Preyanshz committed on
Commit
b52b854
·
verified ·
1 Parent(s): a300f83

Update app.py

Browse files

Added a confidence threshold slider

Files changed (1) hide show
  1. app.py +245 -32
app.py CHANGED
@@ -7,6 +7,8 @@ from PIL import Image
7
  from ultralytics import YOLO
8
  import requests
9
  from io import BytesIO
 
 
10
 
11
 
12
  def save_uploaded_file(uploaded_file):
@@ -15,6 +17,138 @@ def save_uploaded_file(uploaded_file):
15
  tmp_file.write(uploaded_file.getbuffer())
16
  return tmp_file.name
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def yolo_inference_tool():
19
  st.header("YOLO Model Inference Tool")
20
  st.write(
@@ -22,6 +156,12 @@ def yolo_inference_tool():
22
  "You can either upload images or provide an image URL."
23
  )
24
 
 
 
 
 
 
 
25
  # Allow multiple images upload
26
  uploaded_files = st.file_uploader(
27
  "Upload Images", type=["jpg", "jpeg", "png"], key="inference_images", accept_multiple_files=True
@@ -89,7 +229,8 @@ def yolo_inference_tool():
89
  continue
90
 
91
  try:
92
- result = model(np.array(pil_img))
 
93
  except Exception as e:
94
  st.error(f"Inference error on image {getattr(img_file, 'name', 'Unknown')}: {e}")
95
  continue
@@ -100,10 +241,14 @@ def yolo_inference_tool():
100
  # Get inference time from r.speed, if available
101
  inference_time = r.speed.get('inference', None) if isinstance(r.speed, dict) else None
102
  # Compute detection count and average confidence if detections exist
103
- if r.boxes is not None and len(r.boxes) > 0:
104
  detection_count = len(r.boxes)
105
- confs = r.boxes.conf.cpu().numpy() if hasattr(r.boxes.conf, "cpu") else r.boxes.conf
106
- avg_conf = float(np.mean(confs))
 
 
 
 
107
  else:
108
  detection_count = 0
109
  avg_conf = 0.0
@@ -117,20 +262,44 @@ def yolo_inference_tool():
117
 
118
  eta_placeholder.empty()
119
 
120
- # Display per-image metrics if collected
 
 
 
 
 
 
121
  st.subheader("Inference Metrics")
122
- if metrics:
123
- df_metrics = pd.DataFrame(metrics)
124
- st.dataframe(df_metrics, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
125
 
126
- # Display annotated images using pil=True (ensuring RGB)
127
  st.subheader("Annotated Images")
128
- for img_name, r in image_results.items():
129
  try:
130
- annotated_img = r.plot(conf=True, boxes=True, labels=True, pil=True)
131
- st.image(annotated_img, caption=img_name, use_container_width=True)
 
 
 
 
 
 
 
132
  except Exception as e:
133
  st.error(f"Error generating annotated image for {img_name}: {e}")
 
134
 
135
 
136
  def yolo_model_comparison_tool():
@@ -147,6 +316,16 @@ def yolo_model_comparison_tool():
147
  "and compute a Weighted Score that balances these factors.\n\n"
148
  "A progress bar and ETA are shown in real time after you click Submit."
149
  )
 
 
 
 
 
 
 
 
 
 
150
 
151
  images = st.file_uploader("Upload Images", type=["jpg", "jpeg", "png"], key="comparison_images", accept_multiple_files=True)
152
  model_files = st.file_uploader("Upload YOLO models (.pt)", type=["pt"], key="comparison_models", accept_multiple_files=True)
@@ -212,7 +391,8 @@ def yolo_model_comparison_tool():
212
 
213
  # Run inference
214
  try:
215
- result = model(np_img)
 
216
  except Exception as e:
217
  st.error(f"Inference error for model {model_file.name} on {img_file.name}: {e}")
218
  continue
@@ -225,11 +405,14 @@ def yolo_model_comparison_tool():
225
  total_inference_time += r.speed["inference"]
226
 
227
  # Count detections & confidence
228
- if r.boxes is not None:
229
  det_count = len(r.boxes)
230
  total_detections += det_count
231
  if det_count > 0:
232
- confs = r.boxes.conf.cpu().numpy() if hasattr(r.boxes.conf, "cpu") else r.boxes.conf
 
 
 
233
  sum_confidences += confs.sum()
234
  total_conf_count += det_count
235
 
@@ -254,9 +437,7 @@ def yolo_model_comparison_tool():
254
 
255
  # Display aggregated metrics
256
  df = pd.DataFrame(model_agg_data.values())
257
- st.subheader("Aggregated Metrics (Across All Images)")
258
- st.dataframe(df, use_container_width=True)
259
-
260
  # Weighted Scoring with reciprocal-based speed
261
  detection_max = df["Total Detections"].max()
262
  confidence_max = df["Average Confidence"].max()
@@ -281,6 +462,25 @@ def yolo_model_comparison_tool():
281
  gamma_speed * df["Speed Norm"]
282
  )
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  st.subheader("Weighted Score Analysis")
285
  st.write(f"Weights: Detection={alpha_detection}, Confidence={beta_confidence}, Speed={gamma_speed}")
