File size: 13,663 Bytes
771674d
 
 
 
 
b1f8229
 
 
 
 
771674d
 
b1f8229
 
690f42c
 
5006b04
690f42c
771674d
 
 
c151ef0
771674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690f42c
771674d
 
 
 
 
 
aefbd99
771674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aefbd99
771674d
 
8c7fb6f
aefbd99
 
771674d
aefbd99
690f42c
771674d
 
aefbd99
771674d
 
 
 
aefbd99
771674d
 
aefbd99
771674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690f42c
771674d
 
690f42c
771674d
690f42c
771674d
 
 
 
 
 
 
 
 
690f42c
771674d
 
 
 
 
 
 
 
690f42c
 
771674d
 
 
 
 
 
aefbd99
771674d
 
 
 
 
 
 
 
1245792
771674d
aefbd99
771674d
 
 
 
 
aefbd99
771674d
 
 
 
 
 
 
1245792
771674d
 
 
1245792
771674d
690f42c
771674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690f42c
771674d
 
 
 
 
690f42c
771674d
 
 
 
 
 
 
e349241
 
 
 
771674d
e349241
771674d
e349241
 
771674d
 
1245792
 
 
 
 
 
 
 
 
 
 
 
 
e349241
1245792
 
e349241
 
 
 
 
 
 
 
 
 
771674d
 
 
 
 
aefbd99
771674d
 
 
 
 
 
 
 
 
 
 
690f42c
771674d
aefbd99
771674d
 
 
aefbd99
 
771674d
 
 
1245792
771674d
 
 
 
1245792
771674d
 
 
 
1245792
771674d
 
690f42c
771674d
 
 
 
 
 
 
 
 
 
 
 
690f42c
771674d
 
 
690f42c
771674d
 
 
 
aefbd99
771674d
aefbd99
 
 
 
 
771674d
aefbd99
771674d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
"""
Interstitial Cell of Cajal Detection and Quantification Tool
A Gradio app for detecting and counting cells in microscopy images using YOLO.
"""

import os
import traceback
from dataclasses import dataclass, field
from typing import Tuple, Optional

import cv2
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
import supervision as sv
from PIL import Image
from huggingface_hub import snapshot_download
import spaces

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class DetectionConfig:
    """Configuration for detection parameters.

    Holds the defaults used by the sliced-inference pipeline and by the
    result annotator. Per-call overrides (confidence / NMS threshold) are
    applied by the pipeline, not stored here.
    """
    # Detections whose confidence is below this value are discarded.
    confidence_threshold: float = 0.1
    # Threshold passed to `Detections.with_nms` after the slices are merged.
    nms_threshold: float = 0.0
    # Tile size handed to `sv.InferenceSlicer` as `slice_wh`.
    slice_width: int = 1024
    slice_height: int = 1024
    # Tile overlap handed to `sv.InferenceSlicer` as `overlap_wh`.
    overlap_width: int = 0
    overlap_height: int = 0
    # A factory gives each config its own Color object instead of sharing a
    # single default instance; dataclasses also refuse unhashable (mutable
    # dataclass) defaults on Python 3.11+.
    annotation_color: sv.Color = field(default_factory=lambda: sv.Color.RED)
    # Line thickness for the drawn boxes.
    annotation_thickness: int = 4


# ============================================================================
# Model Management
# ============================================================================

