edeler committed on
Commit
5006b04
·
verified ·
1 Parent(s): a0ac02e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -57
app.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  import io
3
  import cv2
4
  import sys
5
- import math
6
  import json
7
  import torch
8
  import gradio as gr
@@ -10,27 +9,31 @@ import numpy as np
10
  import pandas as pd
11
  from PIL import Image
12
  from typing import List, Tuple, Optional, Dict
13
-
14
  from ultralytics import YOLO
15
  import supervision as sv
16
  from huggingface_hub import hf_hub_download
17
- import spaces # ZeroGPU-compatible decorator
 
 
 
 
 
 
18
 
19
  # -----------------------------
20
  # Defaults / configuration
21
  # -----------------------------
22
- REPO_ID = "edeler/ICC" # <- your HF repo with weights
23
- WEIGHTS_FILENAME = "best.pt" # <- adjust if different
24
  LOCAL_MODEL_DIR = "./models/ICC"
25
- EXAMPLES_DIR = "." # scan current repo root for demo images
26
 
27
- # Reasonable defaults (tunable in UI)
28
  DEFAULT_CONF = 0.25
29
- DEFAULT_IOU = 0.50 # for both model-level and global cross-slice NMS
30
  DEFAULT_SLICE_WH = 1024
31
- DEFAULT_OVERLAP = 128 # helps stitching at tile borders
32
  DEFAULT_THICKNESS = 3
33
- DEFAULT_LONG_EDGE = 4096 # optional downscale for very large WSIs/crops
34
 
35
  # -----------------------------
36
  # Torch / device helpers
@@ -60,8 +63,7 @@ def load_model() -> Tuple[YOLO, Dict[int, str]]:
60
  weights_path = hf_hub_download(
61
  repo_id=REPO_ID,
62
  filename=WEIGHTS_FILENAME,
63
- local_dir=LOCAL_MODEL_DIR,
64
- local_dir_use_symlinks=False, # safer in Spaces
65
  )
66
  model = YOLO(weights_path)
67
  class_names = model.model.names if hasattr(model, "model") else model.names
@@ -72,7 +74,6 @@ def load_model() -> Tuple[YOLO, Dict[int, str]]:
72
  # Image utilities
73
  # -----------------------------
74
  def ensure_bgr(img: np.ndarray) -> np.ndarray:
75
- # Gradio provides RGB; OpenCV expects BGR
76
  if img.ndim == 3 and img.shape[2] == 3:
77
  return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
78
  return img
@@ -102,16 +103,9 @@ def run_sliced_inference(
102
  overlap_h: int,
103
  device: str,
104
  ) -> sv.Detections:
105
- """
106
- Uses supervision.InferenceSlicer to run model across tiles,
107
- then returns all detections (to be merged via cross-slice NMS).
108
- """
109
- # inner callback called by slicer
110
  @torch.inference_mode()
111
  def callback(tile_bgr: np.ndarray) -> sv.Detections:
112
- # Ultralytics expects RGB
113
  tile_rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
