Preyanshz commited on
Commit
84d1822
·
verified ·
1 Parent(s): b52b854

Update app.py

Browse files

added overlapping threshold slider

Files changed (1) hide show
  1. app.py +108 -21
app.py CHANGED
@@ -17,7 +17,7 @@ def save_uploaded_file(uploaded_file):
17
  tmp_file.write(uploaded_file.getbuffer())
18
  return tmp_file.name
19
 
20
- def apply_confidence_threshold(result, conf_threshold):
21
  """Apply confidence threshold by modifying the result's boxes directly."""
22
  try:
23
  # If there are no boxes, or the boxes have no confidence values, just return the original image
@@ -38,14 +38,19 @@ def apply_confidence_threshold(result, conf_threshold):
38
  img_with_boxes = result.orig_img.copy()
39
  else:
40
  # Fallback to plot method if orig_img is not available
41
- return Image.fromarray(np.array(result.plot(conf=conf_threshold))), valid_detections
42
 
43
  # Only proceed with drawing if there are valid detections
44
  if valid_detections > 0:
45
  # Create mask of boxes to keep
46
  mask = confs >= conf_threshold
47
 
48
- # For each valid box, draw it on the image
 
 
 
 
 
49
  for i, is_valid in enumerate(mask):
50
  if is_valid:
51
  try:
@@ -71,7 +76,7 @@ def apply_confidence_threshold(result, conf_threshold):
71
  # If we can't get box coordinates, skip this box
72
  continue
73
 
74
- # Get class ID and name
75
  if hasattr(result.boxes, "cls"):
76
  if hasattr(result.boxes.cls, "cpu"):
77
  cls_id = int(result.boxes.cls[i].cpu().item())
@@ -83,12 +88,59 @@ def apply_confidence_threshold(result, conf_threshold):
83
  # Get confidence
84
  conf = confs[i]
85
 
86
- # Get class name
87
- if hasattr(result, 'names') and result.names and cls_id in result.names:
88
- cls_name = result.names[cls_id]
89
- else:
90
- cls_name = f"class_{cls_id}"
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Make sure box coordinates are within image bounds
93
  h, w = img_with_boxes.shape[:2]
94
  box[0] = max(0, min(box[0], w-1))
@@ -96,6 +148,12 @@ def apply_confidence_threshold(result, conf_threshold):
96
  box[2] = max(0, min(box[2], w-1))
97
  box[3] = max(0, min(box[3], h-1))
98
 
 
 
 
 
 
 
99
  # Draw the box
100
  color = (0, 255, 0) # Green box
101
  cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
@@ -109,10 +167,15 @@ def apply_confidence_threshold(result, conf_threshold):
109
  text_x = box[0]
110
  text_y = max(box[1] - 10, text_size[1])
111
  cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
112
- except Exception as e:
113
- # If any error occurs for a specific box, just skip it
114
- st.error(f"Error processing a detection box: {str(e)}")
115
- continue
 
 
 
 
 
116
 
117
  # Convert back to PIL Image for streamlit display
118
  img_pil = Image.fromarray(img_with_boxes)
@@ -122,7 +185,7 @@ def apply_confidence_threshold(result, conf_threshold):
122
  # If anything fails in the custom drawing, fall back to the model's built-in plot method
123
  try:
124
  # Try using the built-in plot method with the threshold
125
- annotated_img = result.plot(conf=conf_threshold)
126
  if isinstance(annotated_img, np.ndarray):
127
  img_pil = Image.fromarray(annotated_img)
128
  else:
@@ -284,17 +347,29 @@ def yolo_inference_tool():
284
  key="single_model_conf_threshold"
285
  )
286
 
287
- # Display annotated images using the current threshold
 
 
 
 
 
 
 
 
 
 
 
 
288
  st.subheader("Annotated Images")
289
  for img_name, r in st.session_state.single_model_results.items():
290
  try:
291
- # Apply confidence threshold and get processed image
292
- processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold)
293
 
294
  # Display the image
