import json

import cv2
import numpy as np
import torch
from PIL import Image
from sklearn.metrics import (jaccard_score, f1_score, accuracy_score,
                             precision_score, recall_score)
from scipy.spatial.distance import directed_hausdorff

# MUST BE FIRST STREAMLIT COMMAND
import streamlit as st

st.set_page_config(
    page_title="Advanced Segmentation Metrics Analyzer",
    page_icon="🧪",
    layout="wide"
)

# Drawing constants. Images are kept in RGB order everywhere in the UI.
SEG_OVERLAY_COLOR = (0, 0, 255)   # blue
SEG_OVERLAY_ALPHA = 100 / 255     # the alpha the original (0, 0, 255, 100) colour intended
BOX_COLOR = (255, 0, 0)           # red
LABEL_COLOR = (255, 255, 255)     # white


def _report_model_error(exc):
    """Show a model-loading failure in the Streamlit UI."""
    st.error(f"⚠️ Model loading failed: {str(exc)}")
    st.info("Please check your internet connection and try again")


# Model loading with enhanced error handling
@st.cache_resource
def load_model():
    """Load the YOLOv8 segmentation model, or return None on failure.

    Tries the official ``ultralytics`` package first. Only if that package
    cannot be imported does it fall back to ``torch.hub``. Any loading failure
    (including a failed weights download) is reported in the UI, and ``None``
    is returned so the app still renders instead of crashing at import time.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        YOLO = None

    if YOLO is not None:
        try:
            return YOLO('yolov8x-seg.pt')
        except Exception as e:
            _report_model_error(e)
            return None

    try:
        # Fallback to torch hub.
        # NOTE(review): the rest of the app expects the ultralytics Results API
        # (results[0].masks / results[0].boxes); confirm the hub model provides it.
        hub_model = torch.hub.load('ultralytics/yolov8', 'yolov8x-seg', pretrained=True)
        return hub_model.to('cuda' if torch.cuda.is_available() else 'cpu')
    except Exception as e:
        _report_model_error(e)
        return None


model = load_model()


def validate_image(img_array):
    """Return ``img_array`` as a 3-channel (H, W, 3) RGB array.

    Handles 2-D grayscale, (H, W, 1) grayscale, (H, W, 2) grayscale+alpha,
    (H, W, 4) RGBA (alpha dropped), and arrays with more than four channels
    (extra channels dropped). 3-channel input is returned unchanged.
    """
    if img_array.ndim == 2:  # Grayscale
        return cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)

    channels = img_array.shape[2]
    if channels in (1, 2):  # Grayscale (optionally with alpha): replicate luminance
        return np.repeat(img_array[:, :, :1], 3, axis=2)
    if channels == 4:  # RGBA
        return cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)
    if channels > 3:  # Extra channels
        return img_array[:, :, :3]
    return img_array


def _binary_iou(a, b):
    """Return |a ∩ b| / |a ∪ b| for two binary masks (0.0 if the union is empty)."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(a, b).sum() / union)


def calculate_boundary_iou(mask1, mask2, boundary_width=2):
    """Return the IoU of the boundary bands of two binary masks.

    Each mask's boundary is extracted with a morphological gradient, using a
    square kernel of side ``boundary_width`` (clamped to at least 1). Any
    failure degrades gracefully to 0.0.
    """
    try:
        width = max(1, int(boundary_width))
        kernel = np.ones((width, width), np.uint8)
        boundary1 = cv2.morphologyEx(np.asarray(mask1, np.uint8), cv2.MORPH_GRADIENT, kernel)
        boundary2 = cv2.morphologyEx(np.asarray(mask2, np.uint8), cv2.MORPH_GRADIENT, kernel)
        return _binary_iou(boundary1, boundary2)
    except Exception:
        return 0.0  # Graceful degradation


def _prediction_masks(result, height, width):
    """Return the predicted instance masks as a uint8 (N, height, width) 0/1 array.

    Masks that are not already at image resolution are resized, so they line
    up with the image the metrics refer to. This alignment is approximate if
    the model returned letterboxed inference-resolution masks.
    """
    raw = result.masks.data.cpu().numpy().astype(np.float32)
    if raw.shape[1:] != (height, width):
        if len(raw):
            raw = np.stack([
                cv2.resize(m, (width, height), interpolation=cv2.INTER_LINEAR)
                for m in raw
            ])
        else:
            raw = np.zeros((0, height, width), np.float32)
    return (raw > 0.5).astype(np.uint8)


