edeler commited on
Commit
aefbd99
Β·
verified Β·
1 Parent(s): d910516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -239
app.py CHANGED
@@ -7,305 +7,245 @@ from ultralytics import YOLO
7
  import supervision as sv
8
  from PIL import Image
9
  from huggingface_hub import snapshot_download
10
- from functools import lru_cache
11
- from typing import Tuple, Optional
12
  import spaces
 
 
13
 
14
- # Constants
15
- DEFAULT_CONFIDENCE_THRESHOLD = 0.1
16
- DEFAULT_NMS_THRESHOLD = 0.0
17
- DEFAULT_SLICE_WIDTH = 1024
18
- DEFAULT_SLICE_HEIGHT = 1024
19
- DEFAULT_OVERLAP_WIDTH = 0
20
- DEFAULT_OVERLAP_HEIGHT = 0
 
 
 
 
21
  ANNOTATION_COLOR = sv.Color.RED
22
  ANNOTATION_THICKNESS = 4
23
- DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
24
- REPO_ID = 'edeler/ICC'
25
-
26
- # Global model cache
27
- _model_cache = None
28
-
29
- @lru_cache(maxsize=1)
30
- def download_model() -> str:
31
- """Download and cache model from Hugging Face Hub."""
32
- model_dir = snapshot_download(REPO_ID, local_dir='./models/ICC')
33
- return os.path.join(model_dir, "best.pt")
34
-
35
- def load_model() -> YOLO:
36
- """Lazy-load and cache the YOLO model."""
37
- global _model_cache
38
- if _model_cache is None:
39
- model_path = download_model()
40
- _model_cache = YOLO(model_path).to(DEVICE)
41
- return _model_cache
42
-
43
- def validate_image(image: Optional[np.ndarray]) -> bool:
44
- """Validate input image."""
45
- if image is None:
46
- return False
47
- if not isinstance(image, np.ndarray):
48
- return False
49
- if image.size == 0:
50
- return False
51
- return True
52
 
53
- def preprocess_image(image: np.ndarray) -> np.ndarray:
54
- """Convert image to BGR format if needed."""
55
- # Gradio provides RGB images, convert to BGR for OpenCV/YOLO
56
- if len(image.shape) == 3 and image.shape[-1] == 3:
57
- # Check if already BGR or RGB by inspecting typical color distributions
58
- return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
59
- return image
60
 
61
- def create_detection_callback(model: YOLO, confidence_threshold: float):
62
- """Factory function to create detection callback with closure over parameters."""
63
- def callback(image_slice: np.ndarray) -> sv.Detections:
64
- with torch.no_grad(): # Disable gradient computation for inference
65
- result = model(image_slice, verbose=False)[0]
66
- detections = sv.Detections.from_ultralytics(result)
67
- return detections[detections.confidence >= confidence_threshold]
68
- return callback
69
-
70
- def perform_detection(
71
- image: np.ndarray,
72
- confidence_threshold: float,
73
- nms_threshold: float,
74
- slice_width: int,
75
- slice_height: int,
76
- overlap_width: int,
77
- overlap_height: int
78
- ) -> Tuple[sv.Detections, int]:
79
- """Perform object detection with slicing and NMS."""
80
- model = load_model()
81
-
82
- # Create slicer with callback
83
- callback = create_detection_callback(model, confidence_threshold)
84
- slicer = sv.InferenceSlicer(
85
- callback=callback,
86
- slice_wh=(slice_width, slice_height),
87
- overlap_wh=(overlap_width, overlap_height),
88
- overlap_ratio_wh=None
89
- )
90
-
91
- # Perform slicing-based inference
92
- detections = slicer(image)
93
-
94
- # Apply NMS
95
- if nms_threshold > 0:
96
- detections = detections.with_nms(threshold=nms_threshold, class_agnostic=False)
97
-
98
- return detections, len(detections)
99
 
