Preyanshz commited on
Commit
b34eb0c
·
verified ·
1 Parent(s): 84d1822

Update app.py

Browse files

fixed the overlapping threshold slider.

Files changed (1) hide show
  1. app.py +207 -152
app.py CHANGED
@@ -23,175 +23,228 @@ def apply_confidence_threshold(result, conf_threshold, iou_threshold=0.45):
23
  # If there are no boxes, or the boxes have no confidence values, just return the original image
24
  if not hasattr(result, 'boxes') or result.boxes is None or len(result.boxes) == 0:
25
  return Image.fromarray(result.orig_img), 0
26
-
27
  # Get the confidence values
28
  if hasattr(result.boxes.conf, "cpu"):
29
  confs = result.boxes.conf.cpu().numpy()
30
  else:
31
  confs = result.boxes.conf
32
-
33
- # Count valid detections for display purposes
34
- valid_detections = sum(confs >= conf_threshold)
35
 
36
  # Create a completely new plot with only the boxes that meet the threshold
37
  if hasattr(result, 'orig_img'):
38
  img_with_boxes = result.orig_img.copy()
39
  else:
40
  # Fallback to plot method if orig_img is not available
41
- return Image.fromarray(np.array(result.plot(conf=conf_threshold, iou=iou_threshold))), valid_detections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Only proceed with drawing if there are valid detections
44
- if valid_detections > 0:
45
- # Create mask of boxes to keep
46
- mask = confs >= conf_threshold
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Apply non-maximum suppression if we have xyxy boxes
49
- boxes_to_draw = []
50
- class_ids = []
51
- confidences = []
52
 
53
- # Collect all valid boxes
54
- for i, is_valid in enumerate(mask):
55
- if is_valid:
56
- try:
57
- # Get the box coordinates (handle different formats)
58
- if hasattr(result.boxes, "xyxy"):
59
- if hasattr(result.boxes.xyxy, "cpu"):
60
- box = result.boxes.xyxy[i].cpu().numpy().astype(int)
61
- else:
62
- box = result.boxes.xyxy[i].astype(int)
63
- elif hasattr(result.boxes, "xywh"): # Handle xywh format if that's what's available
64
- if hasattr(result.boxes.xywh, "cpu"):
65
- xywh = result.boxes.xywh[i].cpu().numpy().astype(int)
66
- else:
67
- xywh = result.boxes.xywh[i].astype(int)
68
- # Convert xywh to xyxy: [x, y, w, h] -> [x1, y1, x2, y2]
69
- box = np.array([
70
- xywh[0] - xywh[2]//2, # x1 = x - w/2
71
- xywh[1] - xywh[3]//2, # y1 = y - h/2
72
- xywh[0] + xywh[2]//2, # x2 = x + w/2
73
- xywh[1] + xywh[3]//2 # y2 = y + h/2
74
- ]).astype(int)
75
- else:
76
- # If we can't get box coordinates, skip this box
77
- continue
78
-
79
- # Get class ID
80
- if hasattr(result.boxes, "cls"):
81
- if hasattr(result.boxes.cls, "cpu"):
82
- cls_id = int(result.boxes.cls[i].cpu().item())
83
- else:
84
- cls_id = int(result.boxes.cls[i])
85
- else:
86
- cls_id = 0 # Default class ID if not available
87
-
88
- # Get confidence
89
- conf = confs[i]
90
-
91
- # Add to our collection
92
- boxes_to_draw.append(box)
93
- class_ids.append(cls_id)
94
- confidences.append(conf)
95
-
96
- except Exception as e:
97
- # If any error occurs for a specific box, just skip it
98
- st.error(f"Error processing a detection box: {str(e)}")
99
- continue
100
 
