Spaces:
Sleeping
Sleeping
Update app.py
Browse filesadded overlapping threshold slider
app.py
CHANGED
|
@@ -17,7 +17,7 @@ def save_uploaded_file(uploaded_file):
|
|
| 17 |
tmp_file.write(uploaded_file.getbuffer())
|
| 18 |
return tmp_file.name
|
| 19 |
|
| 20 |
-
def apply_confidence_threshold(result, conf_threshold):
|
| 21 |
"""Apply confidence threshold by modifying the result's boxes directly."""
|
| 22 |
try:
|
| 23 |
# If there are no boxes, or the boxes have no confidence values, just return the original image
|
|
@@ -38,14 +38,19 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 38 |
img_with_boxes = result.orig_img.copy()
|
| 39 |
else:
|
| 40 |
# Fallback to plot method if orig_img is not available
|
| 41 |
-
return Image.fromarray(np.array(result.plot(conf=conf_threshold))), valid_detections
|
| 42 |
|
| 43 |
# Only proceed with drawing if there are valid detections
|
| 44 |
if valid_detections > 0:
|
| 45 |
# Create mask of boxes to keep
|
| 46 |
mask = confs >= conf_threshold
|
| 47 |
|
| 48 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
for i, is_valid in enumerate(mask):
|
| 50 |
if is_valid:
|
| 51 |
try:
|
|
@@ -71,7 +76,7 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 71 |
# If we can't get box coordinates, skip this box
|
| 72 |
continue
|
| 73 |
|
| 74 |
-
# Get class ID
|
| 75 |
if hasattr(result.boxes, "cls"):
|
| 76 |
if hasattr(result.boxes.cls, "cpu"):
|
| 77 |
cls_id = int(result.boxes.cls[i].cpu().item())
|
|
@@ -83,12 +88,59 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 83 |
# Get confidence
|
| 84 |
conf = confs[i]
|
| 85 |
|
| 86 |
-
#
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
cls_name = f"class_{cls_id}"
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# Make sure box coordinates are within image bounds
|
| 93 |
h, w = img_with_boxes.shape[:2]
|
| 94 |
box[0] = max(0, min(box[0], w-1))
|
|
@@ -96,6 +148,12 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 96 |
box[2] = max(0, min(box[2], w-1))
|
| 97 |
box[3] = max(0, min(box[3], h-1))
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Draw the box
|
| 100 |
color = (0, 255, 0) # Green box
|
| 101 |
cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
|
|
@@ -109,10 +167,15 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 109 |
text_x = box[0]
|
| 110 |
text_y = max(box[1] - 10, text_size[1])
|
| 111 |
cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# Convert back to PIL Image for streamlit display
|
| 118 |
img_pil = Image.fromarray(img_with_boxes)
|
|
@@ -122,7 +185,7 @@ def apply_confidence_threshold(result, conf_threshold):
|
|
| 122 |
# If anything fails in the custom drawing, fall back to the model's built-in plot method
|
| 123 |
try:
|
| 124 |
# Try using the built-in plot method with the threshold
|
| 125 |
-
annotated_img = result.plot(conf=conf_threshold)
|
| 126 |
if isinstance(annotated_img, np.ndarray):
|
| 127 |
img_pil = Image.fromarray(annotated_img)
|
| 128 |
else:
|
|
@@ -284,17 +347,29 @@ def yolo_inference_tool():
|
|
| 284 |
key="single_model_conf_threshold"
|
| 285 |
)
|
| 286 |
|
| 287 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
st.subheader("Annotated Images")
|
| 289 |
for img_name, r in st.session_state.single_model_results.items():
|
| 290 |
try:
|
| 291 |
-
# Apply confidence
|
| 292 |
-
processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold)
|
| 293 |
|
| 294 |
# Display the image
|
| 295 |
st.image(
|
| 296 |
processed_img,
|
| 297 |
-
caption=f"{img_name} (
|
| 298 |
use_container_width=True
|
| 299 |
)
|
| 300 |
except Exception as e:
|
|
@@ -520,6 +595,18 @@ def yolo_model_comparison_tool():
|
|
| 520 |
step=0.05,
|
| 521 |
key="multi_model_conf_threshold"
|
| 522 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
# Display annotated images in a grid (row = image, column = model)
|
| 525 |
st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
|
|
@@ -540,12 +627,12 @@ def yolo_model_comparison_tool():
|
|
| 540 |
continue
|
| 541 |
|
| 542 |
try:
|
| 543 |
-
# Apply confidence
|
| 544 |
-
processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold)
|
| 545 |
|
| 546 |
col.image(
|
| 547 |
processed_img,
|
| 548 |
-
caption=f"{model_name} (
|
| 549 |
use_container_width=True
|
| 550 |
)
|
| 551 |
except Exception as e:
|
|
|
|
| 17 |
tmp_file.write(uploaded_file.getbuffer())
|
| 18 |
return tmp_file.name
|
| 19 |
|
| 20 |
+
def apply_confidence_threshold(result, conf_threshold, iou_threshold=0.45):
|
| 21 |
"""Apply confidence threshold by modifying the result's boxes directly."""
|
| 22 |
try:
|
| 23 |
# If there are no boxes, or the boxes have no confidence values, just return the original image
|
|
|
|
| 38 |
img_with_boxes = result.orig_img.copy()
|
| 39 |
else:
|
| 40 |
# Fallback to plot method if orig_img is not available
|
| 41 |
+
return Image.fromarray(np.array(result.plot(conf=conf_threshold, iou=iou_threshold))), valid_detections
|
| 42 |
|
| 43 |
# Only proceed with drawing if there are valid detections
|
| 44 |
if valid_detections > 0:
|
| 45 |
# Create mask of boxes to keep
|
| 46 |
mask = confs >= conf_threshold
|
| 47 |
|
| 48 |
+
# Apply non-maximum suppression if we have xyxy boxes
|
| 49 |
+
boxes_to_draw = []
|
| 50 |
+
class_ids = []
|
| 51 |
+
confidences = []
|
| 52 |
+
|
| 53 |
+
# Collect all valid boxes
|
| 54 |
for i, is_valid in enumerate(mask):
|
| 55 |
if is_valid:
|
| 56 |
try:
|
|
|
|
| 76 |
# If we can't get box coordinates, skip this box
|
| 77 |
continue
|
| 78 |
|
| 79 |
+
# Get class ID
|
| 80 |
if hasattr(result.boxes, "cls"):
|
| 81 |
if hasattr(result.boxes.cls, "cpu"):
|
| 82 |
cls_id = int(result.boxes.cls[i].cpu().item())
|
|
|
|
| 88 |
# Get confidence
|
| 89 |
conf = confs[i]
|
| 90 |
|
| 91 |
+
# Add to our collection
|
| 92 |
+
boxes_to_draw.append(box)
|
| 93 |
+
class_ids.append(cls_id)
|
| 94 |
+
confidences.append(conf)
|
|
|
|
| 95 |
|
| 96 |
+
except Exception as e:
|
| 97 |
+
# If any error occurs for a specific box, just skip it
|
| 98 |
+
st.error(f"Error processing a detection box: {str(e)}")
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
# Apply non-maximum suppression if we have collected boxes
|
| 102 |
+
if boxes_to_draw:
|
| 103 |
+
try:
|
| 104 |
+
# Convert to numpy arrays
|
| 105 |
+
boxes_np = np.array(boxes_to_draw)
|
| 106 |
+
class_ids_np = np.array(class_ids)
|
| 107 |
+
confidences_np = np.array(confidences)
|
| 108 |
+
|
| 109 |
+
# Apply NMS by class
|
| 110 |
+
unique_classes = np.unique(class_ids_np)
|
| 111 |
+
final_boxes = []
|
| 112 |
+
final_classes = []
|
| 113 |
+
final_confs = []
|
| 114 |
+
|
| 115 |
+
for cls in unique_classes:
|
| 116 |
+
# Get indices for this class
|
| 117 |
+
indices = np.where(class_ids_np == cls)[0]
|
| 118 |
+
if len(indices) == 0:
|
| 119 |
+
continue
|
| 120 |
+
|
| 121 |
+
class_boxes = boxes_np[indices]
|
| 122 |
+
class_confs = confidences_np[indices]
|
| 123 |
+
|
| 124 |
+
# Apply NMS
|
| 125 |
+
keep_indices = cv2.dnn.NMSBoxes(
|
| 126 |
+
class_boxes.tolist(),
|
| 127 |
+
class_confs.tolist(),
|
| 128 |
+
score_threshold=conf_threshold,
|
| 129 |
+
nms_threshold=iou_threshold
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Add kept boxes to final lists
|
| 133 |
+
if len(keep_indices) > 0:
|
| 134 |
+
if isinstance(keep_indices[0], np.ndarray): # Handle different return formats
|
| 135 |
+
keep_indices = keep_indices.flatten()
|
| 136 |
+
|
| 137 |
+
for idx in keep_indices:
|
| 138 |
+
final_boxes.append(class_boxes[idx])
|
| 139 |
+
final_classes.append(cls)
|
| 140 |
+
final_confs.append(class_confs[idx])
|
| 141 |
+
|
| 142 |
+
# Now draw only the final boxes
|
| 143 |
+
for box, cls_id, conf in zip(final_boxes, final_classes, final_confs):
|
| 144 |
# Make sure box coordinates are within image bounds
|
| 145 |
h, w = img_with_boxes.shape[:2]
|
| 146 |
box[0] = max(0, min(box[0], w-1))
|
|
|
|
| 148 |
box[2] = max(0, min(box[2], w-1))
|
| 149 |
box[3] = max(0, min(box[3], h-1))
|
| 150 |
|
| 151 |
+
# Get class name
|
| 152 |
+
if hasattr(result, 'names') and result.names and cls_id in result.names:
|
| 153 |
+
cls_name = result.names[cls_id]
|
| 154 |
+
else:
|
| 155 |
+
cls_name = f"class_{cls_id}"
|
| 156 |
+
|
| 157 |
# Draw the box
|
| 158 |
color = (0, 255, 0) # Green box
|
| 159 |
cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
|
|
|
|
| 167 |
text_x = box[0]
|
| 168 |
text_y = max(box[1] - 10, text_size[1])
|
| 169 |
cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
|
| 170 |
+
|
| 171 |
+
# Update valid_detections to reflect NMS results
|
| 172 |
+
valid_detections = len(final_boxes)
|
| 173 |
+
|
| 174 |
+
except Exception as nms_error:
|
| 175 |
+
# If NMS fails, fall back to original drawing code
|
| 176 |
+
st.warning(f"NMS processing failed, falling back to simple filtering: {str(nms_error)}")
|
| 177 |
+
# The original boxes will be drawn in the fallback code
|
| 178 |
+
pass
|
| 179 |
|
| 180 |
# Convert back to PIL Image for streamlit display
|
| 181 |
img_pil = Image.fromarray(img_with_boxes)
|
|
|
|
| 185 |
# If anything fails in the custom drawing, fall back to the model's built-in plot method
|
| 186 |
try:
|
| 187 |
# Try using the built-in plot method with the threshold
|
| 188 |
+
annotated_img = result.plot(conf=conf_threshold, iou=iou_threshold)
|
| 189 |
if isinstance(annotated_img, np.ndarray):
|
| 190 |
img_pil = Image.fromarray(annotated_img)
|
| 191 |
else:
|
|
|
|
| 347 |
key="single_model_conf_threshold"
|
| 348 |
)
|
| 349 |
|
| 350 |
+
# Add IoU threshold slider for NMS
|
| 351 |
+
st.subheader("Overlapping (IoU) Threshold")
|
| 352 |
+
iou_threshold = st.slider(
|
| 353 |
+
"Adjust IoU threshold for non-maximum suppression",
|
| 354 |
+
min_value=0.1,
|
| 355 |
+
max_value=1.0,
|
| 356 |
+
value=0.45, # Default NMS value
|
| 357 |
+
step=0.05,
|
| 358 |
+
key="single_model_iou_threshold",
|
| 359 |
+
help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Display annotated images using the current thresholds
|
| 363 |
st.subheader("Annotated Images")
|
| 364 |
for img_name, r in st.session_state.single_model_results.items():
|
| 365 |
try:
|
| 366 |
+
# Apply confidence and IoU thresholds and get processed image
|
| 367 |
+
processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold, iou_threshold)
|
| 368 |
|
| 369 |
# Display the image
|
| 370 |
st.image(
|
| 371 |
processed_img,
|
| 372 |
+
caption=f"{img_name} (Conf: {conf_threshold:.2f}, IoU: {iou_threshold:.2f}, Detections: {valid_detections})",
|
| 373 |
use_container_width=True
|
| 374 |
)
|
| 375 |
except Exception as e:
|
|
|
|
| 595 |
step=0.05,
|
| 596 |
key="multi_model_conf_threshold"
|
| 597 |
)
|
| 598 |
+
|
| 599 |
+
# Add IoU threshold slider for NMS
|
| 600 |
+
st.subheader("Overlapping (IoU) Threshold")
|
| 601 |
+
comp_iou_threshold = st.slider(
|
| 602 |
+
"Adjust IoU threshold for non-maximum suppression across all models",
|
| 603 |
+
min_value=0.1,
|
| 604 |
+
max_value=1.0,
|
| 605 |
+
value=0.45, # Default NMS value
|
| 606 |
+
step=0.05,
|
| 607 |
+
key="multi_model_iou_threshold",
|
| 608 |
+
help="Higher values allow more overlapping boxes. Lower values keep only the most confident box in overlapping groups."
|
| 609 |
+
)
|
| 610 |
|
| 611 |
# Display annotated images in a grid (row = image, column = model)
|
| 612 |
st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
|
|
|
|
| 627 |
continue
|
| 628 |
|
| 629 |
try:
|
| 630 |
+
# Apply confidence and IoU thresholds and get processed image
|
| 631 |
+
processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold, comp_iou_threshold)
|
| 632 |
|
| 633 |
col.image(
|
| 634 |
processed_img,
|
| 635 |
+
caption=f"{model_name} (Conf: {comp_conf_threshold:.2f}, IoU: {comp_iou_threshold:.2f}, Det: {valid_detections})",
|
| 636 |
use_container_width=True
|
| 637 |
)
|
| 638 |
except Exception as e:
|