lyimo commited on
Commit
6ebe736
·
verified ·
1 Parent(s): 9915b68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -141
app.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
- RF-DETR Object Counter — live-streaming Gradio app for Hugging Face Spaces.
3
- Annotated frames stream into the UI in real time while counts update as
4
- the model processes the video.
5
  """
6
 
7
  import os
@@ -19,17 +18,19 @@ from rfdetr import RFDETRNano
19
  # ---------------------------------------------------------------------------
20
  # Target classes (COCO indices)
21
  # ---------------------------------------------------------------------------
22
- # COCO has no "goat" or "donkey" closest proxies: sheep goat, horse ≈ donkey
 
23
  TARGET_CLASSES = {
24
  0: "person",
25
  2: "car",
26
  7: "truck",
27
  16: "dog",
28
- 17: "horse", # also catches donkeys
29
- 18: "sheep", # also catches goats
30
  19: "cow",
31
  }
32
  TARGET_IDS = list(TARGET_CLASSES.keys())
 
33
 
34
  DISPLAY_NAMES = {
35
  "person": "person",
@@ -75,138 +76,138 @@ except Exception:
75
 
76
  print("Model ready.")
77
 
78
- BOX_ANNOTATOR = sv.BoxAnnotator(thickness=2)
79
- LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padding=3)
80
-
81
 
82
  # ---------------------------------------------------------------------------
83
- # Drawing helpers
84
  # ---------------------------------------------------------------------------
85
- def draw_counter_panel(frame: np.ndarray, counts: dict,
86
- frame_idx: int, total_frames: int,
87
- fps_proc: float) -> np.ndarray:
88
- """Translucent live-info panel in the top-left corner."""
89
- active = [(name, n) for name, n in counts.items() if n > 0]
90
-
91
- rows = max(1, len(active)) + 1
92
- panel_w = 320
93
- panel_h = 28 + 22 * rows
 
 
 
 
 
 
 
94
 
95
  overlay = frame.copy()
96
- cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h),
97
- (20, 20, 20), -1)
98
- frame = cv2.addWeighted(overlay, 0.65, frame, 0.35, 0)
99
-
100
- cv2.putText(frame, "● LIVE", (24, 36),
101
- cv2.FONT_HERSHEY_SIMPLEX, 0.55, (60, 220, 60), 2, cv2.LINE_AA)
102
- cv2.putText(frame, f"frame {frame_idx}/{total_frames} · {fps_proc:.1f} fps",
103
- (90, 36),
104
- cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200, 200, 200), 1, cv2.LINE_AA)
105
-
106
- y = 60
107
- if not active:
108
- cv2.putText(frame, "scanning…", (28, y),
109
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (180, 180, 180), 1, cv2.LINE_AA)
110
- else:
111
- for name, n in active:
112
- color = CLASS_COLORS.get(name, (200, 200, 200))
113
- cv2.circle(frame, (28, y - 5), 5, color, -1)
114
- display = DISPLAY_NAMES.get(name, name)
115
- cv2.putText(frame, f"{display}: {n}", (44, y),
116
- cv2.FONT_HERSHEY_SIMPLEX, 0.5,
117
- (240, 240, 240), 1, cv2.LINE_AA)
118
- y += 22
119
  return frame
120
 
121
 
122
- def build_counts_html(unique_ids: dict, frame_idx: int, total: int,
123
- elapsed: float) -> str:
124
- """Side-panel live counts as HTML cards."""
 
 
125
  pct = (frame_idx / total * 100) if total else 0
126
- cards = []
127
- for name in TARGET_CLASSES.values():
128
- n = len(unique_ids.get(name, set()))
129
- display = DISPLAY_NAMES.get(name, name)
130
- r, g, b = CLASS_COLORS.get(name, (200, 200, 200))[::-1] # BGR->RGB
131
- opacity = "1.0" if n > 0 else "0.35"
132
- cards.append(
133
- f'<div style="display:flex;justify-content:space-between;'
134
- f'align-items:center;padding:8px 12px;margin:4px 0;'
135
- f'border-radius:8px;background:rgba({r},{g},{b},0.10);'
136
- f'border-left:4px solid rgb({r},{g},{b});opacity:{opacity};">'
137
- f'<span style="font-weight:500;color:#111;">{display}</span>'
138
- f'<span style="font-size:18px;font-weight:700;color:rgb({r},{g},{b});">{n}</span>'
139
- f'</div>'
140
- )
141
- progress = (
142
- f'<div style="margin:8px 0 14px 0;">'
143
- f'<div style="display:flex;justify-content:space-between;font-size:12px;'
144
- f'color:#6b7280;margin-bottom:4px;"><span>frame {frame_idx} / {total}</span>'
145
- f'<span>{pct:.1f}% · {elapsed:.1f}s</span></div>'
146
- f'<div style="height:6px;background:#e5e7eb;border-radius:3px;overflow:hidden;">'
147
- f'<div style="height:100%;width:{pct}%;background:#6366f1;'
148
- f'transition:width 0.2s;"></div></div></div>'
149
  )
150
- return progress + "".join(cards)
151
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- def build_summary_md(unique_ids: dict) -> str:
154
- total = sum(len(ids) for ids in unique_ids.values())
155
- if total == 0:
156
- return ("### ℹ️ No target objects detected.\n"
157
- "Try lowering the confidence threshold or the frame stride.")
158
- lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
159
- for name in TARGET_CLASSES.values():
160
- n = len(unique_ids.get(name, set()))
161
- if n > 0:
162
- lines.append(f"- **{DISPLAY_NAMES.get(name, name).capitalize()}** — {n}")
163
- return "\n".join(lines)
164
 
165
 
166
- def build_table(unique_ids: dict):
167
- rows = []
168
- for name in TARGET_CLASSES.values():
169
- n = len(unique_ids.get(name, set()))
170
- if n > 0:
171
- rows.append([DISPLAY_NAMES.get(name, name).capitalize(), n])
172
- return rows if rows else [["—", 0]]
173
 
174
 
175
  # ---------------------------------------------------------------------------
176
  # Main streaming generator
177
  # ---------------------------------------------------------------------------
178
  def process_video(video_path, confidence, frame_stride):
179
- """
180
- Generator: streams the annotated frame + live counts every iteration,
181
- then yields the saved video and final table on completion.
182
- """
183
  if video_path is None:
184
  yield (None,
185
  '<div style="padding:12px;color:#b91c1c;">⚠️ Please upload a video first.</div>',
186
  "Submit a video to start.",
187
- None,
188
- [])
189
  return
190
 
191
  video_info = sv.VideoInfo.from_video_path(video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  frame_gen = sv.get_video_frames_generator(video_path)
193
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
194
 
195
- unique_ids = defaultdict(set)
196
  last_detections = sv.Detections.empty()
197
-
198
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
199
  total = video_info.total_frames or 0
200
  start_time = time.time()
201
  last_yield = 0.0
 
202
 
203
  yield (None,
204
- build_counts_html(unique_ids, 0, total, 0.0),
205
  "### 🎬 Starting analysis…",
206
- None,
207
- [])
208
-
209
- last_rgb = None
210
 
211
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
212
  for i, frame in enumerate(frame_gen):
@@ -223,16 +224,16 @@ def process_video(video_path, confidence, frame_stride):
223
  detections = tracker.update_with_detections(detections)
224
  last_detections = detections
225
 
226
- for cid, tid in zip(detections.class_id, detections.tracker_id):
227
- if tid is None:
228
- continue
229
- name = TARGET_CLASSES.get(int(cid))
230
- if name:
231
- unique_ids[name].add(int(tid))
232
  else:
233
  detections = last_detections
234
 
235
- # ---- Annotate ----
236
  if len(detections) > 0:
237
  tids = (detections.tracker_id
238
  if detections.tracker_id is not None
@@ -245,34 +246,40 @@ def process_video(video_path, confidence, frame_stride):
245
  tid_str = f"#{tid} " if tid is not None else ""
246
  labels.append(f"{tid_str}{display} {conf:.2f}")
247
 
248
- frame = BOX_ANNOTATOR.annotate(frame, detections)
249
- frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
 
 
250
 
251
- counts_now = {name: len(ids) for name, ids in unique_ids.items()}
252
- elapsed = time.time() - start_time
253
- fps_proc = (i + 1) / elapsed if elapsed > 0 else 0
254
- frame = draw_counter_panel(frame, counts_now, i + 1, total, fps_proc)
 
 
255
 
256
  sink.write_frame(frame)
257
  last_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
258
 
259
- # ---- Yield to UI (throttled to ~5 updates/sec) ----
260
  now = time.time()
261
  if now - last_yield > 0.20 or i == total - 1:
262
  last_yield = now
 
263
  yield (last_rgb,
264
- build_counts_html(unique_ids, i + 1, total, elapsed),
 
265
  "### 🔴 Live analysis in progress…",
266
- None,
267
- [])
268
 
269
- # ---- Final yield: include saved video + summary table ----
270
  elapsed = time.time() - start_time
 
 
 
271
  yield (last_rgb,
272
- build_counts_html(unique_ids, total, total, elapsed),
273
- build_summary_md(unique_ids),
274
- out_path,
275
- build_table(unique_ids))
276
 
277
 
278
  # ---------------------------------------------------------------------------
@@ -289,16 +296,16 @@ CUSTOM_CSS = """
289
  footer {visibility: hidden;}