101
- # Apply non-maximum suppression if we have collected boxes
102
- if boxes_to_draw:
103
- try:
104
- # Convert to numpy arrays
105
- boxes_np = np.array(boxes_to_draw)
106
- class_ids_np = np.array(class_ids)
107
- confidences_np = np.array(confidences)
108
-
109
- # Apply NMS by class
110
- unique_classes = np.unique(class_ids_np)
111
- final_boxes = []
112
- final_classes = []
113
- final_confs = []
114
-
115
- for cls in unique_classes:
116
- # Get indices for this class
117
- indices = np.where(class_ids_np == cls)[0]
118
- if len(indices) == 0:
119
- continue
120
-
121
- class_boxes = boxes_np[indices]
122
- class_confs = confidences_np[indices]
123
-
124
- # Apply NMS
125
- keep_indices = cv2.dnn.NMSBoxes(
126
- class_boxes.tolist(),
127
- class_confs.tolist(),
128
- score_threshold=conf_threshold,
129
- nms_threshold=iou_threshold
130
- )
131
-
132
- # Add kept boxes to final lists
133
- if len(keep_indices) > 0:
134
- if isinstance(keep_indices[0], np.ndarray): # Handle different return formats
135
- keep_indices = keep_indices.flatten()
136
-
137
- for idx in keep_indices:
138
- final_boxes.append(class_boxes[idx])
139
- final_classes.append(cls)
140
- final_confs.append(class_confs[idx])
141
-
142
- # Now draw only the final boxes
143
- for box, cls_id, conf in zip(final_boxes, final_classes, final_confs):
144
- # Make sure box coordinates are within image bounds
145
- h, w = img_with_boxes.shape[:2]
146
- box[0] = max(0, min(box[0], w-1))
147
- box[1] = max(0, min(box[1], h-1))
148
- box[2] = max(0, min(box[2], w-1))
149
- box[3] = max(0, min(box[3], h-1))
150
-
151
- # Get class name
152
- if hasattr(result, 'names') and result.names and cls_id in result.names:
153
- cls_name = result.names[cls_id]
154
- else:
155
- cls_name = f"class_{cls_id}"
156
-
157
- # Draw the box
158
- color = (0, 255, 0) # Green box
159
- cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
160
-
161
- # Add label with confidence
162
- label = f"{cls_name} {conf:.2f}"
163
- font = cv2.FONT_HERSHEY_SIMPLEX
164
- # Calculate text size to place it properly
165
- text_size = cv2.getTextSize(label, font, 0.6, 2)[0]
166
- # Ensure label is drawn within image bounds
167
- text_x = box[0]
168
- text_y = max(box[1] - 10, text_size[1])
169
- cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
170
-
171
- # Update valid_detections to reflect NMS results
172
- valid_detections = len(final_boxes)
173
-
174
- except Exception as nms_error:
175
- # If NMS fails, fall back to original drawing code
176
- st.warning(f"NMS processing failed, falling back to simple filtering: {str(nms_error)}")
177
- # The original boxes will be drawn in the fallback code
178
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
- # Convert back to PIL Image for streamlit display
181
- img_pil = Image.fromarray(img_with_boxes)
182
- return img_pil, valid_detections
183
-
184
  except Exception as e:
185
- # If anything fails in the custom drawing, fall back to the model's built-in plot method
186
  try:
187
- # Try using the built-in plot method with the threshold
188
- annotated_img = result.plot(conf=conf_threshold, iou=iou_threshold)
 
 
 
 
 
189
  if isinstance(annotated_img, np.ndarray):
190
  img_pil = Image.fromarray(annotated_img)
191
  else:
192
  img_pil = annotated_img
193
 
194
- # Count detections meeting threshold
195
  if hasattr(result, 'boxes') and result.boxes is not None and len(result.boxes) > 0:
196
  if hasattr(result.boxes.conf, "cpu"):
197
  confs = result.boxes.conf.cpu().numpy()
@@ -202,11 +255,13 @@ def apply_confidence_threshold(result, conf_threshold, iou_threshold=0.45):
202
  valid_detections = 0
203
 
204
  return img_pil, valid_detections
 
205
  except Exception as nested_e:
206
- # If even the fallback fails, return the original image without annotations
207
  if hasattr(result, 'orig_img'):
208
  return Image.fromarray(result.orig_img), 0
209
- # If we can't even get the original image, create a blank one with error text
 
210
  blank_img = np.zeros((400, 600, 3), dtype=np.uint8)