114
- # Predict with thresholds at model level (faster than filtering post-hoc)
115
  results = model.predict(
116
  source=tile_rgb,
117
  conf=conf,
@@ -120,9 +114,7 @@ def run_sliced_inference(
120
  verbose=False,
121
  half=half_precision_available(device),
122
  )
123
- res = results[0]
124
- det = sv.Detections.from_ultralytics(res)
125
- return det
126
 
127
  slicer = sv.InferenceSlicer(
128
  callback=callback,
@@ -131,7 +123,6 @@ def run_sliced_inference(
131
  overlap_ratio_wh=None,
132
  )
133
  detections = slicer(image_bgr)
134
- # Cross-slice NMS to merge duplicates at tile seams
135
  detections = detections.with_nms(threshold=iou, class_agnostic=False)
136
  return detections
137
 
@@ -150,11 +141,11 @@ def make_labels(det: sv.Detections, names: Dict[int, str], show_labels: bool) ->
150
  def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.DataFrame:
151
  if len(det) == 0:
152
  return pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
153
- xyxy = det.xyxy # (N,4)
154
- data = []
155
  for i in range(len(det)):
156
  cls = int(det.class_id[i])
157
- data.append({
158
  "class_id": cls,
159
  "class_name": names.get(cls, str(cls)),
160
  "confidence": float(det.confidence[i]),
@@ -163,7 +154,7 @@ def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.Dat
163
  "x_max": float(xyxy[i, 2]),
164
  "y_max": float(xyxy[i, 3]),
165
  })
166
- return pd.DataFrame(data)
167
 
168
  def per_class_summary(df: pd.DataFrame) -> str:
169
  if df.empty:
@@ -176,7 +167,7 @@ def per_class_summary(df: pd.DataFrame) -> str:
176
  # -----------------------------
177
  # Gradio inference function
178
  # -----------------------------
179
- @spaces.GPU # ensures ZeroGPU allocation during this call
180
  def detect_objects(
181
  image: np.ndarray,
182
  conf: float,
@@ -195,16 +186,13 @@ def detect_objects(
195
  if image is None:
196
  raise ValueError("Please upload or select an image.")
197
 
198
- # Prepare image (BGR) and optional downscale
199
  image_bgr = ensure_bgr(image)
200
- image_bgr, scale = maybe_downscale_long_edge(image_bgr, long_edge)
201
 
202
- # Load model + names lazily
203
  progress(0.05, desc="Loading model…")
204
  model, names = load_model()
205
  device = get_device()
206
 
207
- # Inference (sliced)
208
  progress(0.35, desc="Running sliced inference…")
209
  with torch.inference_mode():
210
  detections = run_sliced_inference(
@@ -219,59 +207,55 @@ def detect_objects(
219
  device=device,
220
  )
221
 
222
- # Optional class filtering (by names)
223
  if selected_classes:
224
  allow_ids = {cid for cid, cname in names.items() if cname in set(selected_classes)}
225
  if len(detections) > 0:
226
  mask = np.array([int(c) in allow_ids for c in detections.class_id], dtype=bool)
227
  detections = detections[mask]
228
 
229
- # Create labels (optional)
230
  labels = make_labels(detections, names, show_labels)
231
 
232
- # Annotate
233
  progress(0.65, desc="Annotating…")
234
  annotator = sv.BoxAnnotator(thickness=thickness)
235
  annotated = annotator.annotate(scene=image_bgr.copy(), detections=detections, labels=labels)
236
-
237
- # Convert to RGB for display
238
  annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
239
  annotated_pil = Image.fromarray(annotated_rgb)
240
 
241
- # Tabular output + downloadable CSV
242
  df = detections_to_dataframe(detections, names)
243
- csv_bytes = df.to_csv(index=False).encode("utf-8")
244
  summary = per_class_summary(df)
245
 
 
 
 
 
 
246
  progress(1.0, desc="Done.")
247
- return annotated_pil, summary, df, csv_bytes
248
 
249
  except Exception as e:
250
- # Return clean error in summary field
251
  empty_img = None
252
  empty_df = pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
253
- return empty_img, f"Error: {repr(e)}", empty_df, b""
254
 
255
  # -----------------------------
256
  # UI / App
257
  # -----------------------------
258
  def discover_examples(root: str) -> List[str]:
259
  exts = {".jpg", ".jpeg", ".png"}
260
- paths = []
261
  try:
262
  for fname in os.listdir(root):
263
  if os.path.splitext(fname.lower())[1] in exts:
264
- paths.append(os.path.join(root, fname))
265
  except Exception:
266
  pass
267
- # Keep at most a handful to keep UI tidy
268
- return sorted(paths)[:8]
269
 
270
  def reset_all():
271
  # image, output_img, summary, table, download
272
  return gr.update(value=None), gr.update(value=None), gr.update(value=""), pd.DataFrame(), None
273
 
274
- with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification Tool", fill_height=True) as demo:
275
  gr.Markdown("<h1>Interstitial Cell of Cajal (ICC) Detection and Quantification</h1>"
276
  "<p>YOLO-based tiled inference with cross-slice NMS. Adjust parameters under <em>Advanced Settings</em>.</p>")
277
 
@@ -296,8 +280,6 @@ with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification To
296
  thickness = gr.Slider(1, 8, value=DEFAULT_THICKNESS, step=1, label="Bounding box thickness")
297
  show_labels = gr.Checkbox(value=True, label="Show class + confidence labels")
298
  long_edge = gr.Slider(512, 8192, value=DEFAULT_LONG_EDGE, step=64, label="Optional downscale — max long edge (px)")
299
- # Dynamic class list (populated on load)
300
- # We attempt to load names now; if it fails (cold start), we show an empty multiselect.
301
  try:
302
  _, _names = load_model()
303
  class_list = [v for _, v in sorted(_names.items())]
@@ -312,28 +294,25 @@ with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification To
312
  with gr.Column(scale=1):
313
  output_img = gr.Image(label="Detection Result", interactive=False)
314
  detection_summary = gr.Textbox(label="Detection Summary", interactive=False)
 
315
  detections_table = gr.Dataframe(
316
- headers=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"],
317
  label="Detections (table)",
318
  interactive=False,
319
- wrap=True,
320
- height=240,
321
  )
322
- download_csv = gr.DownloadButton(label="Download detections as CSV", value=None, file_name="detections.csv")
 
323
 
324
- # Wire buttons
325
  predict.click(
326
  detect_objects,
327
  inputs=[input_img, conf, iou, slice_w, slice_h, overlap_w, overlap_h, thickness, show_labels, selected_classes, long_edge],
328
  outputs=[output_img, detection_summary, detections_table, download_csv],
329
  )
 
330
  clear.click(
331
  reset_all,
332
  inputs=None,
333
  outputs=[input_img, output_img, detection_summary, detections_table, download_csv],
334
  )
335
 
336
- # Recommended for Spaces stability / concurrency
337
  demo.queue(max_size=16, concurrency_count=1)
338
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)
339
-
 
2
  import io
3
  import cv2
4
  import sys
 
5
  import json
6
  import torch
7
  import gradio as gr
 
9
  import pandas as pd
10
  from PIL import Image
11
  from typing import List, Tuple, Optional, Dict
 
12
  from ultralytics import YOLO
13
  import supervision as sv
14
  from huggingface_hub import hf_hub_download
15
+ import spaces
16
+ import tempfile
17
+
18
+ # ------------------------------------------------------------------
19
+ # Silence Ultralytics config dir warning in read-only home directories
20
+ # ------------------------------------------------------------------
21
+ os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
22
 
23
  # -----------------------------
24
  # Defaults / configuration
25
  # -----------------------------
26
+ REPO_ID = "edeler/ICC" # your HF repo with weights
27
+ WEIGHTS_FILENAME = "best.pt" # adjust if different
28
  LOCAL_MODEL_DIR = "./models/ICC"
29
+ EXAMPLES_DIR = "." # scan repo root for demo images
30
 
 
31
  DEFAULT_CONF = 0.25
32
+ DEFAULT_IOU = 0.50
33
  DEFAULT_SLICE_WH = 1024
34
+ DEFAULT_OVERLAP = 128
35
  DEFAULT_THICKNESS = 3
36
+ DEFAULT_LONG_EDGE = 4096
37
 
38
  # -----------------------------
39
  # Torch / device helpers
 
63
  weights_path = hf_hub_download(
64
  repo_id=REPO_ID,
65
  filename=WEIGHTS_FILENAME,
66
+ local_dir=LOCAL_MODEL_DIR, # no symlink arg (deprecated)
 
67
  )
68
  model = YOLO(weights_path)
69
  class_names = model.model.names if hasattr(model, "model") else model.names
 
74
  # Image utilities
75
  # -----------------------------
76
def ensure_bgr(img: np.ndarray) -> np.ndarray:
    """Convert a 3-channel RGB array to BGR for OpenCV; pass-through otherwise.

    Gradio hands images over as RGB while OpenCV routines expect BGR.
    Grayscale (2-D) and 4-channel (RGBA) inputs are returned unchanged.
    """
    is_rgb = img.ndim == 3 and img.shape[2] == 3
    if not is_rgb:
        return img
    return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
 
103
  overlap_h: int,
104
  device: str,
105
  ) -> sv.Detections:
 
 
 
 
 
106
  @torch.inference_mode()
107
  def callback(tile_bgr: np.ndarray) -> sv.Detections:
 
108
  tile_rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
 
109
  results = model.predict(
110
  source=tile_rgb,
111
  conf=conf,
 
114
  verbose=False,
115
  half=half_precision_available(device),
116
  )
117
+ return sv.Detections.from_ultralytics(results[0])
 
 
118
 
119
  slicer = sv.InferenceSlicer(
120
  callback=callback,
 
123
  overlap_ratio_wh=None,
124
  )
125
  detections = slicer(image_bgr)
 
126
  detections = detections.with_nms(threshold=iou, class_agnostic=False)
127
  return detections
128
 
 
141
  def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.DataFrame:
142
  if len(det) == 0:
143
  return pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
144
+ xyxy = det.xyxy
145
+ rows = []
146
  for i in range(len(det)):
147
  cls = int(det.class_id[i])
148
+ rows.append({
149
  "class_id": cls,
150
  "class_name": names.get(cls, str(cls)),
151
  "confidence": float(det.confidence[i]),
 
154
  "x_max": float(xyxy[i, 2]),
155
  "y_max": float(xyxy[i, 3]),
156
  })
157
+ return pd.DataFrame(rows)
158
 
159
  def per_class_summary(df: pd.DataFrame) -> str:
160
  if df.empty:
 
167
  # -----------------------------
168
  # Gradio inference function
169
  # -----------------------------
170
+ @spaces.GPU
171
  def detect_objects(
172
  image: np.ndarray,
173
  conf: float,
 
186
  if image is None:
187
  raise ValueError("Please upload or select an image.")
188
 
 
189
  image_bgr = ensure_bgr(image)
190
+ image_bgr, _ = maybe_downscale_long_edge(image_bgr, long_edge)
191
 
 
192
  progress(0.05, desc="Loading model…")
193
  model, names = load_model()
194
  device = get_device()
195
 
 
196
  progress(0.35, desc="Running sliced inference…")
197
  with torch.inference_mode():
198
  detections = run_sliced_inference(
 
207
  device=device,
208
  )
209
 
 
210
  if selected_classes:
211
  allow_ids = {cid for cid, cname in names.items() if cname in set(selected_classes)}
212
  if len(detections) > 0:
213
  mask = np.array([int(c) in allow_ids for c in detections.class_id], dtype=bool)
214
  detections = detections[mask]
215
 
 
216
  labels = make_labels(detections, names, show_labels)
217
 
 
218
  progress(0.65, desc="Annotating…")
219
  annotator = sv.BoxAnnotator(thickness=thickness)
220
  annotated = annotator.annotate(scene=image_bgr.copy(), detections=detections, labels=labels)
 
 
221
  annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
222
  annotated_pil = Image.fromarray(annotated_rgb)
223
 
 
224
  df = detections_to_dataframe(detections, names)
 
225
  summary = per_class_summary(df)
226
 
227
+ # Create a temporary CSV file for robust downloads on older Gradio builds
228
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
229
+ df.to_csv(tmp.name, index=False)
230
+ csv_path = tmp.name
231
+
232
  progress(1.0, desc="Done.")
233
+ return annotated_pil, summary, df, csv_path
234
 
235
  except Exception as e:
 
236
  empty_img = None
237
  empty_df = pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
238
+ return empty_img, f"Error: {repr(e)}", empty_df, None
239
 
240
  # -----------------------------
241
  # UI / App
242
  # -----------------------------
243
def discover_examples(root: str) -> List[str]:
    """Collect up to eight example image paths found directly under *root*.

    Matches files by extension (case-insensitive .jpg/.jpeg/.png), returns
    the paths sorted; any listing failure (missing dir, permissions) yields
    an empty list rather than raising.
    """
    exts = {".jpg", ".jpeg", ".png"}
    try:
        found = [
            os.path.join(root, fname)
            for fname in os.listdir(root)
            if os.path.splitext(fname.lower())[1] in exts
        ]
    except Exception:
        found = []
    # Cap the list to keep the examples gallery tidy.
    return sorted(found)[:8]
 
253
 
254
def reset_all():
    """Return component updates that clear every wired output.

    Order matches the clear-button wiring:
    input image, result image, summary text, detections table, CSV download.
    """
    cleared_input = gr.update(value=None)
    cleared_result = gr.update(value=None)
    cleared_summary = gr.update(value="")
    return cleared_input, cleared_result, cleared_summary, pd.DataFrame(), None
257
 
258
+ with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification Tool") as demo:
259
  gr.Markdown("<h1>Interstitial Cell of Cajal (ICC) Detection and Quantification</h1>"
260
  "<p>YOLO-based tiled inference with cross-slice NMS. Adjust parameters under <em>Advanced Settings</em>.</p>")
261
 
 
280
  thickness = gr.Slider(1, 8, value=DEFAULT_THICKNESS, step=1, label="Bounding box thickness")
281
  show_labels = gr.Checkbox(value=True, label="Show class + confidence labels")
282
  long_edge = gr.Slider(512, 8192, value=DEFAULT_LONG_EDGE, step=64, label="Optional downscale — max long edge (px)")
 
 
283
  try:
284
  _, _names = load_model()
285
  class_list = [v for _, v in sorted(_names.items())]
 
294
  with gr.Column(scale=1):
295
  output_img = gr.Image(label="Detection Result", interactive=False)
296
  detection_summary = gr.Textbox(label="Detection Summary", interactive=False)
297
+ # NOTE: remove unsupported 'height' kwarg for older Gradio
298
  detections_table = gr.Dataframe(
 
299
  label="Detections (table)",
300
  interactive=False,
 
 
301
  )
302
+ # Use gr.File for robust downloads across Gradio versions
303
+ download_csv = gr.File(label="Download detections as CSV")
304
 
 
305
  predict.click(
306
  detect_objects,
307
  inputs=[input_img, conf, iou, slice_w, slice_h, overlap_w, overlap_h, thickness, show_labels, selected_classes, long_edge],
308
  outputs=[output_img, detection_summary, detections_table, download_csv],
309
  )
310
+
311
  clear.click(
312
  reset_all,
313
  inputs=None,
314
  outputs=[input_img, output_img, detection_summary, detections_table, download_csv],
315
  )
316
 
 
317
  demo.queue(max_size=16, concurrency_count=1)
318
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)