100
- def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
101
- """Annotate image with detection bounding boxes."""
102
- box_annotator = sv.OrientedBoxAnnotator(
103
- color=ANNOTATION_COLOR,
104
- thickness=ANNOTATION_THICKNESS
105
- )
106
- return box_annotator.annotate(scene=image.copy(), detections=detections)
107
 
108
- @spaces.GPU
109
- def detect_objects(
110
- image: Optional[np.ndarray],
111
- confidence_threshold: float,
112
- nms_threshold: float,
113
- slice_width: int,
114
- slice_height: int,
115
- overlap_width: int,
116
- overlap_height: int
117
- ) -> Tuple[Optional[Image.Image], str]:
118
  """
119
- Main detection function for Gradio interface.
120
 
121
  Args:
122
- image: Input image as numpy array
123
- confidence_threshold: Minimum confidence for detections
124
- nms_threshold: IoU threshold for non-maximum suppression
125
- slice_width: Width of each detection slice
126
- slice_height: Height of each detection slice
127
- overlap_width: Overlap width between slices
128
- overlap_height: Overlap height between slices
129
-
130
  Returns:
131
- Tuple of (annotated image, detection summary)
132
  """
 
 
 
 
133
  try:
134
- # Validate input
135
- if not validate_image(image):
136
- return None, "⚠️ Please upload a valid image."
137
 
138
- # Preprocess image
139
- image_bgr = preprocess_image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
- # Perform detection
142
- detections, total_count = perform_detection(
143
- image_bgr,
144
- confidence_threshold,
145
- nms_threshold,
146
- slice_width,
147
- slice_height,
148
- overlap_width,
149
- overlap_height
150
  )
151
 
 
 
 
 
152
  # Annotate image
153
- annotated_img = annotate_image(image_bgr, detections)
 
 
 
 
 
 
 
154
 
155
  # Convert back to RGB for display
156
- annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
157
 
158
- # Create summary with additional info
159
- summary = f"βœ… Total Detections: {total_count}"
160
- if total_count > 0:
161
- avg_confidence = np.mean(detections.confidence) if len(detections.confidence) > 0 else 0
162
- summary += f"\nπŸ“Š Average Confidence: {avg_confidence:.2%}"
163
 
164
- return Image.fromarray(annotated_img_rgb), summary
165
 
166
  except Exception as e:
167
- return None, f"❌ Error during detection: {str(e)}"
 
 
168
 
169
  def get_example_images() -> list:
170
  """Get list of example images from the current directory."""
171
- example_root = os.path.dirname(__file__)
172
- return [
173
- os.path.join(example_root, file)
174
- for file in os.listdir(example_root)
175
- if file.lower().endswith(('.jpg', '.jpeg', '.png'))
176
- ]
 
 
 
 
 
177
 
178
- def create_interface() -> gr.Blocks:
179
  """Create and configure the Gradio interface."""
180
- with gr.Blocks(title="ICC Detection Tool") as demo:
 
 
 
 
 
 
 
 
 
181
  gr.Markdown(
182
  """
183
  # πŸ”¬ Interstitial Cell of Cajal Detection and Quantification Tool
184
 
185
- Upload an image to detect and count Interstitial Cells of Cajal (ICC).
186
- Adjust the parameters below for fine-tuned detection.
187
- """
 
188
  )
189
 
190
  with gr.Row():
191
  with gr.Column(scale=1):
192
  input_img = gr.Image(
193
- label="πŸ“€ Upload an Image",
194
  type="numpy",
195
- interactive=True
196
  )
197
 
