insomnia7 committed on
Commit cd3fc46 · verified · 1 Parent(s): ecb3410

Update app.py

Files changed (1)
  1. app.py +556 -82
app.py CHANGED
@@ -9,22 +9,19 @@ import matplotlib as mpl
9
  import numpy as np
10
  import uuid
11
  import imageio.v3 as iio
 
 
 
12
 
13
  import torch
 
14
  from torchvision.transforms.functional import to_pil_image
15
  from huggingface_hub import hf_hub_download
16
 
17
  import spaces
18
  import gradio as gr
19
 
20
- # GRADIO_TMP = os.path.join(os.path.dirname(__file__), ".gradio_tmp")
21
- # Path(GRADIO_TMP).mkdir(parents=True, exist_ok=True)
22
-
23
- # os.environ["GRADIO_TEMP_DIR"] = GRADIO_TMP
24
- # os.environ["TMPDIR"] = GRADIO_TMP
25
- # os.environ["TEMP"] = GRADIO_TMP
26
- # os.environ["TMP"] = GRADIO_TMP
27
-
28
  from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
29
  from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
30
  from visualizer import sample_color, draw_mask
@@ -78,15 +75,14 @@ MODEL = 'zhouyik/Qwen3-VL-8B-SAMTok'
78
  TITLE = 'SAMTok: Representing Any Mask with Two Words'
79
 
80
  HEADER = """
81
- <p align="center" style="margin: 1em 0 2em;"><img width="260" src="https://github.com/bytedance/Sa2VA/blob/main/projects/samtok/figs/logo.png"></p>
82
- <h3 align="center">SAMTok: Representing Any Mask with Two Words</h3>
83
  <div style="display: flex; justify-content: center; gap: 5px;">
84
  <a href="https://github.com/bytedance/Sa2VA/tree/main/projects/samtok" target="_blank"><img src="https://img.shields.io/badge/arXiv-2509.18094-red"></a>
85
  <a href="https://github.com/bytedance/Sa2VA/tree/main/projects/samtok" target="_blank"><img src="https://img.shields.io/badge/Project-Page-brightgreen"></a>
86
  <a href="https://huggingface.co/collections/zhouyik/samtok" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue"></a>
87
  <a href="https://github.com/bytedance/Sa2VA" target="_blank"><img src="https://img.shields.io/github/stars/bytedance/Sa2VA"></a>
88
  </div>
89
- <p style="margin-top: 1em;">SAMTok provides a unified mask-token interface for MLLMs. (1) SAMTok compresses region masks into two discrete tokens and faithfully reconstructs them across diverse visual domains. (2) Injecting these mask tokens into MLLMs enables a wide range of region-level mask generation and understanding tasks. (3) The text-based representation of region masks allows a purely textual answer-matching reward for the GRPO of the mask generation task.</p>
90
  """
91
 
92
  JS = """
@@ -99,33 +95,63 @@ function init() {
99
  window.addEventListener('load', init);
100
  """
101
 
102
- device = torch.device('cuda')
103
-
104
- model = Qwen3VLForConditionalGeneration.from_pretrained(
105
- MODEL, torch_dtype="auto"
106
- ).cuda().eval()
107
-
108
- processor = AutoProcessor.from_pretrained(MODEL)
109
 
110
  # build vq-sam2 model
 
 
111
  sam2_ckpt_local = hf_hub_download(repo_id=MODEL, filename="sam2.1_hiera_large.pt")
112
  mask_tokenizer_local = hf_hub_download(repo_id=MODEL, filename="mask_tokenizer_256x2.pth")
113
  CODEBOOK_SIZE = 256
114
  CODEBOOK_DEPTH = 2
115
- sam2_config = SAM2Config(
116
- ckpt_path=sam2_ckpt_local,
117
- )
118
- vq_sam2_config = VQ_SAM2Config(
119
- sam2_config=sam2_config,
120
- codebook_size=CODEBOOK_SIZE,
121
- codebook_depth=CODEBOOK_DEPTH,
122
- shared_codebook=False,
123
- latent_dim=256,
124
- )
125
- vq_sam2 = VQ_SAM2(vq_sam2_config).cuda().eval()
126
- state = torch.load(mask_tokenizer_local, map_location="cpu")
127
- vq_sam2.load_state_dict(state)
128
- sam2_image_processor = DirectResize(1024)
129
 
130
 
131
  colors = sample_color()
@@ -150,10 +176,312 @@ def reset_seg():
150
  def reset_reg():
151
  return 1, gr.update(interactive=False)
152
153
  @spaces.GPU
154
  def infer_seg(media, query):
155
- print("=======>>>enter infer seg")
156
- global model
157
 
158
  if not media:
159
  gr.Warning('Please upload an image')
@@ -185,15 +513,13 @@ def infer_seg(media, query):
185
  return_tensors="pt"
186
  )
187
 
188
- model = model.to(device)
189
-
190
  inputs = inputs.to(model.device)
191
 
192
  generated_ids = model.generate(
193
  **inputs,
194
  max_new_tokens=1024,
195
- do_sample=False,
196
- top_p=1.0,
197
  )
198
  generated_ids_trimmed = [
199
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -201,21 +527,37 @@ def infer_seg(media, query):
201
  output_text = processor.batch_decode(
202
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
203
  )[0]
204
-
205
- print("========>>>>output_text", output_text)
206
- exit(0)
207
 
208
  quant_ids = extract_mt_token_ids_v1(output_text)
209
  if len(quant_ids) % CODEBOOK_DEPTH != 0:
210
  output_text = fix_mt_format_comprehensive(output_text)
211
  quant_ids = extract_mt_token_ids_v2(output_text)
212
213
  batch_size = len(quant_ids) // CODEBOOK_DEPTH
214
  remap_quant_ids = []
215
  tags = []
 
216
  for bs_id in range(batch_size):
217
  chunk_quant_ids = quant_ids[bs_id*CODEBOOK_DEPTH:(bs_id+1)*CODEBOOK_DEPTH]
218
  tags.append(f'<|mt_start|><|mt_{str(chunk_quant_ids[0]).zfill(4)}|><|mt_{str(chunk_quant_ids[1]).zfill(4)}|><|mt_end|>')
 
219
  remap_chunk_quant_ids = [quant_id - book_id*CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
220
  code1 = remap_chunk_quant_ids[0]
221
  code2 = remap_chunk_quant_ids[1]
@@ -240,63 +582,194 @@ def infer_seg(media, query):
240
  # _pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
241
  _pred_masks = _pred_masks.long().unsqueeze(2).cpu() # n, 1, 1, h, w
242
243
  entities = []
244
- unique_tags = list(set(tags))
245
- entity_names = []
246
- for i, tag in enumerate(unique_tags):
247
  for m in re.finditer(re.escape(tag), output_text):
248
- entities.append(dict(entity=f'Target {i + 1}', start=m.start(), end=m.end()))
249
- entity_names.append(f'Target {i + 1}')
250
-
251
  answer = dict(text=output_text, entities=entities)
252
253
  frames = torch.from_numpy(np.array(image)).unsqueeze(0)
254
  imgs = draw_mask(frames, _pred_masks, colors=colors)
255
 
256
  path = f"/tmp/{uuid.uuid4().hex}.png"
257
  iio.imwrite(path, imgs, duration=100, loop=0)
258
 
259
- masks = media, [(m[0, 0].numpy(), entity_names[i]) for i, m in enumerate(_pred_masks)]
260
 
261
- return answer, masks, path
262
 
263
 
264
  def build_demo():
265
  with gr.Blocks(title=TITLE, js=JS, theme=gr.themes.Soft()) as demo:
266
  gr.HTML(HEADER)
267
 
268
- # with gr.Tab('Mask Generation'):
269
- download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
270
- msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
271
- ans_1 = gr.HighlightedText(
272
- label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
273
- with gr.Row():
274
- with gr.Column():
275
- media_1 = gr.Image(type='filepath')
276
-
277
- sample_frames_1 = gr.Slider(1, 32, value=16, step=1, visible=False)
278
-
279
- query_1 = gr.Textbox(label='Text Prompt', placeholder='Please segment the...', elem_id='query_1')
280
-
281
- with gr.Row():
282
- random_btn_1 = gr.Button(value='🔮 Random', visible=False)
283
-
284
- reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
285
- reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
286
-
287
- download_btn_1.render()
288
-
289
- submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')
290
-
291
- with gr.Column():
292
- msk_1.render()
293
- ans_1.render()
294
-
295
- ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
296
- ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1])
297
- ctx_1.then(enable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
298
- # with gr.Tab('Mask Understanding'):
299
- # pass
300
 
301
  return demo
302
 
@@ -304,4 +777,5 @@ if __name__ == '__main__':
304
  demo = build_demo()
305
 
306
  demo.queue()
307
- demo.launch(server_name='0.0.0.0')
 
 
9
  import numpy as np
10
  import uuid
11
  import imageio.v3 as iio
12
+ import base64
13
+ import io
14
+ import re
15
 
16
  import torch
17
+ import torchvision
18
  from torchvision.transforms.functional import to_pil_image
19
  from huggingface_hub import hf_hub_download
20
 
21
  import spaces
22
  import gradio as gr
23
 
24
+ from transformers import SamModel, SamProcessor
25
  from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
26
  from sam2 import VQ_SAM2, VQ_SAM2Config, SAM2Config
27
  from visualizer import sample_color, draw_mask
 
75
  TITLE = 'SAMTok: Representing Any Mask with Two Words'
76
 
77
  HEADER = """
78
+ <h2 align="center">SAMTok: Representing Any Mask with Two Words</h2>
 
79
  <div style="display: flex; justify-content: center; gap: 5px;">
80
  <a href="https://github.com/bytedance/Sa2VA/tree/main/projects/samtok" target="_blank"><img src="https://img.shields.io/badge/arXiv-2509.18094-red"></a>
81
  <a href="https://github.com/bytedance/Sa2VA/tree/main/projects/samtok" target="_blank"><img src="https://img.shields.io/badge/Project-Page-brightgreen"></a>
82
  <a href="https://huggingface.co/collections/zhouyik/samtok" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue"></a>
83
  <a href="https://github.com/bytedance/Sa2VA" target="_blank"><img src="https://img.shields.io/github/stars/bytedance/Sa2VA"></a>
84
  </div>
85
+ <p style="margin-top: 1em;">SAMTok provides a unified mask-token interface for MLLMs.</p>
86
  """
87
 
88
  JS = """
 
95
  window.addEventListener('load', init);
96
  """
97
 
98
+ MT_START_TOKEN = '<|mt_start|>'
99
+ MT_END_TOKEN = '<|mt_end|>'
100
+ MT_CONTEXT_TOKEN = '<|mt_{}|>'
101
 
102
  # build vq-sam2 model
103
+ vq_sam2 = None
104
+ sam2_image_processor = DirectResize(1024)
105
  sam2_ckpt_local = hf_hub_download(repo_id=MODEL, filename="sam2.1_hiera_large.pt")
106
  mask_tokenizer_local = hf_hub_download(repo_id=MODEL, filename="mask_tokenizer_256x2.pth")
107
  CODEBOOK_SIZE = 256
108
  CODEBOOK_DEPTH = 2
109
+ def load_vq_sam2():
110
+ global vq_sam2
111
+
112
+ if vq_sam2 is not None:
113
+ return vq_sam2
114
+
115
+ if hasattr(torch, "set_default_device"):
116
+ torch.set_default_device("cpu")
117
+
118
+ sam2_config = SAM2Config(
119
+ ckpt_path=sam2_ckpt_local,
120
+ )
121
+ vq_sam2_config = VQ_SAM2Config(
122
+ sam2_config=sam2_config,
123
+ codebook_size=CODEBOOK_SIZE,
124
+ codebook_depth=CODEBOOK_DEPTH,
125
+ shared_codebook=False,
126
+ latent_dim=256,
127
+ )
128
+
129
+ vq_sam2 = VQ_SAM2(vq_sam2_config)
130
+ state = torch.load(mask_tokenizer_local, map_location="cpu")
131
+ vq_sam2.load_state_dict(state)
132
+
133
+ vq_sam2 = vq_sam2.cuda().eval()
134
+ return vq_sam2
135
+
136
+ processor = AutoProcessor.from_pretrained(MODEL)
137
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
138
+
139
+ _qwen = None
140
+ _sam = None
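+ # Models are created lazily and only inside @spaces.GPU handlers, since CUDA is not
+ # available in the main process on ZeroGPU Spaces.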
141
+
142
+ def get_qwen():
143
+ """Must be called only inside @spaces.GPU function."""
144
+ global _qwen
145
+ if _qwen is None:
146
+ _qwen = Qwen3VLForConditionalGeneration.from_pretrained(MODEL, torch_dtype="auto").to("cuda").eval()
147
+ return _qwen
148
+
149
+ def get_sam():
150
+ """Must be called only inside @spaces.GPU function."""
151
+ global _sam
152
+ if _sam is None:
153
+ _sam = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda").eval()
154
+ return _sam
155
 
156
 
157
  colors = sample_color()
 
176
  def reset_reg():
177
  return 1, gr.update(interactive=False)
178
 
179
+ def new_mu_state():
180
+ return {
181
+ "image_path": None,
182
+ "ori_size": None, # (w, h)
183
+ "original_sizes": None, # e.g. [h, w]
184
+ "reshaped_input_sizes": None, # e.g. [h', w']
185
+ "image_embeddings": None, # numpy array on CPU
186
+ "points": [],
187
+ "labels": [],
188
+ "cur_mask": None, # np.uint8 (H,W)
189
+ "regions": {},
190
+ "next_region_id": 1,
191
+ }
192
+
193
+ @spaces.GPU
194
+ def mu_on_upload_image(media_path, mu_state):
195
+ if not media_path:
196
+ return new_mu_state(), None, None
197
+
198
+ sam_model = get_sam() # GPU-side
199
+
200
+ img = Image.open(media_path).convert("RGB")
201
+ w, h = img.size
202
+
203
+ inputs = sam_processor(img, return_tensors="pt").to("cuda")
204
+ with torch.no_grad():
205
+ emb = sam_model.get_image_embeddings(inputs["pixel_values"]) # CUDA tensor
206
+
207
+ st = new_mu_state()
208
+ st["image_path"] = media_path
209
+ st["ori_size"] = (w, h)
210
+
211
+ # store sizes as python lists (not tensors)
212
+ st["original_sizes"] = inputs["original_sizes"][0].detach().cpu().tolist()
213
+ st["reshaped_input_sizes"] = inputs["reshaped_input_sizes"][0].detach().cpu().tolist()
214
+
215
+ # store embeddings as CPU numpy (picklable)
216
+ st["image_embeddings"] = emb[0].detach().cpu().to(torch.float16).numpy() # (256,64,64)
217
+
218
+ return st, media_path, None
219
+
220
+ def mu_predict_mask_from_state(mu_state):
221
+ if mu_state["image_embeddings"] is None or mu_state["image_path"] is None:
222
+ return None
223
+ if len(mu_state["points"]) == 0:
224
+ return None
225
+
226
+ sam_model = get_sam()
227
+
228
+ img = Image.open(mu_state["image_path"]).convert("RGB")
229
+
230
+ prompt_inputs = sam_processor(
231
+ img,
232
+ input_points=[mu_state["points"]],
233
+ input_labels=[mu_state["labels"]],
234
+ return_tensors="pt",
235
+ ).to("cuda")
236
+
237
+ # restore embedding to CUDA tensor, shape (1,256,64,64)
238
+ emb = torch.from_numpy(mu_state["image_embeddings"]).float().to("cuda")  # cast back from fp16 to match the model's fp32 weights
239
+ emb = emb.unsqueeze(0)
240
+
241
+ with torch.no_grad():
242
+ outputs = sam_model(
243
+ image_embeddings=emb,
244
+ input_points=prompt_inputs["input_points"],
245
+ input_labels=prompt_inputs["input_labels"],
246
+ multimask_output=False,
247
+ )
248
+
249
+ # postprocess needs lists/tensors on CPU
250
+ original_sizes = torch.tensor([mu_state["original_sizes"]], dtype=torch.long)
251
+ reshaped_sizes = torch.tensor([mu_state["reshaped_input_sizes"]], dtype=torch.long)
252
+
253
+ masks = sam_processor.post_process_masks(
254
+ outputs.pred_masks.detach().cpu(),
255
+ original_sizes,
256
+ reshaped_sizes,
257
+ )
258
+ mask = masks[0][0].numpy()
259
+ mask = (mask > 0).astype(np.uint8)
260
+ return mask
261
+
262
+ @spaces.GPU
263
+ def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
264
+ if mu_state["image_path"] is None:
265
+ return mu_state, None
266
+
267
+ x, y = evt.index
268
+ mu_state["points"].append([float(x), float(y)])
269
+ mu_state["labels"].append(1 if is_positive else 0)
270
+
271
+ mask = mu_predict_mask_from_state(mu_state)
272
+ mu_state["cur_mask"] = mask
273
+ return mu_state, mask
274
+
275
+ def mu_clear_prompts(mu_state):
276
+ mu_state["points"] = []
277
+ mu_state["labels"] = []
278
+ mu_state["cur_mask"] = None
279
+ return mu_state, None
280
+
281
+
282
+ @spaces.GPU
283
+ def mu_save_region(mu_state):
284
+ if mu_state["cur_mask"] is None:
285
+ return mu_state, gr.update(choices=[], value=None)
286
+
287
+ rid = f"region{mu_state['next_region_id']}"
288
+ mu_state["next_region_id"] += 1
289
+
290
+ reg = {"mask": mu_state["cur_mask"], "token_str": None, "zoom_in_token_str": None, "zoom_in_image": None}
291
+
292
+ vq_sam2 = load_vq_sam2()
293
+
294
+ image = Image.open(mu_state["image_path"]).convert('RGB')
295
+ ori_width, ori_height = image.size
296
+
297
+ sam2_image = np.array(image)
298
+ sam2_image = sam2_image_processor.apply_image(sam2_image)
299
+ sam2_pixel_values = torch.from_numpy(sam2_image).permute(2, 0, 1).contiguous()
300
+ sam2_pixel_values = sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
301
+
302
+ masks = torch.stack([torch.from_numpy(np.ascontiguousarray(mu_state["cur_mask"].copy()))])
303
+
304
+ boxes = torchvision.ops.masks_to_boxes(masks)
305
+ x1, y1, x2, y2 = boxes.squeeze().cpu().numpy().tolist()
306
+ boxes_w = x2 - x1
307
+ boxes_h = y2 - y1
308
+ boxes_area = boxes_h * boxes_w
309
+ image_area = ori_height * ori_width
310
+ boxes_occupied_ratio = boxes_area / image_area
311
+
312
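+ # normalize the box to [0, 1] coordinates (divide by image width/height) before passing it to VQ-SAM2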
+ whwh = torch.as_tensor([[ori_width, ori_height, ori_width, ori_height]])
313
+ boxes = boxes / whwh
314
+ boxes = boxes.to(vq_sam2.device)
315
+ masks = [m.unsqueeze(0).to(vq_sam2.device) for m in masks]
316
+
317
+ with torch.no_grad():
318
+ vq_sam2_output = vq_sam2(
319
+ sam2_pixel_values,
320
+ masks,
321
+ boxes,
322
+ reconstruct_mask=False,
323
+ )
324
+
325
+ quant_codes = vq_sam2_output.quant_codes.squeeze().cpu().numpy().astype(np.int32).tolist()
326
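+ # offset each depth's code by depth_idx*CODEBOOK_SIZE so the two codebooks map to disjoint token ids,
+ # e.g. codes [37, 35] -> [37, 291] -> '<|mt_start|><|mt_0037|><|mt_0291|><|mt_end|>'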
+ remap_quant_codes = [depth_idx*CODEBOOK_SIZE+quant_code for depth_idx, quant_code in enumerate(quant_codes)]
327
+ quant_codes = remap_quant_codes
328
+ global_mask_tokens_str = MT_START_TOKEN + ''.join([MT_CONTEXT_TOKEN.format(str(code).zfill(4)) for code in quant_codes]) + MT_END_TOKEN
329
+
330
+ reg["token_str"] = global_mask_tokens_str
331
+
332
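+ # for small regions (box covers <30% of the image) also encode a zoomed-in crop, used later as extra context in infer_understanding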
+ if boxes_occupied_ratio < 0.3:
333
+ bbox_w = x2 - x1
334
+ bbox_h = y2 - y1
335
+ if bbox_w < 140:
336
+ x1 = x1 - (140 - bbox_w) // 2
337
+ x2 = x2 + (140 - bbox_w) // 2
338
+ if bbox_h < 140:
339
+ y1 = y1 - (140 - bbox_h) // 2
340
+ y2 = y2 + (140 - bbox_h) // 2
341
+ x1 = int(max(0, x1))
342
+ x2 = int(min(ori_width, x2))
343
+ y1 = int(max(0, y1))
344
+ y2 = int(min(ori_height, y2))
345
+
346
+ cropped_image = image.crop((x1, y1, x2, y2))
347
+ crop_width, crop_height = cropped_image.size
348
+
349
+ if crop_width > crop_height and crop_width < 280:
350
+ ratio = 280 / crop_height
351
+ new_height = 280
352
+ new_width = int(crop_width * ratio)
353
+ resized_crop_image = cropped_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
354
+ elif crop_height > crop_width and crop_height < 280:
355
+ ratio = 280 / crop_width
356
+ new_width = 280
357
+ new_height = int(crop_height * ratio)
358
+ resized_crop_image = cropped_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
359
+ elif crop_height == crop_width and crop_width < 280:
360
+ ratio = 280 / crop_height
361
+ new_height = 280
362
+ new_width = int(crop_width * ratio)
363
+ resized_crop_image = cropped_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
364
+ else:
365
+ new_height = new_width = None
366
+ resized_crop_image = None
367
+
368
+ if resized_crop_image is None:
369
+ cropped_sam2_image = np.array(cropped_image)
370
+ cropped_sam2_image = sam2_image_processor.apply_image(cropped_sam2_image)
371
+ cropped_sam2_pixel_values = torch.from_numpy(cropped_sam2_image).permute(2, 0, 1).contiguous()
372
+ cropped_sam2_pixel_values = cropped_sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
373
+ else:
374
+ cropped_sam2_image = np.array(resized_crop_image)
375
+ cropped_sam2_image = sam2_image_processor.apply_image(cropped_sam2_image)
376
+ cropped_sam2_pixel_values = torch.from_numpy(cropped_sam2_image).permute(2, 0, 1).contiguous()
377
+ cropped_sam2_pixel_values = cropped_sam2_pixel_values.unsqueeze(0).to(vq_sam2.dtype).to(vq_sam2.device)
378
+
379
+ cropped_masks = torch.stack([torch.from_numpy(np.ascontiguousarray(mu_state["cur_mask"].copy()[y1:y2, x1:x2]))])
380
+ assert cropped_masks.shape[-2] == crop_height and cropped_masks.shape[-1] == crop_width
381
+
382
+ if resized_crop_image is not None:
383
+ resized_crop_masks = torch.nn.functional.interpolate(cropped_masks.unsqueeze(0).float(), size=(new_height, new_width), mode='bilinear')  # bilinear interpolation requires a float tensor
384
+ resized_crop_masks = resized_crop_masks[0] > 0.5
385
+ cropped_masks = resized_crop_masks
386
+ crop_height, crop_width = cropped_masks.shape[-2:]
387
+ cropped_boxes = torchvision.ops.masks_to_boxes(cropped_masks)
388
+ crop_whwh = torch.as_tensor([[crop_width, crop_height, crop_width, crop_height]])
389
+ cropped_boxes = cropped_boxes / crop_whwh
390
+ cropped_boxes = cropped_boxes.to(vq_sam2.device)
391
+ cropped_masks = [m.unsqueeze(0).to(vq_sam2.device) for m in cropped_masks]
392
+
393
+ with torch.no_grad():
394
+ cropped_vq_sam2_output = vq_sam2(
395
+ cropped_sam2_pixel_values,
396
+ cropped_masks,
397
+ cropped_boxes,
398
+ reconstruct_mask=True,
399
+ )
400
+
401
+ crop_quant_codes = cropped_vq_sam2_output.quant_codes.squeeze().detach().cpu().numpy().astype(np.int32).tolist()
402
+ remap_crop_quant_codes = [depth_idx*CODEBOOK_SIZE+quant_code for depth_idx, quant_code in enumerate(crop_quant_codes)]
403
+ crop_quant_codes = remap_crop_quant_codes
404
+ zoom_in_mask_tokens_str = MT_START_TOKEN + ''.join([MT_CONTEXT_TOKEN.format(str(code).zfill(4)) for code in crop_quant_codes]) + MT_END_TOKEN
405
+
406
+ buffer = io.BytesIO()
407
+ if resized_crop_image is None:
408
+ cropped_image.save(buffer, format='JPEG')
409
+ else:
410
+ resized_crop_image.save(buffer, format='JPEG')
411
+ buffer.seek(0)
412
+ b64 = base64.b64encode(buffer.read()).decode("utf-8")
413
+
414
+ reg["zoom_in_token_str"] = zoom_in_mask_tokens_str
415
+ reg["zoom_in_image"] = b64
416
+
417
+ mu_state["regions"][rid] = reg
418
+ choices = list(mu_state["regions"].keys())
419
+ return mu_state, gr.update(choices=choices, value=rid)
420
+
421
+ def replace_region_all(text: str, rid: str, token_str: str) -> str:
422
+ pattern = re.compile(rf"(?<![A-Za-z0-9_]){re.escape(rid)}(?![A-Za-z0-9_])")
423
+ return pattern.sub(f"{rid} {token_str}", text)
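+ # e.g. replace_region_all("describe region1", "region1", "<|mt_start|>...<|mt_end|>") -> "describe region1 <|mt_start|>...<|mt_end|>"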
424
+
425
+ def short_tag_from_codes(code_a: int, code_b: int) -> str:
426
+ return f"<{code_a:04d}-{code_b:04d}>"
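+ # e.g. short_tag_from_codes(12, 268) -> '<0012-0268>'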
427
+
428
+ @spaces.GPU
429
+ def infer_understanding(mu_media, mu_query, mu_state):
430
+ model = get_qwen()
431
+
432
+ if not mu_media:
433
+ gr.Warning("Please upload an image")
434
+ return ""
435
+ if not mu_query:
436
+ gr.Warning("Please provide a text prompt.")
437
+ return ""
438
+
439
+ raw_query = mu_query
440
+
441
+ # 1) find which regions are referenced in the ORIGINAL query
442
+ used = []
443
+ for rid in mu_state["regions"].keys():
444
+ if re.search(rf"(?<![A-Za-z0-9_]){re.escape(rid)}(?![A-Za-z0-9_])", raw_query):
445
+ used.append(rid)
446
+
447
+ # 2) replace ALL occurrences for each used rid
448
+ for rid in used:
449
+ reg = mu_state["regions"][rid]
450
+ token_str = reg.get("token_str")
451
+ if token_str:
452
+ mu_query = replace_region_all(mu_query, rid, token_str)
453
+
454
+ content = [{"type": "image", "image": mu_media}]
455
+ content.append({"type": "text", "text": mu_query})
456
+
457
+ # 3) zoom-in blocks only for used regions
458
+ for rid in used:
459
+ reg = mu_state["regions"][rid]
460
+ zoom_in_image = reg.get("zoom_in_image")
461
+ zoom_in_token_str = reg.get("zoom_in_token_str")
462
+ if zoom_in_image and zoom_in_token_str:
463
+ content.append({"type": "text", "text": f" Zoom in {rid}: "})
464
+ content.append({"type": "image", "image": f"data:image/jpeg;base64,{zoom_in_image}"})
465
+ content.append({"type": "text", "text": f", {zoom_in_token_str}."})
466
+
467
+ messages = [{"role": "user", "content": content}]
468
+ inputs = processor.apply_chat_template(
469
+ messages, tokenize=True, add_generation_prompt=True,
470
+ return_dict=True, return_tensors="pt"
471
+ ).to(model.device)
472
+
473
+ generated_ids = model.generate(**inputs, max_new_tokens=1024)
474
+ generated_ids_trimmed = [
475
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
476
+ ]
477
+ return processor.batch_decode(
478
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
479
+ )[0]
480
+
481
  @spaces.GPU
482
  def infer_seg(media, query):
483
+ model = get_qwen()
484
+ vq_sam2 = load_vq_sam2()
485
 
486
  if not media:
487
  gr.Warning('Please upload an image')
 
513
  return_tensors="pt"
514
  )
515
 
 
 
516
  inputs = inputs.to(model.device)
517
 
518
  generated_ids = model.generate(
519
  **inputs,
520
  max_new_tokens=1024,
521
+ # do_sample=False,
522
+ # top_p=1.0,
523
  )
524
  generated_ids_trimmed = [
525
  out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
527
  output_text = processor.batch_decode(
528
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
529
  )[0]
 
 
 
530
 
531
  quant_ids = extract_mt_token_ids_v1(output_text)
532
+ if len(quant_ids) == 0:
533
+ # only show model response; hide masks & download
534
+ answer = dict(text=output_text, entities=[])
535
+ return (
536
+ answer,
537
+ gr.update(value=None, visible=False), # hide AnnotatedImage
538
+ gr.update(value=None, interactive=False, visible=False), # hide DownloadButton
539
+ gr.update(value="", visible=False), # hide tag-mapping Markdown
+ )
540
+
541
  if len(quant_ids) % CODEBOOK_DEPTH != 0:
542
  output_text = fix_mt_format_comprehensive(output_text)
543
  quant_ids = extract_mt_token_ids_v2(output_text)
544
 
545
+ if len(quant_ids) == 0 or (len(quant_ids) % CODEBOOK_DEPTH != 0):
546
+ answer = dict(text=output_text, entities=[])
547
+ return (
548
+ answer,
549
+ gr.update(value=None, visible=False),
550
+ gr.update(value=None, interactive=False, visible=False),
551
+ gr.update(value="", visible=False),
+ )
552
+
553
  batch_size = len(quant_ids) // CODEBOOK_DEPTH
554
  remap_quant_ids = []
555
  tags = []
556
+ short_tags = []
557
  for bs_id in range(batch_size):
558
  chunk_quant_ids = quant_ids[bs_id*CODEBOOK_DEPTH:(bs_id+1)*CODEBOOK_DEPTH]
559
  tags.append(f'<|mt_start|><|mt_{str(chunk_quant_ids[0]).zfill(4)}|><|mt_{str(chunk_quant_ids[1]).zfill(4)}|><|mt_end|>')
560
+ short_tags.append(short_tag_from_codes(chunk_quant_ids[0], chunk_quant_ids[1]))
561
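  # undo the depth offset (token id = book_id*CODEBOOK_SIZE + codebook index) to recover per-codebook indices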
  remap_chunk_quant_ids = [quant_id - book_id*CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
562
  code1 = remap_chunk_quant_ids[0]
563
  code2 = remap_chunk_quant_ids[1]
 
582
  # _pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
583
  _pred_masks = _pred_masks.long().unsqueeze(2).cpu() # n, 1, 1, h, w
584
 
585
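+ # de-duplicate repeated mask tokens: each unique tag keeps its first predicted mask and a short display tag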
+ tag_to_mask_idx = {}
586
+ tag_to_short = {}
587
+ for i, (tag, stag) in enumerate(zip(tags, short_tags)):
588
+ if tag not in tag_to_mask_idx:
589
+ tag_to_mask_idx[tag] = i
590
+ tag_to_short[tag] = stag
591
+ unique_tags = list(tag_to_mask_idx.keys())
592
+
593
  entities = []
594
+ for tag in unique_tags:
 
 
595
  for m in re.finditer(re.escape(tag), output_text):
596
+ entities.append(dict(entity=tag, start=m.start(), end=m.end()))
597
+
 
598
  answer = dict(text=output_text, entities=entities)
599
 
600
+ # entities = []
601
+ # unique_tags = list(set(tags))
602
+ # entity_names = []
603
+ # for i, tag in enumerate(unique_tags):
604
+ # for m in re.finditer(re.escape(tag), output_text):
605
+ # entities.append(dict(entity=f'Target {i + 1}', start=m.start(), end=m.end()))
606
+ # entity_names.append(f'Target {i + 1}')
607
+
608
+ # answer = dict(text=output_text, entities=entities)
609
+
610
  frames = torch.from_numpy(np.array(image)).unsqueeze(0)
611
  imgs = draw_mask(frames, _pred_masks, colors=colors)
612
 
613
  path = f"/tmp/{uuid.uuid4().hex}.png"
614
  iio.imwrite(path, imgs, duration=100, loop=0)
615
 
616
+ # masks_value = (media, [(m[0, 0].numpy(), entity_names[i]) for i, m in enumerate(_pred_masks)])
617
+ # masks_value = (
618
+ # media,
619
+ # [( _pred_masks[tag_to_mask_idx[tag]][0, 0].numpy(), tag ) for tag in unique_tags]
620
+ # )
621
+
622
+ entity_names = [f"Target {i+1}" for i in range(len(unique_tags))]
623
+ masks_value = (
624
+ media,
625
+ [(_pred_masks[tag_to_mask_idx[tag]][0, 0].numpy().astype(np.uint8) * 255, entity_names[i]) for i, tag in enumerate(unique_tags)]
626
+ )
627
+
628
+ lines = []
629
+ for i, tag in enumerate(unique_tags):
630
+ short_tag = tag_to_short[tag]
631
+ lines.append(f"- **{entity_names[i]}** → `{short_tag}`")
632
+ tag_map_text = "### Mask-Token Mapping\n" + "\n".join(lines)
633
 
634
+ # dynamic color maps keyed by tag
635
+ dyn_color_map = {}
636
+ dyn_color_map_light = {}
637
+ for i, tag in enumerate(unique_tags):
638
+ c = colors[i % len(colors)]
639
+ dyn_color_map[tag] = f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}'
640
+ dyn_color_map_light[tag] = f'#{int(c[0] * 127.5 + 127.5):02x}{int(c[1] * 127.5 + 127.5):02x}{int(c[2] * 127.5 + 127.5):02x}'
641
+
642
+ # return answer, masks, path
643
+ return (
644
+ gr.update(value=answer, color_map=dyn_color_map_light, visible=True), # ans_1
645
+ gr.update(value=masks_value, visible=True), # msk_1
646
+ gr.update(value=path, interactive=True, visible=True), # download
647
+ gr.update(value=tag_map_text, visible=True)
648
+ )
649
 
650
 
651
  def build_demo():
652
  with gr.Blocks(title=TITLE, js=JS, theme=gr.themes.Soft()) as demo:
653
  gr.HTML(HEADER)
654
 
655
+ with gr.Tab('Mask Generation'):
656
+ download_btn_1 = gr.DownloadButton(label='📦 Download', interactive=False, render=False)
657
+ msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
658
+ ans_1 = gr.HighlightedText(
659
+ label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
660
+ tag_map_md = gr.Markdown(label="Mask-Token Mapping", value="", visible=False, render=False)
661
+ with gr.Row():
662
+ with gr.Column():
663
+ media_1 = gr.Image(type='filepath')
664
+
665
+ sample_frames_1 = gr.Slider(1, 32, value=16, step=1, visible=False)
666
+
667
+ # query_1 = gr.Textbox(label='Text Prompt', placeholder='Please segment the...', elem_id='query_1')
668
+ query_1 = gr.Textbox(
669
+ label='Text Prompt',
670
+ placeholder='Please segment the...',
671
+ lines=3,
672
+ max_lines=12,
673
+ elem_id='query_1'
674
+ )
675
+
676
+ with gr.Row():
677
+ random_btn_1 = gr.Button(value='🔮 Random', visible=False)
678
+
679
+ reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1, tag_map_md], value='🗑️ Reset')
680
+ reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
681
+
682
+ download_btn_1.render()
683
+
684
+ submit_btn_1 = gr.Button(value='🚀 Submit', variant='primary', elem_id='submit_1')
685
+
686
+ with gr.Column():
687
+ msk_1.render()
688
+ tag_map_md.render()
689
+ ans_1.render()
690
+
691
+ ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
692
+ ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1, tag_map_md])
693
+ ctx_1.then(enable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
694
+
695
+ EXAMPLES = [
696
+ ["examples/example1.jpeg", "Locate the tissue box in this image and response with its segmentation mask."],
697
+ ["examples/example2.jpg", "Could you please give me a detail description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer."],
698
+ ["examples/example3.png", "Find all the people who are currently standing and response with segmentation masks."],
699
+ ["examples/example4.jpg", "Segment every instance that belongs to the following categories: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush, banner, blanket, bridge, cardboard, counter, curtain, door-stuff, floor-wood, flower, fruit, gravel, house, light, mirror-stuff, net, pillow, platform, playingfield, railroad, river, road, roof, sand, sea, shelf, snow, stairs, tent, towel, wall-brick, wall-stone, wall-tile, wall-wood, water-other, window-blind, window-other, tree-merged, fence-merged, ceiling-merged, sky-other-merged, cabinet-merged, table-merged, floor-other-merged, pavement-merged, mountain-merged, grass-merged, dirt-merged, paper-merged, food-other-merged, building-other-merged, rock-merged, wall-other-merged, rug-merged"],
700
+ ["examples/example5.jpg", "Generate a scene graph for this image. Identify the main objects and describe their relationships to each other."],
701
+ ["examples/example6.jpg", "Which person, wearing a shirt of a primary color, is positioned between the individual in athletic attire and the one in a uniform? A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"]
702
+ ]
703
+ gr.Markdown("## Examples")
704
+ gr.Examples(
705
+ examples=EXAMPLES,
706
+ inputs=[media_1, query_1],
707
+ label="Click an example to load the image and prompt",
708
+ )
709
+ with gr.Tab("Mask Understanding"):
710
+ mu_state = gr.State(new_mu_state())
711
+ mu_point_is_pos = gr.State(True)
712
+
713
+ with gr.Row():
714
+ with gr.Column():
715
+ mu_media = gr.Image(type="filepath", label="Upload Image")
716
+ mu_click_img = gr.Image(label="Click to add points", interactive=True)
717
+
718
+ with gr.Row():
719
+ mu_pos_btn = gr.Button("Positive Point")
720
+ mu_neg_btn = gr.Button("Negative Point")
721
+ mu_clear_btn = gr.Button("Clear Prompts")
722
+ mu_save_btn = gr.Button("Save Region")
723
+
724
+ mu_region_dd = gr.Dropdown(label="Saved Regions", choices=[], interactive=True)
725
+
726
+ mu_query = gr.Textbox(label="Text Prompt", lines=3, max_lines=12)
727
+ mu_submit = gr.Button("Submit", variant="primary")
728
+
729
+ with gr.Column():
730
+ mu_mask_preview = gr.Image(label="Current Mask")
731
+ mu_answer = gr.Textbox(label="Model Response", lines=12)
732
+
733
+ mu_media.change(
734
+ fn=mu_on_upload_image,
735
+ inputs=[mu_media, mu_state],
736
+ outputs=[mu_state, mu_click_img, mu_mask_preview],
737
+ )
738
+
739
+ mu_pos_btn.click(lambda: True, None, mu_point_is_pos)
740
+ mu_neg_btn.click(lambda: False, None, mu_point_is_pos)
741
+
742
+ mu_click_img.select(
743
+ fn=mu_add_point,
744
+ inputs=[mu_state, mu_point_is_pos],
745
+ outputs=[mu_state, mu_mask_preview],
746
+ )
747
+
748
+ mu_clear_btn.click(mu_clear_prompts, [mu_state], [mu_state, mu_mask_preview])
749
+
750
+ mu_save_btn.click(mu_save_region, [mu_state], [mu_state, mu_region_dd])
751
+
752
+ mu_submit.click(
753
+ fn=infer_understanding,
754
+ inputs=[mu_media, mu_query, mu_state],
755
+ outputs=[mu_answer],
756
+ )
757
+
758
+ EXAMPLES_MU = [
759
+ ["examples/example1.jpeg"],
760
+ ["examples/example2.jpg"],
761
+ ["examples/example3.png"],
762
+ ["examples/example4.jpg"],
763
+ ["examples/example5.jpg"],
764
+ ["examples/example6.jpg"],
765
+ ]
766
+
767
+ gr.Markdown("## Examples")
768
+ gr.Examples(
769
+ examples=EXAMPLES_MU,
770
+ inputs=[mu_media], # only load image
771
+ label="Click an example to load the image",
772
+ )
773
 
774
  return demo
775
 
 
777
  demo = build_demo()
778
 
779
  demo.queue()
780
+ # demo.launch(server_name='0.0.0.0')
781
+ demo.launch()