286
  st.dataframe(df[[
@@ -295,11 +495,6 @@ def yolo_model_comparison_tool():
295
  "Weighted Score"
296
  ]], use_container_width=True)
297
 
298
- # Identify best overall model (highest Weighted Score)
299
- best_idx = df["Weighted Score"].idxmax()
300
- best_model = df.loc[best_idx, "Model File"]
301
- best_score = df.loc[best_idx, "Weighted Score"]
302
-
303
  st.markdown(f"""
304
  **Best Overall Model** based on Weighted Score:
305
  **{best_model}** (Score: {best_score:.3f}).
@@ -315,29 +510,47 @@ def yolo_model_comparison_tool():
315
  - Increase **Speed** weight if you need real‐time inference.
316
  """)
317
 
 
 
 
 
 
 
 
 
 
 
 
318
  # Display annotated images in a grid (row = image, column = model)
319
  st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
320
- model_names_sorted = sorted(model_agg_data.keys())
321
-
322
- for img_file in images:
323
- st.markdown(f"### Image: {img_file.name}")
 
 
 
 
 
324
  columns = st.columns(len(model_names_sorted))
325
  for col, model_name in zip(columns, model_names_sorted):
326
- r = model_image_results.get(model_name, {}).get(img_file.name, None)
327
  if r is None:
328
  col.write(f"No results for {model_name}")
329
  continue
330
 
331
  try:
332
- # Return a PIL image in correct RGB color space
333
- annotated_img_pil = r.plot(conf=True, boxes=True, labels=True, pil=True)
 
334
  col.image(
335
- annotated_img_pil,
336
- caption=f"{model_name}",
337
  use_container_width=True
338
  )
339
  except Exception as e:
340
  col.error(f"Error annotating image for {model_name}: {e}")
 
341
 
342
  def main():
343
  st.sidebar.title("Navigation")
 
7
  from ultralytics import YOLO
8
  import requests
9
  from io import BytesIO
10
+ import copy
11
+ import cv2
12
 
13
 
14
  def save_uploaded_file(uploaded_file):
 
17
  tmp_file.write(uploaded_file.getbuffer())
18
  return tmp_file.name
19
 
20
def _box_xyxy(boxes, i):
    """Return the i-th box as an int ndarray [x1, y1, x2, y2], or None if no
    known coordinate attribute is available on *boxes*."""
    if hasattr(boxes, "xyxy"):
        xyxy = boxes.xyxy[i]
        if hasattr(xyxy, "cpu"):
            xyxy = xyxy.cpu().numpy()
        return np.asarray(xyxy).astype(int)
    if hasattr(boxes, "xywh"):
        xywh = boxes.xywh[i]
        if hasattr(xywh, "cpu"):
            xywh = xywh.cpu().numpy()
        xywh = np.asarray(xywh).astype(int)
        # Convert centre-based [x, y, w, h] to corner-based [x1, y1, x2, y2].
        return np.array([
            xywh[0] - xywh[2] // 2,
            xywh[1] - xywh[3] // 2,
            xywh[0] + xywh[2] // 2,
            xywh[1] + xywh[3] // 2,
        ]).astype(int)
    return None