198
- # Advanced settings in accordion
199
- with gr.Accordion("βš™οΈ Advanced Settings", open=False):
200
- confidence_slider = gr.Slider(
201
- minimum=0.01,
202
- maximum=1.0,
203
- value=DEFAULT_CONFIDENCE_THRESHOLD,
204
- step=0.01,
205
- label="Confidence Threshold",
206
- info="Minimum confidence for detections"
207
- )
208
- nms_slider = gr.Slider(
209
- minimum=0.0,
210
- maximum=1.0,
211
- value=DEFAULT_NMS_THRESHOLD,
212
- step=0.05,
213
- label="NMS Threshold",
214
- info="IoU threshold for non-maximum suppression"
215
- )
216
-
217
- with gr.Row():
218
- slice_width = gr.Number(
219
- value=DEFAULT_SLICE_WIDTH,
220
- label="Slice Width",
221
- precision=0
222
- )
223
- slice_height = gr.Number(
224
- value=DEFAULT_SLICE_HEIGHT,
225
- label="Slice Height",
226
- precision=0
227
- )
228
-
229
- with gr.Row():
230
- overlap_width = gr.Number(
231
- value=DEFAULT_OVERLAP_WIDTH,
232
- label="Overlap Width",
233
- precision=0
234
- )
235
- overlap_height = gr.Number(
236
- value=DEFAULT_OVERLAP_HEIGHT,
237
- label="Overlap Height",
238
- precision=0
239
- )
240
-
241
  with gr.Row():
242
- clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
243
- detect_btn = gr.Button("πŸ” Detect", variant="primary")
 
 
 
 
 
 
 
 
 
 
244
 
245
  with gr.Column(scale=1):
246
  output_img = gr.Image(
247
- label="πŸ“Š Detection Result",
248
- interactive=False
 
249
  )
250
  detection_count = gr.Textbox(
251
- label="πŸ“ˆ Detection Summary",
252
  interactive=False,
253
- lines=3
254
  )
255
 
256
- # Examples section
257
- example_images = get_example_images()
258
- if example_images:
259
- with gr.Accordion("πŸ“ Example Images", open=False):
260
- gr.Examples(
261
- examples=[[img] for img in example_images],
262
- inputs=[input_img],
263
- label="Click to load"
264
- )
 
 
 
265
 
266
- # Button actions
267
- detect_btn.click(
268
- fn=detect_objects,
269
- inputs=[
270
- input_img,
271
- confidence_slider,
272
- nms_slider,
273
- slice_width,
274
- slice_height,
275
- overlap_width,
276
- overlap_height
277
- ],
278
- outputs=[output_img, detection_count]
279
- )
280
 
281
  clear_btn.click(
282
- fn=lambda: (None, None, ""),
283
  inputs=None,
284
  outputs=[input_img, output_img, detection_count]
285
  )
286
 
287
- gr.Markdown(
288
- """
289
- ---
290
- ### πŸ’‘ Tips:
291
- - **Confidence Threshold**: Lower values detect more objects but may include false positives
292
- - **NMS Threshold**: Higher values keep more overlapping boxes
293
- - **Slice Settings**: Adjust for large images or when objects are at different scales
294
- """
 
 
 
 
295
  )
296
 
297
  return demo
298
 
 
299
  if __name__ == "__main__":
300
- # Preload model on startup for faster first inference
301
- print(f"πŸš€ Loading model on {DEVICE}...")
302
- load_model()
303
- print("βœ… Model loaded successfully!")
304
-
305
- # Create and launch interface
306
- demo = create_interface()
307
- demo.launch(
308
- server_name="0.0.0.0",
309
- server_port=7860,
310
- show_error=True
311
- )
 
7
  import supervision as sv
8
  from PIL import Image
9
  from huggingface_hub import snapshot_download
 
 
10
  import spaces
11
+ from typing import Tuple, Optional
12
+ import logging
13
 
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Detection parameters
19
+ CONFIDENCE_THRESHOLD = 0.1
20
+ NMS_THRESHOLD = 0.0
21
+ SLICE_WIDTH = 1024
22
+ SLICE_HEIGHT = 1024
23
+ OVERLAP_WIDTH = 0
24
+ OVERLAP_HEIGHT = 0
25
  ANNOTATION_COLOR = sv.Color.RED
26
  ANNOTATION_THICKNESS = 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Device configuration
29
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+ logger.info(f"Using device: {device}")
 
 
 
 
31
 
