Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
RF-DETR Object Counter — Gradio app for Hugging Face Spaces.
|
| 3 |
-
Counts people,
|
| 4 |
-
RF-DETR Medium + ByteTrack
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
import os
|
|
@@ -12,47 +13,47 @@ import cv2
|
|
| 12 |
import gradio as gr
|
| 13 |
import numpy as np
|
| 14 |
import supervision as sv
|
|
|
|
| 15 |
from rfdetr import RFDETRMedium
|
| 16 |
-
from rfdetr.assets.coco_classes import COCO_CLASSES
|
| 17 |
|
| 18 |
# ---------------------------------------------------------------------------
|
| 19 |
-
# Target classes (COCO indices)
|
| 20 |
# ---------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
TARGET_CLASSES = {
|
| 22 |
0: "person",
|
| 23 |
-
1: "bicycle",
|
| 24 |
2: "car",
|
| 25 |
7: "truck",
|
| 26 |
-
# animals
|
| 27 |
-
14: "bird",
|
| 28 |
-
15: "cat",
|
| 29 |
16: "dog",
|
| 30 |
-
17: "horse",
|
| 31 |
-
18: "sheep",
|
| 32 |
19: "cow",
|
| 33 |
-
20: "elephant",
|
| 34 |
-
21: "bear",
|
| 35 |
-
22: "zebra",
|
| 36 |
-
23: "giraffe",
|
| 37 |
}
|
| 38 |
TARGET_IDS = list(TARGET_CLASSES.keys())
|
| 39 |
|
| 40 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
CLASS_COLORS = {
|
| 42 |
-
"person":
|
| 43 |
-
"
|
| 44 |
-
"
|
| 45 |
-
"
|
| 46 |
-
"
|
| 47 |
-
"
|
| 48 |
-
"
|
| 49 |
-
"horse": (245, 120, 120),
|
| 50 |
-
"sheep": (220, 220, 220),
|
| 51 |
-
"cow": (140, 90, 60),
|
| 52 |
-
"elephant": (160, 160, 200),
|
| 53 |
-
"bear": (90, 60, 30),
|
| 54 |
-
"zebra": (40, 40, 40),
|
| 55 |
-
"giraffe": (220, 180, 90),
|
| 56 |
}
|
| 57 |
|
| 58 |
# Example video lives next to app.py
|
|
@@ -60,15 +61,26 @@ APP_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
| 60 |
EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
|
| 61 |
|
| 62 |
# ---------------------------------------------------------------------------
|
| 63 |
-
# Load model
|
| 64 |
# ---------------------------------------------------------------------------
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
try:
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
print("Model ready.")
|
| 73 |
|
| 74 |
# Annotators
|
|
@@ -77,12 +89,12 @@ LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padd
|
|
| 77 |
|
| 78 |
|
| 79 |
def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
|
| 80 |
-
"""Translucent
|
| 81 |
active = [(name, n) for name, n in counts.items() if n > 0]
|
| 82 |
if not active:
|
| 83 |
active = [("No targets yet", 0)]
|
| 84 |
|
| 85 |
-
panel_w =
|
| 86 |
panel_h = 40 + 22 * len(active)
|
| 87 |
overlay = frame.copy()
|
| 88 |
cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
|
|
@@ -95,13 +107,15 @@ def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
|
|
| 95 |
for name, n in active:
|
| 96 |
color = CLASS_COLORS.get(name, (200, 200, 200))
|
| 97 |
cv2.circle(frame, (28, y - 5), 5, color, -1)
|
| 98 |
-
|
|
|
|
| 99 |
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
|
| 100 |
y += 22
|
| 101 |
return frame
|
| 102 |
|
| 103 |
|
| 104 |
-
def process_video(video_path, confidence, frame_stride,
|
|
|
|
| 105 |
if video_path is None:
|
| 106 |
return None, "⚠️ Please upload a video first.", []
|
| 107 |
|
|
@@ -109,15 +123,19 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
|
|
| 109 |
frame_gen = sv.get_video_frames_generator(video_path)
|
| 110 |
tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
|
| 111 |
|
| 112 |
-
unique_ids = defaultdict(set)
|
| 113 |
last_detections = sv.Detections.empty()
|
| 114 |
|
| 115 |
out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
| 116 |
|
| 117 |
with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
|
| 118 |
-
for i, frame in enumerate(progress.tqdm(
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
if i % frame_stride == 0:
|
| 122 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 123 |
detections = MODEL.predict(rgb, threshold=confidence)
|
|
@@ -130,7 +148,7 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
|
|
| 130 |
detections = tracker.update_with_detections(detections)
|
| 131 |
last_detections = detections
|
| 132 |
|
| 133 |
-
# Register unique IDs per class
|
| 134 |
for cid, tid in zip(detections.class_id, detections.tracker_id):
|
| 135 |
if tid is None:
|
| 136 |
continue
|
|
@@ -140,17 +158,18 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
|
|
| 140 |
else:
|
| 141 |
detections = last_detections
|
| 142 |
|
| 143 |
-
# Annotate
|
| 144 |
if len(detections) > 0:
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 154 |
frame = BOX_ANNOTATOR.annotate(frame, detections)
|
| 155 |
frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
|
| 156 |
|
|
@@ -158,21 +177,26 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
|
|
| 158 |
frame = draw_counter_panel(frame, counts_now)
|
| 159 |
sink.write_frame(frame)
|
| 160 |
|
| 161 |
-
# Build summary outputs
|
| 162 |
total = sum(len(ids) for ids in unique_ids.values())
|
| 163 |
if total == 0:
|
| 164 |
-
summary_md = "### ℹ️ No target objects detected.\
|
|
|
|
| 165 |
else:
|
| 166 |
lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
|
| 167 |
for name in TARGET_CLASSES.values():
|
| 168 |
n = len(unique_ids.get(name, set()))
|
| 169 |
if n > 0:
|
| 170 |
-
|
|
|
|
| 171 |
summary_md = "\n".join(lines)
|
| 172 |
|
| 173 |
-
table = [
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
| 176 |
if not table:
|
| 177 |
table = [["—", 0]]
|
| 178 |
|
|
@@ -192,15 +216,14 @@ CUSTOM_CSS = """
|
|
| 192 |
footer {visibility: hidden;}
|
| 193 |
"""
|
| 194 |
|
| 195 |
-
with gr.Blocks(
|
| 196 |
-
css=CUSTOM_CSS, title="RF-DETR Object Counter") as demo:
|
| 197 |
|
| 198 |
with gr.Row(elem_id="title-row"):
|
| 199 |
gr.Markdown(
|
| 200 |
"""
|
| 201 |
-
#
|
| 202 |
-
Count **people,
|
| 203 |
-
Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr)
|
| 204 |
each object is counted **only once** as it moves across frames.
|
| 205 |
"""
|
| 206 |
)
|
|
@@ -218,14 +241,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
|
|
| 218 |
|
| 219 |
with gr.Accordion("⚙️ Advanced settings", open=False):
|
| 220 |
confidence = gr.Slider(
|
| 221 |
-
minimum=0.1, maximum=0.9, value=0.
|
| 222 |
label="Confidence threshold",
|
| 223 |
info="Higher = fewer but more certain detections.",
|
| 224 |
)
|
| 225 |
frame_stride = gr.Slider(
|
| 226 |
-
minimum=1, maximum=
|
| 227 |
-
label="Frame stride",
|
| 228 |
-
info="Process every Nth frame.
|
| 229 |
)
|
| 230 |
|
| 231 |
submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
|
|
@@ -253,12 +276,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
|
|
| 253 |
|
| 254 |
gr.Markdown(
|
| 255 |
"""
|
| 256 |
-
|
| 257 |
-
**Detected categories:** person · bicycle · car · truck · bird · cat · dog · horse ·
|
| 258 |
-
sheep · cow · elephant · bear · zebra · giraffe
|
| 259 |
-
|
| 260 |
-
**Tip:** the first run loads the model (≈45–90 s for Medium). Subsequent runs are much faster.
|
| 261 |
-
Use *Frame stride* if processing is slow on CPU.
|
| 262 |
"""
|
| 263 |
)
|
| 264 |
|
|
@@ -270,4 +288,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
|
|
| 270 |
|
| 271 |
|
| 272 |
if __name__ == "__main__":
|
| 273 |
-
demo.queue(max_size=
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
RF-DETR Object Counter — CPU-optimized Gradio app for Hugging Face Spaces.
|
| 3 |
+
Counts people, cars, trucks, and farm animals (cow, sheep/goat, horse/donkey,
|
| 4 |
+
dog) in video using RF-DETR Medium + ByteTrack so each object is counted
|
| 5 |
+
only once across the whole video.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 13 |
import gradio as gr
|
| 14 |
import numpy as np
|
| 15 |
import supervision as sv
|
| 16 |
+
import torch
|
| 17 |
from rfdetr import RFDETRMedium
|
|
|
|
| 18 |
|
| 19 |
# ---------------------------------------------------------------------------
|
| 20 |
+
# Target classes (COCO indices)
|
| 21 |
# ---------------------------------------------------------------------------
|
| 22 |
+
# Note: COCO does NOT contain "goat" or "donkey". We approximate:
|
| 23 |
+
# goat ~ sheep (closest 4-legged ruminant in COCO)
|
| 24 |
+
# donkey ~ horse (closest equid in COCO)
|
| 25 |
+
# Counts for these will be roughly right; the model still predicts sheep/horse, and the UI shows them via DISPLAY_NAMES as "sheep / goat" and "horse / donkey".
|
| 26 |
TARGET_CLASSES = {
|
| 27 |
0: "person",
|
|
|
|
| 28 |
2: "car",
|
| 29 |
7: "truck",
|
|
|
|
|
|
|
|
|
|
| 30 |
16: "dog",
|
| 31 |
+
17: "horse", # also catches donkeys
|
| 32 |
+
18: "sheep", # also catches goats
|
| 33 |
19: "cow",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
TARGET_IDS = list(TARGET_CLASSES.keys())
|
| 36 |
|
| 37 |
+
# Friendly UI labels
|
| 38 |
+
DISPLAY_NAMES = {
|
| 39 |
+
"person": "person",
|
| 40 |
+
"car": "car",
|
| 41 |
+
"truck": "truck",
|
| 42 |
+
"dog": "dog",
|
| 43 |
+
"horse": "horse / donkey",
|
| 44 |
+
"sheep": "sheep / goat",
|
| 45 |
+
"cow": "cow",
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
# Per-class colours for the live overlay panel (BGR)
|
| 49 |
CLASS_COLORS = {
|
| 50 |
+
"person": (66, 135, 245),
|
| 51 |
+
"car": (66, 245, 167),
|
| 52 |
+
"truck": (245, 66, 161),
|
| 53 |
+
"dog": (120, 245, 200),
|
| 54 |
+
"horse": (245, 120, 120),
|
| 55 |
+
"sheep": (220, 220, 220),
|
| 56 |
+
"cow": (140, 90, 60),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
}
|
| 58 |
|
| 59 |
# Example video lives next to app.py
|
|
|
|
| 61 |
EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
|
| 62 |
|
| 63 |
# ---------------------------------------------------------------------------
|
| 64 |
+
# Load model — on CUDA when available, otherwise CPU
|
| 65 |
# ---------------------------------------------------------------------------
|
| 66 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 67 |
+
print(f"Loading RF-DETR Medium on {DEVICE}…")
|
| 68 |
+
MODEL = RFDETRMedium(device=DEVICE)
|
| 69 |
+
|
| 70 |
+
# optimize_for_inference is only attempted on GPU here (assumed GPU-only, TensorRT-style ops — TODO confirm against the RF-DETR docs). Skip on CPU.
|
| 71 |
+
if DEVICE == "cuda":
|
| 72 |
+
try:
|
| 73 |
+
MODEL.optimize_for_inference()
|
| 74 |
+
print("Optimized for GPU inference.")
|
| 75 |
+
except Exception as e:
|
| 76 |
+
print(f"GPU optimization skipped: {e}")
|
| 77 |
+
|
| 78 |
+
# Give torch all but one CPU core (minimum 1) for CPU inference; tune to your Space's vCPU count
|
| 79 |
try:
|
| 80 |
+
torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
|
| 81 |
+
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
print("Model ready.")
|
| 85 |
|
| 86 |
# Annotators
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
|
| 92 |
+
"""Translucent live-count panel in the top-left corner of the frame."""
|
| 93 |
active = [(name, n) for name, n in counts.items() if n > 0]
|
| 94 |
if not active:
|
| 95 |
active = [("No targets yet", 0)]
|
| 96 |
|
| 97 |
+
panel_w = 280
|
| 98 |
panel_h = 40 + 22 * len(active)
|
| 99 |
overlay = frame.copy()
|
| 100 |
cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
|
|
|
|
| 107 |
for name, n in active:
|
| 108 |
color = CLASS_COLORS.get(name, (200, 200, 200))
|
| 109 |
cv2.circle(frame, (28, y - 5), 5, color, -1)
|
| 110 |
+
display = DISPLAY_NAMES.get(name, name)
|
| 111 |
+
cv2.putText(frame, f"{display}: {n}", (44, y),
|
| 112 |
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
|
| 113 |
y += 22
|
| 114 |
return frame
|
| 115 |
|
| 116 |
|
| 117 |
+
def process_video(video_path, confidence, frame_stride,
|
| 118 |
+
progress=gr.Progress(track_tqdm=True)):
|
| 119 |
if video_path is None:
|
| 120 |
return None, "⚠️ Please upload a video first.", []
|
| 121 |
|
|
|
|
| 123 |
frame_gen = sv.get_video_frames_generator(video_path)
|
| 124 |
tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
|
| 125 |
|
| 126 |
+
unique_ids = defaultdict(set) # class_name -> {tracker_id, ...}
|
| 127 |
last_detections = sv.Detections.empty()
|
| 128 |
|
| 129 |
out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
| 130 |
|
| 131 |
with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
|
| 132 |
+
for i, frame in enumerate(progress.tqdm(
|
| 133 |
+
frame_gen,
|
| 134 |
+
total=video_info.total_frames,
|
| 135 |
+
desc="Analyzing video")):
|
| 136 |
+
|
| 137 |
+
# Detect every Nth frame; reuse previous detections in-between
|
| 138 |
+
# so the output video stays smooth even with high stride.
|
| 139 |
if i % frame_stride == 0:
|
| 140 |
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 141 |
detections = MODEL.predict(rgb, threshold=confidence)
|
|
|
|
| 148 |
detections = tracker.update_with_detections(detections)
|
| 149 |
last_detections = detections
|
| 150 |
|
| 151 |
+
# Register unique tracker IDs per class
|
| 152 |
for cid, tid in zip(detections.class_id, detections.tracker_id):
|
| 153 |
if tid is None:
|
| 154 |
continue
|
|
|
|
| 158 |
else:
|
| 159 |
detections = last_detections
|
| 160 |
|
| 161 |
+
# Annotate frame
|
| 162 |
if len(detections) > 0:
|
| 163 |
+
tids = (detections.tracker_id
|
| 164 |
+
if detections.tracker_id is not None
|
| 165 |
+
else [None] * len(detections))
|
| 166 |
+
labels = []
|
| 167 |
+
for cid, tid, conf in zip(detections.class_id, tids, detections.confidence):
|
| 168 |
+
name = TARGET_CLASSES.get(int(cid), "obj")
|
| 169 |
+
display = DISPLAY_NAMES.get(name, name)
|
| 170 |
+
tid_str = f"#{tid} " if tid is not None else ""
|
| 171 |
+
labels.append(f"{tid_str}{display} {conf:.2f}")
|
| 172 |
+
|
| 173 |
frame = BOX_ANNOTATOR.annotate(frame, detections)
|
| 174 |
frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
|
| 175 |
|
|
|
|
| 177 |
frame = draw_counter_panel(frame, counts_now)
|
| 178 |
sink.write_frame(frame)
|
| 179 |
|
| 180 |
+
# ---------- Build summary outputs ----------
|
| 181 |
total = sum(len(ids) for ids in unique_ids.values())
|
| 182 |
if total == 0:
|
| 183 |
+
summary_md = ("### ℹ️ No target objects detected.\n"
|
| 184 |
+
"Try lowering the confidence threshold or the frame stride.")
|
| 185 |
else:
|
| 186 |
lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
|
| 187 |
for name in TARGET_CLASSES.values():
|
| 188 |
n = len(unique_ids.get(name, set()))
|
| 189 |
if n > 0:
|
| 190 |
+
display = DISPLAY_NAMES.get(name, name).capitalize()
|
| 191 |
+
lines.append(f"- **{display}** — {n}")
|
| 192 |
summary_md = "\n".join(lines)
|
| 193 |
|
| 194 |
+
table = []
|
| 195 |
+
for name in TARGET_CLASSES.values():
|
| 196 |
+
n = len(unique_ids.get(name, set()))
|
| 197 |
+
if n > 0:
|
| 198 |
+
display = DISPLAY_NAMES.get(name, name).capitalize()
|
| 199 |
+
table.append([display, n])
|
| 200 |
if not table:
|
| 201 |
table = [["—", 0]]
|
| 202 |
|
|
|
|
| 216 |
footer {visibility: hidden;}
|
| 217 |
"""
|
| 218 |
|
| 219 |
+
with gr.Blocks(title="RF-DETR Object Counter") as demo:
|
|
|
|
| 220 |
|
| 221 |
with gr.Row(elem_id="title-row"):
|
| 222 |
gr.Markdown(
|
| 223 |
"""
|
| 224 |
+
# 🐄 RF-DETR Object Counter
|
| 225 |
+
Count **people, cars, trucks, and farm animals** in any video.
|
| 226 |
+
Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr) + ByteTrack —
|
| 227 |
each object is counted **only once** as it moves across frames.
|
| 228 |
"""
|
| 229 |
)
|
|
|
|
| 241 |
|
| 242 |
with gr.Accordion("⚙️ Advanced settings", open=False):
|
| 243 |
confidence = gr.Slider(
|
| 244 |
+
minimum=0.1, maximum=0.9, value=0.45, step=0.05,
|
| 245 |
label="Confidence threshold",
|
| 246 |
info="Higher = fewer but more certain detections.",
|
| 247 |
)
|
| 248 |
frame_stride = gr.Slider(
|
| 249 |
+
minimum=1, maximum=15, value=5, step=1,
|
| 250 |
+
label="Frame stride (CPU speed control)",
|
| 251 |
+
info="Process every Nth frame. On CPU, 5–8 is a good balance.",
|
| 252 |
)
|
| 253 |
|
| 254 |
submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
|
|
|
|
| 276 |
|
| 277 |
gr.Markdown(
|
| 278 |
"""
|
| 279 |
+
**Detected categories:** person · car · truck · dog · horse / donkey · sheep / goat · cow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
"""
|
| 281 |
)
|
| 282 |
|
|
|
|
| 288 |
|
| 289 |
|
| 290 |
if __name__ == "__main__":
|
| 291 |
+
demo.queue(max_size=4).launch(
|
| 292 |
+
theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
|
| 293 |
+
css=CUSTOM_CSS,
|
| 294 |
+
)
|