def _draw_detection(img, box, label, color=(0, 255, 0)):
    """Draw one bounding box plus its text label onto *img* in place."""
    h, w = img.shape[:2]
    # Clamp coordinates so OpenCV never draws outside the image bounds.
    x1 = max(0, min(int(box[0]), w - 1))
    y1 = max(0, min(int(box[1]), h - 1))
    x2 = max(0, min(int(box[2]), w - 1))
    y2 = max(0, min(int(box[3]), h - 1))
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_size = cv2.getTextSize(label, font, 0.6, 2)[0]
    # Keep the label from being clipped at the top edge of the image.
    text_y = max(y1 - 10, text_size[1])
    cv2.putText(img, label, (x1, text_y), font, 0.6, color, 2)


def apply_confidence_threshold(result, conf_threshold):
    """Re-draw a YOLO result keeping only detections at/above *conf_threshold*.

    Parameters
    ----------
    result : ultralytics result object
        A single inference result; reads ``orig_img``, ``boxes`` (``conf``,
        ``xyxy``/``xywh``, ``cls``) and ``names`` when present.
    conf_threshold : float
        Minimum confidence a detection must have to be drawn.

    Returns
    -------
    tuple[PIL.Image.Image, int]
        The annotated image and the count of detections meeting the threshold.
        On failure, the original image (or a blank error placeholder) is
        returned with a count of 0.
    """
    try:
        # No boxes at all: return the raw image unannotated.
        if not hasattr(result, 'boxes') or result.boxes is None or len(result.boxes) == 0:
            return Image.fromarray(result.orig_img), 0

        confs = result.boxes.conf
        if hasattr(confs, "cpu"):
            confs = confs.cpu().numpy()
        confs = np.asarray(confs)

        mask = confs >= conf_threshold
        # Plain int (not a numpy scalar) so callers can format it directly.
        valid_detections = int(np.count_nonzero(mask))

        if hasattr(result, 'orig_img'):
            img_with_boxes = result.orig_img.copy()
        else:
            # NOTE: ultralytics' plot(conf=...) toggles the confidence LABEL;
            # it does not filter by threshold. Passing the float threshold here
            # (as before) both failed to filter and hid labels when the slider
            # sat at 0.0, so pass True and accept an unfiltered fallback image.
            return Image.fromarray(np.array(result.plot(conf=True))), valid_detections

        if valid_detections > 0:
            for i, keep in enumerate(mask):
                if not keep:
                    continue
                try:
                    box = _box_xyxy(result.boxes, i)
                    if box is None:
                        # Coordinates unavailable in any known format: skip.
                        continue

                    if hasattr(result.boxes, "cls"):
                        cls_val = result.boxes.cls[i]
                        cls_id = int(cls_val.cpu().item() if hasattr(cls_val, "cpu") else cls_val)
                    else:
                        cls_id = 0  # no class info on this result; use a dummy id

                    if hasattr(result, 'names') and result.names and cls_id in result.names:
                        cls_name = result.names[cls_id]
                    else:
                        cls_name = f"class_{cls_id}"

                    _draw_detection(img_with_boxes, box, f"{cls_name} {confs[i]:.2f}")
                except Exception as e:
                    # One malformed box must not abort the whole image.
                    st.error(f"Error processing a detection box: {str(e)}")
                    continue

        # Convert back to PIL Image for streamlit display.
        return Image.fromarray(img_with_boxes), valid_detections

    except Exception as e:
        # Custom drawing failed entirely: fall back to the built-in renderer.
        try:
            # Same API caveat as above: plot() cannot filter by threshold, so
            # the fallback image may still show below-threshold boxes; only the
            # returned count honors the slider.
            annotated_img = result.plot(conf=True)
            if isinstance(annotated_img, np.ndarray):
                img_pil = Image.fromarray(annotated_img)
            else:
                img_pil = annotated_img

            if hasattr(result, 'boxes') and result.boxes is not None and len(result.boxes) > 0:
                confs = result.boxes.conf
                if hasattr(confs, "cpu"):
                    confs = confs.cpu().numpy()
                valid_detections = int(np.count_nonzero(np.asarray(confs) >= conf_threshold))
            else:
                valid_detections = 0

            return img_pil, valid_detections
        except Exception:
            # Even the fallback failed: show the raw image if we have one,
            # otherwise a blank placeholder carrying the original error text.
            if hasattr(result, 'orig_img'):
                return Image.fromarray(result.orig_img), 0
            blank_img = np.zeros((400, 600, 3), dtype=np.uint8)
            cv2.putText(blank_img, f"Error: {str(e)}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(blank_img, "Could not render annotations", (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            return Image.fromarray(blank_img), 0
152
  def yolo_inference_tool():
153
  st.header("YOLO Model Inference Tool")
154
  st.write(
 
156
  "You can either upload images or provide an image URL."
157
  )
158
 
159
+ # Initialize session state for storing inference results
160
+ if 'single_model_results' not in st.session_state:
161
+ st.session_state.single_model_results = None
162
+ if 'single_model_metrics' not in st.session_state:
163
+ st.session_state.single_model_metrics = None
164
+
165
  # Allow multiple images upload
166
  uploaded_files = st.file_uploader(
167
  "Upload Images", type=["jpg", "jpeg", "png"], key="inference_images", accept_multiple_files=True
 
229
  continue
230
 
231
  try:
232
+ # Run inference with the lowest possible confidence to capture all detections
233
+ result = model(np.array(pil_img), conf=0.01)
234
  except Exception as e:
235
  st.error(f"Inference error on image {getattr(img_file, 'name', 'Unknown')}: {e}")
236
  continue
 
241
  # Get inference time from r.speed, if available
242
  inference_time = r.speed.get('inference', None) if isinstance(r.speed, dict) else None
243
  # Compute detection count and average confidence if detections exist
244
+ if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
245
  detection_count = len(r.boxes)
246
+ if hasattr(r.boxes.conf, "cpu"):
247
+ confs = r.boxes.conf.cpu().numpy()
248
+ avg_conf = float(np.mean(confs))
249
+ else:
250
+ confs = r.boxes.conf
251
+ avg_conf = float(np.mean(confs))
252
  else:
253
  detection_count = 0
254
  avg_conf = 0.0
 
262
 
263
  eta_placeholder.empty()
264
 
265
+ # Store results in session state for persistence
266
+ st.session_state.single_model_results = image_results
267
+ st.session_state.single_model_metrics = metrics
268
+
269
+ # Display results if available in session state (either from button click or slider change)
270
+ if st.session_state.single_model_metrics is not None:
271
+ # Display per-image metrics
272
  st.subheader("Inference Metrics")
273
+ df_metrics = pd.DataFrame(st.session_state.single_model_metrics)
274
+ st.dataframe(df_metrics, use_container_width=True)
275
+
276
+ # Add a confidence threshold slider
277
+ st.subheader("Confidence Threshold")
278
+ conf_threshold = st.slider(
279
+ "Adjust confidence threshold",
280
+ min_value=0.0,
281
+ max_value=1.0,
282
+ value=0.25, # Default value
283
+ step=0.05,
284
+ key="single_model_conf_threshold"
285
+ )
286
 
287
+ # Display annotated images using the current threshold
288
  st.subheader("Annotated Images")
289
+ for img_name, r in st.session_state.single_model_results.items():
290
  try:
291
+ # Apply confidence threshold and get processed image
292
+ processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold)
293
+
294
+ # Display the image
295
+ st.image(
296
+ processed_img,
297
+ caption=f"{img_name} (Threshold: {conf_threshold:.2f}, Detections: {valid_detections})",
298
+ use_container_width=True
299
+ )
300
  except Exception as e:
301
  st.error(f"Error generating annotated image for {img_name}: {e}")
302
+ st.error(str(e))
303
 
304
 
305
  def yolo_model_comparison_tool():
 
316
  "and compute a Weighted Score that balances these factors.\n\n"
317
  "A progress bar and ETA are shown in real time after you click Submit."
318
  )
