jovian commited on
Commit
7c84458
·
1 Parent(s): 409fbec

slicing input

Browse files
Files changed (1) hide show
  1. app.py +32 -13
app.py CHANGED
@@ -79,15 +79,15 @@ class Detection:
79
  device='cuda:0'
80
  )
81
 
82
- def detect_from_image(self, image):
83
  # Perform sliced prediction with both models
84
  results1 = get_sliced_prediction(
85
  image=image,
86
  detection_model=self.model1,
87
- slice_height=256,
88
- slice_width=256,
89
- overlap_height_ratio=0.5,
90
- overlap_width_ratio=0.5,
91
  postprocess_type='NMS',
92
  postprocess_match_metric='IOU',
93
  postprocess_match_threshold=0.1,
@@ -97,10 +97,10 @@ class Detection:
97
  results2 = get_sliced_prediction(
98
  image=image,
99
  detection_model=self.model2,
100
- slice_height=256,
101
- slice_width=256,
102
- overlap_height_ratio=0.5,
103
- overlap_width_ratio=0.5,
104
  postprocess_type='NMS',
105
  postprocess_match_metric='IOU',
106
  postprocess_match_threshold=0.1,
@@ -321,13 +321,13 @@ def upload_image(image):
321
  return image
322
 
323
  @spaces.GPU
324
- def apply_detection(image):
325
  """Run object detection on the uploaded image and return the annotated image."""
326
  # Convert image from PIL to NumPy array
327
  img = np.array(image)
328
 
329
  # Perform detection and get COCO annotations
330
- annotations = detection.detect_from_image(img)
331
 
332
  # Draw the annotations on the image using OpenCV
333
  annotated_image = detection.draw_annotations(img, annotations)
@@ -614,6 +614,19 @@ with gr.Blocks() as demo:
614
 
615
  """)
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
 
618
  with gr.Row(visible=False) as input_row:
619
  # Image Upload and Display in two columns
@@ -626,7 +639,7 @@ with gr.Blocks() as demo:
626
  output_image_component = gr.Image(type="pil", label="Annotated Image")
627
  apply_detection_btn = gr.Button("Apply Detection", variant='primary')
628
  output_annotations = gr.State() # Store annotations
629
- apply_detection_btn.click(apply_detection, inputs=upload_image_component, outputs=[output_image_component, output_annotations])
630
 
631
 
632
 
@@ -753,6 +766,9 @@ with gr.Blocks() as demo:
753
  ).then(
754
  lambda login_state: (
755
  gr.update(visible=login_state), # Show header_row
 
 
 
756
  gr.update(visible=login_state), # Show input_row
757
  gr.update(visible=login_state), # Show area_graph_row
758
  gr.update(visible=login_state), # Show area_btn_row
@@ -768,7 +784,10 @@ with gr.Blocks() as demo:
768
  gr.update(visible=not login_state) # for login
769
  ),
770
  inputs=login_successful,
771
- outputs=[header_row,
 
 
 
772
  input_row,
773
  area_graph_row,
774
  area_btn_row,
 
79
  device='cuda:0'
80
  )
81
 
82
+ def detect_from_image(self, image,slice_width_input,slice_height_input,overlap_width_input,overlap_height_input):
83
  # Perform sliced prediction with both models
84
  results1 = get_sliced_prediction(
85
  image=image,
86
  detection_model=self.model1,
87
+ slice_height=slice_height_input,
88
+ slice_width=slice_width_input,
89
+ overlap_height_ratio=overlap_height_input,
90
+ overlap_width_ratio=overlap_width_input,
91
  postprocess_type='NMS',
92
  postprocess_match_metric='IOU',
93
  postprocess_match_threshold=0.1,
 
97
  results2 = get_sliced_prediction(
98
  image=image,
99
  detection_model=self.model2,
100
+ slice_height=slice_height_input,
101
+ slice_width=slice_width_input,
102
+ overlap_height_ratio=overlap_height_input,
103
+ overlap_width_ratio=overlap_width_input,
104
  postprocess_type='NMS',
105
  postprocess_match_metric='IOU',
106
  postprocess_match_threshold=0.1,
 
321
  return image
322
 
323
  @spaces.GPU
324
+ def apply_detection(image,slice_width_input,slice_height_input,overlap_width_input,overlap_height_input):
325
  """Run object detection on the uploaded image and return the annotated image."""
326
  # Convert image from PIL to NumPy array
327
  img = np.array(image)
328
 
329
  # Perform detection and get COCO annotations
330
+ annotations = detection.detect_from_image(img,slice_width_input,slice_height_input,overlap_width_input,overlap_height_input)
331
 
332
  # Draw the annotations on the image using OpenCV
333
  annotated_image = detection.draw_annotations(img, annotations)
 
614
 
615
  """)
616
 
617
+ with gr.Row(visible=False) as slicing_text:
618
+ gr.Markdown("### Choose the width and height dimension and the overlapping ratio of the slice to determine how small the model can detect")
619
+
620
+
621
+ with gr.Row(visible=False) as slicing_dim_input:
622
+ # Add inputs for width and height
623
+ slice_width_input = gr.Number(label="Slice Width (pixels)", value=256)
624
+ slice_height_input = gr.Number(label="Slice Height (pixels)", value=256)
625
+
626
+ with gr.Row(visible=False) as slicing_overlap_input:
627
+ overlap_width_input = gr.Slider(0, 1, step=0.01, label="Overlap Width Ratio", value=0.5)
628
+ overlap_height_input = gr.Slider(0, 1, step=0.01, label="Overlap Height Ratio", value=0.5)
629
+
630
 
631
  with gr.Row(visible=False) as input_row:
632
  # Image Upload and Display in two columns
 
639
  output_image_component = gr.Image(type="pil", label="Annotated Image")
640
  apply_detection_btn = gr.Button("Apply Detection", variant='primary')
641
  output_annotations = gr.State() # Store annotations
642
+ apply_detection_btn.click(apply_detection, inputs=[upload_image_component,slice_width_input,slice_height_input,overlap_width_input,overlap_height_input], outputs=[output_image_component, output_annotations])
643
 
644
 
645
 
 
766
  ).then(
767
  lambda login_state: (
768
  gr.update(visible=login_state), # Show header_row
769
+ gr.update(visible=login_state), # Show slicing text
770
+ gr.update(visible=login_state), # Show slicing_dim_input
771
+ gr.update(visible=login_state), # Show slicing_overlap_input
772
  gr.update(visible=login_state), # Show input_row
773
  gr.update(visible=login_state), # Show area_graph_row
774
  gr.update(visible=login_state), # Show area_btn_row
 
784
  gr.update(visible=not login_state) # for login
785
  ),
786
  inputs=login_successful,
787
+ outputs=[header_row,
788
+ slicing_text,
789
+ slicing_dim_input,
790
+ slicing_overlap_input,
791
  input_row,
792
  area_graph_row,
793
  area_btn_row,