Spaces:
Build error
Build error
Update app.py
Browse filesadded threshold slider
app.py
CHANGED
|
@@ -7,6 +7,8 @@ from PIL import Image
|
|
| 7 |
from ultralytics import YOLO
|
| 8 |
import requests
|
| 9 |
from io import BytesIO
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def save_uploaded_file(uploaded_file):
|
|
@@ -15,6 +17,138 @@ def save_uploaded_file(uploaded_file):
|
|
| 15 |
tmp_file.write(uploaded_file.getbuffer())
|
| 16 |
return tmp_file.name
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def yolo_inference_tool():
|
| 19 |
st.header("YOLO Model Inference Tool")
|
| 20 |
st.write(
|
|
@@ -22,6 +156,12 @@ def yolo_inference_tool():
|
|
| 22 |
"You can either upload images or provide an image URL."
|
| 23 |
)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Allow multiple images upload
|
| 26 |
uploaded_files = st.file_uploader(
|
| 27 |
"Upload Images", type=["jpg", "jpeg", "png"], key="inference_images", accept_multiple_files=True
|
|
@@ -89,7 +229,8 @@ def yolo_inference_tool():
|
|
| 89 |
continue
|
| 90 |
|
| 91 |
try:
|
| 92 |
-
|
|
|
|
| 93 |
except Exception as e:
|
| 94 |
st.error(f"Inference error on image {getattr(img_file, 'name', 'Unknown')}: {e}")
|
| 95 |
continue
|
|
@@ -100,10 +241,14 @@ def yolo_inference_tool():
|
|
| 100 |
# Get inference time from r.speed, if available
|
| 101 |
inference_time = r.speed.get('inference', None) if isinstance(r.speed, dict) else None
|
| 102 |
# Compute detection count and average confidence if detections exist
|
| 103 |
-
if r.boxes is not None and len(r.boxes) > 0:
|
| 104 |
detection_count = len(r.boxes)
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
else:
|
| 108 |
detection_count = 0
|
| 109 |
avg_conf = 0.0
|
|
@@ -117,20 +262,44 @@ def yolo_inference_tool():
|
|
| 117 |
|
| 118 |
eta_placeholder.empty()
|
| 119 |
|
| 120 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
st.subheader("Inference Metrics")
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
-
# Display annotated images using
|
| 127 |
st.subheader("Annotated Images")
|
| 128 |
-
for img_name, r in
|
| 129 |
try:
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
except Exception as e:
|
| 133 |
st.error(f"Error generating annotated image for {img_name}: {e}")
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def yolo_model_comparison_tool():
|
|
@@ -147,6 +316,16 @@ def yolo_model_comparison_tool():
|
|
| 147 |
"and compute a Weighted Score that balances these factors.\n\n"
|
| 148 |
"A progress bar and ETA are shown in real time after you click Submit."
|
| 149 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
images = st.file_uploader("Upload Images", type=["jpg", "jpeg", "png"], key="comparison_images", accept_multiple_files=True)
|
| 152 |
model_files = st.file_uploader("Upload YOLO models (.pt)", type=["pt"], key="comparison_models", accept_multiple_files=True)
|
|
@@ -212,7 +391,8 @@ def yolo_model_comparison_tool():
|
|
| 212 |
|
| 213 |
# Run inference
|
| 214 |
try:
|
| 215 |
-
|
|
|
|
| 216 |
except Exception as e:
|
| 217 |
st.error(f"Inference error for model {model_file.name} on {img_file.name}: {e}")
|
| 218 |
continue
|
|
@@ -225,11 +405,14 @@ def yolo_model_comparison_tool():
|
|
| 225 |
total_inference_time += r.speed["inference"]
|
| 226 |
|
| 227 |
# Count detections & confidence
|
| 228 |
-
if r.boxes is not None:
|
| 229 |
det_count = len(r.boxes)
|
| 230 |
total_detections += det_count
|
| 231 |
if det_count > 0:
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
| 233 |
sum_confidences += confs.sum()
|
| 234 |
total_conf_count += det_count
|
| 235 |
|
|
@@ -254,9 +437,7 @@ def yolo_model_comparison_tool():
|
|
| 254 |
|
| 255 |
# Display aggregated metrics
|
| 256 |
df = pd.DataFrame(model_agg_data.values())
|
| 257 |
-
|
| 258 |
-
st.dataframe(df, use_container_width=True)
|
| 259 |
-
|
| 260 |
# Weighted Scoring with reciprocal-based speed
|
| 261 |
detection_max = df["Total Detections"].max()
|
| 262 |
confidence_max = df["Average Confidence"].max()
|
|
@@ -281,6 +462,25 @@ def yolo_model_comparison_tool():
|
|
| 281 |
gamma_speed * df["Speed Norm"]
|
| 282 |
)
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
st.subheader("Weighted Score Analysis")
|
| 285 |
st.write(f"Weights: Detection={alpha_detection}, Confidence={beta_confidence}, Speed={gamma_speed}")
|
| 286 |
st.dataframe(df[[
|
|
@@ -295,11 +495,6 @@ def yolo_model_comparison_tool():
|
|
| 295 |
"Weighted Score"
|
| 296 |
]], use_container_width=True)
|
| 297 |
|
| 298 |
-
# Identify best overall model (highest Weighted Score)
|
| 299 |
-
best_idx = df["Weighted Score"].idxmax()
|
| 300 |
-
best_model = df.loc[best_idx, "Model File"]
|
| 301 |
-
best_score = df.loc[best_idx, "Weighted Score"]
|
| 302 |
-
|
| 303 |
st.markdown(f"""
|
| 304 |
**Best Overall Model** based on Weighted Score:
|
| 305 |
**{best_model}** (Score: {best_score:.3f}).
|
|
@@ -315,29 +510,47 @@ def yolo_model_comparison_tool():
|
|
| 315 |
- Increase **Speed** weight if you need real‐time inference.
|
| 316 |
""")
|
| 317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
# Display annotated images in a grid (row = image, column = model)
|
| 319 |
st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
|
| 320 |
-
model_names_sorted = sorted(model_agg_data.keys())
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
columns = st.columns(len(model_names_sorted))
|
| 325 |
for col, model_name in zip(columns, model_names_sorted):
|
| 326 |
-
r = model_image_results.get(model_name, {}).get(
|
| 327 |
if r is None:
|
| 328 |
col.write(f"No results for {model_name}")
|
| 329 |
continue
|
| 330 |
|
| 331 |
try:
|
| 332 |
-
#
|
| 333 |
-
|
|
|
|
| 334 |
col.image(
|
| 335 |
-
|
| 336 |
-
caption=f"{model_name}",
|
| 337 |
use_container_width=True
|
| 338 |
)
|
| 339 |
except Exception as e:
|
| 340 |
col.error(f"Error annotating image for {model_name}: {e}")
|
|
|
|
| 341 |
|
| 342 |
def main():
|
| 343 |
st.sidebar.title("Navigation")
|
|
|
|
| 7 |
from ultralytics import YOLO
|
| 8 |
import requests
|
| 9 |
from io import BytesIO
|
| 10 |
+
import copy
|
| 11 |
+
import cv2
|
| 12 |
|
| 13 |
|
| 14 |
def save_uploaded_file(uploaded_file):
|
|
|
|
| 17 |
tmp_file.write(uploaded_file.getbuffer())
|
| 18 |
return tmp_file.name
|
| 19 |
|
| 20 |
+
def apply_confidence_threshold(result, conf_threshold):
|
| 21 |
+
"""Apply confidence threshold by modifying the result's boxes directly."""
|
| 22 |
+
try:
|
| 23 |
+
# If there are no boxes, or the boxes have no confidence values, just return the original image
|
| 24 |
+
if not hasattr(result, 'boxes') or result.boxes is None or len(result.boxes) == 0:
|
| 25 |
+
return Image.fromarray(result.orig_img), 0
|
| 26 |
+
|
| 27 |
+
# Get the confidence values
|
| 28 |
+
if hasattr(result.boxes.conf, "cpu"):
|
| 29 |
+
confs = result.boxes.conf.cpu().numpy()
|
| 30 |
+
else:
|
| 31 |
+
confs = result.boxes.conf
|
| 32 |
+
|
| 33 |
+
# Count valid detections for display purposes
|
| 34 |
+
valid_detections = sum(confs >= conf_threshold)
|
| 35 |
+
|
| 36 |
+
# Create a completely new plot with only the boxes that meet the threshold
|
| 37 |
+
if hasattr(result, 'orig_img'):
|
| 38 |
+
img_with_boxes = result.orig_img.copy()
|
| 39 |
+
else:
|
| 40 |
+
# Fallback to plot method if orig_img is not available
|
| 41 |
+
return Image.fromarray(np.array(result.plot(conf=conf_threshold))), valid_detections
|
| 42 |
+
|
| 43 |
+
# Only proceed with drawing if there are valid detections
|
| 44 |
+
if valid_detections > 0:
|
| 45 |
+
# Create mask of boxes to keep
|
| 46 |
+
mask = confs >= conf_threshold
|
| 47 |
+
|
| 48 |
+
# For each valid box, draw it on the image
|
| 49 |
+
for i, is_valid in enumerate(mask):
|
| 50 |
+
if is_valid:
|
| 51 |
+
try:
|
| 52 |
+
# Get the box coordinates (handle different formats)
|
| 53 |
+
if hasattr(result.boxes, "xyxy"):
|
| 54 |
+
if hasattr(result.boxes.xyxy, "cpu"):
|
| 55 |
+
box = result.boxes.xyxy[i].cpu().numpy().astype(int)
|
| 56 |
+
else:
|
| 57 |
+
box = result.boxes.xyxy[i].astype(int)
|
| 58 |
+
elif hasattr(result.boxes, "xywh"): # Handle xywh format if that's what's available
|
| 59 |
+
if hasattr(result.boxes.xywh, "cpu"):
|
| 60 |
+
xywh = result.boxes.xywh[i].cpu().numpy().astype(int)
|
| 61 |
+
else:
|
| 62 |
+
xywh = result.boxes.xywh[i].astype(int)
|
| 63 |
+
# Convert xywh to xyxy: [x, y, w, h] -> [x1, y1, x2, y2]
|
| 64 |
+
box = np.array([
|
| 65 |
+
xywh[0] - xywh[2]//2, # x1 = x - w/2
|
| 66 |
+
xywh[1] - xywh[3]//2, # y1 = y - h/2
|
| 67 |
+
xywh[0] + xywh[2]//2, # x2 = x + w/2
|
| 68 |
+
xywh[1] + xywh[3]//2 # y2 = y + h/2
|
| 69 |
+
]).astype(int)
|
| 70 |
+
else:
|
| 71 |
+
# If we can't get box coordinates, skip this box
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
# Get class ID and name
|
| 75 |
+
if hasattr(result.boxes, "cls"):
|
| 76 |
+
if hasattr(result.boxes.cls, "cpu"):
|
| 77 |
+
cls_id = int(result.boxes.cls[i].cpu().item())
|
| 78 |
+
else:
|
| 79 |
+
cls_id = int(result.boxes.cls[i])
|
| 80 |
+
else:
|
| 81 |
+
cls_id = 0 # Default class ID if not available
|
| 82 |
+
|
| 83 |
+
# Get confidence
|
| 84 |
+
conf = confs[i]
|
| 85 |
+
|
| 86 |
+
# Get class name
|
| 87 |
+
if hasattr(result, 'names') and result.names and cls_id in result.names:
|
| 88 |
+
cls_name = result.names[cls_id]
|
| 89 |
+
else:
|
| 90 |
+
cls_name = f"class_{cls_id}"
|
| 91 |
+
|
| 92 |
+
# Make sure box coordinates are within image bounds
|
| 93 |
+
h, w = img_with_boxes.shape[:2]
|
| 94 |
+
box[0] = max(0, min(box[0], w-1))
|
| 95 |
+
box[1] = max(0, min(box[1], h-1))
|
| 96 |
+
box[2] = max(0, min(box[2], w-1))
|
| 97 |
+
box[3] = max(0, min(box[3], h-1))
|
| 98 |
+
|
| 99 |
+
# Draw the box
|
| 100 |
+
color = (0, 255, 0) # Green box
|
| 101 |
+
cv2.rectangle(img_with_boxes, (box[0], box[1]), (box[2], box[3]), color, 2)
|
| 102 |
+
|
| 103 |
+
# Add label with confidence
|
| 104 |
+
label = f"{cls_name} {conf:.2f}"
|
| 105 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 106 |
+
# Calculate text size to place it properly
|
| 107 |
+
text_size = cv2.getTextSize(label, font, 0.6, 2)[0]
|
| 108 |
+
# Ensure label is drawn within image bounds
|
| 109 |
+
text_x = box[0]
|
| 110 |
+
text_y = max(box[1] - 10, text_size[1])
|
| 111 |
+
cv2.putText(img_with_boxes, label, (text_x, text_y), font, 0.6, color, 2)
|
| 112 |
+
except Exception as e:
|
| 113 |
+
# If any error occurs for a specific box, just skip it
|
| 114 |
+
st.error(f"Error processing a detection box: {str(e)}")
|
| 115 |
+
continue
|
| 116 |
+
|
| 117 |
+
# Convert back to PIL Image for streamlit display
|
| 118 |
+
img_pil = Image.fromarray(img_with_boxes)
|
| 119 |
+
return img_pil, valid_detections
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
# If anything fails in the custom drawing, fall back to the model's built-in plot method
|
| 123 |
+
try:
|
| 124 |
+
# Try using the built-in plot method with the threshold
|
| 125 |
+
annotated_img = result.plot(conf=conf_threshold)
|
| 126 |
+
if isinstance(annotated_img, np.ndarray):
|
| 127 |
+
img_pil = Image.fromarray(annotated_img)
|
| 128 |
+
else:
|
| 129 |
+
img_pil = annotated_img
|
| 130 |
+
|
| 131 |
+
# Count detections meeting threshold
|
| 132 |
+
if hasattr(result, 'boxes') and result.boxes is not None and len(result.boxes) > 0:
|
| 133 |
+
if hasattr(result.boxes.conf, "cpu"):
|
| 134 |
+
confs = result.boxes.conf.cpu().numpy()
|
| 135 |
+
else:
|
| 136 |
+
confs = result.boxes.conf
|
| 137 |
+
valid_detections = sum(confs >= conf_threshold)
|
| 138 |
+
else:
|
| 139 |
+
valid_detections = 0
|
| 140 |
+
|
| 141 |
+
return img_pil, valid_detections
|
| 142 |
+
except Exception as nested_e:
|
| 143 |
+
# If even the fallback fails, return the original image without annotations
|
| 144 |
+
if hasattr(result, 'orig_img'):
|
| 145 |
+
return Image.fromarray(result.orig_img), 0
|
| 146 |
+
# If we can't even get the original image, create a blank one with error text
|
| 147 |
+
blank_img = np.zeros((400, 600, 3), dtype=np.uint8)
|
| 148 |
+
cv2.putText(blank_img, f"Error: {str(e)}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
| 149 |
+
cv2.putText(blank_img, "Could not render annotations", (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
|
| 150 |
+
return Image.fromarray(blank_img), 0
|
| 151 |
+
|
| 152 |
def yolo_inference_tool():
|
| 153 |
st.header("YOLO Model Inference Tool")
|
| 154 |
st.write(
|
|
|
|
| 156 |
"You can either upload images or provide an image URL."
|
| 157 |
)
|
| 158 |
|
| 159 |
+
# Initialize session state for storing inference results
|
| 160 |
+
if 'single_model_results' not in st.session_state:
|
| 161 |
+
st.session_state.single_model_results = None
|
| 162 |
+
if 'single_model_metrics' not in st.session_state:
|
| 163 |
+
st.session_state.single_model_metrics = None
|
| 164 |
+
|
| 165 |
# Allow multiple images upload
|
| 166 |
uploaded_files = st.file_uploader(
|
| 167 |
"Upload Images", type=["jpg", "jpeg", "png"], key="inference_images", accept_multiple_files=True
|
|
|
|
| 229 |
continue
|
| 230 |
|
| 231 |
try:
|
| 232 |
+
# Run inference with the lowest possible confidence to capture all detections
|
| 233 |
+
result = model(np.array(pil_img), conf=0.01)
|
| 234 |
except Exception as e:
|
| 235 |
st.error(f"Inference error on image {getattr(img_file, 'name', 'Unknown')}: {e}")
|
| 236 |
continue
|
|
|
|
| 241 |
# Get inference time from r.speed, if available
|
| 242 |
inference_time = r.speed.get('inference', None) if isinstance(r.speed, dict) else None
|
| 243 |
# Compute detection count and average confidence if detections exist
|
| 244 |
+
if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
|
| 245 |
detection_count = len(r.boxes)
|
| 246 |
+
if hasattr(r.boxes.conf, "cpu"):
|
| 247 |
+
confs = r.boxes.conf.cpu().numpy()
|
| 248 |
+
avg_conf = float(np.mean(confs))
|
| 249 |
+
else:
|
| 250 |
+
confs = r.boxes.conf
|
| 251 |
+
avg_conf = float(np.mean(confs))
|
| 252 |
else:
|
| 253 |
detection_count = 0
|
| 254 |
avg_conf = 0.0
|
|
|
|
| 262 |
|
| 263 |
eta_placeholder.empty()
|
| 264 |
|
| 265 |
+
# Store results in session state for persistence
|
| 266 |
+
st.session_state.single_model_results = image_results
|
| 267 |
+
st.session_state.single_model_metrics = metrics
|
| 268 |
+
|
| 269 |
+
# Display results if available in session state (either from button click or slider change)
|
| 270 |
+
if st.session_state.single_model_metrics is not None:
|
| 271 |
+
# Display per-image metrics
|
| 272 |
st.subheader("Inference Metrics")
|
| 273 |
+
df_metrics = pd.DataFrame(st.session_state.single_model_metrics)
|
| 274 |
+
st.dataframe(df_metrics, use_container_width=True)
|
| 275 |
+
|
| 276 |
+
# Add a confidence threshold slider
|
| 277 |
+
st.subheader("Confidence Threshold")
|
| 278 |
+
conf_threshold = st.slider(
|
| 279 |
+
"Adjust confidence threshold",
|
| 280 |
+
min_value=0.0,
|
| 281 |
+
max_value=1.0,
|
| 282 |
+
value=0.25, # Default value
|
| 283 |
+
step=0.05,
|
| 284 |
+
key="single_model_conf_threshold"
|
| 285 |
+
)
|
| 286 |
|
| 287 |
+
# Display annotated images using the current threshold
|
| 288 |
st.subheader("Annotated Images")
|
| 289 |
+
for img_name, r in st.session_state.single_model_results.items():
|
| 290 |
try:
|
| 291 |
+
# Apply confidence threshold and get processed image
|
| 292 |
+
processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold)
|
| 293 |
+
|
| 294 |
+
# Display the image
|
| 295 |
+
st.image(
|
| 296 |
+
processed_img,
|
| 297 |
+
caption=f"{img_name} (Threshold: {conf_threshold:.2f}, Detections: {valid_detections})",
|
| 298 |
+
use_container_width=True
|
| 299 |
+
)
|
| 300 |
except Exception as e:
|
| 301 |
st.error(f"Error generating annotated image for {img_name}: {e}")
|
| 302 |
+
st.error(str(e))
|
| 303 |
|
| 304 |
|
| 305 |
def yolo_model_comparison_tool():
|
|
|
|
| 316 |
"and compute a Weighted Score that balances these factors.\n\n"
|
| 317 |
"A progress bar and ETA are shown in real time after you click Submit."
|
| 318 |
)
|
| 319 |
+
|
| 320 |
+
# Initialize session state for storing model comparison results
|
| 321 |
+
if 'model_agg_data' not in st.session_state:
|
| 322 |
+
st.session_state.model_agg_data = None
|
| 323 |
+
if 'model_image_results' not in st.session_state:
|
| 324 |
+
st.session_state.model_image_results = None
|
| 325 |
+
if 'model_metrics_df' not in st.session_state:
|
| 326 |
+
st.session_state.model_metrics_df = None
|
| 327 |
+
if 'best_model_info' not in st.session_state:
|
| 328 |
+
st.session_state.best_model_info = None
|
| 329 |
|
| 330 |
images = st.file_uploader("Upload Images", type=["jpg", "jpeg", "png"], key="comparison_images", accept_multiple_files=True)
|
| 331 |
model_files = st.file_uploader("Upload YOLO models (.pt)", type=["pt"], key="comparison_models", accept_multiple_files=True)
|
|
|
|
| 391 |
|
| 392 |
# Run inference
|
| 393 |
try:
|
| 394 |
+
# Use low confidence to capture all detections
|
| 395 |
+
result = model(np_img, conf=0.01)
|
| 396 |
except Exception as e:
|
| 397 |
st.error(f"Inference error for model {model_file.name} on {img_file.name}: {e}")
|
| 398 |
continue
|
|
|
|
| 405 |
total_inference_time += r.speed["inference"]
|
| 406 |
|
| 407 |
# Count detections & confidence
|
| 408 |
+
if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
|
| 409 |
det_count = len(r.boxes)
|
| 410 |
total_detections += det_count
|
| 411 |
if det_count > 0:
|
| 412 |
+
if hasattr(r.boxes.conf, "cpu"):
|
| 413 |
+
confs = r.boxes.conf.cpu().numpy()
|
| 414 |
+
else:
|
| 415 |
+
confs = r.boxes.conf
|
| 416 |
sum_confidences += confs.sum()
|
| 417 |
total_conf_count += det_count
|
| 418 |
|
|
|
|
| 437 |
|
| 438 |
# Display aggregated metrics
|
| 439 |
df = pd.DataFrame(model_agg_data.values())
|
| 440 |
+
|
|
|
|
|
|
|
| 441 |
# Weighted Scoring with reciprocal-based speed
|
| 442 |
detection_max = df["Total Detections"].max()
|
| 443 |
confidence_max = df["Average Confidence"].max()
|
|
|
|
| 462 |
gamma_speed * df["Speed Norm"]
|
| 463 |
)
|
| 464 |
|
| 465 |
+
# Identify best overall model (highest Weighted Score)
|
| 466 |
+
best_idx = df["Weighted Score"].idxmax()
|
| 467 |
+
best_model = df.loc[best_idx, "Model File"]
|
| 468 |
+
best_score = df.loc[best_idx, "Weighted Score"]
|
| 469 |
+
|
| 470 |
+
# Store results in session state
|
| 471 |
+
st.session_state.model_agg_data = model_agg_data
|
| 472 |
+
st.session_state.model_image_results = model_image_results
|
| 473 |
+
st.session_state.model_metrics_df = df
|
| 474 |
+
st.session_state.best_model_info = (best_model, best_score)
|
| 475 |
+
|
| 476 |
+
# Display results if available in session state
|
| 477 |
+
if st.session_state.model_metrics_df is not None:
|
| 478 |
+
df = st.session_state.model_metrics_df
|
| 479 |
+
best_model, best_score = st.session_state.best_model_info
|
| 480 |
+
|
| 481 |
+
st.subheader("Aggregated Metrics (Across All Images)")
|
| 482 |
+
st.dataframe(df, use_container_width=True)
|
| 483 |
+
|
| 484 |
st.subheader("Weighted Score Analysis")
|
| 485 |
st.write(f"Weights: Detection={alpha_detection}, Confidence={beta_confidence}, Speed={gamma_speed}")
|
| 486 |
st.dataframe(df[[
|
|
|
|
| 495 |
"Weighted Score"
|
| 496 |
]], use_container_width=True)
|
| 497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
st.markdown(f"""
|
| 499 |
**Best Overall Model** based on Weighted Score:
|
| 500 |
**{best_model}** (Score: {best_score:.3f}).
|
|
|
|
| 510 |
- Increase **Speed** weight if you need real‐time inference.
|
| 511 |
""")
|
| 512 |
|
| 513 |
+
# Add a confidence threshold slider
|
| 514 |
+
st.subheader("Confidence Threshold")
|
| 515 |
+
comp_conf_threshold = st.slider(
|
| 516 |
+
"Adjust confidence threshold for all models",
|
| 517 |
+
min_value=0.0,
|
| 518 |
+
max_value=1.0,
|
| 519 |
+
value=0.25, # Default value
|
| 520 |
+
step=0.05,
|
| 521 |
+
key="multi_model_conf_threshold"
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
# Display annotated images in a grid (row = image, column = model)
|
| 525 |
st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
|
| 526 |
+
model_names_sorted = sorted(st.session_state.model_agg_data.keys())
|
| 527 |
+
|
| 528 |
+
# Extract the image file names from the stored results
|
| 529 |
+
image_names = set()
|
| 530 |
+
for model_results in st.session_state.model_image_results.values():
|
| 531 |
+
image_names.update(model_results.keys())
|
| 532 |
+
|
| 533 |
+
for img_name in sorted(image_names):
|
| 534 |
+
st.markdown(f"### Image: {img_name}")
|
| 535 |
columns = st.columns(len(model_names_sorted))
|
| 536 |
for col, model_name in zip(columns, model_names_sorted):
|
| 537 |
+
r = st.session_state.model_image_results.get(model_name, {}).get(img_name, None)
|
| 538 |
if r is None:
|
| 539 |
col.write(f"No results for {model_name}")
|
| 540 |
continue
|
| 541 |
|
| 542 |
try:
|
| 543 |
+
# Apply confidence threshold and get processed image
|
| 544 |
+
processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold)
|
| 545 |
+
|
| 546 |
col.image(
|
| 547 |
+
processed_img,
|
| 548 |
+
caption=f"{model_name} (Threshold: {comp_conf_threshold:.2f}, Detections: {valid_detections})",
|
| 549 |
use_container_width=True
|
| 550 |
)
|
| 551 |
except Exception as e:
|
| 552 |
col.error(f"Error annotating image for {model_name}: {e}")
|
| 553 |
+
col.error(str(e))
|
| 554 |
|
| 555 |
def main():
|
| 556 |
st.sidebar.title("Navigation")
|