319
+
320
+ # Initialize session state for storing model comparison results
321
+ if 'model_agg_data' not in st.session_state:
322
+ st.session_state.model_agg_data = None
323
+ if 'model_image_results' not in st.session_state:
324
+ st.session_state.model_image_results = None
325
+ if 'model_metrics_df' not in st.session_state:
326
+ st.session_state.model_metrics_df = None
327
+ if 'best_model_info' not in st.session_state:
328
+ st.session_state.best_model_info = None
329
 
330
  images = st.file_uploader("Upload Images", type=["jpg", "jpeg", "png"], key="comparison_images", accept_multiple_files=True)
331
  model_files = st.file_uploader("Upload YOLO models (.pt)", type=["pt"], key="comparison_models", accept_multiple_files=True)
 
391
 
392
  # Run inference
393
  try:
394
+ # Use low confidence to capture all detections
395
+ result = model(np_img, conf=0.01)
396
  except Exception as e:
397
  st.error(f"Inference error for model {model_file.name} on {img_file.name}: {e}")
398
  continue
 
405
  total_inference_time += r.speed["inference"]
406
 
407
  # Count detections & confidence
408
+ if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
409
  det_count = len(r.boxes)
410
  total_detections += det_count
411
  if det_count > 0:
412
+ if hasattr(r.boxes.conf, "cpu"):
413
+ confs = r.boxes.conf.cpu().numpy()
414
+ else:
415
+ confs = r.boxes.conf
416
  sum_confidences += confs.sum()