32
+ # Model initialization
33
+ def load_model():
34
+ """Load YOLO model from Hugging Face Hub with error handling."""
35
+ try:
36
+ repo_id = 'edeler/ICC'
37
+ logger.info(f"Downloading model from {repo_id}...")
38
+ model_dir = snapshot_download(repo_id, local_dir='./models/ICC')
39
+ model_path = os.path.join(model_dir, "best.pt")
40
+
41
+ if not os.path.exists(model_path):
42
+ raise FileNotFoundError(f"Model file not found at {model_path}")
43
+
44
+ logger.info(f"Loading model from {model_path}...")
45
+ model = YOLO(model_path)
46
+ model.to(device)
47
+ logger.info("Model loaded successfully")
48
+ return model
49
+ except Exception as e:
50
+ logger.error(f"Error loading model: {str(e)}")
51
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Load model once at startup
54
+ model = load_model()
 
 
 
 
 
55
 
56
+ @spaces.GPU(duration=60) # Allocate GPU for up to 60 seconds
57
+ def detect_objects(image: Optional[np.ndarray]) -> Tuple[Optional[Image.Image], str]:
 
 
 
 
 
 
 
 
58
  """
59
+ Detect objects in the input image using YOLO model with sliced inference.
60
 
61
  Args:
62
+ image: Input image as numpy array (RGB format from Gradio)
63
+
 
 
 
 
 
 
64
  Returns:
65
+ Tuple of (annotated PIL Image, detection summary string)
66
  """
67
+ # Validate input
68
+ if image is None:
69
+ return None, "⚠️ Please upload an image first."
70
+
71
  try:
72
+ # Convert RGB (from Gradio) to BGR (for OpenCV/YOLO)
73
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
74
 
75
+ # Define callback for sliced inference
76
+ def inference_callback(image_slice: np.ndarray) -> sv.Detections:
77
+ """Process each image slice."""
78
+ result = model(image_slice, verbose=False)[0]
79
+ detections = sv.Detections.from_ultralytics(result)
80
+ # Filter by confidence threshold
81
+ return detections[detections.confidence >= CONFIDENCE_THRESHOLD]
82
+
83
+ # Initialize slicer
84
+ slicer = sv.InferenceSlicer(
85
+ callback=inference_callback,
86
+ slice_wh=(SLICE_WIDTH, SLICE_HEIGHT),
87
+ overlap_wh=(OVERLAP_WIDTH, OVERLAP_HEIGHT),
88
+ overlap_ratio_wh=None
89
+ )
90
+
91
+ # Perform inference
92
+ logger.info("Running detection...")
93
+ detections = slicer(image_bgr)
94
 
95
+ # Apply NMS to remove duplicate detections
96
+ detections = detections.with_nms(
97
+ threshold=NMS_THRESHOLD,
98
+ class_agnostic=False
 
 
 
 
 
99
  )
100
 
101
+ # Count detections
102
+ total_detections = len(detections)
103
+ logger.info(f"Found {total_detections} detections")
104
+
105
  # Annotate image
106
+ box_annotator = sv.OrientedBoxAnnotator(
107
+ color=ANNOTATION_COLOR,
108
+ thickness=ANNOTATION_THICKNESS
109
+ )
110
+ annotated_img_bgr = box_annotator.annotate(
111
+ scene=image_bgr.copy(),
112
+ detections=detections
113
+ )
114
 
115
  # Convert back to RGB for display
116
+ annotated_img_rgb = cv2.cvtColor(annotated_img_bgr, cv2.COLOR_BGR2RGB)
117
 
118
+ # Create result message
119
+ result_msg = f"βœ… Total Detections: {total_detections}"
 
 
 
120
 
121
+ return Image.fromarray(annotated_img_rgb), result_msg
122
 
123
  except Exception as e:
124
+ error_msg = f"❌ Error during detection: {str(e)}"
125
+ logger.error(error_msg)
126
+ return None, error_msg
127
 
128
  def get_example_images() -> list:
129
  """Get list of example images from the current directory."""