def _mock_ground_truth(count, height, width):
    """Return a placeholder ground truth: ``count`` copies of a centred half-size box.

    NOTE: this is synthetic, not real annotation data. Scores computed against
    it only measure overlap with the image centre.
    """
    gt = np.zeros((count, height, width), np.uint8)
    gt[:, height // 4:3 * height // 4, width // 4:3 * width // 4] = 1
    return gt


def _mask_scores(gt_mask, pred_mask, boundary_width):
    """Return (IoU, Dice, pixel accuracy, boundary IoU) for one mask pair."""
    gt = gt_mask.astype(bool)
    pred = pred_mask.astype(bool)
    tp = int(np.logical_and(gt, pred).sum())
    fp = int(np.logical_and(~gt, pred).sum())
    fn = int(np.logical_and(gt, ~pred).sum())
    iou = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0
    dice = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) else 0.0
    accuracy = float(np.mean(gt == pred))
    boundary = calculate_boundary_iou(gt_mask, pred_mask, boundary_width)
    return float(iou), float(dice), accuracy, float(boundary)


def calculate_metrics(results, img_shape, boundary_width=2):
    """Compute segmentation metrics for ``results[0]`` against a mock ground truth.

    Args:
        results: Ultralytics-style results list; ``results[0].masks`` and
            ``results[0].boxes`` are read.
        img_shape: Shape of the analysed image; only ``[:2]`` (H, W) is used.
        boundary_width: Kernel size used for the Boundary IoU.

    Returns:
        A metrics dict with the keys 'IoU' (mean / per_instance / per-class
        mean IoU in 'class_wise'), 'Dice', 'Pixel_Accuracy', 'Boundary_IoU',
        'Object_Counts' and 'Class_Distribution' (per-class counts). All values
        are plain Python numbers, so the dict is JSON-serialisable. On failure,
        returns ``{"error": message}`` instead.
    """
    if model is None:
        return {"error": "Model not loaded"}
    if not results or results[0].masks is None:
        return {"error": "No objects detected"}

    try:
        result = results[0]
        h, w = img_shape[:2]
        pred_masks = _prediction_masks(result, h, w)
        gt_masks = _mock_ground_truth(len(pred_masks), h, w)
        class_ids = result.boxes.cls.tolist() if result.boxes is not None else []

        per_instance_iou = []
        per_class_iou = {}
        counts = {}
        dice_sum = accuracy_sum = boundary_sum = 0.0

        for i, (pred_mask, gt_mask) in enumerate(zip(pred_masks, gt_masks)):
            if not gt_mask.any():
                continue
            try:
                cls_name = model.names[int(class_ids[i])]
                iou, dice, accuracy, boundary = _mask_scores(gt_mask, pred_mask, boundary_width)
            except Exception:
                # Drop the whole instance so every average covers the same set.
                continue
            per_instance_iou.append(iou)
            per_class_iou.setdefault(cls_name, []).append(iou)
            counts[cls_name] = counts.get(cls_name, 0) + 1
            dice_sum += dice
            accuracy_sum += accuracy
            boundary_sum += boundary

        metrics = {
            'IoU': {'mean': 0, 'per_instance': per_instance_iou, 'class_wise': {}},
            'Dice': 0,
            'Pixel_Accuracy': 0,
            'Boundary_IoU': 0,
            'Object_Counts': counts,
            'Class_Distribution': dict(counts),
        }

        valid_masks = len(per_instance_iou)
        if valid_masks > 0:
            metrics['IoU']['mean'] = float(np.mean(per_instance_iou))
            metrics['IoU']['class_wise'] = {
                name: float(np.mean(ious)) for name, ious in per_class_iou.items()
            }
            metrics['Dice'] = dice_sum / valid_masks
            metrics['Pixel_Accuracy'] = accuracy_sum / valid_masks
            metrics['Boundary_IoU'] = boundary_sum / valid_masks

        return metrics
    except Exception as e:
        return {"error": f"Metric calculation failed: {str(e)}"}