class ModelManager:
    """Download a YOLO checkpoint from the Hugging Face Hub and run inference with it."""

    def __init__(
        self,
        repo_id: str = 'edeler/ICC',
        model_filename: str = 'best.pt',
        local_dir: str = './models/ICC',
        min_confidence: float = 0.01,
    ):
        """
        Args:
            repo_id: Hugging Face Hub repository that holds the checkpoint.
            model_filename: Checkpoint file name inside the downloaded snapshot.
            local_dir: Directory the snapshot is downloaded into.
            min_confidence: Confidence floor handed to the YOLO predictor when
                `predict` is called without an explicit `confidence`. Without
                it Ultralytics applies its own default floor (0.25), which
                silently drops detections that callers filtering at a lower
                threshold expect to receive.

        Raises:
            RuntimeError: If the snapshot cannot be downloaded or the
                checkpoint cannot be found or loaded.
        """
        self.repo_id = repo_id
        self.model_filename = model_filename
        self.local_dir = local_dir
        self.min_confidence = min_confidence
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = None
        self._load_model()

    def _load_model(self):
        """Download the repository snapshot and load the YOLO model onto `self.device`.

        Raises:
            RuntimeError: Wrapping any download/lookup/load failure; the
                original exception is kept as `__cause__`.
        """
        try:
            model_dir = snapshot_download(self.repo_id, local_dir=self.local_dir)
            model_path = os.path.join(model_dir, self.model_filename)

            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model file not found: {model_path}")

            self.model = YOLO(model_path).to(self.device)
            print(f"Model loaded successfully on {self.device}")
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def predict(self, image: np.ndarray, confidence: Optional[float] = None) -> sv.Detections:
        """Run inference on an image.

        Args:
            image: Image array passed straight to the YOLO model.
            confidence: Minimum confidence the predictor keeps; defaults to
                `self.min_confidence`.

        Returns:
            Detections converted from the first Ultralytics result.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        conf = self.min_confidence if confidence is None else confidence
        # verbose=False: the slicer calls this once per tile; per-call logs
        # would flood the server output.
        result = self.model(image, conf=conf, verbose=False)[0]
        return sv.Detections.from_ultralytics(result)


# ============================================================================
# Detection Pipeline
# ============================================================================

class DetectionPipeline:
    """Sliced inference, NMS merging and box annotation for a single image."""

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.model_manager = model_manager
        self.config = config

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Convert an input image to 3-channel BGR where the layout is recognised.

        2-D arrays and (H, W, 1) arrays are treated as grayscale, 3-channel
        arrays as RGB and 4-channel arrays as RGBA. Any other layout is
        returned unchanged.
        """
        if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] == 1):
            # The single-channel 3-D form used to fall through unconverted and
            # reach the model/annotator as a 1-channel image.
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        if image.shape[-1] == 3:  # RGB
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if image.shape[-1] == 4:  # RGBA
            return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
        return image

    def _create_callback(self, confidence_threshold: float):
        """Build the per-slice callback for `sv.InferenceSlicer`.

        The callback runs the model on one slice and keeps only detections
        with confidence >= `confidence_threshold`.
        """
        def callback(image_slice: np.ndarray) -> sv.Detections:
            detections = self.model_manager.predict(image_slice)
            return detections[detections.confidence >= confidence_threshold]
        return callback

    def detect(
        self,
        image: np.ndarray,
        confidence_threshold: Optional[float] = None,
        nms_threshold: Optional[float] = None
    ) -> Tuple[sv.Detections, np.ndarray]:
        """
        Perform detection on an image.

        Args:
            image: Input image as numpy array
            confidence_threshold: Override default confidence threshold
            nms_threshold: Override default NMS threshold

        Returns:
            Tuple of (detections, processed_image), where processed_image is
            the BGR image the detections refer to.
        """
        # `None` means "use the configured default"; 0.0 is a valid override.
        conf_thresh = confidence_threshold if confidence_threshold is not None else self.config.confidence_threshold
        nms_thresh = nms_threshold if nms_threshold is not None else self.config.nms_threshold

        image_bgr = self._preprocess_image(image)

        slicer = sv.InferenceSlicer(
            callback=self._create_callback(conf_thresh),
            slice_wh=(self.config.slice_width, self.config.slice_height),
            overlap_wh=(self.config.overlap_width, self.config.overlap_height)
        )

        # Run the model tile by tile and merge the per-tile detections.
        detections = slicer(image_bgr)

        # Remove duplicates across tiles (per class).
        detections = detections.with_nms(threshold=nms_thresh, class_agnostic=False)

        return detections, image_bgr

    def annotate(self, image: np.ndarray, detections: sv.Detections) -> np.ndarray:
        """Return a copy of `image` with the detections drawn as oriented boxes."""
        box_annotator = sv.OrientedBoxAnnotator(
            color=self.config.annotation_color,
            thickness=self.config.annotation_thickness
        )
        return box_annotator.annotate(scene=image.copy(), detections=detections)


# ============================================================================
# Statistics and Reporting
# ============================================================================

class DetectionStats:
    """Turn a set of detections into a short Markdown report."""

    @staticmethod
    def generate_summary(detections: sv.Detections) -> str:
        """Summarize detections as Markdown.

        Returns "No detections found." for an empty set. Otherwise reports the
        total count, the mean/min/max confidence (when confidences exist), and
        a per-class count when more than one class id is present.
        """
        count = len(detections)
        if not count:
            return "No detections found."

        lines = [f"**Total Detections:** {count}"]

        scores = detections.confidence
        if scores is not None and len(scores) > 0:
            lines.extend([
                "\n**Confidence Statistics:**",
                f"- Average: {np.mean(scores):.3f}",
                f"- Min: {np.min(scores):.3f}",
                f"- Max: {np.max(scores):.3f}",
            ])

        labels = detections.class_id
        if labels is not None:
            class_ids, class_counts = np.unique(labels, return_counts=True)
            # Only worth listing when the model produced several classes.
            if len(class_ids) > 1:
                lines.append("\n**Class Distribution:**")
                lines.extend(
                    f"- Class {cid}: {n} detections"
                    for cid, n in zip(class_ids, class_counts)
                )

        return "\n".join(lines)


# ============================================================================
# Gradio Interface
# ============================================================================

