Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import tempfile | |
| import pandas as pd | |
| import numpy as np | |
| import time | |
| from PIL import Image | |
| from ultralytics import YOLO | |
| import requests | |
| from io import BytesIO | |
| import copy | |
| import cv2 | |
def save_uploaded_file(uploaded_file):
    """Save an uploaded file to a temporary file and return its path.

    Only the extension of ``uploaded_file.name`` is used as the temp-file
    suffix. Passing the whole name (as was done previously) breaks as soon as
    the name contains a path separator, because the separator would become
    part of the temporary file's path.

    Args:
        uploaded_file: Object exposing a ``name`` string and a ``getbuffer()``
            method returning the file content as a bytes-like object.

    Returns:
        str: Path of the temporary file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    import os

    extension = os.path.splitext(os.path.basename(uploaded_file.name))[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
        tmp_file.write(uploaded_file.getbuffer())
    # Returned after the ``with`` block so the content is flushed and closed.
    return tmp_file.name
def _as_numpy(values):
    """Return ``values`` as a NumPy array.

    Objects exposing ``.cpu()`` (e.g. torch tensors) are moved to the CPU and
    converted with ``.numpy()``; anything else goes through ``np.asarray``.
    """
    if hasattr(values, "cpu"):
        return values.cpu().numpy()
    return np.asarray(values)


def _count_confident(result, conf_threshold):
    """Count the boxes of ``result`` whose confidence is >= ``conf_threshold``."""
    boxes = getattr(result, "boxes", None)
    if boxes is None or len(boxes) == 0:
        return 0
    return int(np.count_nonzero(_as_numpy(boxes.conf) >= conf_threshold))


def _render_with_builtin_plot(result, conf_threshold, iou_threshold):
    """Annotate via ``result.plot`` and return the outcome as a PIL image.

    NOTE(review): ultralytics documents ``Results.plot(conf=...)`` as a boolean
    that toggles the score label rather than a threshold, and lists no ``iou``
    argument - confirm against the installed version. If that holds, this
    fallback draws every stored box regardless of the thresholds.
    """
    try:
        annotated = result.plot(conf=conf_threshold, iou=iou_threshold)
    except Exception:
        # Retry without ``iou`` for plot() implementations that reject it.
        annotated = result.plot(conf=conf_threshold)
    if isinstance(annotated, np.ndarray):
        return Image.fromarray(annotated)
    return annotated


def _boxes_as_xyxy(boxes):
    """Return all boxes as an (N, 4) float array of x1, y1, x2, y2, or None."""
    if hasattr(boxes, "xyxy"):
        return _as_numpy(boxes.xyxy).astype(float).reshape(-1, 4)
    if hasattr(boxes, "xywh"):
        # Centre/size format: corners are centre -/+ half the size.
        xywh = _as_numpy(boxes.xywh).astype(float).reshape(-1, 4)
        half_size = xywh[:, 2:4] / 2
        return np.hstack((xywh[:, :2] - half_size, xywh[:, :2] + half_size))
    return None


def _iou_with_many(box, others):
    """Return the IoU of one xyxy ``box`` against each row of ``others``."""
    left = np.maximum(box[0], others[:, 0])
    top = np.maximum(box[1], others[:, 1])
    right = np.minimum(box[2], others[:, 2])
    bottom = np.minimum(box[3], others[:, 3])
    # Clipping each side at zero makes disjoint boxes contribute no overlap.
    intersection = np.clip(right - left, 0, None) * np.clip(bottom - top, 0, None)
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    other_areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    union = box_area + other_areas - intersection
    iou = np.zeros_like(intersection, dtype=float)
    # Degenerate (zero-area) pairs keep an IoU of 0 instead of dividing by 0.
    np.divide(intersection, union, out=iou, where=union > 0)
    return iou


def _nms_per_class(boxes, scores, class_ids, iou_threshold):
    """Greedy per-class non-maximum suppression.

    Returns the indices (into ``boxes``) of the surviving detections, grouped
    by ascending class id and, within a class, by descending score.
    """
    kept = []
    for class_id in np.unique(class_ids):
        members = np.flatnonzero(class_ids == class_id)
        # Stable sort: equal scores keep their original relative order.
        order = members[np.argsort(-scores[members], kind="stable")]
        while order.size:
            best, order = order[0], order[1:]
            kept.append(int(best))
            if order.size:
                overlaps = _iou_with_many(boxes[best], boxes[order])
                # Drop everything overlapping the kept box by more than the threshold.
                order = order[overlaps <= iou_threshold]
    return kept


def _draw_detections(image, boxes, class_ids, scores, names):
    """Draw boxes and ``"<class> <score>"`` labels onto ``image`` in place."""
    height, width = image.shape[:2]
    upper_bounds = np.array([width - 1, height - 1, width - 1, height - 1], dtype=float)
    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, class_id, score in zip(boxes, class_ids, scores):
        # Clamp to the image and truncate to integer pixel coordinates.
        x1, y1, x2, y2 = (int(v) for v in np.clip(box, 0, upper_bounds))
        class_id = int(class_id)
        if names and class_id in names:
            class_name = names[class_id]
        else:
            class_name = f"class_{class_id}"
        # Deterministic per-class colour. The tuple is ordered (b, g, r) as in
        # the OpenCV convention; NOTE(review): the array is later handed to PIL
        # unchanged, so the displayed channel order depends on orig_img.
        color = (
            (class_id * 80 + 90) % 255,
            (class_id * 50 + 170) % 255,
            (class_id * 100 + 50) % 255,
        )
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

        label = f"{class_name} {score:.2f}"
        text_width, text_height = cv2.getTextSize(label, font, 0.5, 2)[0]
        label_height = text_height + 10
        if y1 >= label_height:
            # Usual case: label sits directly above the box.
            label_top, label_bottom = y1 - label_height, y1
        else:
            # No room above the box: put the label just inside its top edge
            # so the text is not drawn outside the image.
            label_top, label_bottom = y1, y1 + label_height
        cv2.rectangle(image, (x1, label_top), (x1 + text_width, label_bottom), color, -1)
        cv2.putText(image, label, (x1, label_bottom - 5), font, 0.5, (255, 255, 255), 1)


def apply_confidence_threshold(result, conf_threshold, iou_threshold=0.45):
    """Re-annotate a detection result using the given thresholds.

    Boxes below ``conf_threshold`` are discarded, the rest go through greedy
    per-class non-maximum suppression with ``iou_threshold``, and the
    survivors are drawn on a copy of ``result.orig_img``.

    Args:
        result: Detection result exposing ``boxes`` (with ``conf`` and
            ``xyxy`` or ``xywh``, optionally ``cls``), ``orig_img`` and
            optionally ``names`` and ``plot()``.
        conf_threshold: Minimum confidence for a box to be kept.
        iou_threshold: Boxes of the same class overlapping a higher-scoring
            box by more than this IoU are suppressed.

    Returns:
        tuple: ``(PIL.Image, int)`` - the annotated image and the number of
        detections drawn. If the custom rendering fails, the image comes from
        ``result.plot`` and the count is the number of boxes meeting the
        confidence threshold; if that fails too, the original image (or a
        blank error image) is returned with a count of 0.
    """
    try:
        boxes = getattr(result, "boxes", None)
        if boxes is None or len(boxes) == 0:
            return Image.fromarray(result.orig_img), 0

        if not hasattr(result, "orig_img"):
            # Nothing to draw on ourselves: defer to the result's own plot().
            return (
                _render_with_builtin_plot(result, conf_threshold, iou_threshold),
                _count_confident(result, conf_threshold),
            )

        canvas = result.orig_img.copy()
        scores = _as_numpy(boxes.conf).reshape(-1)
        confident = scores >= conf_threshold
        xyxy = _boxes_as_xyxy(boxes)
        if xyxy is None or not confident.any():
            # No usable box format, or nothing passed the confidence threshold.
            return Image.fromarray(canvas), 0

        if hasattr(boxes, "cls"):
            class_ids = _as_numpy(boxes.cls).reshape(-1).astype(int)
        else:
            class_ids = np.zeros(len(scores), dtype=int)  # default class id

        xyxy = xyxy[confident]
        scores = scores[confident]
        class_ids = class_ids[confident]

        kept = _nms_per_class(xyxy, scores, class_ids, iou_threshold)
        _draw_detections(
            canvas, xyxy[kept], class_ids[kept], scores[kept], getattr(result, "names", None)
        )
        return Image.fromarray(canvas), len(kept)
    except Exception as e:
        # Custom rendering failed: try the result's built-in plot method.
        try:
            return (
                _render_with_builtin_plot(result, conf_threshold, iou_threshold),
                _count_confident(result, conf_threshold),
            )
        except Exception:
            # Last resort: return the original image.
            if hasattr(result, 'orig_img'):
                return Image.fromarray(result.orig_img), 0
            # If even that is unavailable, create a blank image with the error message.
            blank_img = np.zeros((400, 600, 3), dtype=np.uint8)
            cv2.putText(blank_img, f"Error: {str(e)}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            cv2.putText(blank_img, "Could not render annotations", (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
            return Image.fromarray(blank_img), 0
def yolo_inference_tool():
    """Render the single-model inference page.

    The user uploads images (and/or gives one image URL) plus a ``.pt`` model.
    On submit, every image is run through the model once at a very low
    confidence so that all candidate detections are stored in
    ``st.session_state``; the confidence/IoU sliders then re-filter and
    re-draw those stored results on each rerun without running inference
    again.
    """
    import contextlib
    import os

    # Upper bound (seconds) for fetching the optional image URL, so a slow or
    # unresponsive host cannot hang the app indefinitely.
    url_timeout_seconds = 10

    st.header("YOLO Model Inference Tool")
    st.write(
        "Upload one or more images and a YOLO model (.pt) file to run inference and view detailed results. "
        "You can either upload images or provide an image URL."
    )

    # Initialize session state for storing inference results
    if 'single_model_results' not in st.session_state:
        st.session_state.single_model_results = None
    if 'single_model_metrics' not in st.session_state:
        st.session_state.single_model_metrics = None

    # Allow multiple images upload
    uploaded_files = st.file_uploader(
        "Upload Images", type=["jpg", "jpeg", "png"], key="inference_images", accept_multiple_files=True
    )
    # Text input for a single image URL
    url_input = st.text_input("Enter image URL (optional)", key="inference_url")
    model_file = st.file_uploader("Upload YOLO model (.pt)", type=["pt"], key="inference_model")

    if st.button("Submit (Single-Model Inference)"):
        # Combine uploaded files and the URL image into a single list. The URL
        # is fetched only here, on submit, so slider-triggered reruns do not
        # hit the network again.
        images = list(uploaded_files) if uploaded_files else []
        url = url_input.strip() if url_input else ""
        if url:
            # NOTE(review): the URL is user-supplied and fetched server-side;
            # consider restricting schemes/hosts if this app is exposed publicly.
            try:
                response = requests.get(url, timeout=url_timeout_seconds)
                if response.status_code == 200:
                    image_bytes = BytesIO(response.content)
                    # Assign a name attribute for consistency with uploads
                    image_bytes.name = url
                    images.append(image_bytes)
                else:
                    st.error("Failed to fetch image from URL.")
            except requests.RequestException as e:
                st.error(f"Error fetching image from URL: {e}")

        if not images or not model_file:
            st.error("Please upload at least one image (or provide an image URL) and a model.")
            return

        # Save the model to disk so it can be loaded by path
        model_path = save_uploaded_file(model_file)
        try:
            try:
                model = YOLO(model_path)
            except Exception as e:
                st.error(f"Error loading model: {e}")
                return

            total_images = len(images)
            progress_bar = st.progress(0)
            eta_placeholder = st.empty()
            start_time = time.time()

            def report_progress(steps_done):
                """Show the fraction of finished images and a linear ETA."""
                fraction_done = steps_done / total_images
                progress_bar.progress(fraction_done)
                elapsed_time = time.time() - start_time
                remaining_steps = total_images - steps_done
                eta_seconds = remaining_steps * (elapsed_time / steps_done)
                eta_placeholder.info(f"Progress: {fraction_done:.1%}. ETA: ~{eta_seconds:.1f} s")

            # Inference results keyed by image name, and per-image metric rows
            image_results = {}
            metrics = []
            for steps_done, img_file in enumerate(images, start=1):
                img_name = getattr(img_file, 'name', 'Unknown')
                try:
                    try:
                        pil_img = Image.open(img_file).convert("RGB")
                    except Exception as e:
                        st.error(f"Error reading image {img_name}: {e}")
                        continue
                    try:
                        # Run inference with the lowest possible confidence to capture all detections.
                        # NOTE(review): ultralytics documents numpy inputs as BGR while this
                        # array is RGB - confirm whether the channels should be swapped.
                        result = model(np.array(pil_img), conf=0.01)
                    except Exception as e:
                        st.error(f"Inference error on image {img_name}: {e}")
                        continue

                    r = result[0]
                    image_results[img_name] = r
                    # Get inference time from r.speed, if available
                    inference_time = r.speed.get('inference', None) if isinstance(r.speed, dict) else None
                    # Compute detection count and average confidence if detections exist
                    if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
                        detection_count = len(r.boxes)
                        confs = r.boxes.conf
                        if hasattr(confs, "cpu"):
                            confs = confs.cpu().numpy()
                        avg_conf = float(np.mean(confs))
                    else:
                        detection_count = 0
                        avg_conf = 0.0
                    metrics.append({
                        "Image": img_name,
                        "Inference Time (ms)": inference_time if inference_time is not None else "N/A",
                        "Detections": detection_count,
                        "Average Confidence": f"{avg_conf:.2f}"
                    })
                finally:
                    # Report after the image has been handled (also on the
                    # ``continue`` paths) so the bar reflects finished work.
                    report_progress(steps_done)
            eta_placeholder.empty()
        finally:
            # The temp file was created with delete=False; remove it once the
            # model is no longer needed so repeated submits do not fill the disk.
            with contextlib.suppress(OSError):
                os.remove(model_path)

        # Store results in session state for persistence across reruns
        st.session_state.single_model_results = image_results
        st.session_state.single_model_metrics = metrics

    # Display results if available in session state (either from button click or slider change)
    if st.session_state.single_model_metrics is not None:
        # Display per-image metrics
        st.subheader("Inference Metrics")
        df_metrics = pd.DataFrame(st.session_state.single_model_metrics)
        st.dataframe(df_metrics, use_container_width=True)

        # Add a confidence threshold slider
        st.subheader("Confidence Threshold")
        conf_threshold = st.slider(
            "Adjust confidence threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.25,  # Default value
            step=0.05,
            key="single_model_conf_threshold"
        )
        # Add IoU threshold slider for NMS
        st.subheader("Overlapping (IoU) Threshold")
        iou_threshold = st.slider(
            "Adjust IoU threshold for non-maximum suppression",
            min_value=0.0,
            max_value=1.0,
            value=0.45,  # Default NMS value
            step=0.05,
            key="single_model_iou_threshold",
            help="Controls how overlapping boxes are filtered. Lower values (0.1-0.3) remove more overlapping boxes, higher values (0.7-0.9) allow more overlaps. The standard YOLO default is 0.45."
        )

        # Display annotated images using the current thresholds
        st.subheader("Annotated Images")
        for img_name, r in st.session_state.single_model_results.items():
            try:
                # Apply confidence and IoU thresholds and get processed image
                processed_img, valid_detections = apply_confidence_threshold(r, conf_threshold, iou_threshold)
                st.image(
                    processed_img,
                    caption=f"{img_name} (Conf: {conf_threshold:.2f}, IoU: {iou_threshold:.2f}, Detections: {valid_detections})",
                    use_container_width=True
                )
            except Exception as e:
                st.error(f"Error generating annotated image for {img_name}: {e}")
def yolo_model_comparison_tool():
    """
    Multi-model, multi-image comparison subpage,
    with Weighted Scoring that uses a reciprocal-based speed metric
    and a real-time progress bar + ETA display.

    On submit, every uploaded image is decoded once and run through every
    uploaded model at a very low confidence. Per-model aggregates (average
    inference time over the inferences that reported a time, total
    detections, average confidence) are normalised against the best model
    and combined into a weighted score. Results are kept in
    ``st.session_state`` so the threshold sliders can re-draw the annotated
    grid without re-running inference.
    """
    import contextlib
    import os

    st.header("YOLO Models Comparison Tool (Multi-Image, Weighted Score + Progress Bar)")
    st.write(
        "Upload **one or more images** and **multiple YOLO model (.pt) files**. "
        "Then click **Submit** to run inference across all images with each model. "
        "We aggregate metrics (Avg Inference Time, Total Detections, Avg Confidence) "
        "and compute a Weighted Score that balances these factors.\n\n"
        "A progress bar and ETA are shown in real time after you click Submit."
    )

    # Initialize session state for storing model comparison results
    if 'model_agg_data' not in st.session_state:
        st.session_state.model_agg_data = None
    if 'model_image_results' not in st.session_state:
        st.session_state.model_image_results = None
    if 'model_metrics_df' not in st.session_state:
        st.session_state.model_metrics_df = None
    if 'best_model_info' not in st.session_state:
        st.session_state.best_model_info = None

    images = st.file_uploader("Upload Images", type=["jpg", "jpeg", "png"], key="comparison_images", accept_multiple_files=True)
    model_files = st.file_uploader("Upload YOLO models (.pt)", type=["pt"], key="comparison_models", accept_multiple_files=True)

    # Example weights. You can expose them as sliders if you want user customization.
    alpha_detection = 0.4
    beta_confidence = 0.3
    gamma_speed = 0.3  # speed = reciprocal of time

    if st.button("Submit (Multi-Model Comparison)"):
        if not images or not model_files:
            st.error("Please upload at least one image and at least one model.")
            return

        # Decode every image once up front instead of once per model.
        # NOTE(review): ultralytics documents numpy inputs as BGR while these
        # arrays are RGB - confirm whether the channels should be swapped.
        decoded_images = []
        for img_file in images:
            try:
                decoded_images.append((img_file.name, np.array(Image.open(img_file).convert("RGB"))))
            except Exception as e:
                st.error(f"Error reading image {img_file.name}: {e}")

        # Initialize progress tracking
        total_inferences = len(decoded_images) * len(model_files)
        if total_inferences == 0:
            st.error("No valid images or models to process.")
            return
        progress_bar = st.progress(0)
        eta_placeholder = st.empty()
        start_time = time.time()
        steps_done = 0

        def report_progress(done):
            """Show the fraction of finished inferences and a linear ETA."""
            fraction_done = done / total_inferences
            progress_bar.progress(fraction_done)
            elapsed_time = time.time() - start_time
            remaining_steps = total_inferences - done
            eta_seconds = remaining_steps * (elapsed_time / done)
            eta_placeholder.info(f"Progress: {fraction_done:.1%}. ETA: ~{eta_seconds:.1f} s")

        # Aggregated metrics per model
        model_agg_data = {}
        # Results for each (model, image) so we can display side-by-side
        model_image_results = {m.name: {} for m in model_files}

        for model_file in model_files:
            model_path = save_uploaded_file(model_file)
            try:
                try:
                    model = YOLO(model_path)
                except Exception as e:
                    st.error(f"Error loading model {model_file.name}: {e}")
                    # Count the skipped inferences so the bar can still reach 100%.
                    steps_done += len(decoded_images)
                    report_progress(steps_done)
                    continue

                total_inference_time = 0.0
                timed_inferences = 0
                total_detections = 0
                sum_confidences = 0.0
                total_conf_count = 0

                for img_name, np_img in decoded_images:
                    try:
                        # Use low confidence to capture all detections
                        result = model(np_img, conf=0.01)
                    except Exception as e:
                        st.error(f"Inference error for model {model_file.name} on {img_name}: {e}")
                        continue
                    finally:
                        # Runs on success and on failure: one more step finished.
                        steps_done += 1
                        report_progress(steps_done)

                    r = result[0]
                    model_image_results[model_file.name][img_name] = r

                    # Accumulate inference time
                    if isinstance(r.speed, dict) and "inference" in r.speed:
                        total_inference_time += r.speed["inference"]
                        timed_inferences += 1

                    # Count detections & confidence
                    if hasattr(r, 'boxes') and r.boxes is not None and len(r.boxes) > 0:
                        det_count = len(r.boxes)
                        total_detections += det_count
                        confs = r.boxes.conf
                        if hasattr(confs, "cpu"):
                            confs = confs.cpu().numpy()
                        sum_confidences += float(np.sum(confs))
                        total_conf_count += det_count
            finally:
                # Temp files are created with delete=False; clean up per model.
                with contextlib.suppress(OSError):
                    os.remove(model_path)

            # Average only over inferences that actually reported a time. A
            # model with none gets an infinite time (speed 0) instead of a
            # bogus 0 ms that would make it look like the fastest model.
            avg_inference_time = total_inference_time / timed_inferences if timed_inferences > 0 else float("inf")
            avg_confidence = sum_confidences / total_conf_count if total_conf_count > 0 else 0.0
            model_agg_data[model_file.name] = {
                "Model File": model_file.name,
                "Avg Inference Time (ms)": avg_inference_time,
                "Total Detections": total_detections,
                "Average Confidence": avg_confidence
            }

        # Now that all inferences are done, remove the ETA info
        eta_placeholder.empty()
        if not model_agg_data:
            st.write("No valid models processed.")
            return

        df = pd.DataFrame(list(model_agg_data.values()))

        # Weighted Scoring with reciprocal-based speed
        detection_max = df["Total Detections"].max()
        confidence_max = df["Average Confidence"].max()
        if detection_max == 0:
            detection_max = 1
        if confidence_max == 0:
            confidence_max = 1
        df["Detection Norm"] = df["Total Detections"] / detection_max
        df["Confidence Norm"] = df["Average Confidence"] / confidence_max

        # Convert time to speed = 1 / time, then normalize
        eps = 1e-9  # guards against division by zero for a 0 ms average
        df["Speed Val"] = 1.0 / (df["Avg Inference Time (ms)"] + eps)
        max_speed_val = df["Speed Val"].max() if not df["Speed Val"].isnull().all() else 1
        if max_speed_val == 0:
            max_speed_val = 1
        df["Speed Norm"] = df["Speed Val"] / max_speed_val

        df["Weighted Score"] = (
            alpha_detection * df["Detection Norm"] +
            beta_confidence * df["Confidence Norm"] +
            gamma_speed * df["Speed Norm"]
        )

        # Identify best overall model (highest Weighted Score)
        best_idx = df["Weighted Score"].idxmax()
        best_model = df.loc[best_idx, "Model File"]
        best_score = df.loc[best_idx, "Weighted Score"]

        # Store results in session state
        st.session_state.model_agg_data = model_agg_data
        st.session_state.model_image_results = model_image_results
        st.session_state.model_metrics_df = df
        st.session_state.best_model_info = (best_model, best_score)

    # Display results if available in session state
    if st.session_state.model_metrics_df is not None:
        df = st.session_state.model_metrics_df
        best_model, best_score = st.session_state.best_model_info

        st.subheader("Aggregated Metrics (Across All Images)")
        st.dataframe(df, use_container_width=True)

        st.subheader("Weighted Score Analysis")
        st.write(f"Weights: Detection={alpha_detection}, Confidence={beta_confidence}, Speed={gamma_speed}")
        st.dataframe(df[[
            "Model File",
            "Avg Inference Time (ms)",
            "Total Detections",
            "Average Confidence",
            "Detection Norm",
            "Confidence Norm",
            "Speed Val",
            "Speed Norm",
            "Weighted Score"
        ]], use_container_width=True)

        st.markdown(f"""