def visualize_results(img, results):
    """Return (segmentation overlay, detection image) drawn on copies of RGB ``img``.

    Each instance polygon is filled separately and alpha-blended in blue.
    Boxes are drawn in red with a "<class> <conf>" label. On any drawing
    failure, a warning is shown and the original image is returned twice.
    """
    try:
        result = results[0]

        # Segmentation overlay
        seg_img = img.copy()
        if result.masks is not None:
            overlay = img.copy()
            drawn = False
            for poly in result.masks.xy:
                if len(poly) < 3:
                    continue  # degenerate / empty polygon
                # fillPoly requires int32 points; fill one instance at a time so
                # overlapping instances are not XOR-ed into holes.
                cv2.fillPoly(overlay, [np.asarray(poly, dtype=np.int32)], SEG_OVERLAY_COLOR)
                drawn = True
            if drawn:
                seg_img = cv2.addWeighted(overlay, SEG_OVERLAY_ALPHA, img, 1 - SEG_OVERLAY_ALPHA, 0)

        # Bounding boxes
        det_img = img.copy()
        boxes = result.boxes
        if boxes is not None:
            for box, cls, conf in zip(boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist()):
                x1, y1, x2, y2 = map(int, box)
                cv2.rectangle(det_img, (x1, y1), (x2, y2), BOX_COLOR, 2)
                # Keep the label on-screen when the box touches the top edge.
                label_y = max(y1 - 10, 15)
                cv2.putText(det_img, f"{model.names[int(cls)]} {conf:.2f}", (x1, label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, LABEL_COLOR, 2)

        return seg_img, det_img
    except Exception as e:
        st.warning(f"Visualization failed: {str(e)}")
        return img, img  # Fallback to original images


def process_image(input_img, conf_threshold=None, boundary_width=2):
    """Run the segmentation pipeline on ``input_img``.

    Args:
        input_img: A PIL image or array-like image.
        conf_threshold: Optional detection confidence threshold passed to
            the model; ``None`` uses the model's default.
        boundary_width: Kernel size used for the Boundary IoU metric.

    Returns:
        (segmentation PIL image, detection PIL image, metrics dict). On
        failure, returns (None, None, {"error": message}) and shows the
        error in the UI.
    """
    try:
        if model is None:
            raise RuntimeError("Model not loaded")

        # Normalise palette / LA / CMYK / etc. to RGB before array conversion.
        if isinstance(input_img, Image.Image) and input_img.mode != "RGB":
            input_img = input_img.convert("RGB")
        img = validate_image(np.array(input_img))

        # Ultralytics interprets numpy input as BGR (OpenCV order); drawing
        # and display below stay in RGB.
        bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        predict_kwargs = {"retina_masks": True}  # masks at original image resolution
        if conf_threshold is not None:
            predict_kwargs["conf"] = conf_threshold
        results = model(bgr, **predict_kwargs)

        seg_img, det_img = visualize_results(img, results)
        metrics = calculate_metrics(results, img.shape, boundary_width=boundary_width)
        return Image.fromarray(seg_img), Image.fromarray(det_img), metrics
    except Exception as e:
        st.error(f"Processing failed: {str(e)}")
        return None, None, {"error": str(e)}


def _render_results(seg_img, det_img, metrics):
    """Render the visual, metrics and raw-data tabs for a successful analysis."""
    tabs = st.tabs(["Visual Results", "Metrics Dashboard", "Raw Data"])

    with tabs[0]:
        st.subheader("Segmentation Analysis")
        cols = st.columns(2)
        cols[0].image(seg_img, caption="Segmentation Mask", use_column_width=True)
        cols[1].image(det_img, caption="Detected Objects", use_column_width=True)

    with tabs[1]:
        st.subheader("Performance Metrics")
        st.caption("Scores are computed against a synthetic placeholder ground truth "
                   "(a centred box), not real annotations.")
        st.metric("Mean IoU", f"{metrics['IoU']['mean']:.2%}", help="Intersection over Union")
        st.metric("Dice Coefficient", f"{metrics['Dice']:.2%}", help="F1 Score for segmentation")
        st.metric("Pixel Accuracy", f"{metrics['Pixel_Accuracy']:.2%}")
        st.metric("Boundary IoU", f"{metrics['Boundary_IoU']:.2%}",
                  help="IoU of the mask boundary bands")
        st.plotly_chart({
            'data': [{
                'x': list(metrics['Class_Distribution'].keys()),
                'y': list(metrics['Class_Distribution'].values()),
                'type': 'bar'
            }],
            'layout': {'title': 'Class Distribution'}
        })

    with tabs[2]:
        st.download_button(
            "Download Metrics",
            json.dumps(metrics, indent=2),
            "metrics.json",
            "application/json"
        )
        st.json(metrics)


# Main UI
def main():
    """Streamlit entry point: sidebar configuration, upload, analysis and results."""
    st.title("🧪 Advanced Segmentation Metrics Analyzer")
    st.markdown("""
    Upload an image to analyze object segmentation performance using YOLOv8.
    The system provides detailed metrics and visualizations.
    """)

    with st.sidebar:
        st.header("Configuration")
        conf_threshold = st.slider("Confidence Threshold", 0.1, 1.0, 0.5)
        boundary_width = st.slider("Boundary Width (pixels)", 1, 10, 2)
        st.markdown("---")
        st.markdown(f"**Device:** {'GPU 🔥' if torch.cuda.is_available() else 'CPU 🐢'}")

    uploaded_file = st.file_uploader(
        "Choose an image",
        type=["jpg", "jpeg", "png", "bmp"],
        help="Supports JPG, PNG, BMP formats"
    )
    if not uploaded_file:
        return

    try:
        img = Image.open(uploaded_file)
        col1, _ = st.columns(2)
        with col1:
            st.image(img, caption="Original Image", use_column_width=True)

        if st.button("Analyze", type="primary"):
            with st.spinner("Processing..."):
                seg_img, det_img, metrics = process_image(
                    img, conf_threshold=conf_threshold, boundary_width=boundary_width
                )

            if metrics and "error" not in metrics:
                _render_results(seg_img, det_img, metrics)
            elif metrics and "error" in metrics:
                st.error(metrics["error"])
    except Exception as e:
        st.error(f"Error loading image: {str(e)}")


if __name__ == "__main__":
    main()