Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
| 2 |
import io
|
| 3 |
import cv2
|
| 4 |
import sys
|
| 5 |
-
import math
|
| 6 |
import json
|
| 7 |
import torch
|
| 8 |
import gradio as gr
|
|
@@ -10,27 +9,31 @@ import numpy as np
|
|
| 10 |
import pandas as pd
|
| 11 |
from PIL import Image
|
| 12 |
from typing import List, Tuple, Optional, Dict
|
| 13 |
-
|
| 14 |
from ultralytics import YOLO
|
| 15 |
import supervision as sv
|
| 16 |
from huggingface_hub import hf_hub_download
|
| 17 |
-
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# -----------------------------
|
| 20 |
# Defaults / configuration
|
| 21 |
# -----------------------------
|
| 22 |
-
REPO_ID = "edeler/ICC" #
|
| 23 |
-
WEIGHTS_FILENAME = "best.pt" #
|
| 24 |
LOCAL_MODEL_DIR = "./models/ICC"
|
| 25 |
-
EXAMPLES_DIR = "." # scan
|
| 26 |
|
| 27 |
-
# Reasonable defaults (tunable in UI)
|
| 28 |
DEFAULT_CONF = 0.25
|
| 29 |
-
DEFAULT_IOU = 0.50
|
| 30 |
DEFAULT_SLICE_WH = 1024
|
| 31 |
-
DEFAULT_OVERLAP = 128
|
| 32 |
DEFAULT_THICKNESS = 3
|
| 33 |
-
DEFAULT_LONG_EDGE = 4096
|
| 34 |
|
| 35 |
# -----------------------------
|
| 36 |
# Torch / device helpers
|
|
@@ -60,8 +63,7 @@ def load_model() -> Tuple[YOLO, Dict[int, str]]:
|
|
| 60 |
weights_path = hf_hub_download(
|
| 61 |
repo_id=REPO_ID,
|
| 62 |
filename=WEIGHTS_FILENAME,
|
| 63 |
-
local_dir=LOCAL_MODEL_DIR,
|
| 64 |
-
local_dir_use_symlinks=False, # safer in Spaces
|
| 65 |
)
|
| 66 |
model = YOLO(weights_path)
|
| 67 |
class_names = model.model.names if hasattr(model, "model") else model.names
|
|
@@ -72,7 +74,6 @@ def load_model() -> Tuple[YOLO, Dict[int, str]]:
|
|
| 72 |
# Image utilities
|
| 73 |
# -----------------------------
|
| 74 |
def ensure_bgr(img: np.ndarray) -> np.ndarray:
|
| 75 |
-
# Gradio provides RGB; OpenCV expects BGR
|
| 76 |
if img.ndim == 3 and img.shape[2] == 3:
|
| 77 |
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 78 |
return img
|
|
@@ -102,16 +103,9 @@ def run_sliced_inference(
|
|
| 102 |
overlap_h: int,
|
| 103 |
device: str,
|
| 104 |
) -> sv.Detections:
|
| 105 |
-
"""
|
| 106 |
-
Uses supervision.InferenceSlicer to run model across tiles,
|
| 107 |
-
then returns all detections (to be merged via cross-slice NMS).
|
| 108 |
-
"""
|
| 109 |
-
# inner callback called by slicer
|
| 110 |
@torch.inference_mode()
|
| 111 |
def callback(tile_bgr: np.ndarray) -> sv.Detections:
|
| 112 |
-
# Ultralytics expects RGB
|
| 113 |
tile_rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
|
| 114 |
-
# Predict with thresholds at model level (faster than filtering post-hoc)
|
| 115 |
results = model.predict(
|
| 116 |
source=tile_rgb,
|
| 117 |
conf=conf,
|
|
@@ -120,9 +114,7 @@ def run_sliced_inference(
|
|
| 120 |
verbose=False,
|
| 121 |
half=half_precision_available(device),
|
| 122 |
)
|
| 123 |
-
|
| 124 |
-
det = sv.Detections.from_ultralytics(res)
|
| 125 |
-
return det
|
| 126 |
|
| 127 |
slicer = sv.InferenceSlicer(
|
| 128 |
callback=callback,
|
|
@@ -131,7 +123,6 @@ def run_sliced_inference(
|
|
| 131 |
overlap_ratio_wh=None,
|
| 132 |
)
|
| 133 |
detections = slicer(image_bgr)
|
| 134 |
-
# Cross-slice NMS to merge duplicates at tile seams
|
| 135 |
detections = detections.with_nms(threshold=iou, class_agnostic=False)
|
| 136 |
return detections
|
| 137 |
|
|
@@ -150,11 +141,11 @@ def make_labels(det: sv.Detections, names: Dict[int, str], show_labels: bool) ->
|
|
| 150 |
def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.DataFrame:
|
| 151 |
if len(det) == 0:
|
| 152 |
return pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
|
| 153 |
-
xyxy = det.xyxy
|
| 154 |
-
|
| 155 |
for i in range(len(det)):
|
| 156 |
cls = int(det.class_id[i])
|
| 157 |
-
|
| 158 |
"class_id": cls,
|
| 159 |
"class_name": names.get(cls, str(cls)),
|
| 160 |
"confidence": float(det.confidence[i]),
|
|
@@ -163,7 +154,7 @@ def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.Dat
|
|
| 163 |
"x_max": float(xyxy[i, 2]),
|
| 164 |
"y_max": float(xyxy[i, 3]),
|
| 165 |
})
|
| 166 |
-
return pd.DataFrame(
|
| 167 |
|
| 168 |
def per_class_summary(df: pd.DataFrame) -> str:
|
| 169 |
if df.empty:
|
|
@@ -176,7 +167,7 @@ def per_class_summary(df: pd.DataFrame) -> str:
|
|
| 176 |
# -----------------------------
|
| 177 |
# Gradio inference function
|
| 178 |
# -----------------------------
|
| 179 |
-
@spaces.GPU
|
| 180 |
def detect_objects(
|
| 181 |
image: np.ndarray,
|
| 182 |
conf: float,
|
|
@@ -195,16 +186,13 @@ def detect_objects(
|
|
| 195 |
if image is None:
|
| 196 |
raise ValueError("Please upload or select an image.")
|
| 197 |
|
| 198 |
-
# Prepare image (BGR) and optional downscale
|
| 199 |
image_bgr = ensure_bgr(image)
|
| 200 |
-
image_bgr,
|
| 201 |
|
| 202 |
-
# Load model + names lazily
|
| 203 |
progress(0.05, desc="Loading model…")
|
| 204 |
model, names = load_model()
|
| 205 |
device = get_device()
|
| 206 |
|
| 207 |
-
# Inference (sliced)
|
| 208 |
progress(0.35, desc="Running sliced inference…")
|
| 209 |
with torch.inference_mode():
|
| 210 |
detections = run_sliced_inference(
|
|
@@ -219,59 +207,55 @@ def detect_objects(
|
|
| 219 |
device=device,
|
| 220 |
)
|
| 221 |
|
| 222 |
-
# Optional class filtering (by names)
|
| 223 |
if selected_classes:
|
| 224 |
allow_ids = {cid for cid, cname in names.items() if cname in set(selected_classes)}
|
| 225 |
if len(detections) > 0:
|
| 226 |
mask = np.array([int(c) in allow_ids for c in detections.class_id], dtype=bool)
|
| 227 |
detections = detections[mask]
|
| 228 |
|
| 229 |
-
# Create labels (optional)
|
| 230 |
labels = make_labels(detections, names, show_labels)
|
| 231 |
|
| 232 |
-
# Annotate
|
| 233 |
progress(0.65, desc="Annotating…")
|
| 234 |
annotator = sv.BoxAnnotator(thickness=thickness)
|
| 235 |
annotated = annotator.annotate(scene=image_bgr.copy(), detections=detections, labels=labels)
|
| 236 |
-
|
| 237 |
-
# Convert to RGB for display
|
| 238 |
annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
|
| 239 |
annotated_pil = Image.fromarray(annotated_rgb)
|
| 240 |
|
| 241 |
-
# Tabular output + downloadable CSV
|
| 242 |
df = detections_to_dataframe(detections, names)
|
| 243 |
-
csv_bytes = df.to_csv(index=False).encode("utf-8")
|
| 244 |
summary = per_class_summary(df)
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
progress(1.0, desc="Done.")
|
| 247 |
-
return annotated_pil, summary, df,
|
| 248 |
|
| 249 |
except Exception as e:
|
| 250 |
-
# Return clean error in summary field
|
| 251 |
empty_img = None
|
| 252 |
empty_df = pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
|
| 253 |
-
return empty_img, f"Error: {repr(e)}", empty_df,
|
| 254 |
|
| 255 |
# -----------------------------
|
| 256 |
# UI / App
|
| 257 |
# -----------------------------
|
| 258 |
def discover_examples(root: str) -> List[str]:
|
| 259 |
exts = {".jpg", ".jpeg", ".png"}
|
| 260 |
-
|
| 261 |
try:
|
| 262 |
for fname in os.listdir(root):
|
| 263 |
if os.path.splitext(fname.lower())[1] in exts:
|
| 264 |
-
|
| 265 |
except Exception:
|
| 266 |
pass
|
| 267 |
-
|
| 268 |
-
return sorted(paths)[:8]
|
| 269 |
|
| 270 |
def reset_all():
|
| 271 |
# image, output_img, summary, table, download
|
| 272 |
return gr.update(value=None), gr.update(value=None), gr.update(value=""), pd.DataFrame(), None
|
| 273 |
|
| 274 |
-
with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification Tool"
|
| 275 |
gr.Markdown("<h1>Interstitial Cell of Cajal (ICC) Detection and Quantification</h1>"
|
| 276 |
"<p>YOLO-based tiled inference with cross-slice NMS. Adjust parameters under <em>Advanced Settings</em>.</p>")
|
| 277 |
|
|
@@ -296,8 +280,6 @@ with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification To
|
|
| 296 |
thickness = gr.Slider(1, 8, value=DEFAULT_THICKNESS, step=1, label="Bounding box thickness")
|
| 297 |
show_labels = gr.Checkbox(value=True, label="Show class + confidence labels")
|
| 298 |
long_edge = gr.Slider(512, 8192, value=DEFAULT_LONG_EDGE, step=64, label="Optional downscale — max long edge (px)")
|
| 299 |
-
# Dynamic class list (populated on load)
|
| 300 |
-
# We attempt to load names now; if it fails (cold start), we show an empty multiselect.
|
| 301 |
try:
|
| 302 |
_, _names = load_model()
|
| 303 |
class_list = [v for _, v in sorted(_names.items())]
|
|
@@ -312,28 +294,25 @@ with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification To
|
|
| 312 |
with gr.Column(scale=1):
|
| 313 |
output_img = gr.Image(label="Detection Result", interactive=False)
|
| 314 |
detection_summary = gr.Textbox(label="Detection Summary", interactive=False)
|
|
|
|
| 315 |
detections_table = gr.Dataframe(
|
| 316 |
-
headers=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"],
|
| 317 |
label="Detections (table)",
|
| 318 |
interactive=False,
|
| 319 |
-
wrap=True,
|
| 320 |
-
height=240,
|
| 321 |
)
|
| 322 |
-
|
|
|
|
| 323 |
|
| 324 |
-
# Wire buttons
|
| 325 |
predict.click(
|
| 326 |
detect_objects,
|
| 327 |
inputs=[input_img, conf, iou, slice_w, slice_h, overlap_w, overlap_h, thickness, show_labels, selected_classes, long_edge],
|
| 328 |
outputs=[output_img, detection_summary, detections_table, download_csv],
|
| 329 |
)
|
|
|
|
| 330 |
clear.click(
|
| 331 |
reset_all,
|
| 332 |
inputs=None,
|
| 333 |
outputs=[input_img, output_img, detection_summary, detections_table, download_csv],
|
| 334 |
)
|
| 335 |
|
| 336 |
-
# Recommended for Spaces stability / concurrency
|
| 337 |
demo.queue(max_size=16, concurrency_count=1)
|
| 338 |
demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)
|
| 339 |
-
|
|
|
|
| 2 |
import io
|
| 3 |
import cv2
|
| 4 |
import sys
|
|
|
|
| 5 |
import json
|
| 6 |
import torch
|
| 7 |
import gradio as gr
|
|
|
|
| 9 |
import pandas as pd
|
| 10 |
from PIL import Image
|
| 11 |
from typing import List, Tuple, Optional, Dict
|
|
|
|
| 12 |
from ultralytics import YOLO
|
| 13 |
import supervision as sv
|
| 14 |
from huggingface_hub import hf_hub_download
|
| 15 |
+
import spaces
|
| 16 |
+
import tempfile
|
| 17 |
+
|
| 18 |
+
# ------------------------------------------------------------------
|
| 19 |
+
# Silence Ultralytics config dir warning in read-only home directories
|
| 20 |
+
# ------------------------------------------------------------------
|
| 21 |
+
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")
|
| 22 |
|
| 23 |
# -----------------------------
|
| 24 |
# Defaults / configuration
|
| 25 |
# -----------------------------
|
| 26 |
+
REPO_ID = "edeler/ICC" # your HF repo with weights
|
| 27 |
+
WEIGHTS_FILENAME = "best.pt" # adjust if different
|
| 28 |
LOCAL_MODEL_DIR = "./models/ICC"
|
| 29 |
+
EXAMPLES_DIR = "." # scan repo root for demo images
|
| 30 |
|
|
|
|
| 31 |
DEFAULT_CONF = 0.25
|
| 32 |
+
DEFAULT_IOU = 0.50
|
| 33 |
DEFAULT_SLICE_WH = 1024
|
| 34 |
+
DEFAULT_OVERLAP = 128
|
| 35 |
DEFAULT_THICKNESS = 3
|
| 36 |
+
DEFAULT_LONG_EDGE = 4096
|
| 37 |
|
| 38 |
# -----------------------------
|
| 39 |
# Torch / device helpers
|
|
|
|
| 63 |
weights_path = hf_hub_download(
|
| 64 |
repo_id=REPO_ID,
|
| 65 |
filename=WEIGHTS_FILENAME,
|
| 66 |
+
local_dir=LOCAL_MODEL_DIR, # no symlink arg (deprecated)
|
|
|
|
| 67 |
)
|
| 68 |
model = YOLO(weights_path)
|
| 69 |
class_names = model.model.names if hasattr(model, "model") else model.names
|
|
|
|
| 74 |
# Image utilities
|
| 75 |
# -----------------------------
|
| 76 |
def ensure_bgr(img: np.ndarray) -> np.ndarray:
|
|
|
|
| 77 |
if img.ndim == 3 and img.shape[2] == 3:
|
| 78 |
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 79 |
return img
|
|
|
|
| 103 |
overlap_h: int,
|
| 104 |
device: str,
|
| 105 |
) -> sv.Detections:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
@torch.inference_mode()
|
| 107 |
def callback(tile_bgr: np.ndarray) -> sv.Detections:
|
|
|
|
| 108 |
tile_rgb = cv2.cvtColor(tile_bgr, cv2.COLOR_BGR2RGB)
|
|
|
|
| 109 |
results = model.predict(
|
| 110 |
source=tile_rgb,
|
| 111 |
conf=conf,
|
|
|
|
| 114 |
verbose=False,
|
| 115 |
half=half_precision_available(device),
|
| 116 |
)
|
| 117 |
+
return sv.Detections.from_ultralytics(results[0])
|
|
|
|
|
|
|
| 118 |
|
| 119 |
slicer = sv.InferenceSlicer(
|
| 120 |
callback=callback,
|
|
|
|
| 123 |
overlap_ratio_wh=None,
|
| 124 |
)
|
| 125 |
detections = slicer(image_bgr)
|
|
|
|
| 126 |
detections = detections.with_nms(threshold=iou, class_agnostic=False)
|
| 127 |
return detections
|
| 128 |
|
|
|
|
| 141 |
def detections_to_dataframe(det: sv.Detections, names: Dict[int, str]) -> pd.DataFrame:
|
| 142 |
if len(det) == 0:
|
| 143 |
return pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
|
| 144 |
+
xyxy = det.xyxy
|
| 145 |
+
rows = []
|
| 146 |
for i in range(len(det)):
|
| 147 |
cls = int(det.class_id[i])
|
| 148 |
+
rows.append({
|
| 149 |
"class_id": cls,
|
| 150 |
"class_name": names.get(cls, str(cls)),
|
| 151 |
"confidence": float(det.confidence[i]),
|
|
|
|
| 154 |
"x_max": float(xyxy[i, 2]),
|
| 155 |
"y_max": float(xyxy[i, 3]),
|
| 156 |
})
|
| 157 |
+
return pd.DataFrame(rows)
|
| 158 |
|
| 159 |
def per_class_summary(df: pd.DataFrame) -> str:
|
| 160 |
if df.empty:
|
|
|
|
| 167 |
# -----------------------------
|
| 168 |
# Gradio inference function
|
| 169 |
# -----------------------------
|
| 170 |
+
@spaces.GPU
|
| 171 |
def detect_objects(
|
| 172 |
image: np.ndarray,
|
| 173 |
conf: float,
|
|
|
|
| 186 |
if image is None:
|
| 187 |
raise ValueError("Please upload or select an image.")
|
| 188 |
|
|
|
|
| 189 |
image_bgr = ensure_bgr(image)
|
| 190 |
+
image_bgr, _ = maybe_downscale_long_edge(image_bgr, long_edge)
|
| 191 |
|
|
|
|
| 192 |
progress(0.05, desc="Loading model…")
|
| 193 |
model, names = load_model()
|
| 194 |
device = get_device()
|
| 195 |
|
|
|
|
| 196 |
progress(0.35, desc="Running sliced inference…")
|
| 197 |
with torch.inference_mode():
|
| 198 |
detections = run_sliced_inference(
|
|
|
|
| 207 |
device=device,
|
| 208 |
)
|
| 209 |
|
|
|
|
| 210 |
if selected_classes:
|
| 211 |
allow_ids = {cid for cid, cname in names.items() if cname in set(selected_classes)}
|
| 212 |
if len(detections) > 0:
|
| 213 |
mask = np.array([int(c) in allow_ids for c in detections.class_id], dtype=bool)
|
| 214 |
detections = detections[mask]
|
| 215 |
|
|
|
|
| 216 |
labels = make_labels(detections, names, show_labels)
|
| 217 |
|
|
|
|
| 218 |
progress(0.65, desc="Annotating…")
|
| 219 |
annotator = sv.BoxAnnotator(thickness=thickness)
|
| 220 |
annotated = annotator.annotate(scene=image_bgr.copy(), detections=detections, labels=labels)
|
|
|
|
|
|
|
| 221 |
annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
|
| 222 |
annotated_pil = Image.fromarray(annotated_rgb)
|
| 223 |
|
|
|
|
| 224 |
df = detections_to_dataframe(detections, names)
|
|
|
|
| 225 |
summary = per_class_summary(df)
|
| 226 |
|
| 227 |
+
# Create a temporary CSV file for robust downloads on older Gradio builds
|
| 228 |
+
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
|
| 229 |
+
df.to_csv(tmp.name, index=False)
|
| 230 |
+
csv_path = tmp.name
|
| 231 |
+
|
| 232 |
progress(1.0, desc="Done.")
|
| 233 |
+
return annotated_pil, summary, df, csv_path
|
| 234 |
|
| 235 |
except Exception as e:
|
|
|
|
| 236 |
empty_img = None
|
| 237 |
empty_df = pd.DataFrame(columns=["class_id", "class_name", "confidence", "x_min", "y_min", "x_max", "y_max"])
|
| 238 |
+
return empty_img, f"Error: {repr(e)}", empty_df, None
|
| 239 |
|
| 240 |
# -----------------------------
|
| 241 |
# UI / App
|
| 242 |
# -----------------------------
|
| 243 |
def discover_examples(root: str) -> List[str]:
|
| 244 |
exts = {".jpg", ".jpeg", ".png"}
|
| 245 |
+
out = []
|
| 246 |
try:
|
| 247 |
for fname in os.listdir(root):
|
| 248 |
if os.path.splitext(fname.lower())[1] in exts:
|
| 249 |
+
out.append(os.path.join(root, fname))
|
| 250 |
except Exception:
|
| 251 |
pass
|
| 252 |
+
return sorted(out)[:8]
|
|
|
|
| 253 |
|
| 254 |
def reset_all():
|
| 255 |
# image, output_img, summary, table, download
|
| 256 |
return gr.update(value=None), gr.update(value=None), gr.update(value=""), pd.DataFrame(), None
|
| 257 |
|
| 258 |
+
with gr.Blocks(title="Interstitial Cell of Cajal Detection and Quantification Tool") as demo:
|
| 259 |
gr.Markdown("<h1>Interstitial Cell of Cajal (ICC) Detection and Quantification</h1>"
|
| 260 |
"<p>YOLO-based tiled inference with cross-slice NMS. Adjust parameters under <em>Advanced Settings</em>.</p>")
|
| 261 |
|
|
|
|
| 280 |
thickness = gr.Slider(1, 8, value=DEFAULT_THICKNESS, step=1, label="Bounding box thickness")
|
| 281 |
show_labels = gr.Checkbox(value=True, label="Show class + confidence labels")
|
| 282 |
long_edge = gr.Slider(512, 8192, value=DEFAULT_LONG_EDGE, step=64, label="Optional downscale — max long edge (px)")
|
|
|
|
|
|
|
| 283 |
try:
|
| 284 |
_, _names = load_model()
|
| 285 |
class_list = [v for _, v in sorted(_names.items())]
|
|
|
|
| 294 |
with gr.Column(scale=1):
|
| 295 |
output_img = gr.Image(label="Detection Result", interactive=False)
|
| 296 |
detection_summary = gr.Textbox(label="Detection Summary", interactive=False)
|
| 297 |
+
# NOTE: remove unsupported 'height' kwarg for older Gradio
|
| 298 |
detections_table = gr.Dataframe(
|
|
|
|
| 299 |
label="Detections (table)",
|
| 300 |
interactive=False,
|
|
|
|
|
|
|
| 301 |
)
|
| 302 |
+
# Use gr.File for robust downloads across Gradio versions
|
| 303 |
+
download_csv = gr.File(label="Download detections as CSV")
|
| 304 |
|
|
|
|
| 305 |
predict.click(
|
| 306 |
detect_objects,
|
| 307 |
inputs=[input_img, conf, iou, slice_w, slice_h, overlap_w, overlap_h, thickness, show_labels, selected_classes, long_edge],
|
| 308 |
outputs=[output_img, detection_summary, detections_table, download_csv],
|
| 309 |
)
|
| 310 |
+
|
| 311 |
clear.click(
|
| 312 |
reset_all,
|
| 313 |
inputs=None,
|
| 314 |
outputs=[input_img, output_img, detection_summary, detections_table, download_csv],
|
| 315 |
)
|
| 316 |
|
|
|
|
| 317 |
demo.queue(max_size=16, concurrency_count=1)
|
| 318 |
demo.launch(server_name="0.0.0.0", server_port=7860, debug=False)
|
|
|