417
  total_conf_count += det_count
418
 
 
437
 
438
  # Display aggregated metrics
439
  df = pd.DataFrame(model_agg_data.values())
440
+
 
 
441
  # Weighted Scoring with reciprocal-based speed
442
  detection_max = df["Total Detections"].max()
443
  confidence_max = df["Average Confidence"].max()
 
462
  gamma_speed * df["Speed Norm"]
463
  )
464
 
465
+ # Identify best overall model (highest Weighted Score)
466
+ best_idx = df["Weighted Score"].idxmax()
467
+ best_model = df.loc[best_idx, "Model File"]
468
+ best_score = df.loc[best_idx, "Weighted Score"]
469
+
470
+ # Store results in session state
471
+ st.session_state.model_agg_data = model_agg_data
472
+ st.session_state.model_image_results = model_image_results
473
+ st.session_state.model_metrics_df = df
474
+ st.session_state.best_model_info = (best_model, best_score)
475
+
476
+ # Display results if available in session state
477
+ if st.session_state.model_metrics_df is not None:
478
+ df = st.session_state.model_metrics_df
479
+ best_model, best_score = st.session_state.best_model_info
480
+
481
+ st.subheader("Aggregated Metrics (Across All Images)")
482
+ st.dataframe(df, use_container_width=True)
483
+
484
  st.subheader("Weighted Score Analysis")
485
  st.write(f"Weights: Detection={alpha_detection}, Confidence={beta_confidence}, Speed={gamma_speed}")
486
  st.dataframe(df[[
 
495
  "Weighted Score"
496
  ]], use_container_width=True)
497
 
 
 
 
 
 
498
  st.markdown(f"""
499
  **Best Overall Model** based on Weighted Score:
500
  **{best_model}** (Score: {best_score:.3f}).
 
510
  - Increase **Speed** weight if you need real‐time inference.
511
  """)
512
 
513
+ # Add a confidence threshold slider
514
+ st.subheader("Confidence Threshold")
515
+ comp_conf_threshold = st.slider(
516
+ "Adjust confidence threshold for all models",
517
+ min_value=0.0,
518
+ max_value=1.0,
519
+ value=0.25, # Default value
520
+ step=0.05,
521
+ key="multi_model_conf_threshold"
522
+ )
523
+
524
  # Display annotated images in a grid (row = image, column = model)
525
  st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
526
+ model_names_sorted = sorted(st.session_state.model_agg_data.keys())
527
+
528
+ # Extract the image file names from the stored results
529
+ image_names = set()
530
+ for model_results in st.session_state.model_image_results.values():
531
+ image_names.update(model_results.keys())
532
+
533
+ for img_name in sorted(image_names):
534
+ st.markdown(f"### Image: {img_name}")
535
  columns = st.columns(len(model_names_sorted))
536
  for col, model_name in zip(columns, model_names_sorted):
537
+ r = st.session_state.model_image_results.get(model_name, {}).get(img_name, None)
538
  if r is None:
539
  col.write(f"No results for {model_name}")
540
  continue
541
 
542
  try:
543
+ # Apply confidence threshold and get processed image
544
+ processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold)
545
+
546
  col.image(
547
+ processed_img,
548
+ caption=f"{model_name} (Threshold: {comp_conf_threshold:.2f}, Detections: {valid_detections})",
549
  use_container_width=True
550
  )
551
  except Exception as e:
552
  col.error(f"Error annotating image for {model_name}: {e}")
553
+ col.error(str(e))
554
 
555
  def main():
556
  st.sidebar.title("Navigation")