vikhyatk committed on
Commit
8d0b547
·
verified ·
1 Parent(s): 7e87351

Update app.py

Browse files
Files changed (1)
  1. app.py +3 -151
app.py CHANGED
@@ -42,165 +42,17 @@ os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True
42
  moondream = AutoModelForCausalLM.from_pretrained(
43
  "vikhyatk/moondream-next",
44
  trust_remote_code=True,
45
- torch_dtype=torch.bfloat16,
46
  device_map={"": "cuda"},
47
  revision=REVISION
48
  )
49
  moondream.eval()
50
 
51
 
52
- def convert_to_entities(text, coords):
53
- """
54
- Converts a string with special markers into an entity representation.
55
- Markers:
56
- - <|coord|> pairs indicate coordinate markers
57
- - <|start_ground_points|> indicates the start of grounding
58
- - <|start_ground_text|> indicates the start of a ground term
59
- - <|end_ground|> indicates the end of a ground term
60
-
61
- Returns:
62
- - Dictionary with cleaned text and entities with their character positions
63
- """
64
- # Initialize variables
65
- cleaned_text = ""
66
- entities = []
67
- entity = []
68
-
69
- # Track current position in cleaned text
70
- current_pos = 0
71
- # Track if we're currently processing an entity
72
- in_entity = False
73
- entity_start = 0
74
-
75
- i = 0
76
- while i < len(text):
77
- # Check for markers
78
- if text[i : i + 9] == "<|coord|>":
79
- i += 9
80
- entity.append(coords.pop(0))
81
- continue
82
-
83
- elif text[i : i + 23] == "<|start_ground_points|>":
84
- in_entity = True
85
- entity_start = current_pos
86
- i += 23
87
- continue
88
-
89
- elif text[i : i + 21] == "<|start_ground_text|>":
90
- entity_start = current_pos
91
- i += 21
92
- continue
93
-
94
- elif text[i : i + 14] == "<|end_ground|>":
95
- # Store entity position
96
- entities.append(
97
- {
98
- "entity": json.dumps(entity),
99
- "start": entity_start,
100
- "end": current_pos,
101
- }
102
- )
103
- entity = []
104
- in_entity = False
105
- i += 14
106
- continue
107
-
108
- # Add character to cleaned text
109
- cleaned_text += text[i]
110
- current_pos += 1
111
- i += 1
112
-
113
- return {"text": cleaned_text, "entities": entities}
114
-
115
-
116
- @spaces.GPU(duration=30)
117
- def answer_question(img, prompt, reasoning):
118
- buffer = ""
119
- resp = moondream.query(img, prompt, stream=True, reasoning=reasoning)
120
- reasoning_text = resp["reasoning"]["text"] if reasoning else "[reasoning disabled]"
121
- entities = [
122
- {"start": g["start_idx"], "end": g["end_idx"], "entity": json.dumps(g["points"])}
123
- for g in resp["reasoning"]["grounding"]
124
- ] if reasoning else []
125
- for new_text in resp["answer"]:
126
- buffer += new_text
127
- yield buffer.strip(), {"text": reasoning_text, "entities": entities}
128
-
129
-
130
- @spaces.GPU(duration=10)
131
- def caption(img, mode):
132
- if img is None:
133
- yield ""
134
- return
135
-
136
- buffer = ""
137
- if mode == "Short":
138
- l = "short"
139
- elif mode == "Long":
140
- l = "long"
141
- else:
142
- l = "normal"
143
- for t in moondream.caption(img, length=l, stream=True)["caption"]:
144
- buffer += t
145
- yield buffer.strip()
146
-
147
- @spaces.GPU(duration=10)
148
- def detect(img, object, eos_bias):
149
- if img is None:
150
- yield "", gr.update(visible=False, value=None)
151
- return
152
-
153
- eos_bias = float(eos_bias)
154
-
155
- objs = moondream.detect(img, object, settings={"eos_bias": eos_bias})["objects"]
156
-
157
- w, h = img.size
158
- if w > 768 or h > 768:
159
- img = Resize(768)(img)
160
- w, h = img.size
161
-
162
- draw_image = ImageDraw.Draw(img)
163
- for o in objs:
164
- draw_image.rectangle(
165
- (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
166
- outline="red",
167
- width=3,
168
- )
169
-
170
- yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
171
- visible=True, value=img
172
- )
173
-
174
-
175
- @spaces.GPU(duration=10)
176
- def point(img, object):
177
- if img is None:
178
- yield "", gr.update(visible=False, value=None)
179
- return
180
-
181
- w, h = img.size
182
- if w > 768 or h > 768:
183
- img = Resize(768)(img)
184
- w, h = img.size
185
-
186
- objs = moondream.point(img, object, settings={"max_objects": 200})["points"]
187
- draw_image = ImageDraw.Draw(img)
188
- for o in objs:
189
- draw_image.ellipse(
190
- (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
191
- fill="red",
192
- outline="blue",
193
- width=2,
194
- )
195
-
196
- yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
197
- visible=True, value=img
198
- )
199
-
200
  @spaces.GPU(duration=10)
201
  def localized_query(img, x, y, question):
202
  if img is None:
203
- yield "", {"text": "", "entities": []}, gr.update(visible=False, value=None)
204
  return
205
 
206
  answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]
@@ -277,7 +129,7 @@ with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
277
 
278
  with gr.Column():
279
  output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
280
- ann = gr.Image(visible=False)
281
 
282
 
283
  demo.queue().launch()
 
42
  moondream = AutoModelForCausalLM.from_pretrained(
43
  "vikhyatk/moondream-next",
44
  trust_remote_code=True,
45
+ dtype=torch.bfloat16,
46
  device_map={"": "cuda"},
47
  revision=REVISION
48
  )
49
  moondream.eval()
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  @spaces.GPU(duration=10)
53
  def localized_query(img, x, y, question):
54
  if img is None:
55
+ yield "", gr.update(visible=False, value=None)
56
  return
57
 
58
  answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]
 
129
 
130
  with gr.Column():
131
  output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
132
+ ann = gr.Image(visible=False, watermark="Click on the image on the right, not here")
133
 
134
 
135
  demo.queue().launch()