**Best Overall Model** based on Weighted Score:
**{best_model}** (Score: {best_score:.3f}).
### Interpretation:
- **Detection Norm** → fraction of the best detection count.
- **Confidence Norm** → fraction of the highest average confidence.
- **Speed Norm** → fraction of the highest (1/time). The fastest model is near 1; others are a fraction of that speed.
If you find one factor more important, adjust the weights:
- Increase **Detection** weight if you care about finding as many objects as possible.
- Increase **Confidence** weight if you only trust high‐confidence detections.
- Increase **Speed** weight if you need real‐time inference.
""")

        # Add a confidence threshold slider
        st.subheader("Confidence Threshold")
        comp_conf_threshold = st.slider(
            "Adjust confidence threshold for all models",
            min_value=0.0,
            max_value=1.0,
            value=0.25,  # Default value
            step=0.05,
            key="multi_model_conf_threshold"
        )
        # Add IoU threshold slider for NMS
        st.subheader("Overlapping (IoU) Threshold")
        comp_iou_threshold = st.slider(
            "Adjust IoU threshold for non-maximum suppression across all models",
            min_value=0.0,
            max_value=1.0,
            value=0.45,  # Default NMS value
            step=0.05,
            key="multi_model_iou_threshold",
            help="Controls how overlapping boxes are filtered. Lower values (0.1-0.3) remove more overlapping boxes, higher values (0.7-0.9) allow more overlaps. The standard YOLO default is 0.45."
        )

        # Display annotated images in a grid (row = image, column = model)
        st.subheader("Annotated Images Grid (Row = Image, Column = Model)")
        model_names_sorted = sorted(st.session_state.model_agg_data.keys())

        # Extract the image file names from the stored results
        image_names = set()
        for model_results in st.session_state.model_image_results.values():
            image_names.update(model_results.keys())

        for img_name in sorted(image_names):
            st.markdown(f"### Image: {img_name}")
            columns = st.columns(len(model_names_sorted))
            for col, model_name in zip(columns, model_names_sorted):
                r = st.session_state.model_image_results.get(model_name, {}).get(img_name, None)
                if r is None:
                    col.write(f"No results for {model_name}")
                    continue
                try:
                    # Apply confidence and IoU thresholds and get processed image
                    processed_img, valid_detections = apply_confidence_threshold(r, comp_conf_threshold, comp_iou_threshold)
                    col.image(
                        processed_img,
                        caption=f"{model_name} (Conf: {comp_conf_threshold:.2f}, IoU: {comp_iou_threshold:.2f}, Det: {valid_detections})",
                        use_container_width=True
                    )
                except Exception as e:
                    col.error(f"Error annotating image for {model_name}: {e}")
def main():
    """Show the sidebar navigation and render the page the user selected."""
    inference_page = "YOLO Model Inference Tool"
    comparison_page = "YOLO Models Comparison Tool"

    st.sidebar.title("Navigation")
    selected_page = st.sidebar.radio("Go to", (inference_page, comparison_page))

    # Anything other than the inference page falls through to the comparison page.
    render_page = yolo_inference_tool if selected_page == inference_page else yolo_model_comparison_tool
    render_page()


if __name__ == "__main__":
    main()