Arulkumar03 committed on
Commit
2b09f60
·
1 Parent(s): c3b4316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -24
app.py CHANGED
@@ -93,28 +93,26 @@ def draw_mask(mask, image, random_color=True):
93
  return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
94
 
95
 
96
def run_grounding(input_image, grounding_caption, box_threshold, text_threshold, task='predict'):
    """Run Grounding DINO open-vocabulary detection on *input_image*.

    Args:
        input_image: PIL image to process (converted to RGB internally).
        grounding_caption: text prompt describing what to detect.
        box_threshold: confidence threshold for box predictions.
        text_threshold: confidence threshold for text/phrase matching.
        task: 'predict' (default) returns the image with detection boxes;
            'segment' additionally runs SAM-style segmentation and overlays
            the mask on the boxed image.
            NOTE(review): the original referenced a free variable ``task``
            that was never defined (NameError) — it is now a keyword
            parameter with a backward-compatible default.

    Returns:
        A PIL.Image with detection boxes drawn (and the segmentation mask
        composited on top when ``task == 'segment'``).
    """
    init_image = input_image.convert("RGB")
    original_size = init_image.size

    _, image_tensor = image_transform_grounding(init_image)
    image_pil: Image = image_transform_grounding_for_vis(init_image)

    # Run grounding detection. Both tasks need boxes and the annotated frame;
    # the original only built `annotated_frame` in the predict branch, so the
    # segment branch crashed with NameError.
    boxes, logits, phrases = predict(model, image_tensor, grounding_caption,
                                     box_threshold, text_threshold, device='cpu')
    annotated_frame = annotate(image_source=np.asarray(image_pil),
                               boxes=boxes, logits=logits, phrases=phrases)

    if task == 'segment':
        segmented_frame_masks = segment(image_tensor, model, boxes=boxes)
        # draw_mask alpha-composites the first mask over the boxed frame.
        annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
        return Image.fromarray(annotated_frame_with_mask)

    # annotate() produces BGR (OpenCV convention) — convert for PIL display.
    return Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
 
 
 
 
 
 
118
 
119
 
120
  if __name__ == "__main__":
@@ -136,9 +134,13 @@ if __name__ == "__main__":
136
  gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/Arulkumar03/SOTA-Grounding-DINO.ipynb'>Grounding DINO</a><h3><center>")
137
  gr.Markdown("<h3><center>Note the model runs on CPU, so it may take a while to run the model.<h3><center>")
138
 
 
139
  with gr.Row():
140
  with gr.Column():
141
  input_image = gr.Image(source='upload', type="pil")
 
 
 
142
  grounding_caption = gr.Textbox(label="Detection Prompt")
143
  run_button = gr.Button(label="Run")
144
  with gr.Accordion("Advanced options", open=False):
@@ -154,18 +156,15 @@ if __name__ == "__main__":
154
  type="pil",
155
  # label="grounding results"
156
  ).style(full_width=True, full_height=True)
157
- # gallery = gr.Gallery(label="Generated images", show_label=False).style(
158
- # grid=[1], height="auto", container=True, full_width=True, full_height=True)
159
 
160
  run_button.click(fn=run_grounding, inputs=[
161
- input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
162
  gr.Examples(
163
- [["watermelon.jpg", "watermelon", 0.25, 0.25]],
164
- inputs = [input_image, grounding_caption, box_threshold, text_threshold],
165
- outputs = [gallery,gr.Choice(["segment", "classify"], label="Select Task")],
166
- fn=run_grounding,
167
- cache_examples=True,
168
- label='Try this example input!'
169
- )
170
  block.launch(share=False, show_api=False, show_error=True)
171
-
 
93
  return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
94
 
95
 
96
def run_grounding(input_image, choice, grounding_caption, box_threshold, text_threshold, do_segmentation=False):
    """Run Grounding DINO detection, optionally followed by segmentation.

    Args:
        input_image: PIL image to process (converted to RGB internally).
        choice: 'segment' overlays a segmentation mask on the detections;
            any other value returns the box-annotated image only.
        grounding_caption: text prompt describing what to detect.
        box_threshold: confidence threshold for box predictions.
        text_threshold: confidence threshold for text/phrase matching.
        do_segmentation: unused legacy flag; kept for interface
            compatibility. It now has a default so the Gradio click handler,
            which only supplies five inputs, no longer raises TypeError.

    Returns:
        A PIL.Image with detection boxes drawn (and the segmentation mask
        composited on top when ``choice == 'segment'``).
    """
    init_image = input_image.convert("RGB")
    original_size = init_image.size

    _, image_tensor = image_transform_grounding(init_image)
    image_pil: Image = image_transform_grounding_for_vis(init_image)

    # Detection is required for both modes; the original computed
    # `annotated_frame` only in the else-branch, so the segment branch
    # crashed with NameError and its mask result was silently discarded.
    boxes, logits, phrases = predict(model, image_tensor, grounding_caption,
                                     box_threshold, text_threshold, device='cpu')
    annotated_frame = annotate(image_source=np.asarray(image_pil),
                               boxes=boxes, logits=logits, phrases=phrases)

    if choice == 'segment':
        segmented_frame_masks = segment(image_tensor, model, boxes=boxes)
        # Composite the first predicted mask over the boxed frame and
        # actually return it (the original fell through and dropped it).
        annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated_frame)
        return Image.fromarray(annotated_frame_with_mask)

    # annotate() produces BGR (OpenCV convention) — convert for PIL display.
    return Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))
116
 
117
 
118
  if __name__ == "__main__":
 
134
  gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/Arulkumar03/SOTA-Grounding-DINO.ipynb'>Grounding DINO</a><h3><center>")
135
  gr.Markdown("<h3><center>Note the model runs on CPU, so it may take a while to run the model.<h3><center>")
136
 
137
+
138
  with gr.Row():
139
  with gr.Column():
140
  input_image = gr.Image(source='upload', type="pil")
141
+ choice = gr.Radio(
142
+ ["segment", "classify"], default="segment", label="Choose Operation"
143
+ )
144
  grounding_caption = gr.Textbox(label="Detection Prompt")
145
  run_button = gr.Button(label="Run")
146
  with gr.Accordion("Advanced options", open=False):
 
156
  type="pil",
157
  # label="grounding results"
158
  ).style(full_width=True, full_height=True)
 
 
159
 
160
  run_button.click(fn=run_grounding, inputs=[
161
+ input_image, choice, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
162
  gr.Examples(
163
+ [["watermelon.jpg", "segment", "watermelon", 0.25, 0.25]],
164
+ inputs=[input_image, choice, grounding_caption, box_threshold, text_threshold],
165
+ outputs=[gallery],
166
+ fn=run_grounding,
167
+ cache_examples=True,
168
+ label='Try this example input!'
169
+ )
170
  block.launch(share=False, show_api=False, show_error=True)