"""
Interstitial Cell of Cajal Detection and Quantification Tool

A Gradio app for detecting and counting cells in microscopy images using YOLO.
"""

import os
import cv2
import torch
import gradio as gr
import numpy as np
from typing import Tuple, Optional
from dataclasses import dataclass, field
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download
import spaces


# ============================================================================
# Configuration
# ============================================================================

@dataclass
class DetectionConfig:
    """Configuration for detection parameters."""

    # Minimum confidence a detection must have to be kept.
    confidence_threshold: float = 0.1
    # IoU threshold handed to ``sv.Detections.with_nms`` after slicing.
    nms_threshold: float = 0.0
    # Size of each tile fed to the model by ``sv.InferenceSlicer``.
    slice_width: int = 1024
    slice_height: int = 1024
    # Overlap between neighbouring tiles, in pixels.
    overlap_width: int = 0
    overlap_height: int = 0
    # default_factory instead of a plain default: since Python 3.11,
    # dataclasses reject unhashable instance defaults, which a non-frozen
    # dataclass such as sv.Color may be.
    annotation_color: sv.Color = field(default_factory=lambda: sv.Color.RED)
    annotation_thickness: int = 4


# ============================================================================
# Model Management
# ============================================================================

class ModelManager:
    """Manages model loading and inference."""

    def __init__(self, repo_id: str = 'edeler/ICC', model_filename: str = 'best.pt'):
        """Download and load the model immediately.

        Args:
            repo_id: Hugging Face Hub repository holding the weights.
            model_filename: Weights file name inside the repository.

        Raises:
            RuntimeError: If the model cannot be downloaded or loaded.
        """
        self.repo_id = repo_id
        self.model_filename = model_filename
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        self._load_model()

    def _load_model(self):
        """Download the repository snapshot and load the YOLO weights.

        Raises:
            RuntimeError: If the download fails, the weights file is missing,
                or YOLO cannot load it. The original exception is chained.
        """
        try:
            model_dir = snapshot_download(self.repo_id, local_dir='./models/ICC')
            model_path = os.path.join(model_dir, self.model_filename)
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")
            self.model = YOLO(model_path).to(self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def predict(self, image: np.ndarray, conf: Optional[float] = None) -> sv.Detections:
        """Run inference on an image.

        Args:
            image: Image array passed straight to the YOLO model.
            conf: Optional confidence cut-off forwarded to Ultralytics. When
                ``None``, Ultralytics' own default threshold is used.

        Returns:
            The detections of the first (only) result, as ``sv.Detections``.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        # Without an explicit `conf`, Ultralytics drops everything below its
        # default threshold (0.25) before we ever see it, so lower
        # user-selected thresholds would have no effect.
        predict_kwargs = {} if conf is None else {"conf": conf}
        result = self.model(image, **predict_kwargs)[0]
        return sv.Detections.from_ultralytics(result)


# ============================================================================
# Detection Pipeline
# ============================================================================

class DetectionPipeline:
    """Handles the complete detection pipeline."""

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.model_manager = model_manager
        self.config = config

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert image to BGR format if needed.

        Grayscale (2-D or single-channel), RGB, and RGBA inputs are converted
        to 3-channel BGR. Any other layout is returned unchanged.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):  # Grayscale
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[-1] == 3:  # RGB
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if image.shape[-1] == 4:  # RGBA
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return image

    def _create_callback(self, confidence_threshold: float):
        """Create the per-slice callback used by ``sv.InferenceSlicer``.

        The threshold is passed to the model so low-confidence boxes are not
        discarded by Ultralytics' default cut-off. The result is then filtered
        again so the returned detections always satisfy the threshold.
        """
        def callback(image_slice: np.ndarray) -> sv.Detections:
            detections = self.model_manager.predict(image_slice, conf=confidence_threshold)
            return detections[detections.confidence >= confidence_threshold]

        return callback

    def detect(
        self,
        image: np.ndarray,
        confidence_threshold: Optional[float] = None,
        nms_threshold: Optional[float] = None,
    ) -> Tuple[sv.Detections, np.ndarray]:
        """
        Perform detection on an image.

        Args:
            image: Input image as numpy array
            confidence_threshold: Override default confidence threshold
            nms_threshold: Override default NMS threshold

        Returns:
            Tuple of (detections, processed_image), where processed_image is
            the BGR image the model was run on.
        """
        # Use provided thresholds or defaults
        conf_thresh = (
            confidence_threshold
            if confidence_threshold is not None
            else self.config.confidence_threshold
        )
        nms_thresh = nms_threshold if nms_threshold is not None else self.config.nms_threshold

        # Preprocess image
        image_bgr = self._preprocess_image(image)

        # Initialize slicer with callback
        slicer = sv.InferenceSlicer(
            callback=self._create_callback(conf_thresh),
            slice_wh=(self.config.slice_width, self.config.slice_height),
            overlap_wh=(self.config.overlap_width, self.config.overlap_height),
        )

        # Perform slicing-based inference
        detections = slicer(image_bgr)

        # Apply Non-Maximum Suppression across the merged slice results
        detections = detections.with_nms(threshold=nms_thresh, class_agnostic=False)

        return detections, image_bgr

    def annotate(self, image: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Return a copy of *image* with oriented boxes drawn for *detections*."""
        box_annotator = sv.OrientedBoxAnnotator(
            color=self.config.annotation_color,
            thickness=self.config.annotation_thickness,
        )
        annotated = box_annotator.annotate(scene=image.copy(), detections=detections)
        return annotated


# ============================================================================
# Statistics and Reporting
# ============================================================================

class DetectionStats:
    """Generate statistics from detections."""

    @staticmethod
    def generate_summary(detections: sv.Detections) -> str:
        """Generate a Markdown summary of detections.

        Includes the total count and min/mean/max confidence. A per-class
        breakdown is added only when more than one class is present.
        """
        total = len(detections)

        if total == 0:
            return "No detections found."

        summary = [f"**Total Detections:** {total}"]

        # Confidence statistics
        if detections.confidence is not None and len(detections.confidence) > 0:
            avg_conf = np.mean(detections.confidence)
            min_conf = np.min(detections.confidence)
            max_conf = np.max(detections.confidence)

            summary.append("\n**Confidence Statistics:**")
            summary.append(f"- Average: {avg_conf:.3f}")
            summary.append(f"- Min: {min_conf:.3f}")
            summary.append(f"- Max: {max_conf:.3f}")

        # Class distribution (if multiple classes)
        if detections.class_id is not None and len(np.unique(detections.class_id)) > 1:
            summary.append("\n**Class Distribution:**")
            unique, counts = np.unique(detections.class_id, return_counts=True)
            for cls_id, count in zip(unique, counts):
                summary.append(f"- Class {cls_id}: {count} detections")

        return "\n".join(summary)


# ============================================================================
# Gradio Interface
# ============================================================================

class GradioApp:
    """Gradio interface for the detection app."""

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.config = config
        self.pipeline = DetectionPipeline(model_manager, config)
        self.stats = DetectionStats()

    @spaces.GPU
    def process_image(
        self,
        image: Optional[np.ndarray],
        confidence: float,
    ) -> Tuple[Optional[Image.Image], str]:
        """
        Process image and return annotated result.

        Args:
            image: Input image
            confidence: Confidence threshold

        Returns:
            Tuple of (annotated_image, summary_text). On failure the image is
            None and the text describes the error.
        """
        if image is None:
            return None, "Please upload an image."

        try:
            # Perform detection (NMS threshold from config)
            detections, image_bgr = self.pipeline.detect(
                image,
                confidence_threshold=confidence,
                nms_threshold=None,  # Use default from config
            )

            # Annotate image
            annotated = self.pipeline.annotate(image_bgr, detections)

            # Convert to RGB for display
            annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

            # Generate summary
            summary = self.stats.generate_summary(detections)

            return Image.fromarray(annotated_rgb), summary

        except Exception as e:
            # UI boundary: report the error to the user instead of crashing.
            return None, f"Error during detection: {str(e)}"

    @staticmethod
    def _find_example_images() -> list:
        """Return sorted paths of .jpg/.jpeg/.png files next to this module."""
        # abspath guards against an empty dirname when __file__ is relative.
        example_root = os.path.dirname(os.path.abspath(__file__))
        return sorted(
            os.path.join(example_root, file)
            for file in os.listdir(example_root)
            if file.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def create_interface(self) -> gr.Blocks:
        """Create and configure the Gradio interface."""
        # Single source of truth for the slider's initial and reset value.
        default_confidence = self.config.confidence_threshold

        with gr.Blocks(theme=gr.themes.Soft(), title="ICC Detection Tool") as demo:
            # Header
            gr.Markdown(
                """
                # 🔬 Interstitial Cell of Cajal Detection and Quantification Tool

                Upload a microscopy image to automatically detect and count Interstitial Cells of Cajal (ICC).
                """
            )

            with gr.Row():
                # Left column - Input and controls
                with gr.Column(scale=1):
                    # Examples section - FIRST
                    gr.Markdown("### 📋 Select an Example Image")

                    # Create input_img component first (needed for Examples)
                    input_img = gr.Image(
                        label="",
                        type="numpy",
                        interactive=True,
                        show_label=False,
                    )

                    example_images = self._find_example_images()
                    if example_images:
                        gr.Examples(
                            examples=example_images,
                            inputs=[input_img],
                        )

                    # Upload section - SECOND
                    gr.Markdown("### 📤 Or Upload Your Own Image")

                    # Detection parameters - COLLAPSIBLE
                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        confidence_slider = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=default_confidence,
                            step=0.01,
                            label="Confidence Threshold",
                            info="Minimum confidence for a detection",
                        )

                    # Action buttons
                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")

                # Right column - Output
                with gr.Column(scale=1):
                    gr.Markdown("### 📊 Results")
                    output_img = gr.Image(
                        label="Detection Result",
                        interactive=False,
                    )
                    detection_summary = gr.Markdown(
                        value="Results will appear here...",
                        label="Detection Summary",
                    )

            # Footer
            gr.Markdown(
                """
                ---
                **Note:** This tool uses a YOLO-based model for cell detection with sliced inference for high-resolution images.
                """
            )

            # Event handlers
            def reset_interface():
                """Clear input/output images and restore default summary and slider."""
                return None, None, "Results will appear here...", default_confidence

            clear_btn.click(
                fn=reset_interface,
                inputs=None,
                outputs=[input_img, output_img, detection_summary, confidence_slider],
            )

            detect_btn.click(
                fn=self.process_image,
                inputs=[input_img, confidence_slider],
                outputs=[output_img, detection_summary],
            )

        return demo


# ============================================================================
# Main Application Entry Point
# ============================================================================

def main():
    """Initialize and launch the application."""
    try:
        # Initialize configuration
        config = DetectionConfig()

        # Initialize model manager
        print("Loading model...")
        model_manager = ModelManager(repo_id='edeler/ICC', model_filename='best.pt')

        # Create and launch Gradio app
        print("Initializing interface...")
        app = GradioApp(model_manager, config)
        demo = app.create_interface()

        print("Launching application...")
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
        )
    except Exception as e:
        print(f"Failed to start application: {str(e)}")
        raise


if __name__ == "__main__":
    main()