lyimo commited on
Commit
912ec81
·
verified ·
1 Parent(s): e4c58c8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RF-DETR Object Counter — Gradio app for Hugging Face Spaces.
3
+ Counts people, bicycles, cars, trucks, and animals in video using
4
+ RF-DETR Medium + ByteTrack (so each object is counted only once).
5
+ """
6
+
7
+ import os
8
+ import tempfile
9
+ from collections import defaultdict
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import supervision as sv
15
+ from rfdetr import RFDETRMedium
16
+ from rfdetr.assets.coco_classes import COCO_CLASSES
17
+
18
+ # ---------------------------------------------------------------------------
19
+ # Target classes (COCO indices) — exactly what the user asked for
20
+ # ---------------------------------------------------------------------------
21
+ TARGET_CLASSES = {
22
+ 0: "person",
23
+ 1: "bicycle",
24
+ 2: "car",
25
+ 7: "truck",
26
+ # animals
27
+ 14: "bird",
28
+ 15: "cat",
29
+ 16: "dog",
30
+ 17: "horse",
31
+ 18: "sheep",
32
+ 19: "cow",
33
+ 20: "elephant",
34
+ 21: "bear",
35
+ 22: "zebra",
36
+ 23: "giraffe",
37
+ }
38
+ TARGET_IDS = list(TARGET_CLASSES.keys())
39
+
40
+ # Per-class colour palette (BGR) for the live overlay
41
+ CLASS_COLORS = {
42
+ "person": (66, 135, 245),
43
+ "bicycle": (245, 173, 66),
44
+ "car": (66, 245, 167),
45
+ "truck": (245, 66, 161),
46
+ "bird": (245, 230, 66),
47
+ "cat": (200, 120, 245),
48
+ "dog": (120, 245, 200),
49
+ "horse": (245, 120, 120),
50
+ "sheep": (220, 220, 220),
51
+ "cow": (140, 90, 60),
52
+ "elephant": (160, 160, 200),
53
+ "bear": (90, 60, 30),
54
+ "zebra": (40, 40, 40),
55
+ "giraffe": (220, 180, 90),
56
+ }
57
+
58
+ # Example video lives next to app.py
59
+ APP_DIR = os.path.dirname(os.path.abspath(__file__))
60
+ EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
61
+
62
+ # ---------------------------------------------------------------------------
63
+ # Load model once at startup
64
+ # ---------------------------------------------------------------------------
65
+ print("Loading RF-DETR Medium…")
66
+ MODEL = RFDETRMedium()
67
+ try:
68
+ MODEL.optimize_for_inference() # speeds up subsequent predicts
69
+ print("Model optimized for inference.")
70
+ except Exception as e:
71
+ print(f"(Optimization skipped: {e})")
72
+ print("Model ready.")
73
+
74
+ # Annotators
75
+ BOX_ANNOTATOR = sv.BoxAnnotator(thickness=2)
76
+ LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padding=3)
77
+
78
+
79
+ def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
80
+ """Translucent counter panel in the top-left corner."""
81
+ active = [(name, n) for name, n in counts.items() if n > 0]
82
+ if not active:
83
+ active = [("No targets yet", 0)]
84
+
85
+ panel_w = 230
86
+ panel_h = 40 + 22 * len(active)
87
+ overlay = frame.copy()
88
+ cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
89
+ frame = cv2.addWeighted(overlay, 0.65, frame, 0.35, 0)
90
+
91
+ cv2.putText(frame, "LIVE COUNTS", (24, 38),
92
+ cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
93
+
94
+ y = 62
95
+ for name, n in active:
96
+ color = CLASS_COLORS.get(name, (200, 200, 200))
97
+ cv2.circle(frame, (28, y - 5), 5, color, -1)
98
+ cv2.putText(frame, f"{name}: {n}", (44, y),
99
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
100
+ y += 22
101
+ return frame
102
+
103
+
104
+ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(track_tqdm=True)):
105
+ if video_path is None:
106
+ return None, "⚠️ Please upload a video first.", []
107
+
108
+ video_info = sv.VideoInfo.from_video_path(video_path)
109
+ frame_gen = sv.get_video_frames_generator(video_path)
110
+ tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
111
+
112
+ unique_ids = defaultdict(set) # class_name -> {tracker_id, ...}
113
+ last_detections = sv.Detections.empty()
114
+
115
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
116
+
117
+ with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
118
+ for i, frame in enumerate(progress.tqdm(frame_gen, total=video_info.total_frames,
119
+ desc="Analyzing video")):
120
+ # Detect every Nth frame; reuse previous detections in-between to keep video smooth
121
+ if i % frame_stride == 0:
122
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
123
+ detections = MODEL.predict(rgb, threshold=confidence)
124
+
125
+ # Keep only the classes we care about
126
+ if len(detections) > 0:
127
+ mask = np.isin(detections.class_id, TARGET_IDS)
128
+ detections = detections[mask]
129
+
130
+ detections = tracker.update_with_detections(detections)
131
+ last_detections = detections
132
+
133
+ # Register unique IDs per class
134
+ for cid, tid in zip(detections.class_id, detections.tracker_id):
135
+ if tid is None:
136
+ continue
137
+ name = TARGET_CLASSES.get(int(cid))
138
+ if name:
139
+ unique_ids[name].add(int(tid))
140
+ else:
141
+ detections = last_detections
142
+
143
+ # Annotate
144
+ if len(detections) > 0:
145
+ labels = [
146
+ f"#{tid} {TARGET_CLASSES.get(int(cid), 'obj')} {conf:.2f}"
147
+ for cid, tid, conf in zip(
148
+ detections.class_id,
149
+ detections.tracker_id if detections.tracker_id is not None
150
+ else [None] * len(detections),
151
+ detections.confidence,
152
+ )
153
+ ]
154
+ frame = BOX_ANNOTATOR.annotate(frame, detections)
155
+ frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
156
+
157
+ counts_now = {name: len(ids) for name, ids in unique_ids.items()}
158
+ frame = draw_counter_panel(frame, counts_now)
159
+ sink.write_frame(frame)
160
+
161
+ # Build summary outputs
162
+ total = sum(len(ids) for ids in unique_ids.values())
163
+ if total == 0:
164
+ summary_md = "### ℹ️ No target objects detected.\nTry lowering the confidence threshold."
165
+ else:
166
+ lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
167
+ for name in TARGET_CLASSES.values():
168
+ n = len(unique_ids.get(name, set()))
169
+ if n > 0:
170
+ lines.append(f"- **{name.capitalize()}** — {n}")
171
+ summary_md = "\n".join(lines)
172
+
173
+ table = [[name.capitalize(), len(unique_ids.get(name, set()))]
174
+ for name in TARGET_CLASSES.values()
175
+ if len(unique_ids.get(name, set())) > 0]
176
+ if not table:
177
+ table = [["—", 0]]
178
+
179
+ return out_path, summary_md, table
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # UI
184
+ # ---------------------------------------------------------------------------
185
+ CUSTOM_CSS = """
186
+ .gradio-container {max-width: 1200px !important; margin: auto;}
187
+ #title-row {text-align: center; padding: 8px 0 0 0;}
188
+ #title-row h1 {font-weight: 700; letter-spacing: -0.5px; margin-bottom: 4px;}
189
+ #title-row p {color: #6b7280; margin-top: 0;}
190
+ .card {border: 1px solid #e5e7eb; border-radius: 14px; padding: 16px;
191
+ background: #ffffff;}
192
+ footer {visibility: hidden;}
193
+ """
194
+
195
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
196
+ css=CUSTOM_CSS, title="RF-DETR Object Counter") as demo:
197
+
198
+ with gr.Row(elem_id="title-row"):
199
+ gr.Markdown(
200
+ """
201
+ # 🚦 RF-DETR Object Counter
202
+ Count **people, bicycles, cars, trucks, and animals** in any video.
203
+ Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr) (Roboflow, ICLR 2026) and ByteTrack —
204
+ each object is counted **only once** as it moves across frames.
205
+ """
206
+ )
207
+
208
+ with gr.Row():
209
+ with gr.Column(scale=1):
210
+ with gr.Group(elem_classes="card"):
211
+ gr.Markdown("### 📥 Input")
212
+ video_input = gr.Video(
213
+ label="Upload a video",
214
+ sources=["upload"],
215
+ format="mp4",
216
+ height=320,
217
+ )
218
+
219
+ with gr.Accordion("⚙️ Advanced settings", open=False):
220
+ confidence = gr.Slider(
221
+ minimum=0.1, maximum=0.9, value=0.5, step=0.05,
222
+ label="Confidence threshold",
223
+ info="Higher = fewer but more certain detections.",
224
+ )
225
+ frame_stride = gr.Slider(
226
+ minimum=1, maximum=10, value=2, step=1,
227
+ label="Frame stride",
228
+ info="Process every Nth frame. Higher = faster, slightly less accurate.",
229
+ )
230
+
231
+ submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
232
+
233
+ gr.Markdown("#### 🎬 Example video")
234
+ gr.Examples(
235
+ examples=[[EXAMPLE_VIDEO]],
236
+ inputs=video_input,
237
+ label=None,
238
+ examples_per_page=4,
239
+ )
240
+
241
+ with gr.Column(scale=1):
242
+ with gr.Group(elem_classes="card"):
243
+ gr.Markdown("### 📤 Annotated output")
244
+ video_output = gr.Video(label="Annotated video", height=320)
245
+ summary_output = gr.Markdown("Submit a video to see the results here.")
246
+ table_output = gr.Dataframe(
247
+ headers=["Class", "Unique count"],
248
+ datatype=["str", "number"],
249
+ label="Per-class totals",
250
+ interactive=False,
251
+ wrap=True,
252
+ )
253
+
254
+ gr.Markdown(
255
+ """
256
+ ---
257
+ **Detected categories:** person · bicycle · car · truck · bird · cat · dog · horse ·
258
+ sheep · cow · elephant · bear · zebra · giraffe
259
+
260
+ **Tip:** the first run loads the model (≈45–90 s for Medium). Subsequent runs are much faster.
261
+ Use *Frame stride* if processing is slow on CPU.
262
+ """
263
+ )
264
+
265
+ submit_btn.click(
266
+ fn=process_video,
267
+ inputs=[video_input, confidence, frame_stride],
268
+ outputs=[video_output, summary_output, table_output],
269
+ )
270
+
271
+
272
+ if __name__ == "__main__":
273
+ demo.queue(max_size=8).launch()