| | """ |
| | Interstitial Cell of Cajal Detection and Quantification Tool |
| | A Gradio app for detecting and counting cells in microscopy images using YOLO. |
| | """ |
| |
|
| | import os |
| | import cv2 |
| | import torch |
| | import gradio as gr |
| | import numpy as np |
| | from typing import Tuple, Optional |
| | from dataclasses import dataclass |
| | from ultralytics import YOLO |
| | import supervision as sv |
| | from PIL import Image |
| | from huggingface_hub import snapshot_download |
| | import spaces |
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class DetectionConfig:
    """Configuration for detection parameters.

    Consumed by DetectionPipeline: the slice/overlap fields drive
    sv.InferenceSlicer tiling, the thresholds drive filtering and NMS,
    and the annotation fields drive sv.OrientedBoxAnnotator.
    """
    confidence_threshold: float = 0.1        # minimum confidence kept per slice (matches UI default)
    nms_threshold: float = 0.0               # IoU threshold passed to Detections.with_nms
    slice_width: int = 1024                  # inference tile width, px
    slice_height: int = 1024                 # inference tile height, px
    overlap_width: int = 0                   # horizontal overlap between tiles, px
    overlap_height: int = 0                  # vertical overlap between tiles, px
    annotation_color: sv.Color = sv.Color.RED    # oriented-box outline color
    annotation_thickness: int = 4            # outline thickness, px
|
| |
|
| | |
| | |
| | |
| |
|
class ModelManager:
    """Manages model loading and inference.

    Downloads a YOLO weights snapshot from the Hugging Face Hub on
    construction and keeps the loaded model on the best available device.
    """

    def __init__(self, repo_id: str = 'edeler/ICC', model_filename: str = 'best.pt'):
        """Initialize the manager and eagerly load the model.

        Args:
            repo_id: Hugging Face Hub repository containing the weights.
            model_filename: Weights file name inside the snapshot.

        Raises:
            RuntimeError: If the snapshot download or model load fails.
        """
        self.repo_id = repo_id
        self.model_filename = model_filename
        # Prefer GPU when available; CPU inference still works, just slower.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Download and load the YOLO model.

        Raises:
            RuntimeError: Wraps any download/load failure with the original
                exception chained as the cause.
        """
        try:
            model_dir = snapshot_download(self.repo_id, local_dir='./models/ICC')
            model_path = os.path.join(model_dir, self.model_filename)

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            self.model = YOLO(model_path).to(self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def predict(self, image: np.ndarray) -> sv.Detections:
        """Run inference on an image.

        Args:
            image: Image array in a format the YOLO model accepts
                (callers pass BGR — see DetectionPipeline._preprocess_image).

        Returns:
            Detections converted to the supervision format.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        result = self.model(image)[0]
        return sv.Detections.from_ultralytics(result)