211
  cv2.putText(blank_img, f"Error: {str(e)}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
212
  cv2.putText(blank_img, "Could not render annotations", (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
@@ -351,12 +406,12 @@ def yolo_inference_tool():
351
  st.subheader("Overlapping (IoU) Threshold")
352
  iou_threshold = st.slider(
353
  "Adjust IoU threshold for non-maximum suppression",
354
- min_value=0.1,
355
  max_value=1.0,
356
  value=0.45, # Default NMS value
357
  step=0.05,
358
  key="single_model_iou_threshold",
359
- help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
360
  )
361
 
362
  # Display annotated images using the current thresholds
@@ -600,12 +655,12 @@ def yolo_model_comparison_tool():
600
  st.subheader("Overlapping (IoU) Threshold")
601
  comp_iou_threshold = st.slider(
602
  "Adjust IoU threshold for non-maximum suppression across all models",
603
- min_value=0.1,
604
  max_value=1.0,
605
  value=0.45, # Default NMS value
606
  step=0.05,
607
  key="multi_model_iou_threshold",
608
- help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
609
  )
610
 
611
  # Display annotated images in a grid (row = image, column = model)
 
23
  # If there are no boxes, or the boxes have no confidence values, just return the original image
24
  if not hasattr(result, 'boxes') or result.boxes is None or len(result.boxes) == 0:
25
  return Image.fromarray(result.orig_img), 0
26
+
27
  # Get the confidence values
28
  if hasattr(result.boxes.conf, "cpu"):
29
  confs = result.boxes.conf.cpu().numpy()
30
  else:
31
  confs = result.boxes.conf
32
+
33
+ # First filter by confidence threshold
34
+ conf_mask = confs >= conf_threshold
35
 
36
  # Create a completely new plot with only the boxes that meet the threshold
37
  if hasattr(result, 'orig_img'):
38
  img_with_boxes = result.orig_img.copy()
39
  else:
40
  # Fallback to plot method if orig_img is not available
41
+ try:
42
+ # First try the combined approach
43
+ return Image.fromarray(np.array(result.plot(conf=conf_threshold, iou=iou_threshold))), sum(conf_mask)
44
+ except:
45
+ # Fallback to just confidence if iou param is not supported
46
+ return Image.fromarray(np.array(result.plot(conf=conf_threshold))), sum(conf_mask)
47
+
48
+ # Collect all boxes that meet confidence threshold
49
+ filtered_boxes = []
50
+ filtered_classes = []
51
+ filtered_confs = []
52
+
53
+ for i in range(len(confs)):
54
+ if confs[i] < conf_threshold:
55
+ continue
56
+
57
+ try:
58
+ # Get the box coordinates (handle different formats)
59
+ if hasattr(result.boxes, "xyxy"):
60
+ if hasattr(result.boxes.xyxy, "cpu"):
61
+ box = result.boxes.xyxy[i].cpu().numpy().astype(float)
62
+ else:
63
+ box = result.boxes.xyxy[i].astype(float)
64
+ elif hasattr(result.boxes, "xywh"):
65
+ if hasattr(result.boxes.xywh, "cpu"):
66
+ xywh = result.boxes.xywh[i].cpu().numpy().astype(float)
67
+ else:
68
+ xywh = result.boxes.xywh[i].astype(float)
69
+ box = np.array([
70
+ xywh[0] - xywh[2]/2, # x1 = x - w/2
71
+ xywh[1] - xywh[3]/2, # y1 = y - h/2
72
+ xywh[0] + xywh[2]/2, # x2 = x + w/2
73
+ xywh[1] + xywh[3]/2 # y2 = y + h/2
74
+ ]).astype(float)
75
+ else:
76
+ continue # Skip if no box format available
77
+
78
+ # Get class ID
79
+ if hasattr(result.boxes, "cls"):
80
+ if hasattr(result.boxes.cls, "cpu"):
81
+ cls_id = int(result.boxes.cls[i].cpu().item())
82
+ else:
83
+ cls_id = int(result.boxes.cls[i])
84
+ else:
85
+ cls_id = 0 # Default class ID if not available
86
+
87
+ # Store the box, class, and confidence
88
+ filtered_boxes.append(box)
89
+ filtered_classes.append(cls_id)
90
+ filtered_confs.append(confs[i])
91
+
92
+ except Exception as e:
93
+ st.error(f"Error processing detection box: {str(e)}")
94
+ continue
95
+
96
+ if not filtered_boxes:
97
+ # No boxes passed the confidence threshold
98
+ return Image.fromarray(img_with_boxes), 0
99
+
100
+ # Convert to numpy arrays for processing
101
+ boxes_array = np.array(filtered_boxes)
102
+ classes_array = np.array(filtered_classes)
103
+ confs_array = np.array(filtered_confs)
104
 
105
+ # Get unique classes for per-class NMS
106
+ unique_classes = np.unique(classes_array)
107
+
108
+ # Final boxes to draw after NMS
109
+ final_boxes = []
110
+ final_classes = []
111
+ final_confs = []
112
+
113
+ # Helper function to calculate IoU between two boxes
114
+ def calculate_iou(box1, box2):
115
+ # Calculate intersection area
116
+ x1 = max(box1[0], box2[0])
117
+ y1 = max(box1[1], box2[1])
118
+ x2 = min(box1[2], box2[2])
119
+ y2 = min(box1[3], box2[3])
120
 
121
+ if x2 < x1 or y2 < y1:
122
+ return 0.0 # No intersection
 
 
123
 
124
+ intersection_area = (x2 - x1) * (y2 - y1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Calculate union area
127
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
128
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
129
+ union_area = box1_area + box2_area - intersection_area
130
+
131
+ # Return IoU
132
+ if union_area <= 0:
133
+ return 0.0
134
+ return intersection_area / union_area
135
+
136
+ # Apply NMS per class as shown in the diagram
137
+ for cls in unique_classes:
138
+ # Get all boxes for this class
139
+ class_indices = np.where(classes_array == cls)[0]
140
+ if len(class_indices) == 0:
141
+ continue
142
+
143
+ # Get boxes and scores for this class
144
+ class_boxes = boxes_array[class_indices]
145
+ class_scores = confs_array[class_indices]
146
+
147
+ # We'll keep track of which boxes to keep
148
+ keep_boxes = []
149
+
150
+ # While we still have boxes to process
151
+ while len(class_indices) > 0:
152
+ # Find the box with highest confidence
153
+ max_conf_idx = np.argmax(class_scores)
154
+ max_conf_box = class_boxes[max_conf_idx]
155
+ max_conf = class_scores[max_conf_idx]
156
+
157
+ # Add this box to our final list
158
+ keep_boxes.append(class_indices[max_conf_idx])
159
+
160
+ # Remove this box from consideration
161
+ class_boxes = np.delete(class_boxes, max_conf_idx, axis=0)
162
+ class_scores = np.delete(class_scores, max_conf_idx)
163
+ class_indices = np.delete(class_indices, max_conf_idx)
164
+
165
+ # If no boxes left, we're done with this class
166
+ if len(class_indices) == 0:
167
+ break
168
+
169
+ # Calculate IoU of the saved box with the rest
170
+ ious = np.array([calculate_iou(max_conf_box, box) for box in class_boxes])
171
+
172
+ # Remove boxes with IoU > threshold
173
+ boxes_to_keep = ious <= iou_threshold
174
+ class_boxes = class_boxes[boxes_to_keep]
175
+ class_scores = class_scores[boxes_to_keep]
176
+ class_indices = class_indices[boxes_to_keep]
177
+
178
+ # Add all kept boxes for this class to our final lists
179
+ for idx in keep_boxes:
180
+ final_boxes.append(filtered_boxes[idx])
181
+ final_classes.append(filtered_classes[idx])
182
+ final_confs.append(filtered_confs[idx])
183
+
184
+ # Count valid detections after NMS
185
+ valid_detections = len(final_boxes)
186
+
187
+ # Draw all final boxes
188
+ for i, (box, cls_id, conf) in enumerate(zip(final_boxes, final_classes, final_confs)):
189
+ # Make sure box coordinates are within image bounds
190
+ h, w = img_with_boxes.shape[:2]
191
+ box[0] = max(0, min(box[0], w-1))
192
+ box[1] = max(0, min(box[1], h-1))
193
+ box[2] = max(0, min(box[2], w-1))
194
+ box[3] = max(0, min(box[3], h-1))
195
+
196
+ # Convert to integers for drawing
197
+ box = box.astype(int)
198
+
199
+ # Get class name
200
+ if hasattr(result, 'names') and result.names and cls_id in result.names:
201
+ cls_name = result.names[cls_id]
202
+ else:
203
+ cls_name = f"class_{cls_id}"
204
+
205
+ # Create a deterministic color based on class ID
206
+ # Fixed color per class for consistency
207
+ color_r = (cls_id * 100 + 50) % 255
208
+ color_g = (cls_id * 50 + 170) % 255
209
+ color_b = (cls_id * 80 + 90) % 255
210
+ color = (color_b, color_g, color_r) # BGR format for OpenCV
211
+
212
+ # Draw rectangle
213
+ cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
214
+
215
+ # Add label with confidence
216
+ label = f"{cls_name} {conf:.2f}"
217
+ font = cv2.FONT_HERSHEY_SIMPLEX
218
+ text_size = cv2.getTextSize(label, font, 0.5, 2)[0]
219
+
220
+ # Create filled rectangle for text background
221
+ rect_y1 = max(0, box[1] - text_size[1] - 10)
222
+ cv2.rectangle(img_with_boxes, (box[0], rect_y1),
223
+ (box[0] + text_size[0], box[1]), color, -1)
224
+
225
+ # Draw text with white color
226
+ cv2.putText(img_with_boxes, label, (box[0], box[1] - 5),
227
+ font, 0.5, (255, 255, 255), 1)
228
+
229
+ # Return the annotated image and detection count
230
+ return Image.fromarray(img_with_boxes), valid_detections
231
 
 
 
 
 
232
  except Exception as e:
233
+ # If our custom implementation fails, try using the model's built-in plot method
234
  try:
235
+ try:
236
+ # Try with both parameters if supported
237
+ annotated_img = result.plot(conf=conf_threshold, iou=iou_threshold)
238
+ except:
239
+ # Fallback to just confidence parameter
240
+ annotated_img = result.plot(conf=conf_threshold)
241
+
242
  if isinstance(annotated_img, np.ndarray):
243
  img_pil = Image.fromarray(annotated_img)
244
  else:
245
  img_pil = annotated_img
246
 
247
+ # Count detections meeting the confidence threshold
248
  if hasattr(result, 'boxes') and result.boxes is not None and len(result.boxes) > 0:
249
  if hasattr(result.boxes.conf, "cpu"):
250
  confs = result.boxes.conf.cpu().numpy()
 
255
  valid_detections = 0
256
 
257
  return img_pil, valid_detections
258
+
259
  except Exception as nested_e:
260
+ # Last resort: return the original image
261
  if hasattr(result, 'orig_img'):
262
  return Image.fromarray(result.orig_img), 0
263
+
264
+ # If even that fails, create a blank image with error message
265
  blank_img = np.zeros((400, 600, 3), dtype=np.uint8)
266
  cv2.putText(blank_img, f"Error: {str(e)}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
267
  cv2.putText(blank_img, "Could not render annotations", (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
 
406
  st.subheader("Overlapping (IoU) Threshold")
407
  iou_threshold = st.slider(
408
  "Adjust IoU threshold for non-maximum suppression",
409
+ min_value=0.0,
410
  max_value=1.0,
411
  value=0.45, # Default NMS value
412
  step=0.05,
413
  key="single_model_iou_threshold",
414
+ help="Controls how overlapping boxes are filtered. Lower values (0.1-0.3) remove more overlapping boxes, higher values (0.7-0.9) allow more overlaps. The standard YOLO default is 0.45."
415
  )
416
 
417
  # Display annotated images using the current thresholds
 
655
  st.subheader("Overlapping (IoU) Threshold")
656
  comp_iou_threshold = st.slider(
657
  "Adjust IoU threshold for non-maximum suppression across all models",
658
+ min_value=0.0,
659
  max_value=1.0,
660
  value=0.45, # Default NMS value
661
  step=0.05,
662
  key="multi_model_iou_threshold",
663
+ help="Controls how overlapping boxes are filtered. Lower values (0.1-0.3) remove more overlapping boxes, higher values (0.7-0.9) allow more overlaps. The standard YOLO default is 0.45."
664
  )
665
 
666
  # Display annotated images in a grid (row = image, column = model)