lyimo committed on
Commit
9915b68
·
verified ·
1 Parent(s): 51683d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -100
app.py CHANGED
@@ -1,12 +1,12 @@
1
  """
2
- RF-DETR Object Counter — CPU-optimized Gradio app for Hugging Face Spaces.
3
- Counts people, cars, trucks, and farm animals (cow, sheep/goat, horse/donkey,
4
- dog) in video using RF-DETR Medium + ByteTrack so each object is counted
5
- only once across the whole video.
6
  """
7
 
8
  import os
9
  import tempfile
 
10
  from collections import defaultdict
11
 
12
  import cv2
@@ -14,15 +14,12 @@ import gradio as gr
14
  import numpy as np
15
  import supervision as sv
16
  import torch
17
- from rfdetr import RFDETRMedium
18
 
19
  # ---------------------------------------------------------------------------
20
  # Target classes (COCO indices)
21
  # ---------------------------------------------------------------------------
22
- # Note: COCO does NOT contain "goat" or "donkey". We approximate:
23
- # goat ~ sheep (closest 4-legged ruminant in COCO)
24
- # donkey ~ horse (closest equid in COCO)
25
- # Counts for these will be roughly right; labels will say sheep/horse.
26
  TARGET_CLASSES = {
27
  0: "person",
28
  2: "car",
@@ -34,7 +31,6 @@ TARGET_CLASSES = {
34
  }
35
  TARGET_IDS = list(TARGET_CLASSES.keys())
36
 
37
- # Friendly UI labels
38
  DISPLAY_NAMES = {
39
  "person": "person",
40
  "car": "car",
@@ -45,8 +41,7 @@ DISPLAY_NAMES = {
45
  "cow": "cow",
46
  }
47
 
48
- # Per-class colours for the live overlay panel (BGR)
49
- CLASS_COLORS = {
50
  "person": (66, 135, 245),
51
  "car": (66, 245, 167),
52
  "truck": (245, 66, 161),
@@ -56,18 +51,16 @@ CLASS_COLORS = {
56
  "cow": (140, 90, 60),
57
  }
58
 
59
- # Example video lives next to app.py
60
  APP_DIR = os.path.dirname(os.path.abspath(__file__))
61
  EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
62
 
63
  # ---------------------------------------------------------------------------
64
- # Load model — pinned to CPU
65
  # ---------------------------------------------------------------------------
66
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
- print(f"Loading RF-DETR Medium on {DEVICE}…")
68
- MODEL = RFDETRMedium(device=DEVICE)
69
 
70
- # optimize_for_inference is GPU-only (TensorRT-style ops). Skip on CPU.
71
  if DEVICE == "cuda":
72
  try:
73
  MODEL.optimize_for_inference()
@@ -75,7 +68,6 @@ if DEVICE == "cuda":
75
  except Exception as e:
76
  print(f"GPU optimization skipped: {e}")
77
 
78
- # Use a few threads for torch CPU inference; tune to your Space's vCPU count
79
  try:
80
  torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
81
  except Exception:
@@ -83,64 +75,147 @@ except Exception:
83
 
84
  print("Model ready.")
85
 
86
- # Annotators
87
  BOX_ANNOTATOR = sv.BoxAnnotator(thickness=2)
88
  LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padding=3)
89
 
90
 
91
- def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
92
- """Translucent live-count panel in the top-left corner of the frame."""
 
 
 
 
 
93
  active = [(name, n) for name, n in counts.items() if n > 0]
94
- if not active:
95
- active = [("No targets yet", 0)]
96
 
97
- panel_w = 280
98
- panel_h = 40 + 22 * len(active)
 
 
99
  overlay = frame.copy()
100
- cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
 
101
  frame = cv2.addWeighted(overlay, 0.65, frame, 0.35, 0)
102
 
103
- cv2.putText(frame, "LIVE COUNTS", (24, 38),
104
- cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
105
 
106
- y = 62
107
- for name, n in active:
108
- color = CLASS_COLORS.get(name, (200, 200, 200))
109
- cv2.circle(frame, (28, y - 5), 5, color, -1)
110
- display = DISPLAY_NAMES.get(name, name)
111
- cv2.putText(frame, f"{display}: {n}", (44, y),
112
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
113
- y += 22
 
 
 
 
 
114
  return frame
115
 
116
 
117
- def process_video(video_path, confidence, frame_stride,
118
- progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  if video_path is None:
120
- return None, "⚠️ Please upload a video first.", []
 
 
 
 
 
121
 
122
  video_info = sv.VideoInfo.from_video_path(video_path)
123
  frame_gen = sv.get_video_frames_generator(video_path)
124
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
125
 
126
- unique_ids = defaultdict(set) # class_name -> {tracker_id, ...}
127
  last_detections = sv.Detections.empty()
128
 
129
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
132
- for i, frame in enumerate(progress.tqdm(
133
- frame_gen,
134
- total=video_info.total_frames,
135
- desc="Analyzing video")):
136
 
137
- # Detect every Nth frame; reuse previous detections in-between
138
- # so the output video stays smooth even with high stride.
139
  if i % frame_stride == 0:
140
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
141
  detections = MODEL.predict(rgb, threshold=confidence)
142
 
143
- # Keep only the classes we care about
144
  if len(detections) > 0:
145
  mask = np.isin(detections.class_id, TARGET_IDS)
146
  detections = detections[mask]
@@ -148,7 +223,6 @@ def process_video(video_path, confidence, frame_stride,
148
  detections = tracker.update_with_detections(detections)
149
  last_detections = detections
150
 
151
- # Register unique tracker IDs per class
152
  for cid, tid in zip(detections.class_id, detections.tracker_id):
153
  if tid is None:
154
  continue
@@ -158,13 +232,14 @@ def process_video(video_path, confidence, frame_stride,
158
  else:
159
  detections = last_detections
160
 
161
- # Annotate frame
162
  if len(detections) > 0:
163
  tids = (detections.tracker_id
164
  if detections.tracker_id is not None
165
  else [None] * len(detections))
166
  labels = []
167
- for cid, tid, conf in zip(detections.class_id, tids, detections.confidence):
 
168
  name = TARGET_CLASSES.get(int(cid), "obj")
169
  display = DISPLAY_NAMES.get(name, name)
170
  tid_str = f"#{tid} " if tid is not None else ""
@@ -174,61 +249,61 @@ def process_video(video_path, confidence, frame_stride,
174
  frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
175
 
176
  counts_now = {name: len(ids) for name, ids in unique_ids.items()}
177
- frame = draw_counter_panel(frame, counts_now)
178
- sink.write_frame(frame)
179
-
180
- # ---------- Build summary outputs ----------
181
- total = sum(len(ids) for ids in unique_ids.values())
182
- if total == 0:
183
- summary_md = ("### ℹ️ No target objects detected.\n"
184
- "Try lowering the confidence threshold or the frame stride.")
185
- else:
186
- lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
187
- for name in TARGET_CLASSES.values():
188
- n = len(unique_ids.get(name, set()))
189
- if n > 0:
190
- display = DISPLAY_NAMES.get(name, name).capitalize()
191
- lines.append(f"- **{display}** — {n}")
192
- summary_md = "\n".join(lines)
193
-
194
- table = []
195
- for name in TARGET_CLASSES.values():
196
- n = len(unique_ids.get(name, set()))
197
- if n > 0:
198
- display = DISPLAY_NAMES.get(name, name).capitalize()
199
- table.append([display, n])
200
- if not table:
201
- table = [["—", 0]]
202
 
203
- return out_path, summary_md, table
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  # ---------------------------------------------------------------------------
207
  # UI
208
  # ---------------------------------------------------------------------------
209
  CUSTOM_CSS = """
210
- .gradio-container {max-width: 1200px !important; margin: auto;}
211
  #title-row {text-align: center; padding: 8px 0 0 0;}
212
  #title-row h1 {font-weight: 700; letter-spacing: -0.5px; margin-bottom: 4px;}
213
  #title-row p {color: #6b7280; margin-top: 0;}
214
  .card {border: 1px solid #e5e7eb; border-radius: 14px; padding: 16px;
215
  background: #ffffff;}
 
216
  footer {visibility: hidden;}
217
  """
218
 
219
- with gr.Blocks(title="RF-DETR Object Counter") as demo:
220
 
221
  with gr.Row(elem_id="title-row"):
222
  gr.Markdown(
223
  """
224
- # 🐄 RF-DETR Object Counter
225
- Count **people, cars, trucks, and farm animals** in any video.
226
- Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr) + ByteTrack —
227
- each object is counted **only once** as it moves across frames.
 
228
  """
229
  )
230
 
231
  with gr.Row():
 
232
  with gr.Column(scale=1):
233
  with gr.Group(elem_classes="card"):
234
  gr.Markdown("### 📥 Input")
@@ -236,54 +311,70 @@ with gr.Blocks(title="RF-DETR Object Counter") as demo:
236
  label="Upload a video",
237
  sources=["upload"],
238
  format="mp4",
239
- height=320,
240
  )
241
 
242
  with gr.Accordion("⚙️ Advanced settings", open=False):
243
  confidence = gr.Slider(
244
  minimum=0.1, maximum=0.9, value=0.45, step=0.05,
245
  label="Confidence threshold",
246
- info="Higher = fewer but more certain detections.",
247
  )
248
  frame_stride = gr.Slider(
249
- minimum=1, maximum=15, value=5, step=1,
250
- label="Frame stride (CPU speed control)",
251
- info="Process every Nth frame. On CPU, 5–8 is a good balance.",
252
  )
253
 
254
- submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
 
255
 
256
  gr.Markdown("#### 🎬 Example video")
257
  gr.Examples(
258
  examples=[[EXAMPLE_VIDEO]],
259
  inputs=video_input,
260
- label=None,
261
  examples_per_page=4,
262
  )
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  with gr.Column(scale=1):
265
  with gr.Group(elem_classes="card"):
266
- gr.Markdown("### 📤 Annotated output")
267
- video_output = gr.Video(label="Annotated video", height=320)
268
- summary_output = gr.Markdown("Submit a video to see the results here.")
269
  table_output = gr.Dataframe(
270
  headers=["Class", "Unique count"],
271
  datatype=["str", "number"],
272
- label="Per-class totals",
273
  interactive=False,
274
  wrap=True,
275
  )
276
 
277
- gr.Markdown(
278
- """
279
- **Detected categories:** person · car · truck · dog · horse / donkey · sheep / goat · cow
280
- """
281
- )
282
-
283
  submit_btn.click(
284
  fn=process_video,
285
  inputs=[video_input, confidence, frame_stride],
286
- outputs=[video_output, summary_output, table_output],
 
287
  )
288
 
289
 
 
1
  """
2
+ RF-DETR Object Counter — live-streaming Gradio app for Hugging Face Spaces.
3
+ Annotated frames stream into the UI in real time while counts update as
4
+ the model processes the video.
 
5
  """
6
 
7
  import os
8
  import tempfile
9
+ import time
10
  from collections import defaultdict
11
 
12
  import cv2
 
14
  import numpy as np
15
  import supervision as sv
16
  import torch
17
+ from rfdetr import RFDETRNano
18
 
19
  # ---------------------------------------------------------------------------
20
  # Target classes (COCO indices)
21
  # ---------------------------------------------------------------------------
22
+ # COCO has no "goat" or "donkey" class; closest proxies: sheep ≈ goat, horse ≈ donkey
 
 
 
23
  TARGET_CLASSES = {
24
  0: "person",
25
  2: "car",
 
31
  }
32
  TARGET_IDS = list(TARGET_CLASSES.keys())
33
 
 
34
  DISPLAY_NAMES = {
35
  "person": "person",
36
  "car": "car",
 
41
  "cow": "cow",
42
  }
43
 
44
+ CLASS_COLORS = { # BGR
 
45
  "person": (66, 135, 245),
46
  "car": (66, 245, 167),
47
  "truck": (245, 66, 161),
 
51
  "cow": (140, 90, 60),
52
  }
53
 
 
54
  APP_DIR = os.path.dirname(os.path.abspath(__file__))
55
  EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
56
 
57
  # ---------------------------------------------------------------------------
58
+ # Load model — uses CUDA when available, otherwise falls back to CPU
59
  # ---------------------------------------------------------------------------
60
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
61
+ print(f"Loading RF-DETR Nano on {DEVICE}…")
62
+ MODEL = RFDETRNano(device=DEVICE)
63
 
 
64
  if DEVICE == "cuda":
65
  try:
66
  MODEL.optimize_for_inference()
 
68
  except Exception as e:
69
  print(f"GPU optimization skipped: {e}")
70
 
 
71
  try:
72
  torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
73
  except Exception:
 
75
 
76
  print("Model ready.")
77
 
 
78
  BOX_ANNOTATOR = sv.BoxAnnotator(thickness=2)
79
  LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padding=3)
80
 
81
 
82
+ # ---------------------------------------------------------------------------
83
+ # Drawing helpers
84
+ # ---------------------------------------------------------------------------
85
def draw_counter_panel(frame: np.ndarray, counts: dict,
                       frame_idx: int, total_frames: int,
                       fps_proc: float) -> np.ndarray:
    """Draw a translucent live-info panel in the top-left corner.

    The panel has a header row (a "LIVE" tag, the frame position and the
    processing rate) followed by one row per class whose count is above
    zero, or a single "scanning..." row when nothing has been counted yet.

    Args:
        frame: Image to annotate. It is not modified; the panel is drawn
            on a blended copy.
        counts: Mapping of class name to the count shown for that class.
        frame_idx: Frame number shown in the header.
        total_frames: Total frame count shown in the header.
        fps_proc: Processing rate shown in the header, in frames per second.

    Returns:
        A new image with the panel drawn on it.
    """
    active = [(name, n) for name, n in counts.items() if n > 0]

    # One header row plus one row per active class (at least one body row
    # so the "scanning..." placeholder always fits).
    rows = max(1, len(active)) + 1
    panel_w = 320
    panel_h = 28 + 22 * rows

    # Paint the dark rectangle on a copy, then blend it back so the video
    # stays visible underneath the panel.
    overlay = frame.copy()
    cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h),
                  (20, 20, 20), -1)
    frame = cv2.addWeighted(overlay, 0.65, frame, 0.35, 0)

    # NOTE: cv2.putText with the Hershey fonts only renders ASCII; any other
    # character is drawn as "?". Keep every overlay string ASCII-only.
    cv2.putText(frame, " LIVE", (24, 36),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, (60, 220, 60), 2, cv2.LINE_AA)
    cv2.putText(frame, f"frame {frame_idx}/{total_frames} | {fps_proc:.1f} fps",
                (90, 36),
                cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200, 200, 200), 1, cv2.LINE_AA)

    y = 60
    if not active:
        cv2.putText(frame, "scanning...", (28, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1, cv2.LINE_AA)
    else:
        for name, n in active:
            # Colour dot followed by "<label>: <count>"; grey dot for
            # classes that have no entry in CLASS_COLORS.
            color = CLASS_COLORS.get(name, (200, 200, 200))
            cv2.circle(frame, (28, y - 5), 5, color, -1)
            display = DISPLAY_NAMES.get(name, name)
            cv2.putText(frame, f"{display}: {n}", (44, y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (240, 240, 240), 1, cv2.LINE_AA)
            y += 22
    return frame
120
 
121
 
122
def build_counts_html(unique_ids: dict, frame_idx: int, total: int,
                      elapsed: float) -> str:
    """Render the side-panel progress bar and per-class count cards as HTML.

    Args:
        unique_ids: Mapping of class name to the collection of IDs counted
            for that class; only its length is used.
        frame_idx: Number of frames processed so far.
        total: Total number of frames, or 0 when unknown.
        elapsed: Seconds elapsed, shown next to the percentage.

    Returns:
        An HTML fragment: the progress block followed by one card for every
        class in TARGET_CLASSES (cards with a zero count are dimmed).
    """
    # Clamp to [0, 100]: frame_idx can exceed total when the reported frame
    # count is lower than the number of frames actually decoded, which
    # would otherwise yield a bar width and label above 100%.
    pct = (frame_idx / total * 100) if total else 0
    pct = max(0.0, min(100.0, pct))

    cards = []
    for name in TARGET_CLASSES.values():
        n = len(unique_ids.get(name, set()))
        display = DISPLAY_NAMES.get(name, name)
        r, g, b = CLASS_COLORS.get(name, (200, 200, 200))[::-1]  # BGR->RGB
        opacity = "1.0" if n > 0 else "0.35"
        cards.append(
            f'<div style="display:flex;justify-content:space-between;'
            f'align-items:center;padding:8px 12px;margin:4px 0;'
            f'border-radius:8px;background:rgba({r},{g},{b},0.10);'
            f'border-left:4px solid rgb({r},{g},{b});opacity:{opacity};">'
            f'<span style="font-weight:500;color:#111;">{display}</span>'
            f'<span style="font-size:18px;font-weight:700;color:rgb({r},{g},{b});">{n}</span>'
            f'</div>'
        )

    # The width uses the same one-decimal formatting as the label so the
    # inline style never carries a long float repr.
    progress = (
        f'<div style="margin:8px 0 14px 0;">'
        f'<div style="display:flex;justify-content:space-between;font-size:12px;'
        f'color:#6b7280;margin-bottom:4px;"><span>frame {frame_idx} / {total}</span>'
        f'<span>{pct:.1f}% · {elapsed:.1f}s</span></div>'
        f'<div style="height:6px;background:#e5e7eb;border-radius:3px;overflow:hidden;">'
        f'<div style="height:100%;width:{pct:.1f}%;background:#6366f1;'
        f'transition:width 0.2s;"></div></div></div>'
    )
    return progress + "".join(cards)
151
+
152
+
153
def build_summary_md(unique_ids: dict) -> str:
    """Build the final Markdown summary of unique objects per class.

    Returns a "no target objects" hint when nothing was counted; otherwise
    a heading with the grand total followed by one bullet per class with a
    non-zero count, in TARGET_CLASSES order.
    """
    grand_total = sum(len(ids) for ids in unique_ids.values())
    if not grand_total:
        return ("### ℹ️ No target objects detected.\n"
                "Try lowering the confidence threshold or the frame stride.")

    per_class = (
        (class_name, len(unique_ids.get(class_name, set())))
        for class_name in TARGET_CLASSES.values()
    )
    bullets = [
        f"- **{DISPLAY_NAMES.get(class_name, class_name).capitalize()}** — {count}"
        for class_name, count in per_class
        if count > 0
    ]
    heading = [f"### ✅ Total unique objects detected: **{grand_total}**", ""]
    return "\n".join(heading + bullets)
164
+
165
+
166
def build_table(unique_ids: dict):
    """Build the rows for the final per-class totals table.

    Each row is ``[capitalised display name, unique count]`` for a class
    with a non-zero count, in TARGET_CLASSES order. When no class has been
    counted, a single placeholder row is returned instead.
    """
    per_class = (
        (class_name, len(unique_ids.get(class_name, set())))
        for class_name in TARGET_CLASSES.values()
    )
    table = [
        [DISPLAY_NAMES.get(class_name, class_name).capitalize(), count]
        for class_name, count in per_class
        if count > 0
    ]
    if not table:
        return [["—", 0]]
    return table
173
+
174
+
175
+ # ---------------------------------------------------------------------------
176
+ # Main streaming generator
177
+ # ---------------------------------------------------------------------------
178
+ def process_video(video_path, confidence, frame_stride):
179
+ """
180
+ Generator: streams the annotated frame + live counts every iteration,
181
+ then yields the saved video and final table on completion.
182
+ """
183
  if video_path is None:
184
+ yield (None,
185
+ '<div style="padding:12px;color:#b91c1c;">⚠️ Please upload a video first.</div>',
186
+ "Submit a video to start.",
187
+ None,
188
+ [])
189
+ return
190
 
191
  video_info = sv.VideoInfo.from_video_path(video_path)
192
  frame_gen = sv.get_video_frames_generator(video_path)
193
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
194
 
195
+ unique_ids = defaultdict(set)
196
  last_detections = sv.Detections.empty()
197
 
198
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
199
+ total = video_info.total_frames or 0
200
+ start_time = time.time()
201
+ last_yield = 0.0
202
+
203
+ yield (None,
204
+ build_counts_html(unique_ids, 0, total, 0.0),
205
+ "### 🎬 Starting analysis…",
206
+ None,
207
+ [])
208
+
209
+ last_rgb = None
210
 
211
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
212
+ for i, frame in enumerate(frame_gen):
 
 
 
213
 
214
+ # ---- Detection (every Nth frame) ----
 
215
  if i % frame_stride == 0:
216
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
217
  detections = MODEL.predict(rgb, threshold=confidence)
218
 
 
219
  if len(detections) > 0:
220
  mask = np.isin(detections.class_id, TARGET_IDS)
221
  detections = detections[mask]
 
223
  detections = tracker.update_with_detections(detections)
224
  last_detections = detections
225
 
 
226
  for cid, tid in zip(detections.class_id, detections.tracker_id):
227
  if tid is None:
228
  continue
 
232
  else:
233
  detections = last_detections
234
 
235
+ # ---- Annotate ----
236
  if len(detections) > 0:
237
  tids = (detections.tracker_id
238
  if detections.tracker_id is not None
239
  else [None] * len(detections))
240
  labels = []
241
+ for cid, tid, conf in zip(detections.class_id, tids,
242
+ detections.confidence):
243
  name = TARGET_CLASSES.get(int(cid), "obj")
244
  display = DISPLAY_NAMES.get(name, name)
245
  tid_str = f"#{tid} " if tid is not None else ""
 
249
  frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
250
 
251
  counts_now = {name: len(ids) for name, ids in unique_ids.items()}
252
+ elapsed = time.time() - start_time
253
+ fps_proc = (i + 1) / elapsed if elapsed > 0 else 0
254
+ frame = draw_counter_panel(frame, counts_now, i + 1, total, fps_proc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ sink.write_frame(frame)
257
+ last_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
258
+
259
+ # ---- Yield to UI (throttled to ~5 updates/sec) ----
260
+ now = time.time()
261
+ if now - last_yield > 0.20 or i == total - 1:
262
+ last_yield = now
263
+ yield (last_rgb,
264
+ build_counts_html(unique_ids, i + 1, total, elapsed),
265
+ "### 🔴 Live analysis in progress…",
266
+ None,
267
+ [])
268
+
269
+ # ---- Final yield: include saved video + summary table ----
270
+ elapsed = time.time() - start_time
271
+ yield (last_rgb,
272
+ build_counts_html(unique_ids, total, total, elapsed),
273
+ build_summary_md(unique_ids),
274
+ out_path,
275
+ build_table(unique_ids))
276
 
277
 
278
  # ---------------------------------------------------------------------------
279
  # UI
280
  # ---------------------------------------------------------------------------
281
  CUSTOM_CSS = """
282
+ .gradio-container {max-width: 1240px !important; margin: auto;}
283
  #title-row {text-align: center; padding: 8px 0 0 0;}
284
  #title-row h1 {font-weight: 700; letter-spacing: -0.5px; margin-bottom: 4px;}
285
  #title-row p {color: #6b7280; margin-top: 0;}
286
  .card {border: 1px solid #e5e7eb; border-radius: 14px; padding: 16px;
287
  background: #ffffff;}
288
+ #live-frame img {border-radius: 10px;}
289
  footer {visibility: hidden;}
290
  """
291
 
292
+ with gr.Blocks(title="RF-DETR Live Object Counter") as demo:
293
 
294
  with gr.Row(elem_id="title-row"):
295
  gr.Markdown(
296
  """
297
+ # 🐄 RF-DETR Live Object Counter
298
+ Watch detections appear **frame by frame** as the model processes your video
299
+ counts update in real time. Powered by
300
+ [RF-DETR Nano](https://github.com/roboflow/rf-detr) + ByteTrack
301
+ (each object counted only once).
302
  """
303
  )
304
 
305
  with gr.Row():
306
+ # ---------- Left: input ----------
307
  with gr.Column(scale=1):
308
  with gr.Group(elem_classes="card"):
309
  gr.Markdown("### 📥 Input")
 
311
  label="Upload a video",
312
  sources=["upload"],
313
  format="mp4",
314
+ height=260,
315
  )
316
 
317
  with gr.Accordion("⚙️ Advanced settings", open=False):
318
  confidence = gr.Slider(
319
  minimum=0.1, maximum=0.9, value=0.45, step=0.05,
320
  label="Confidence threshold",
 
321
  )
322
  frame_stride = gr.Slider(
323
+ minimum=1, maximum=15, value=3, step=1,
324
+ label="Frame stride (CPU speed)",
325
+ info="Detect every Nth frame. Higher = faster.",
326
  )
327
 
328
+ submit_btn = gr.Button("▶️ Start Live Analysis",
329
+ variant="primary", size="lg")
330
 
331
  gr.Markdown("#### 🎬 Example video")
332
  gr.Examples(
333
  examples=[[EXAMPLE_VIDEO]],
334
  inputs=video_input,
 
335
  examples_per_page=4,
336
  )
337
 
338
+ # ---------- Right: live view ----------
339
+ with gr.Column(scale=2):
340
+ with gr.Group(elem_classes="card"):
341
+ with gr.Row():
342
+ with gr.Column(scale=3):
343
+ gr.Markdown("### 🔴 Live View")
344
+ live_frame = gr.Image(
345
+ label=None,
346
+ show_label=False,
347
+ elem_id="live-frame",
348
+ height=420,
349
+ )
350
+ with gr.Column(scale=1, min_width=220):
351
+ gr.Markdown("### 📊 Live Counts")
352
+ live_counts = gr.HTML(
353
+ value=build_counts_html(defaultdict(set), 0, 0, 0)
354
+ )
355
+
356
+ # ---------- Bottom: final results ----------
357
+ with gr.Row():
358
+ with gr.Column(scale=1):
359
+ with gr.Group(elem_classes="card"):
360
+ gr.Markdown("### 📤 Final annotated video")
361
+ video_output = gr.Video(label="Download / replay", height=260)
362
  with gr.Column(scale=1):
363
  with gr.Group(elem_classes="card"):
364
+ gr.Markdown("### 📈 Final totals")
365
+ summary_output = gr.Markdown("Run an analysis to see results.")
 
366
  table_output = gr.Dataframe(
367
  headers=["Class", "Unique count"],
368
  datatype=["str", "number"],
 
369
  interactive=False,
370
  wrap=True,
371
  )
372
 
 
 
 
 
 
 
373
  submit_btn.click(
374
  fn=process_video,
375
  inputs=[video_input, confidence, frame_stride],
376
+ outputs=[live_frame, live_counts, summary_output,
377
+ video_output, table_output],
378
  )
379
 
380