Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,10 +5,10 @@ Gradio app to compare object‑detection models:
|
|
| 5 |
• Roboflow RF‑DETR (Base, Large)
|
| 6 |
• Custom fine‑tuned checkpoints (.pt/.pth upload)
|
| 7 |
|
| 8 |
-
Revision 2025‑04‑19‑
|
| 9 |
-
•
|
| 10 |
-
•
|
| 11 |
-
•
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
|
@@ -31,7 +31,7 @@ from rfdetr.util.coco_classes import COCO_CLASSES
|
|
| 31 |
###############################################################################
|
| 32 |
|
| 33 |
YOLO_MODEL_MAP: Dict[str, str] = {
|
| 34 |
-
#
|
| 35 |
"YOLOv12‑n": "yolo12n.pt",
|
| 36 |
"YOLOv12‑s": "yolo12s.pt",
|
| 37 |
"YOLOv12‑m": "yolo12m.pt",
|
|
@@ -44,7 +44,6 @@ YOLO_MODEL_MAP: Dict[str, str] = {
|
|
| 44 |
"YOLOv11‑x": "yolo11x.pt",
|
| 45 |
}
|
| 46 |
|
| 47 |
-
|
| 48 |
RFDETR_MODEL_MAP = {
|
| 49 |
"RF‑DETR‑Base (29M)": "base",
|
| 50 |
"RF‑DETR‑Large (128M)": "large",
|
|
@@ -58,12 +57,11 @@ ALL_MODELS = list(YOLO_MODEL_MAP.keys()) + list(RFDETR_MODEL_MAP.keys()) + [
|
|
| 58 |
_loaded: Dict[str, object] = {}
|
| 59 |
|
| 60 |
def load_model(choice: str, custom_file: Optional[Path] = None):
|
| 61 |
-
"""Fetch and cache a detector instance for *choice*."""
|
| 62 |
if choice in _loaded:
|
| 63 |
return _loaded[choice]
|
| 64 |
|
| 65 |
if choice in YOLO_MODEL_MAP:
|
| 66 |
-
model = YOLO(YOLO_MODEL_MAP[choice])
|
| 67 |
elif choice in RFDETR_MODEL_MAP:
|
| 68 |
model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
|
| 69 |
elif choice.startswith("Custom YOLO"):
|
|
@@ -84,8 +82,8 @@ def load_model(choice: str, custom_file: Optional[Path] = None):
|
|
| 84 |
# Inference helpers
|
| 85 |
###############################################################################
|
| 86 |
|
| 87 |
-
BOX_THICKNESS = 2
|
| 88 |
-
BOX_ALPHA = 0.6
|
| 89 |
|
| 90 |
box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
|
| 91 |
label_annotator = sv.LabelAnnotator()
|
|
@@ -118,7 +116,7 @@ def run_single_inference(model, image: Image.Image, threshold: float) -> Tuple[I
|
|
| 118 |
return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
|
| 119 |
|
| 120 |
###############################################################################
|
| 121 |
-
#
|
| 122 |
###############################################################################
|
| 123 |
|
| 124 |
def compare_models(
|
|
@@ -132,50 +130,54 @@ def compare_models(
|
|
| 132 |
if img.mode != "RGB":
|
| 133 |
img = img.convert("RGB")
|
| 134 |
|
| 135 |
-
total_steps = len(models) * 2
|
| 136 |
progress = gr.Progress()
|
| 137 |
|
| 138 |
-
# ----- Phase 1: preload weights -----
|
| 139 |
detectors: Dict[str, object] = {}
|
| 140 |
for i, name in enumerate(models, 1):
|
| 141 |
try:
|
| 142 |
detectors[name] = load_model(name, custom_file)
|
| 143 |
except Exception as exc:
|
| 144 |
-
detectors[name] = exc
|
| 145 |
progress(i, total=total_steps, desc=f"Loading {name}")
|
| 146 |
|
| 147 |
-
|
| 148 |
-
results: List[Image.Image] = []
|
| 149 |
legends: Dict[str, str] = {}
|
| 150 |
|
| 151 |
for j, name in enumerate(models, 1):
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
if isinstance(
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
| 160 |
continue
|
| 161 |
try:
|
| 162 |
-
annotated, latency = run_single_inference(
|
| 163 |
-
|
|
|
|
| 164 |
legends[name] = f"{latency*1000:.1f} ms"
|
| 165 |
except Exception as exc:
|
| 166 |
-
|
|
|
|
|
|
|
| 167 |
legends[name] = f"ERROR: {str(exc).splitlines()[0][:120]}"
|
| 168 |
-
progress(
|
| 169 |
|
| 170 |
-
yield results, legends
|
| 171 |
|
| 172 |
###############################################################################
|
| 173 |
-
#
|
| 174 |
###############################################################################
|
| 175 |
|
| 176 |
def build_demo():
|
| 177 |
with gr.Blocks(title="CV Model Comparison") as demo:
|
| 178 |
-
gr.Markdown(
|
|
|
|
|
|
|
| 179 |
|
| 180 |
with gr.Row():
|
| 181 |
sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Models")
|
|
@@ -188,8 +190,9 @@ def build_demo():
|
|
| 188 |
gallery = gr.Gallery(label="Results", columns=2, height="auto")
|
| 189 |
legend_out = gr.JSON(label="Latency / status by model")
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
|
|
|
| 193 |
|
| 194 |
return demo
|
| 195 |
|
|
|
|
| 5 |
• Roboflow RF‑DETR (Base, Large)
|
| 6 |
• Custom fine‑tuned checkpoints (.pt/.pth upload)
|
| 7 |
|
| 8 |
+
Revision 2025‑04‑19‑e:
|
| 9 |
+
• Gallery items now carry captions so you can see which model produced which image (and latency).
|
| 10 |
+
• Captions display as "Model (xx ms)" or error status.
|
| 11 |
+
• No other behaviour changed: pre‑loading, progress bar, thin semi‑transparent boxes, concise error labels.
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
|
|
|
| 31 |
###############################################################################
|
| 32 |
|
| 33 |
YOLO_MODEL_MAP: Dict[str, str] = {
|
| 34 |
+
# Ultralytics filenames omit the "v"
|
| 35 |
"YOLOv12‑n": "yolo12n.pt",
|
| 36 |
"YOLOv12‑s": "yolo12s.pt",
|
| 37 |
"YOLOv12‑m": "yolo12m.pt",
|
|
|
|
| 44 |
"YOLOv11‑x": "yolo11x.pt",
|
| 45 |
}
|
| 46 |
|
|
|
|
| 47 |
RFDETR_MODEL_MAP = {
|
| 48 |
"RF‑DETR‑Base (29M)": "base",
|
| 49 |
"RF‑DETR‑Large (128M)": "large",
|
|
|
|
| 57 |
_loaded: Dict[str, object] = {}
|
| 58 |
|
| 59 |
def load_model(choice: str, custom_file: Optional[Path] = None):
|
|
|
|
| 60 |
if choice in _loaded:
|
| 61 |
return _loaded[choice]
|
| 62 |
|
| 63 |
if choice in YOLO_MODEL_MAP:
|
| 64 |
+
model = YOLO(YOLO_MODEL_MAP[choice])
|
| 65 |
elif choice in RFDETR_MODEL_MAP:
|
| 66 |
model = RFDETRBase() if RFDETR_MODEL_MAP[choice] == "base" else RFDETRLarge()
|
| 67 |
elif choice.startswith("Custom YOLO"):
|
|
|
|
| 82 |
# Inference helpers
|
| 83 |
###############################################################################
|
| 84 |
|
| 85 |
+
BOX_THICKNESS = 2
|
| 86 |
+
BOX_ALPHA = 0.6
|
| 87 |
|
| 88 |
box_annotator = sv.BoxAnnotator(thickness=BOX_THICKNESS)
|
| 89 |
label_annotator = sv.LabelAnnotator()
|
|
|
|
| 116 |
return Image.fromarray(cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)), runtime
|
| 117 |
|
| 118 |
###############################################################################
|
| 119 |
+
# Callback with progress & captions
|
| 120 |
###############################################################################
|
| 121 |
|
| 122 |
def compare_models(
|
|
|
|
| 130 |
if img.mode != "RGB":
|
| 131 |
img = img.convert("RGB")
|
| 132 |
|
| 133 |
+
total_steps = len(models) * 2
|
| 134 |
progress = gr.Progress()
|
| 135 |
|
|
|
|
| 136 |
detectors: Dict[str, object] = {}
|
| 137 |
for i, name in enumerate(models, 1):
|
| 138 |
try:
|
| 139 |
detectors[name] = load_model(name, custom_file)
|
| 140 |
except Exception as exc:
|
| 141 |
+
detectors[name] = exc
|
| 142 |
progress(i, total=total_steps, desc=f"Loading {name}")
|
| 143 |
|
| 144 |
+
results: List[Tuple[Image.Image, str]] = []
|
|
|
|
| 145 |
legends: Dict[str, str] = {}
|
| 146 |
|
| 147 |
for j, name in enumerate(models, 1):
|
| 148 |
+
item = detectors[name]
|
| 149 |
+
step = len(models) + j
|
| 150 |
+
if isinstance(item, Exception):
|
| 151 |
+
placeholder = Image.new("RGB", img.size, (40, 40, 40))
|
| 152 |
+
emsg = str(item)
|
| 153 |
+
caption = f"{name} – Unavailable" if "No such file" in emsg or "not found" in emsg else f"{name} – ERROR"
|
| 154 |
+
results.append((placeholder, caption))
|
| 155 |
+
legends[name] = caption
|
| 156 |
+
progress(step, total=total_steps, desc=f"Skipped {name}")
|
| 157 |
continue
|
| 158 |
try:
|
| 159 |
+
annotated, latency = run_single_inference(item, img, threshold)
|
| 160 |
+
caption = f"{name} ({latency*1000:.1f} ms)"
|
| 161 |
+
results.append((annotated, caption))
|
| 162 |
legends[name] = f"{latency*1000:.1f} ms"
|
| 163 |
except Exception as exc:
|
| 164 |
+
placeholder = Image.new("RGB", img.size, (40, 40, 40))
|
| 165 |
+
caption = f"{name} – ERROR"
|
| 166 |
+
results.append((placeholder, caption))
|
| 167 |
legends[name] = f"ERROR: {str(exc).splitlines()[0][:120]}"
|
| 168 |
+
progress(step, total=total_steps, desc=f"Inference {name}")
|
| 169 |
|
| 170 |
+
yield results, legends
|
| 171 |
|
| 172 |
###############################################################################
|
| 173 |
+
# UI
|
| 174 |
###############################################################################
|
| 175 |
|
| 176 |
def build_demo():
|
| 177 |
with gr.Blocks(title="CV Model Comparison") as demo:
|
| 178 |
+
gr.Markdown(
|
| 179 |
+
"""# 🔍 Compare Object‑Detection Models\nUpload an image, select detectors, and click **Run Inference**.\nCaptions beneath each result show which model (and latency) generated it."""
|
| 180 |
+
)
|
| 181 |
|
| 182 |
with gr.Row():
|
| 183 |
sel_models = gr.CheckboxGroup(ALL_MODELS, value=["YOLOv12‑n"], label="Models")
|
|
|
|
| 190 |
gallery = gr.Gallery(label="Results", columns=2, height="auto")
|
| 191 |
legend_out = gr.JSON(label="Latency / status by model")
|
| 192 |
|
| 193 |
+
gr.Button("Run Inference", variant="primary").click(
|
| 194 |
+
compare_models, [sel_models, img_in, conf_slider, ckpt_file], [gallery, legend_out]
|
| 195 |
+
)
|
| 196 |
|
| 197 |
return demo
|
| 198 |
|