290
  """
291
 
292
- with gr.Blocks(title="RF-DETR Live Object Counter") as demo:
293
 
294
  with gr.Row(elem_id="title-row"):
295
  gr.Markdown(
296
  """
297
- # 🐄 RF-DETR Live Object Counter
298
- Watch detections appear **frame by frame** as the model processes your video ��
299
- counts update in real time. Powered by
300
- [RF-DETR Nano](https://github.com/roboflow/rf-detr) + ByteTrack
301
- (each object counted only once).
302
  """
303
  )
304
 
@@ -342,15 +349,14 @@ with gr.Blocks(title="RF-DETR Live Object Counter") as demo:
342
  with gr.Column(scale=3):
343
  gr.Markdown("### 🔴 Live View")
344
  live_frame = gr.Image(
345
- label=None,
346
  show_label=False,
347
  elem_id="live-frame",
348
- height=420,
349
  )
350
- with gr.Column(scale=1, min_width=220):
351
  gr.Markdown("### 📊 Live Counts")
352
  live_counts = gr.HTML(
353
- value=build_counts_html(defaultdict(set), 0, 0, 0)
354
  )
355
 
356
  # ---------- Bottom: final results ----------
@@ -361,20 +367,13 @@ with gr.Blocks(title="RF-DETR Live Object Counter") as demo:
361
  video_output = gr.Video(label="Download / replay", height=260)
362
  with gr.Column(scale=1):
363
  with gr.Group(elem_classes="card"):
364
- gr.Markdown("### 📈 Final totals")
365
  summary_output = gr.Markdown("Run an analysis to see results.")
366
- table_output = gr.Dataframe(
367
- headers=["Class", "Unique count"],
368
- datatype=["str", "number"],
369
- interactive=False,
370
- wrap=True,
371
- )
372
 
373
  submit_btn.click(
374
  fn=process_video,
375
  inputs=[video_input, confidence, frame_stride],
376
- outputs=[live_frame, live_counts, summary_output,
377
- video_output, table_output],
378
  )
379
 
380
 
 
1
  """
2
+ RF-DETR Truck Counter — counts trucks crossing a fixed horizontal line
3
+ at the center of the video (either direction counts as one crossing).
 
4
  """
5
 
6
  import os
 
18
  # ---------------------------------------------------------------------------
19
  # Target classes (COCO indices)
20
  # ---------------------------------------------------------------------------
21
+ # We still DETECT all of these so users see them tracked on screen,
22
+ # but only TRUCKS contribute to the line-crossing count.
23
  TARGET_CLASSES = {
24
  0: "person",
25
  2: "car",
26
  7: "truck",
27
  16: "dog",
28
+ 17: "horse",
29
+ 18: "sheep",
30
  19: "cow",
31
  }
32
  TARGET_IDS = list(TARGET_CLASSES.keys())
33
+ TRUCK_CLASS_ID = 7
34
 
35
  DISPLAY_NAMES = {
36
  "person": "person",
 
76
 
77
  print("Model ready.")
78
 
 
 
 
79
 
80
  # ---------------------------------------------------------------------------
81
+ # Draw a custom truck counter directly on the centerline
82
  # ---------------------------------------------------------------------------
83
+ def draw_truck_label_on_line(frame: np.ndarray, line_y: int, total: int) -> np.ndarray:
84
+ text = f"TRUCKS CROSSED: {total}"
85
+ font = cv2.FONT_HERSHEY_SIMPLEX
86
+ scale = max(0.6, frame.shape[1] / 1600)
87
+ thickness = max(2, int(2 * scale))
88
+
89
+ (tw, th), baseline = cv2.getTextSize(text, font, scale, thickness)
90
+ pad_x, pad_y = 14, 10
91
+ box_w = tw + 2 * pad_x
92
+ box_h = th + 2 * pad_y + baseline
93
+
94
+ cx = frame.shape[1] // 2
95
+ x1 = cx - box_w // 2
96
+ y1 = line_y - box_h // 2
97
+ x2 = x1 + box_w
98
+ y2 = y1 + box_h
99
 
100
  overlay = frame.copy()
101
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), (245, 66, 161), -1) # truck-pink
102
+ frame = cv2.addWeighted(overlay, 0.88, frame, 0.12, 0)
103
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 255, 255), 2)
104
+
105
+ text_x = x1 + pad_x
106
+ text_y = y1 + pad_y + th
107
+ cv2.putText(frame, text, (text_x, text_y),
108
+ font, scale, (255, 255, 255), thickness, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  return frame
110
 
111
 
112
+ # ---------------------------------------------------------------------------
113
+ # HTML side panel
114
+ # ---------------------------------------------------------------------------
115
+ def build_counts_html(truck_total: int, truck_in: int, truck_out: int,
116
+ frame_idx: int, total: int, elapsed: float) -> str:
117
  pct = (frame_idx / total * 100) if total else 0
118
+
119
+ hero = (
120
+ '<div style="text-align:center;padding:22px 14px;margin-bottom:14px;'
121
+ 'background:linear-gradient(135deg,#f5318b,#c2185b);color:white;'
122
+ 'border-radius:14px;box-shadow:0 4px 12px rgba(245,49,139,0.35);">'
123
+ '<div style="font-size:32px;line-height:1;">🚛</div>'
124
+ '<div style="font-size:11px;opacity:0.9;letter-spacing:1.5px;'
125
+ 'margin-top:6px;">TRUCKS CROSSED</div>'
126
+ f'<div style="font-size:54px;font-weight:800;line-height:1.0;'
127
+ f'margin-top:4px;">{truck_total}</div>'
128
+ '<div style="display:flex;justify-content:center;gap:14px;'
129
+ 'margin-top:10px;font-size:12px;opacity:0.95;">'
130
+ f'<span>↓ {truck_in} down</span>'
131
+ f'<span>↑ {truck_out} up</span>'
132
+ '</div></div>'
 
 
 
 
 
 
 
 
133
  )
 
134
 
135
+ progress = (
136
+ '<div style="margin:8px 0 4px 0;">'
137
+ '<div style="display:flex;justify-content:space-between;font-size:11px;'
138
+ 'color:#6b7280;margin-bottom:4px;">'
139
+ f'<span>frame {frame_idx} / {total}</span>'
140
+ f'<span>{pct:.1f}% · {elapsed:.1f}s</span>'
141
+ '</div>'
142
+ '<div style="height:6px;background:#e5e7eb;border-radius:3px;overflow:hidden;">'
143
+ f'<div style="height:100%;width:{pct}%;background:#6366f1;transition:width 0.2s;"></div>'
144
+ '</div></div>'
145
+ )
146
 
147
+ return hero + progress
 
 
 
 
 
 
 
 
 
 
148
 
149
 
150
+ def build_summary_md(truck_total: int, truck_in: int, truck_out: int) -> str:
151
+ if truck_total == 0:
152
+ return ("### ℹ️ No trucks crossed the center line.\n"
153
+ "Try a lower confidence threshold or a smaller frame stride.")
154
+ return (f"### 🚛 Total trucks crossed: **{truck_total}**\n\n"
155
+ f"- Going down: {truck_in}\n"
156
+ f"- Going up: {truck_out}")
157
 
158
 
159
  # ---------------------------------------------------------------------------
160
  # Main streaming generator
161
  # ---------------------------------------------------------------------------
162
  def process_video(video_path, confidence, frame_stride):
 
 
 
 
163
  if video_path is None:
164
  yield (None,
165
  '<div style="padding:12px;color:#b91c1c;">⚠️ Please upload a video first.</div>',
166
  "Submit a video to start.",
167
+ None)
 
168
  return
169
 
170
  video_info = sv.VideoInfo.from_video_path(video_path)
171
+ width, height = video_info.width, video_info.height
172
+
173
+ # ---- Line zone fixed at vertical center ----
174
+ line_y = height // 2
175
+ line_zone = sv.LineZone(
176
+ start=sv.Point(0, line_y),
177
+ end=sv.Point(width, line_y),
178
+ )
179
+
180
+ # ---- Annotators (sized to frame) ----
181
+ scale = max(0.5, width / 1280)
182
+ box_ann = sv.BoxAnnotator(thickness=max(2, int(2 * scale)))
183
+ label_ann = sv.LabelAnnotator(
184
+ text_scale=0.5 * scale,
185
+ text_thickness=max(1, int(1 * scale)),
186
+ text_padding=4,
187
+ )
188
+ trace_ann = sv.TraceAnnotator(thickness=max(2, int(2 * scale)), trace_length=40)
189
+ # We draw the line ourselves so we control the label completely
190
+ line_ann = sv.LineZoneAnnotator(
191
+ thickness=max(2, int(3 * scale)),
192
+ text_thickness=1, text_scale=0.01, # effectively hide default text
193
+ display_in_count=False,
194
+ display_out_count=False,
195
+ )
196
+
197
  frame_gen = sv.get_video_frames_generator(video_path)
198
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
199
 
 
200
  last_detections = sv.Detections.empty()
 
201
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
202
  total = video_info.total_frames or 0
203
  start_time = time.time()
204
  last_yield = 0.0
205
+ last_rgb = None
206
 
207
  yield (None,
208
+ build_counts_html(0, 0, 0, 0, total, 0.0),
209
  "### 🎬 Starting analysis…",
210
+ None)
 
 
 
211
 
212
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
213
  for i, frame in enumerate(frame_gen):
 
224
  detections = tracker.update_with_detections(detections)
225
  last_detections = detections
226
 
227
+ # ---- Only trucks feed the line zone ----
228
+ if len(detections) > 0:
229
+ truck_mask = detections.class_id == TRUCK_CLASS_ID
230
+ truck_detections = detections[truck_mask]
231
+ if len(truck_detections) > 0:
232
+ line_zone.trigger(truck_detections)
233
  else:
234
  detections = last_detections
235
 
236
+ # ---- Annotate everything detected (visual richness) ----
237
  if len(detections) > 0:
238
  tids = (detections.tracker_id
239
  if detections.tracker_id is not None
 
246
  tid_str = f"#{tid} " if tid is not None else ""
247
  labels.append(f"{tid_str}{display} {conf:.2f}")
248
 
249
+ frame = trace_ann.annotate(scene=frame, detections=detections)
250
+ frame = box_ann.annotate(scene=frame, detections=detections)
251
+ frame = label_ann.annotate(scene=frame, detections=detections,
252
+ labels=labels)
253
 
254
+ # ---- Draw line + custom truck counter on the line ----
255
+ frame = line_ann.annotate(frame=frame, line_counter=line_zone)
256
+ truck_in = int(line_zone.in_count)
257
+ truck_out = int(line_zone.out_count)
258
+ truck_total = truck_in + truck_out
259
+ frame = draw_truck_label_on_line(frame, line_y, truck_total)
260
 
261
  sink.write_frame(frame)
262
  last_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
263
 
264
+ # ---- Yield to UI (throttled ~5/sec) ----
265
  now = time.time()
266
  if now - last_yield > 0.20 or i == total - 1:
267
  last_yield = now
268
+ elapsed = time.time() - start_time
269
  yield (last_rgb,
270
+ build_counts_html(truck_total, truck_in, truck_out,
271
+ i + 1, total, elapsed),
272
  "### 🔴 Live analysis in progress…",
273
+ None)
 
274
 
 
275
  elapsed = time.time() - start_time
276
+ truck_in = int(line_zone.in_count)
277
+ truck_out = int(line_zone.out_count)
278
+ truck_total = truck_in + truck_out
279
  yield (last_rgb,
280
+ build_counts_html(truck_total, truck_in, truck_out, total, total, elapsed),
281
+ build_summary_md(truck_total, truck_in, truck_out),
282
+ out_path)
 
283
 
284
 
285
  # ---------------------------------------------------------------------------
 
296
  footer {visibility: hidden;}
297
  """
298
 
299
+ with gr.Blocks(title="RF-DETR Truck Counter") as demo:
300
 
301
  with gr.Row(elem_id="title-row"):
302
  gr.Markdown(
303
  """
304
+ # 🚛 RF-DETR Truck Counter
305
+ Counts trucks crossing a horizontal line at the **center of the video**
306
+ in either direction. Powered by
307
+ [RF-DETR Nano](https://github.com/roboflow/rf-detr) + ByteTrack +
308
+ `sv.LineZone`.
309
  """
310
  )
311
 
 
349
  with gr.Column(scale=3):
350
  gr.Markdown("### 🔴 Live View")
351
  live_frame = gr.Image(
 
352
  show_label=False,
353
  elem_id="live-frame",
354
+ height=440,
355
  )
356
+ with gr.Column(scale=1, min_width=240):
357
  gr.Markdown("### 📊 Live Counts")
358
  live_counts = gr.HTML(
359
+ value=build_counts_html(0, 0, 0, 0, 0, 0)
360
  )
361
 
362
  # ---------- Bottom: final results ----------
 
367
  video_output = gr.Video(label="Download / replay", height=260)
368
  with gr.Column(scale=1):
369
  with gr.Group(elem_classes="card"):
370
+ gr.Markdown("### 📈 Final summary")
371
  summary_output = gr.Markdown("Run an analysis to see results.")
 
 
 
 
 
 
372
 
373
  submit_btn.click(
374
  fn=process_video,
375
  inputs=[video_input, confidence, frame_stride],
376
+ outputs=[live_frame, live_counts, summary_output, video_output],
 
377
  )
378
 
379