130
+ try:
131
+ example_root = os.path.dirname(__file__) or "."
132
+ example_images = [
133
+ os.path.join(example_root, file)
134
+ for file in os.listdir(example_root)
135
+ if file.lower().endswith((".jpg", ".jpeg", ".png"))
136
+ ]
137
+ return example_images[:10] # Limit to 10 examples
138
+ except Exception as e:
139
+ logger.warning(f"Could not load example images: {str(e)}")
140
+ return []
141
 
142
+ def create_interface():
143
  """Create and configure the Gradio interface."""
144
+
145
+ with gr.Blocks(
146
+ title="ICC Detection Tool",
147
+ theme=gr.themes.Soft(),
148
+ css="""
149
+ .gradio-container {max-width: 1200px !important}
150
+ #title {text-align: center; color: #2563eb;}
151
+ """
152
+ ) as demo:
153
+
154
  gr.Markdown(
155
  """
156
  # πŸ”¬ Interstitial Cell of Cajal Detection and Quantification Tool
157
 
158
+ Upload an image to detect and quantify Interstitial Cells of Cajal (ICC).
159
+ The model uses advanced YOLO detection with sliced inference for accurate results.
160
+ """,
161
+ elem_id="title"
162
  )
163
 
164
  with gr.Row():
165
  with gr.Column(scale=1):
166
  input_img = gr.Image(
167
+ label="πŸ“€ Upload Image",
168
  type="numpy",
169
+ height=400
170
  )
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Row():
173
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary", scale=1)
174
+ detect_btn = gr.Button("πŸ” Detect", variant="primary", scale=2)
175
+
176
+ # Example images
177
+ example_images = get_example_images()
178
+ if example_images:
179
+ with gr.Accordion("πŸ“ Example Images", open=False):
180
+ gr.Examples(
181
+ examples=[[img] for img in example_images],
182
+ inputs=[input_img],
183
+ label=None
184
+ )
185
 
186
  with gr.Column(scale=1):
187
  output_img = gr.Image(
188
+ label="✨ Detection Result",
189
+ type="pil",
190
+ height=400
191
  )
192
  detection_count = gr.Textbox(
193
+ label="πŸ“Š Detection Summary",
194
  interactive=False,
195
+ lines=2
196
  )
197
 
198
+ # Model information
199
+ with gr.Accordion("ℹ️ Model Information", open=False):
200
+ gr.Markdown(
201
+ f"""
202
+ **Configuration:**
203
+ - Confidence Threshold: {CONFIDENCE_THRESHOLD}
204
+ - NMS Threshold: {NMS_THRESHOLD}
205
+ - Slice Size: {SLICE_WIDTH}x{SLICE_HEIGHT}
206
+ - Device: {device.upper()}
207
+ - Model: edeler/ICC
208
+ """
209
+ )
210
 
211
+ # Event handlers
212
+ def reset_interface():
213
+ """Reset all components to initial state."""
214
+ return None, None, ""
 
 
 
 
 
 
 
 
 
 
215
 
216
  clear_btn.click(
217
+ fn=reset_interface,
218
  inputs=None,
219
  outputs=[input_img, output_img, detection_count]
220
  )
221
 
222
+ detect_btn.click(
223
+ fn=detect_objects,
224
+ inputs=[input_img],
225
+ outputs=[output_img, detection_count],
226
+ api_name="detect"
227
+ )
228
+
229
+ # Allow Enter key to trigger detection
230
+ input_img.upload(
231
+ fn=lambda: "⏳ Image uploaded. Click 'Detect' to start...",
232
+ inputs=None,
233
+ outputs=detection_count
234
  )
235
 
236
  return demo
237
 
238
+ # Main execution
239
  if __name__ == "__main__":
240
+ try:
241
+ demo = create_interface()
242
+ demo.queue(max_size=20) # Enable queuing for better handling of concurrent requests
243
+ demo.launch(
244
+ server_name="0.0.0.0",
245
+ server_port=7860,
246
+ share=False, # Set to True if you want a public link
247
+ show_error=True
248
+ )
249
+ except Exception as e:
250
+ logger.error(f"Failed to launch app: {str(e)}")
251
+ raise