Update app.py
Browse files
app.py
CHANGED
|
@@ -69,7 +69,6 @@ def fix_mt_format_comprehensive(text):
|
|
| 69 |
text = re.sub(pattern_too_few_no_end, replacement_too_few_no_end, text)
|
| 70 |
return text
|
| 71 |
|
| 72 |
-
|
| 73 |
MODEL = 'zhouyik/Qwen3-VL-8B-SAMTok'
|
| 74 |
|
| 75 |
TITLE = 'SAMTok: Representing Any Mask with Two Words'
|
|
@@ -153,7 +152,6 @@ def get_sam():
|
|
| 153 |
_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda").eval()
|
| 154 |
return _sam
|
| 155 |
|
| 156 |
-
|
| 157 |
colors = sample_color()
|
| 158 |
color_map = {f'Target {i + 1}': f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}' for i, c in enumerate(colors * 255)}
|
| 159 |
color_map_light = {
|
|
@@ -164,15 +162,12 @@ color_map_light = {
|
|
| 164 |
def enable_btns():
|
| 165 |
return (gr.update(interactive=True), ) * 4
|
| 166 |
|
| 167 |
-
|
| 168 |
def disable_btns():
|
| 169 |
return (gr.update(interactive=False), ) * 4
|
| 170 |
|
| 171 |
-
|
| 172 |
def reset_seg():
|
| 173 |
return 16, gr.update(interactive=False)
|
| 174 |
|
| 175 |
-
|
| 176 |
def reset_reg():
|
| 177 |
return 1, gr.update(interactive=False)
|
| 178 |
|
|
@@ -249,14 +244,13 @@ def mu_predict_mask_from_state(mu_state):
|
|
| 249 |
# postprocess needs lists/tensors on CPU
|
| 250 |
original_sizes = torch.tensor([mu_state["original_sizes"]], dtype=torch.long)
|
| 251 |
reshaped_sizes = torch.tensor([mu_state["reshaped_input_sizes"]], dtype=torch.long)
|
| 252 |
-
|
| 253 |
masks = sam_processor.post_process_masks(
|
| 254 |
outputs.pred_masks.detach().cpu(),
|
| 255 |
original_sizes,
|
| 256 |
reshaped_sizes,
|
| 257 |
)
|
| 258 |
-
mask = masks[0][0].numpy()
|
| 259 |
-
mask = (mask > 0).astype(np.
|
| 260 |
return mask
|
| 261 |
|
| 262 |
@spaces.GPU
|
|
@@ -272,13 +266,33 @@ def mu_add_point(evt: gr.SelectData, mu_state, is_positive: bool):
|
|
| 272 |
mu_state["cur_mask"] = mask
|
| 273 |
return mu_state, mask
|
| 274 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
def mu_clear_prompts(mu_state):
|
| 276 |
mu_state["points"] = []
|
| 277 |
mu_state["labels"] = []
|
| 278 |
mu_state["cur_mask"] = None
|
| 279 |
return mu_state, None
|
| 280 |
|
| 281 |
-
|
| 282 |
@spaces.GPU
|
| 283 |
def mu_save_region(mu_state):
|
| 284 |
if mu_state["cur_mask"] is None:
|
|
@@ -468,9 +482,17 @@ def infer_understanding(mu_media, mu_query, mu_state):
|
|
| 468 |
inputs = processor.apply_chat_template(
|
| 469 |
messages, tokenize=True, add_generation_prompt=True,
|
| 470 |
return_dict=True, return_tensors="pt"
|
| 471 |
-
).to(device)
|
| 472 |
|
| 473 |
-
generated_ids = model.generate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
generated_ids_trimmed = [
|
| 475 |
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 476 |
]
|
|
@@ -485,11 +507,11 @@ def infer_seg(media, query):
|
|
| 485 |
|
| 486 |
if not media:
|
| 487 |
gr.Warning('Please upload an image')
|
| 488 |
-
return None, None, None
|
| 489 |
|
| 490 |
if not query:
|
| 491 |
gr.Warning('Please provide a text prompt.')
|
| 492 |
-
return None, None, None
|
| 493 |
|
| 494 |
image = Image.open(media).convert('RGB')
|
| 495 |
ori_width, ori_height = image.size
|
|
@@ -518,8 +540,11 @@ def infer_seg(media, query):
|
|
| 518 |
generated_ids = model.generate(
|
| 519 |
**inputs,
|
| 520 |
max_new_tokens=1024,
|
| 521 |
-
|
| 522 |
-
|
|
|
|
|
|
|
|
|
|
| 523 |
)
|
| 524 |
generated_ids_trimmed = [
|
| 525 |
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
@@ -536,7 +561,6 @@ def infer_seg(media, query):
|
|
| 536 |
answer,
|
| 537 |
gr.update(value=None, visible=False), # hide AnnotatedImage
|
| 538 |
gr.update(value=None, interactive=False, visible=False), # hide DownloadButton
|
| 539 |
-
gr.update(value="", visible=True),
|
| 540 |
)
|
| 541 |
|
| 542 |
if len(quant_ids) % CODEBOOK_DEPTH != 0:
|
|
@@ -549,17 +573,14 @@ def infer_seg(media, query):
|
|
| 549 |
answer,
|
| 550 |
gr.update(value=None, visible=False),
|
| 551 |
gr.update(value=None, interactive=False, visible=False),
|
| 552 |
-
gr.update(value="", visible=False),
|
| 553 |
)
|
| 554 |
|
| 555 |
batch_size = len(quant_ids) // CODEBOOK_DEPTH
|
| 556 |
remap_quant_ids = []
|
| 557 |
tags = []
|
| 558 |
-
short_tags = []
|
| 559 |
for bs_id in range(batch_size):
|
| 560 |
chunk_quant_ids = quant_ids[bs_id*CODEBOOK_DEPTH:(bs_id+1)*CODEBOOK_DEPTH]
|
| 561 |
-
tags.append(f'<|
|
| 562 |
-
short_tags.append(short_tag_from_codes(chunk_quant_ids[0], chunk_quant_ids[1]))
|
| 563 |
remap_chunk_quant_ids = [quant_id - book_id*CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
|
| 564 |
code1 = remap_chunk_quant_ids[0]
|
| 565 |
code2 = remap_chunk_quant_ids[1]
|
|
@@ -581,15 +602,12 @@ def infer_seg(media, query):
|
|
| 581 |
_pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
|
| 582 |
_pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
|
| 583 |
_pred_masks = _pred_masks > 0.5
|
| 584 |
-
# _pred_masks = _pred_masks[:, 0, :, :].cpu().numpy().astype(np.uint8)
|
| 585 |
_pred_masks = _pred_masks.long().unsqueeze(2).cpu() # n, 1, 1, h, w
|
| 586 |
|
| 587 |
tag_to_mask_idx = {}
|
| 588 |
-
|
| 589 |
-
for i, (tag, stag) in enumerate(zip(tags, short_tags)):
|
| 590 |
if tag not in tag_to_mask_idx:
|
| 591 |
tag_to_mask_idx[tag] = i
|
| 592 |
-
tag_to_short[tag] = stag
|
| 593 |
unique_tags = list(tag_to_mask_idx.keys())
|
| 594 |
|
| 595 |
entities = []
|
|
@@ -606,35 +624,19 @@ def infer_seg(media, query):
|
|
| 606 |
iio.imwrite(path, imgs, duration=100, loop=0)
|
| 607 |
|
| 608 |
mask_items = []
|
| 609 |
-
entity_names =
|
| 610 |
for i, tag in enumerate(unique_tags):
|
| 611 |
-
m = _pred_masks[tag_to_mask_idx[tag]][0, 0].numpy()
|
| 612 |
mask_items.append((m, entity_names[i]))
|
| 613 |
masks_value = (media, mask_items)
|
| 614 |
|
| 615 |
-
lines = []
|
| 616 |
-
for i, tag in enumerate(unique_tags):
|
| 617 |
-
short_tag = tag_to_short[tag]
|
| 618 |
-
lines.append(f"- **{entity_names[i]}** → `{short_tag}`")
|
| 619 |
-
tag_map_text = "### Mask-Token Mapping\n" + "\n".join(lines)
|
| 620 |
-
|
| 621 |
-
# dynamic color maps keyed by tag
|
| 622 |
-
dyn_color_map = {}
|
| 623 |
-
dyn_color_map_light = {}
|
| 624 |
-
for i, tag in enumerate(unique_tags):
|
| 625 |
-
c = colors[i % len(colors)]
|
| 626 |
-
dyn_color_map[entity_names[i]] = f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}'
|
| 627 |
-
dyn_color_map_light[tag] = f'#{int(c[0] * 127.5 + 127.5):02x}{int(c[1] * 127.5 + 127.5):02x}{int(c[2] * 127.5 + 127.5):02x}'
|
| 628 |
-
|
| 629 |
# return answer, masks, path
|
| 630 |
return (
|
| 631 |
-
gr.update(value=answer,
|
| 632 |
gr.update(value=masks_value, visible=True), # msk_1
|
| 633 |
gr.update(value=path, interactive=True, visible=True), # download
|
| 634 |
-
gr.update(value=tag_map_text, visible=True)
|
| 635 |
)
|
| 636 |
|
| 637 |
-
|
| 638 |
def build_demo():
|
| 639 |
with gr.Blocks(title=TITLE, js=JS, theme=gr.themes.Soft()) as demo:
|
| 640 |
gr.HTML(HEADER)
|
|
@@ -650,7 +652,6 @@ def build_demo():
|
|
| 650 |
|
| 651 |
sample_frames_1 = gr.Slider(1, 32, value=16, step=1, visible=False)
|
| 652 |
|
| 653 |
-
# query_1 = gr.Textbox(label='Text Prompt', placeholder='Please segment the...', elem_id='query_1')
|
| 654 |
query_1 = gr.Textbox(
|
| 655 |
label='Text Prompt',
|
| 656 |
placeholder='Please segment the...',
|
|
@@ -672,10 +673,9 @@ def build_demo():
|
|
| 672 |
with gr.Column():
|
| 673 |
msk_1.render()
|
| 674 |
ans_1.render()
|
| 675 |
-
tag_map_md = gr.Markdown(value="", visible=False)
|
| 676 |
|
| 677 |
ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
|
| 678 |
-
ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1
|
| 679 |
ctx_1.then(enable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
|
| 680 |
|
| 681 |
EXAMPLES = [
|
|
@@ -684,7 +684,7 @@ def build_demo():
|
|
| 684 |
["examples/example3.png", "Find all the people who are currently standing and response with segmentation masks."],
|
| 685 |
["examples/example4.jpg", "Segment every instance that belongs to the following categories: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush, banner, blanket, bridge, cardboard, counter, curtain, door-stuff, floor-wood, flower, fruit, gravel, house, light, mirror-stuff, net, pillow, platform, playingfield, railroad, river, road, roof, sand, sea, shelf, snow, stairs, tent, towel, wall-brick, wall-stone, wall-tile, wall-wood, water-other, window-blind, window-other, tree-merged, fence-merged, ceiling-merged, sky-other-merged, cabinet-merged, table-merged, floor-other-merged, pavement-merged, mountain-merged, grass-merged, dirt-merged, paper-merged, food-other-merged, building-other-merged, rock-merged, wall-other-merged, rug-merged"],
|
| 686 |
["examples/example5.jpg", "Generate a scene graph for this image. Identify the main objects and describe their relationships to each other."],
|
| 687 |
-
["examples/example6.jpg", "
|
| 688 |
]
|
| 689 |
gr.Markdown("## Examples")
|
| 690 |
gr.Examples(
|
|
@@ -693,6 +693,24 @@ def build_demo():
|
|
| 693 |
label="Click an example to load the image and prompt",
|
| 694 |
)
|
| 695 |
with gr.Tab("Mask Understanding"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
mu_state = gr.State(new_mu_state())
|
| 697 |
mu_point_is_pos = gr.State(True)
|
| 698 |
|
|
@@ -725,9 +743,20 @@ def build_demo():
|
|
| 725 |
mu_pos_btn.click(lambda: True, None, mu_point_is_pos)
|
| 726 |
mu_neg_btn.click(lambda: False, None, mu_point_is_pos)
|
| 727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
mu_click_img.select(
|
| 729 |
-
fn=
|
| 730 |
-
inputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
outputs=[mu_state, mu_mask_preview],
|
| 732 |
)
|
| 733 |
|
|
@@ -763,5 +792,4 @@ if __name__ == '__main__':
|
|
| 763 |
demo = build_demo()
|
| 764 |
|
| 765 |
demo.queue()
|
| 766 |
-
# demo.launch(server_name='0.0.0.0')
|
| 767 |
demo.launch()
|
|
|
|
| 69 |
text = re.sub(pattern_too_few_no_end, replacement_too_few_no_end, text)
|
| 70 |
return text
|
| 71 |
|
|
|
|
| 72 |
MODEL = 'zhouyik/Qwen3-VL-8B-SAMTok'
|
| 73 |
|
| 74 |
TITLE = 'SAMTok: Representing Any Mask with Two Words'
|
|
|
|
| 152 |
_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to("cuda").eval()
|
| 153 |
return _sam
|
| 154 |
|
|
|
|
| 155 |
colors = sample_color()
|
| 156 |
color_map = {f'Target {i + 1}': f'#{int(c[0]):02x}{int(c[1]):02x}{int(c[2]):02x}' for i, c in enumerate(colors * 255)}
|
| 157 |
color_map_light = {
|
|
|
|
| 162 |
def enable_btns():
|
| 163 |
return (gr.update(interactive=True), ) * 4
|
| 164 |
|
|
|
|
| 165 |
def disable_btns():
|
| 166 |
return (gr.update(interactive=False), ) * 4
|
| 167 |
|
|
|
|
| 168 |
def reset_seg():
|
| 169 |
return 16, gr.update(interactive=False)
|
| 170 |
|
|
|
|
| 171 |
def reset_reg():
|
| 172 |
return 1, gr.update(interactive=False)
|
| 173 |
|
|
|
|
| 244 |
# postprocess needs lists/tensors on CPU
|
| 245 |
original_sizes = torch.tensor([mu_state["original_sizes"]], dtype=torch.long)
|
| 246 |
reshaped_sizes = torch.tensor([mu_state["reshaped_input_sizes"]], dtype=torch.long)
|
|
|
|
| 247 |
masks = sam_processor.post_process_masks(
|
| 248 |
outputs.pred_masks.detach().cpu(),
|
| 249 |
original_sizes,
|
| 250 |
reshaped_sizes,
|
| 251 |
)
|
| 252 |
+
mask = masks[0][0][0].numpy()
|
| 253 |
+
mask = (mask > 0).astype(np.float32)
|
| 254 |
return mask
|
| 255 |
|
| 256 |
@spaces.GPU
|
|
|
|
| 266 |
mu_state["cur_mask"] = mask
|
| 267 |
return mu_state, mask
|
| 268 |
|
| 269 |
+
@spaces.GPU
|
| 270 |
+
def mu_add_point_xy(xy, mu_state, is_positive: bool):
|
| 271 |
+
if mu_state["image_path"] is None:
|
| 272 |
+
return mu_state, None
|
| 273 |
+
|
| 274 |
+
if xy is None:
|
| 275 |
+
return mu_state, mu_state.get("cur_mask")
|
| 276 |
+
|
| 277 |
+
x, y = xy # xy is a tuple/list of two ints
|
| 278 |
+
mu_state["points"].append([float(x), float(y)])
|
| 279 |
+
mu_state["labels"].append(1 if is_positive else 0)
|
| 280 |
+
|
| 281 |
+
mask = mu_predict_mask_from_state(mu_state)
|
| 282 |
+
mu_state["cur_mask"] = mask
|
| 283 |
+
return mu_state, mask
|
| 284 |
+
|
| 285 |
+
def mu_evt_to_xy(evt: gr.SelectData):
|
| 286 |
+
# return plain python types only (picklable)
|
| 287 |
+
x, y = evt.index
|
| 288 |
+
return (int(x), int(y))
|
| 289 |
+
|
| 290 |
def mu_clear_prompts(mu_state):
|
| 291 |
mu_state["points"] = []
|
| 292 |
mu_state["labels"] = []
|
| 293 |
mu_state["cur_mask"] = None
|
| 294 |
return mu_state, None
|
| 295 |
|
|
|
|
| 296 |
@spaces.GPU
|
| 297 |
def mu_save_region(mu_state):
|
| 298 |
if mu_state["cur_mask"] is None:
|
|
|
|
| 482 |
inputs = processor.apply_chat_template(
|
| 483 |
messages, tokenize=True, add_generation_prompt=True,
|
| 484 |
return_dict=True, return_tensors="pt"
|
| 485 |
+
).to(model.device)
|
| 486 |
|
| 487 |
+
generated_ids = model.generate(
|
| 488 |
+
**inputs,
|
| 489 |
+
max_new_tokens=1024,
|
| 490 |
+
do_sample=True,
|
| 491 |
+
top_p=0.8,
|
| 492 |
+
top_k=20,
|
| 493 |
+
temperature=0.7,
|
| 494 |
+
repetition_penalty=1.0,
|
| 495 |
+
)
|
| 496 |
generated_ids_trimmed = [
|
| 497 |
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
| 498 |
]
|
|
|
|
| 507 |
|
| 508 |
if not media:
|
| 509 |
gr.Warning('Please upload an image')
|
| 510 |
+
return None, None, None
|
| 511 |
|
| 512 |
if not query:
|
| 513 |
gr.Warning('Please provide a text prompt.')
|
| 514 |
+
return None, None, None
|
| 515 |
|
| 516 |
image = Image.open(media).convert('RGB')
|
| 517 |
ori_width, ori_height = image.size
|
|
|
|
| 540 |
generated_ids = model.generate(
|
| 541 |
**inputs,
|
| 542 |
max_new_tokens=1024,
|
| 543 |
+
do_sample=True,
|
| 544 |
+
top_p=0.8,
|
| 545 |
+
top_k=20,
|
| 546 |
+
temperature=0.7,
|
| 547 |
+
repetition_penalty=1.0,
|
| 548 |
)
|
| 549 |
generated_ids_trimmed = [
|
| 550 |
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
|
|
| 561 |
answer,
|
| 562 |
gr.update(value=None, visible=False), # hide AnnotatedImage
|
| 563 |
gr.update(value=None, interactive=False, visible=False), # hide DownloadButton
|
|
|
|
| 564 |
)
|
| 565 |
|
| 566 |
if len(quant_ids) % CODEBOOK_DEPTH != 0:
|
|
|
|
| 573 |
answer,
|
| 574 |
gr.update(value=None, visible=False),
|
| 575 |
gr.update(value=None, interactive=False, visible=False),
|
|
|
|
| 576 |
)
|
| 577 |
|
| 578 |
batch_size = len(quant_ids) // CODEBOOK_DEPTH
|
| 579 |
remap_quant_ids = []
|
| 580 |
tags = []
|
|
|
|
| 581 |
for bs_id in range(batch_size):
|
| 582 |
chunk_quant_ids = quant_ids[bs_id*CODEBOOK_DEPTH:(bs_id+1)*CODEBOOK_DEPTH]
|
| 583 |
+
tags.append(f'<|mt_{str(chunk_quant_ids[0]).zfill(4)}|><|mt_{str(chunk_quant_ids[1]).zfill(4)}|>')
|
|
|
|
| 584 |
remap_chunk_quant_ids = [quant_id - book_id*CODEBOOK_SIZE for book_id, quant_id in enumerate(chunk_quant_ids)]
|
| 585 |
code1 = remap_chunk_quant_ids[0]
|
| 586 |
code2 = remap_chunk_quant_ids[1]
|
|
|
|
| 602 |
_pred_masks = vq_sam2.forward_with_codes(sam2_pixel_values, quant_ids)
|
| 603 |
_pred_masks = torch.nn.functional.interpolate(_pred_masks, size=(ori_height, ori_width), mode='bilinear')
|
| 604 |
_pred_masks = _pred_masks > 0.5
|
|
|
|
| 605 |
_pred_masks = _pred_masks.long().unsqueeze(2).cpu() # n, 1, 1, h, w
|
| 606 |
|
| 607 |
tag_to_mask_idx = {}
|
| 608 |
+
for i, tag in enumerate(tags):
|
|
|
|
| 609 |
if tag not in tag_to_mask_idx:
|
| 610 |
tag_to_mask_idx[tag] = i
|
|
|
|
| 611 |
unique_tags = list(tag_to_mask_idx.keys())
|
| 612 |
|
| 613 |
entities = []
|
|
|
|
| 624 |
iio.imwrite(path, imgs, duration=100, loop=0)
|
| 625 |
|
| 626 |
mask_items = []
|
| 627 |
+
entity_names = unique_tags
|
| 628 |
for i, tag in enumerate(unique_tags):
|
| 629 |
+
m = _pred_masks[tag_to_mask_idx[tag]][0, 0].numpy()
|
| 630 |
mask_items.append((m, entity_names[i]))
|
| 631 |
masks_value = (media, mask_items)
|
| 632 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
# return answer, masks, path
|
| 634 |
return (
|
| 635 |
+
gr.update(value=answer, visible=True), # ans_1
|
| 636 |
gr.update(value=masks_value, visible=True), # msk_1
|
| 637 |
gr.update(value=path, interactive=True, visible=True), # download
|
|
|
|
| 638 |
)
|
| 639 |
|
|
|
|
| 640 |
def build_demo():
|
| 641 |
with gr.Blocks(title=TITLE, js=JS, theme=gr.themes.Soft()) as demo:
|
| 642 |
gr.HTML(HEADER)
|
|
|
|
| 652 |
|
| 653 |
sample_frames_1 = gr.Slider(1, 32, value=16, step=1, visible=False)
|
| 654 |
|
|
|
|
| 655 |
query_1 = gr.Textbox(
|
| 656 |
label='Text Prompt',
|
| 657 |
placeholder='Please segment the...',
|
|
|
|
| 673 |
with gr.Column():
|
| 674 |
msk_1.render()
|
| 675 |
ans_1.render()
|
|
|
|
| 676 |
|
| 677 |
ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
|
| 678 |
+
ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1])
|
| 679 |
ctx_1.then(enable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
|
| 680 |
|
| 681 |
EXAMPLES = [
|
|
|
|
| 684 |
["examples/example3.png", "Find all the people who are currently standing and response with segmentation masks."],
|
| 685 |
["examples/example4.jpg", "Segment every instance that belongs to the following categories: person, bicycle, car, motorcycle, airplane, bus, train, truck, boat, traffic light, fire hydrant, stop sign, parking meter, bench, bird, cat, dog, horse, sheep, cow, elephant, bear, zebra, giraffe, backpack, umbrella, handbag, tie, suitcase, frisbee, skis, snowboard, sports ball, kite, baseball bat, baseball glove, skateboard, surfboard, tennis racket, bottle, wine glass, cup, fork, knife, spoon, bowl, banana, apple, sandwich, orange, broccoli, carrot, hot dog, pizza, donut, cake, chair, couch, potted plant, bed, dining table, toilet, tv, laptop, mouse, remote, keyboard, cell phone, microwave, oven, toaster, sink, refrigerator, book, clock, vase, scissors, teddy bear, hair drier, toothbrush, banner, blanket, bridge, cardboard, counter, curtain, door-stuff, floor-wood, flower, fruit, gravel, house, light, mirror-stuff, net, pillow, platform, playingfield, railroad, river, road, roof, sand, sea, shelf, snow, stairs, tent, towel, wall-brick, wall-stone, wall-tile, wall-wood, water-other, window-blind, window-other, tree-merged, fence-merged, ceiling-merged, sky-other-merged, cabinet-merged, table-merged, floor-other-merged, pavement-merged, mountain-merged, grass-merged, dirt-merged, paper-merged, food-other-merged, building-other-merged, rock-merged, wall-other-merged, rug-merged"],
|
| 686 |
["examples/example5.jpg", "Generate a scene graph for this image. Identify the main objects and describe their relationships to each other."],
|
| 687 |
+
["examples/example6.jpg", "What item for sale indicates that the primary product is also offered in a ready-to-eat form? A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>"]
|
| 688 |
]
|
| 689 |
gr.Markdown("## Examples")
|
| 690 |
gr.Examples(
|
|
|
|
| 693 |
label="Click an example to load the image and prompt",
|
| 694 |
)
|
| 695 |
with gr.Tab("Mask Understanding"):
|
| 696 |
+
MU_INSTRUCTIONS = """
|
| 697 |
+
### Mask Understanding — Instructions
|
| 698 |
+
|
| 699 |
+
1. **Upload an image.**
|
| 700 |
+
2. **Create a region mask**
|
| 701 |
+
- Click **Clear Prompts**
|
| 702 |
+
- Click **Positive Point**, then click on the target region in the image.
|
| 703 |
+
- The **Current Mask** preview updates after each click. Add more clicks to refine the mask.
|
| 704 |
+
- Click **Save Region** to store the current mask. A new region ID (e.g., `region1`) will be created.
|
| 705 |
+
3. *(Optional)* Repeat Step 2 to add more regions.
|
| 706 |
+
4. **Enter a text prompt.** When referring to a saved region, use its exact auto-generated ID (e.g., `region1`), e.g. `Given a detailed description of region1.`
|
| 707 |
+
You can reference multiple regions, e.g. `Compare region1 and region2 and describe their differences.`
|
| 708 |
+
|
| 709 |
+
**Tips:** Use **Negative Point** to remove unwanted parts; use **Clear Prompts** to reset points.
|
| 710 |
+
"""
|
| 711 |
+
with gr.Accordion("Instructions (click to expand)", open=False):
|
| 712 |
+
gr.Markdown(MU_INSTRUCTIONS)
|
| 713 |
+
mu_click_xy = gr.State(None)
|
| 714 |
mu_state = gr.State(new_mu_state())
|
| 715 |
mu_point_is_pos = gr.State(True)
|
| 716 |
|
|
|
|
| 743 |
mu_pos_btn.click(lambda: True, None, mu_point_is_pos)
|
| 744 |
mu_neg_btn.click(lambda: False, None, mu_point_is_pos)
|
| 745 |
|
| 746 |
+
# mu_click_img.select(
|
| 747 |
+
# fn=mu_add_point,
|
| 748 |
+
# inputs=[mu_state, mu_point_is_pos],
|
| 749 |
+
# outputs=[mu_state, mu_mask_preview],
|
| 750 |
+
# )
|
| 751 |
+
|
| 752 |
mu_click_img.select(
|
| 753 |
+
fn=mu_evt_to_xy,
|
| 754 |
+
inputs=None,
|
| 755 |
+
outputs=mu_click_xy,
|
| 756 |
+
queue=False,
|
| 757 |
+
).then(
|
| 758 |
+
fn=mu_add_point_xy,
|
| 759 |
+
inputs=[mu_click_xy, mu_state, mu_point_is_pos],
|
| 760 |
outputs=[mu_state, mu_mask_preview],
|
| 761 |
)
|
| 762 |
|
|
|
|
| 792 |
demo = build_demo()
|
| 793 |
|
| 794 |
demo.queue()
|
|
|
|
| 795 |
demo.launch()
|