class GradioApp:
    """Gradio interface for the detection app."""

    def __init__(self, model_manager: ModelManager, config: DetectionConfig):
        self.pipeline = DetectionPipeline(model_manager, config)
        self.stats = DetectionStats()
        # Single source for the slider's initial and reset value.
        self.default_confidence = config.confidence_threshold

    @spaces.GPU
    def process_image(
        self,
        image: Optional[np.ndarray],
        confidence: float
    ) -> Tuple[Optional[Image.Image], str]:
        """
        Process image and return annotated result.

        Args:
            image: Input image (RGB numpy array from Gradio), or None
            confidence: Confidence threshold

        Returns:
            Tuple of (annotated_image, summary_text). On failure the image is
            None and the text describes the error.
        """
        if image is None:
            return None, "Please upload an image."

        try:
            # NMS threshold comes from the config.
            detections, image_bgr = self.pipeline.detect(
                image,
                confidence_threshold=confidence,
                nms_threshold=None
            )

            annotated = self.pipeline.annotate(image_bgr, detections)

            # The pipeline works in BGR; Gradio/PIL display RGB.
            annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

            summary = self.stats.generate_summary(detections)

            return Image.fromarray(annotated_rgb), summary

        except Exception as e:
            # Deliberately best-effort: show the error in the UI, but keep the
            # traceback in the server log so failures remain debuggable.
            traceback.print_exc()
            return None, f"Error during detection: {str(e)}"

    def create_interface(self) -> gr.Blocks:
        """Create and configure the Gradio interface."""
        default_confidence = self.default_confidence

        with gr.Blocks(theme=gr.themes.Soft(), title="ICC Detection Tool") as demo:
            # Header
            gr.Markdown(
                """
                # 🔬 Interstitial Cell of Cajal Detection and Quantification Tool
                
                Upload a microscopy image to automatically detect and count Interstitial Cells of Cajal (ICC).
                """
            )

            with gr.Row():
                # Left column - Input and controls
                with gr.Column(scale=1):
                    gr.Markdown("### 📋 Select an Example Image")

                    # The input component must exist before gr.Examples can target it.
                    input_img = gr.Image(
                        label="",
                        type="numpy",
                        interactive=True,
                        show_label=False
                    )

                    # abspath: dirname() of a relative __file__ would be "".
                    # sorted: os.listdir order is arbitrary.
                    example_root = os.path.dirname(os.path.abspath(__file__))
                    example_images = [
                        os.path.join(example_root, file)
                        for file in sorted(os.listdir(example_root))
                        if file.lower().endswith(('.jpg', '.jpeg', '.png'))
                    ]

                    if example_images:
                        gr.Examples(
                            examples=example_images,
                            inputs=[input_img],
                        )

                    gr.Markdown("### 📤 Or Upload Your Own Image")

                    with gr.Accordion("⚙️ Advanced Settings", open=False):
                        confidence_slider = gr.Slider(
                            minimum=0.01,
                            maximum=1.0,
                            value=default_confidence,
                            step=0.01,
                            label="Confidence Threshold",
                            info="Minimum confidence for a detection"
                        )

                    # Action buttons
                    with gr.Row():
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                        detect_btn = gr.Button("🔍 Detect", variant="primary", size="lg")

                # Right column - Output
                with gr.Column(scale=1):
                    gr.Markdown("### 📊 Results")
                    output_img = gr.Image(
                        label="Detection Result",
                        interactive=False
                    )
                    detection_summary = gr.Markdown(
                        value="Results will appear here...",
                        label="Detection Summary"
                    )

            # Footer
            gr.Markdown(
                """
                ---
                **Note:** This tool uses a YOLO-based model for cell detection with sliced inference for high-resolution images.
                """
            )

            # Event handlers
            def reset_interface():
                return None, None, "Results will appear here...", default_confidence

            clear_btn.click(
                fn=reset_interface,
                inputs=None,
                outputs=[input_img, output_img, detection_summary, confidence_slider]
            )

            detect_btn.click(
                fn=self.process_image,
                inputs=[input_img, confidence_slider],
                outputs=[output_img, detection_summary]
            )

        return demo


# ============================================================================
# Main Application Entry Point
# ============================================================================

def main():
    """Build the config, model and UI, then serve the app on 0.0.0.0:7860.

    Any start-up failure is printed and then re-raised.
    """
    try:
        settings = DetectionConfig()

        print("Loading model...")
        manager = ModelManager(repo_id='edeler/ICC', model_filename='best.pt')

        print("Initializing interface...")
        interface = GradioApp(manager, settings).create_interface()

        print("Launching application...")
        launch_options = {
            "server_name": "0.0.0.0",
            "server_port": 7860,
            "show_error": True,
        }
        interface.launch(**launch_options)

    except Exception as err:
        print(f"Failed to start application: {str(err)}")
        raise


# Launch the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()