295
  st.image(
296
  processed_img,
297
- caption=f"{img_name} (Threshold: {conf_threshold:.2f}, Detections: {valid_detections})",
298
  use_container_width=True
299
  )
300
  except Exception as e:
@@ -520,6 +595,18 @@ def yolo_model_comparison_tool():
520
  step=0.05,
521
  key="multi_model_conf_threshold"
522
  )
 
 
 
 
 
 
 
 
 
 
 
 
523
 
524
  # Display annotated images in a grid (row = image, column = model)
525
  st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
@@ -540,12 +627,12 @@ def yolo_model_comparison_tool():
540
  continue
541
 
542
  try:
543
- # Apply confidence threshold and get processed image
544
- processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold)
545
 
546
  col.image(
547
  processed_img,
548
- caption=f"{model_name} (Threshold: {comp_conf_threshold:.2f}, Detections: {valid_detections})",
549
  use_container_width=True
550
  )
551
  except Exception as e:
 
17
  tmp_file.write(uploaded_file.getbuffer())
18
  return tmp_file.name
19
 
20
+ def apply_confidence_threshold(result, conf_threshold, iou_threshold=0.45):
21
  """Apply confidence threshold by modifying the result's boxes directly."""
22
  try:
23
  # If there are no boxes, or the boxes have no confidence values, just return the original image
 
38
  img_with_boxes = result.orig_img.copy()
39
  else:
40
  # Fallback to plot method if orig_img is not available
41
+ return Image.fromarray(np.array(result.plot(conf=conf_threshold, iou=iou_threshold))), valid_detections
42
 
43
  # Only proceed with drawing if there are valid detections
44
  if valid_detections > 0:
45
  # Create mask of boxes to keep
46
  mask = confs >= conf_threshold
47
 
48
+ # Apply non-maximum suppression if we have xyxy boxes
49
+ boxes_to_draw = []
50
+ class_ids = []
51
+ confidences = []
52
+
53
+ # Collect all valid boxes
54
  for i, is_valid in enumerate(mask):
55
  if is_valid:
56
  try:
 
76
  # If we can't get box coordinates, skip this box
77
  continue
78
 
79
+ # Get class ID
80
  if hasattr(result.boxes, "cls"):
81
  if hasattr(result.boxes.cls, "cpu"):
82
  cls_id = int(result.boxes.cls[i].cpu().item())
 
88
  # Get confidence
89
  conf = confs[i]
90
 
91
+ # Add to our collection
92
+ boxes_to_draw.append(box)
93
+ class_ids.append(cls_id)
94
+ confidences.append(conf)
 
95
 
96
+ except Exception as e:
97
+ # If any error occurs for a specific box, just skip it
98
+ st.error(f"Error processing a detection box: {str(e)}")
99
+ continue
100
+
101
+ # Apply non-maximum suppression if we have collected boxes
102
+ if boxes_to_draw:
103
+ try:
104
+ # Convert to numpy arrays
105
+ boxes_np = np.array(boxes_to_draw)
106
+ class_ids_np = np.array(class_ids)
107
+ confidences_np = np.array(confidences)
108
+
109
+ # Apply NMS by class
110
+ unique_classes = np.unique(class_ids_np)
111
+ final_boxes = []
112
+ final_classes = []
113
+ final_confs = []
114
+
115
+ for cls in unique_classes:
116
+ # Get indices for this class
117
+ indices = np.where(class_ids_np == cls)[0]
118
+ if len(indices) == 0:
119
+ continue
120
+
121
+ class_boxes = boxes_np[indices]
122
+ class_confs = confidences_np[indices]
123
+
124
+ # Apply NMS
125
+ keep_indices = cv2.dnn.NMSBoxes(
126
+ class_boxes.tolist(),
127
+ class_confs.tolist(),
128
+ score_threshold=conf_threshold,
129
+ nms_threshold=iou_threshold
130
+ )
131
+
132
+ # Add kept boxes to final lists
133
+ if len(keep_indices) > 0:
134
+ if isinstance(keep_indices[0], np.ndarray): # Handle different return formats
135
+ keep_indices = keep_indices.flatten()
136
+
137
+ for idx in keep_indices:
138
+ final_boxes.append(class_boxes[idx])
139
+ final_classes.append(cls)
140
+ final_confs.append(class_confs[idx])
141
+
142
+ # Now draw only the final boxes
143
+ for box, cls_id, conf in zip(final_boxes, final_classes, final_confs):
144
  # Make sure box coordinates are within image bounds