|
| |
|
| | |
| | |
| | |
| |
|
class DetectionPipeline:
    """Handles the complete detection pipeline.

    Wraps preprocessing, sliced inference via sv.InferenceSlicer,
    NMS merging, and annotation of the result image.
    """

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.model_manager = model_manager
        self.config = config

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert image to BGR format if needed (OpenCV convention)."""
        if len(image.shape) == 2:
            # Grayscale input: expand to 3-channel BGR.
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        channels = image.shape[-1]
        if channels == 3:
            # Gradio supplies RGB; swap to BGR.
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if channels == 4:
            # Drop alpha while converting.
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return image

    def _create_callback(self, confidence_threshold: float):
        """Create the per-tile callback used by the inference slicer."""
        def on_slice(tile: np.ndarray) -> sv.Detections:
            dets = self.model_manager.predict(tile)
            # Keep only detections meeting the requested confidence.
            return dets[dets.confidence >= confidence_threshold]
        return on_slice

    def detect(
        self,
        image: np.ndarray,
        confidence_threshold: Optional[float] = None,
        nms_threshold: Optional[float] = None
    ) -> Tuple[sv.Detections, np.ndarray]:
        """
        Perform detection on an image.

        Args:
            image: Input image as numpy array
            confidence_threshold: Override default confidence threshold
            nms_threshold: Override default NMS threshold

        Returns:
            Tuple of (detections, processed_image)
        """
        cfg = self.config
        conf = cfg.confidence_threshold if confidence_threshold is None else confidence_threshold
        nms = cfg.nms_threshold if nms_threshold is None else nms_threshold

        bgr = self._preprocess_image(image)

        # Tile the image so high-resolution inputs fit the model input size.
        slicer = sv.InferenceSlicer(
            callback=self._create_callback(conf),
            slice_wh=(cfg.slice_width, cfg.slice_height),
            overlap_wh=(cfg.overlap_width, cfg.overlap_height)
        )
        merged = slicer(bgr)

        # Per-class NMS removes duplicates created along tile boundaries.
        return merged.with_nms(threshold=nms, class_agnostic=False), bgr

    def annotate(self, image: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Annotate a copy of the image with oriented detection boxes."""
        annotator = sv.OrientedBoxAnnotator(
            color=self.config.annotation_color,
            thickness=self.config.annotation_thickness
        )
        return annotator.annotate(scene=image.copy(), detections=detections)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class DetectionStats:
    """Generate statistics from detections."""

    @staticmethod
    def generate_summary(detections: sv.Detections) -> str:
        """Build a Markdown summary: total count, confidence stats, class mix."""
        count = len(detections)
        if count == 0:
            return "No detections found."

        lines = [f"**Total Detections:** {count}"]

        conf = detections.confidence
        if conf is not None and len(conf) > 0:
            lines.extend([
                "\n**Confidence Statistics:**",
                f"- Average: {np.mean(conf):.3f}",
                f"- Min: {np.min(conf):.3f}",
                f"- Max: {np.max(conf):.3f}",
            ])

        class_ids = detections.class_id
        # Only show the breakdown when there is more than one class present.
        if class_ids is not None and len(np.unique(class_ids)) > 1:
            lines.append("\n**Class Distribution:**")
            ids, counts = np.unique(class_ids, return_counts=True)
            lines.extend(
                f"- Class {cid}: {n} detections" for cid, n in zip(ids, counts)
            )

        return "\n".join(lines)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GradioApp:
    """Gradio interface for the detection app."""

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.pipeline = DetectionPipeline(model_manager, config)
        self.stats = DetectionStats()

    @spaces.GPU
    def process_image(
        self,
        image: Optional[np.ndarray],
        confidence: float
    ) -> Tuple[Optional[Image.Image], str]:
        """
        Process image and return annotated result.

        Args:
            image: Input image (RGB numpy array from Gradio), or None.
            confidence: Confidence threshold.

        Returns:
            Tuple of (annotated_image, summary_text); the image is None when
            input is missing or detection fails, with an explanatory message.
        """
        if image is None:
            return None, "Please upload an image."

        try:
            # nms_threshold=None means "use the configured default".
            detections, image_bgr = self.pipeline.detect(
                image,
                confidence_threshold=confidence,
                nms_threshold=None
            )

            annotated = self.pipeline.annotate(image_bgr, detections)

            # OpenCV works in BGR; convert back to RGB for display.
            annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

            summary = self.stats.generate_summary(detections)

            return Image.fromarray(annotated_rgb), summary

        except Exception as e:
            # Surface the failure in the UI instead of crashing the app.
            return None, f"Error during detection: {str(e)}"

    def create_interface(self) -> gr.Blocks:
        """Create and configure the Gradio interface."""

        with gr.Blocks(theme=gr.themes.Soft(), title="ICC Detection Tool") as demo:

            gr.Markdown(
                """
                # 🔬 Interstitial Cell of Cajal Detection and Quantification Tool

                Upload a microscopy image to automatically detect and count Interstitial Cells of Cajal (ICC).
                """
            )

            with gr.Row():

                with gr.Column(scale=1):

                    gr.Markdown("### 📋 Select an Example Image")

                    input_img = gr.Image(
                        label="",
                        type="numpy",
                        interactive=True,
                        show_label=False
                    )

                    # BUGFIX: dirname(__file__) is "" when launched as
                    # `python app.py`, and os.listdir("") raises
                    # FileNotFoundError — abspath makes it robust.
                    example_root = os.path.dirname(os.path.abspath(__file__))
                    # Sorted so the example gallery order is deterministic
                    # (os.listdir order is filesystem-dependent).
                    example_images = sorted(
                        os.path.join(example_root, name)
                        for name in os.listdir(example_root)
                        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
                    )

                    if example_images:
                        gr.Examples(
                            examples=example_images,
                            inputs=[input_img],
                        )

                    gr.Markdown("### 📤 Or Upload Your Own Image")

                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        confidence_slider = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=0.1,
                            step=0.01,
                            label="Confidence Threshold",
                            info="Minimum confidence for a detection"
                        )

                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")

                with gr.Column(scale=1):
                    gr.Markdown("### 📊 Results")
                    output_img = gr.Image(
                        label="Detection Result",
                        interactive=False
                    )
                    detection_summary = gr.Markdown(
                        value="Results will appear here...",
                        label="Detection Summary"
                    )

            gr.Markdown(
                """
                ---
                **Note:** This tool uses a YOLO-based model for cell detection with sliced inference for high-resolution images.
                """
            )

            def reset_interface():
                # Clear both images and the summary, restore the slider default.
                return None, None, "Results will appear here...", 0.1

            clear_btn.click(
                fn=reset_interface,
                inputs=None,
                outputs=[input_img, output_img, detection_summary, confidence_slider]
            )

            detect_btn.click(
                fn=self.process_image,
                inputs=[input_img, confidence_slider],
                outputs=[output_img, detection_summary]
            )

        return demo
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Initialize and launch the application.

    Builds the config, loads the model, assembles the Gradio UI, and
    serves it; re-raises any startup failure after logging it.
    """
    try:
        config = DetectionConfig()

        print("Loading model...")
        manager = ModelManager(repo_id='edeler/ICC', model_filename='best.pt')

        print("Initializing interface...")
        interface = GradioApp(manager, config).create_interface()

        print("Launching application...")
        # Bind on all interfaces at the conventional Gradio/Spaces port.
        interface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )

    except Exception as e:
        print(f"Failed to start application: {str(e)}")
        raise


if __name__ == "__main__":
    main()