insomnia7 commited on
Commit
2b62d60
·
verified ·
1 Parent(s): f56fd54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -24
app.py CHANGED
@@ -548,6 +548,7 @@ def infer_seg(media, query):
548
  answer,
549
  gr.update(value=None, visible=False),
550
  gr.update(value=None, interactive=False, visible=False),
 
551
  )
552
 
553
  batch_size = len(quant_ids) // CODEBOOK_DEPTH
@@ -597,34 +598,18 @@ def infer_seg(media, query):
597
 
598
  answer = dict(text=output_text, entities=entities)
599
 
600
- # entities = []
601
- # unique_tags = list(set(tags))
602
- # entity_names = []
603
- # for i, tag in enumerate(unique_tags):
604
- # for m in re.finditer(re.escape(tag), output_text):
605
- # entities.append(dict(entity=f'Target {i + 1}', start=m.start(), end=m.end()))
606
- # entity_names.append(f'Target {i + 1}')
607
-
608
- # answer = dict(text=output_text, entities=entities)
609
-
610
  frames = torch.from_numpy(np.array(image)).unsqueeze(0)
611
  imgs = draw_mask(frames, _pred_masks, colors=colors)
612
 
613
  path = f"/tmp/{uuid.uuid4().hex}.png"
614
  iio.imwrite(path, imgs, duration=100, loop=0)
615
 
616
- base_img = Image.open(media).convert("RGB")
617
  entity_names = [f"Target {i+1}" for i in range(len(unique_tags))]
618
- masks_value = (
619
- base_img,
620
- [
621
- (
622
- _pred_masks[tag_to_mask_idx[tag]][0, 0].numpy().astype(bool),
623
- entity_names[i]
624
- )
625
- for i, tag in enumerate(unique_tags)
626
- ]
627
- )
628
 
629
  lines = []
630
  for i, tag in enumerate(unique_tags):
@@ -658,7 +643,6 @@ def build_demo():
658
  msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
659
  ans_1 = gr.HighlightedText(
660
  label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
661
- tag_map_md = gr.Markdown(value="", visible=False)
662
  with gr.Row():
663
  with gr.Column():
664
  media_1 = gr.Image(type='filepath')
@@ -677,7 +661,7 @@ def build_demo():
677
  with gr.Row():
678
  random_btn_1 = gr.Button(value='🔮 Random', visible=False)
679
 
680
- reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1, tag_map_md], value='🗑️ Reset')
681
  reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
682
 
683
  download_btn_1.render()
@@ -687,7 +671,7 @@ def build_demo():
687
  with gr.Column():
688
  msk_1.render()
689
  ans_1.render()
690
- tag_map_md
691
 
692
  ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
693
  ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1, tag_map_md])
 
548
  answer,
549
  gr.update(value=None, visible=False),
550
  gr.update(value=None, interactive=False, visible=False),
551
+ gr.update(value="", visible=False),
552
  )
553
 
554
  batch_size = len(quant_ids) // CODEBOOK_DEPTH
 
598
 
599
  answer = dict(text=output_text, entities=entities)
600
 
 
 
 
 
 
 
 
 
 
 
601
  frames = torch.from_numpy(np.array(image)).unsqueeze(0)
602
  imgs = draw_mask(frames, _pred_masks, colors=colors)
603
 
604
  path = f"/tmp/{uuid.uuid4().hex}.png"
605
  iio.imwrite(path, imgs, duration=100, loop=0)
606
 
607
+ mask_items = []
608
  entity_names = [f"Target {i+1}" for i in range(len(unique_tags))]
609
+ for i, tag in enumerate(unique_tags):
610
+ m = _pred_masks[tag_to_mask_idx[tag]][0, 0].numpy().astype(np.uint8) * 255 # (H,W), 0/255
611
+ mask_items.append((m, entity_names[i]))
612
+ masks_value = (media, mask_items)
 
 
 
 
 
 
613
 
614
  lines = []
615
  for i, tag in enumerate(unique_tags):
 
643
  msk_1 = gr.AnnotatedImage(label='De-tokenized 2D masks', color_map=color_map, render=False)
644
  ans_1 = gr.HighlightedText(
645
  label='Model Response', color_map=color_map_light, show_inline_category=False, render=False)
 
646
  with gr.Row():
647
  with gr.Column():
648
  media_1 = gr.Image(type='filepath')
 
661
  with gr.Row():
662
  random_btn_1 = gr.Button(value='🔮 Random', visible=False)
663
 
664
+ reset_btn_1 = gr.ClearButton([media_1, query_1, msk_1, ans_1], value='🗑️ Reset')
665
  reset_btn_1.click(reset_seg, None, [sample_frames_1, download_btn_1])
666
 
667
  download_btn_1.render()
 
671
  with gr.Column():
672
  msk_1.render()
673
  ans_1.render()
674
+ tag_map_md = gr.Markdown(value="", visible=False)
675
 
676
  ctx_1 = submit_btn_1.click(disable_btns, None, [random_btn_1, reset_btn_1, download_btn_1, submit_btn_1])
677
  ctx_1 = ctx_1.then(infer_seg, [media_1, query_1], [ans_1, msk_1, download_btn_1, tag_map_md])