145
  h, w = img_with_boxes.shape[:2]
146
  box[0] = max(0, min(box[0], w-1))
 
148
  box[2] = max(0, min(box[2], w-1))
149
  box[3] = max(0, min(box[3], h-1))
150
 
151
+ # Get class name
152
+ if hasattr(result, 'names') and result.names and cls_id in result.names:
153
+ cls_name = result.names[cls_id]
154
+ else:
155
+ cls_name = f"class_{cls_id}"
156
+
157
  # Draw the box
158
  color = (0, 255, 0) # Green box
159
  cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
 
167
  text_x = box[0]
168
  text_y = max(box[1] - 10, text_size[1])
169
  cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
170
+
171
+ # Update valid_detections to reflect NMS results
172
+ valid_detections = len(final_boxes)
173
+
174
+ except Exception as nms_error:
175
+ # If NMS fails, fall back to original drawing code
176
+ st.warning(f"NMS processing failed, falling back to simple filtering: {str(nms_error)}")
177
+ # The original boxes will be drawn in the fallback code
178
+ pass
179
 
180
  # Convert back to PIL Image for streamlit display
181
  img_pil = Image.fromarray(img_with_boxes)
 
185
  # If anything fails in the custom drawing, fall back to the model's built-in plot method
186
  try:
187
  # Try using the built-in plot method with the threshold
188
+ annotated_img = result.plot(conf=conf_threshold, iou=iou_threshold)
189
  if isinstance(annotated_img, np.ndarray):
190
  img_pil = Image.fromarray(annotated_img)
191
  else:
 
347
  key="single_model_conf_threshold"
348
  )
349
 
350
+ # Add IoU threshold slider for NMS
351
+ st.subheader("Overlapping (IoU) Threshold")
352
+ iou_threshold = st.slider(
353
+ "Adjust IoU threshold for non-maximum suppression",
354
+ min_value=0.1,
355
+ max_value=1.0,
356
+ value=0.45, # Default NMS value
357
+ step=0.05,
358
+ key="single_model_iou_threshold",
359
+ help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
360
+ )
361
+
362
+ # Display annotated images using the current thresholds
363
  st.subheader("Annotated Images")
364
  for img_name, r in st.session_state.single_model_results.items():
365
  try:
366
+ # Apply confidence and IoU thresholds and get processed image
367
+ processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold, iou_threshold)
368
 
369
  # Display the image
370
  st.image(
371
  processed_img,
372
+ caption=f"{img_name} (Conf: {conf_threshold:.2f}, IoU: {iou_threshold:.2f}, Detections: {valid_detections})",
373
  use_container_width=True
374
  )
375
  except Exception as e:
 
595
  step=0.05,
596
  key="multi_model_conf_threshold"
597
  )
598
+
599
+ # Add IoU threshold slider for NMS
600
+ st.subheader("Overlapping (IoU) Threshold")
601
+ comp_iou_threshold = st.slider(
602
+ "Adjust IoU threshold for non-maximum suppression across all models",
603
+ min_value=0.1,
604
+ max_value=1.0,
605
+ value=0.45, # Default NMS value
606
+ step=0.05,
607
+ key="multi_model_iou_threshold",
608
+ help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
609
+ )
610
 
611
  # Display annotated images in a grid (row = image, column = model)
612
  st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
 
627
  continue
628
 
629
  try:
630
+ # Apply confidence and IoU thresholds and get processed image
631
+ processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold, comp_iou_threshold)
632
 
633
  col.image(
634
  processed_img,
635
+ caption=f"{model_name} (Conf: {comp_conf_threshold:.2f}, IoU: {comp_iou_threshold:.2f}, Det: {valid_detections})",
636
  use_container_width=True
637
  )
638
  except Exception as e: