lyimo commited on
Commit
a4a341f
·
verified ·
1 Parent(s): 912ec81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -78
app.py CHANGED
@@ -1,7 +1,8 @@
1
  """
2
- RF-DETR Object Counter — Gradio app for Hugging Face Spaces.
3
- Counts people, bicycles, cars, trucks, and animals in video using
4
- RF-DETR Medium + ByteTrack (so each object is counted only once).
 
5
  """
6
 
7
  import os
@@ -12,47 +13,47 @@ import cv2
12
  import gradio as gr
13
  import numpy as np
14
  import supervision as sv
 
15
  from rfdetr import RFDETRMedium
16
- from rfdetr.assets.coco_classes import COCO_CLASSES
17
 
18
  # ---------------------------------------------------------------------------
19
- # Target classes (COCO indices) — exactly what the user asked for
20
  # ---------------------------------------------------------------------------
 
 
 
 
21
  TARGET_CLASSES = {
22
  0: "person",
23
- 1: "bicycle",
24
  2: "car",
25
  7: "truck",
26
- # animals
27
- 14: "bird",
28
- 15: "cat",
29
  16: "dog",
30
- 17: "horse",
31
- 18: "sheep",
32
  19: "cow",
33
- 20: "elephant",
34
- 21: "bear",
35
- 22: "zebra",
36
- 23: "giraffe",
37
  }
38
  TARGET_IDS = list(TARGET_CLASSES.keys())
39
 
40
- # Per-class colour palette (BGR) for the live overlay
 
 
 
 
 
 
 
 
 
 
 
41
  CLASS_COLORS = {
42
- "person": (66, 135, 245),
43
- "bicycle": (245, 173, 66),
44
- "car": (66, 245, 167),
45
- "truck": (245, 66, 161),
46
- "bird": (245, 230, 66),
47
- "cat": (200, 120, 245),
48
- "dog": (120, 245, 200),
49
- "horse": (245, 120, 120),
50
- "sheep": (220, 220, 220),
51
- "cow": (140, 90, 60),
52
- "elephant": (160, 160, 200),
53
- "bear": (90, 60, 30),
54
- "zebra": (40, 40, 40),
55
- "giraffe": (220, 180, 90),
56
  }
57
 
58
  # Example video lives next to app.py
@@ -60,15 +61,26 @@ APP_DIR = os.path.dirname(os.path.abspath(__file__))
60
  EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
61
 
62
  # ---------------------------------------------------------------------------
63
- # Load model once at startup
64
  # ---------------------------------------------------------------------------
65
- print("Loading RF-DETR Medium…")
66
- MODEL = RFDETRMedium()
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
- MODEL.optimize_for_inference() # speeds up subsequent predicts
69
- print("Model optimized for inference.")
70
- except Exception as e:
71
- print(f"(Optimization skipped: {e})")
72
  print("Model ready.")
73
 
74
  # Annotators
@@ -77,12 +89,12 @@ LABEL_ANNOTATOR = sv.LabelAnnotator(text_scale=0.45, text_thickness=1, text_padd
77
 
78
 
79
  def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
80
- """Translucent counter panel in the top-left corner."""
81
  active = [(name, n) for name, n in counts.items() if n > 0]
82
  if not active:
83
  active = [("No targets yet", 0)]
84
 
85
- panel_w = 230
86
  panel_h = 40 + 22 * len(active)
87
  overlay = frame.copy()
88
  cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
@@ -95,13 +107,15 @@ def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
95
  for name, n in active:
96
  color = CLASS_COLORS.get(name, (200, 200, 200))
97
  cv2.circle(frame, (28, y - 5), 5, color, -1)
98
- cv2.putText(frame, f"{name}: {n}", (44, y),
 
99
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
100
  y += 22
101
  return frame
102
 
103
 
104
- def process_video(video_path, confidence, frame_stride, progress=gr.Progress(track_tqdm=True)):
 
105
  if video_path is None:
106
  return None, "⚠️ Please upload a video first.", []
107
 
@@ -109,15 +123,19 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
109
  frame_gen = sv.get_video_frames_generator(video_path)
110
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
111
 
112
- unique_ids = defaultdict(set) # class_name -> {tracker_id, ...}
113
  last_detections = sv.Detections.empty()
114
 
115
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
116
 
117
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
118
- for i, frame in enumerate(progress.tqdm(frame_gen, total=video_info.total_frames,
119
- desc="Analyzing video")):
120
- # Detect every Nth frame; reuse previous detections in-between to keep video smooth
 
 
 
 
121
  if i % frame_stride == 0:
122
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
123
  detections = MODEL.predict(rgb, threshold=confidence)
@@ -130,7 +148,7 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
130
  detections = tracker.update_with_detections(detections)
131
  last_detections = detections
132
 
133
- # Register unique IDs per class
134
  for cid, tid in zip(detections.class_id, detections.tracker_id):
135
  if tid is None:
136
  continue
@@ -140,17 +158,18 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
140
  else:
141
  detections = last_detections
142
 
143
- # Annotate
144
  if len(detections) > 0:
145
- labels = [
146
- f"#{tid} {TARGET_CLASSES.get(int(cid), 'obj')} {conf:.2f}"
147
- for cid, tid, conf in zip(
148
- detections.class_id,
149
- detections.tracker_id if detections.tracker_id is not None
150
- else [None] * len(detections),
151
- detections.confidence,
152
- )
153
- ]
 
154
  frame = BOX_ANNOTATOR.annotate(frame, detections)
155
  frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
156
 
@@ -158,21 +177,26 @@ def process_video(video_path, confidence, frame_stride, progress=gr.Progress(tra
158
  frame = draw_counter_panel(frame, counts_now)
159
  sink.write_frame(frame)
160
 
161
- # Build summary outputs
162
  total = sum(len(ids) for ids in unique_ids.values())
163
  if total == 0:
164
- summary_md = "### ℹ️ No target objects detected.\nTry lowering the confidence threshold."
 
165
  else:
166
  lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
167
  for name in TARGET_CLASSES.values():
168
  n = len(unique_ids.get(name, set()))
169
  if n > 0:
170
- lines.append(f"- **{name.capitalize()}** — {n}")
 
171
  summary_md = "\n".join(lines)
172
 
173
- table = [[name.capitalize(), len(unique_ids.get(name, set()))]
174
- for name in TARGET_CLASSES.values()
175
- if len(unique_ids.get(name, set())) > 0]
 
 
 
176
  if not table:
177
  table = [["—", 0]]
178
 
@@ -192,15 +216,14 @@ CUSTOM_CSS = """
192
  footer {visibility: hidden;}
193
  """
194
 
195
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
196
- css=CUSTOM_CSS, title="RF-DETR Object Counter") as demo:
197
 
198
  with gr.Row(elem_id="title-row"):
199
  gr.Markdown(
200
  """
201
- # 🚦 RF-DETR Object Counter
202
- Count **people, bicycles, cars, trucks, and animals** in any video.
203
- Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr) (Roboflow, ICLR 2026) and ByteTrack —
204
  each object is counted **only once** as it moves across frames.
205
  """
206
  )
@@ -218,14 +241,14 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
218
 
219
  with gr.Accordion("⚙️ Advanced settings", open=False):
220
  confidence = gr.Slider(
221
- minimum=0.1, maximum=0.9, value=0.5, step=0.05,
222
  label="Confidence threshold",
223
  info="Higher = fewer but more certain detections.",
224
  )
225
  frame_stride = gr.Slider(
226
- minimum=1, maximum=10, value=2, step=1,
227
- label="Frame stride",
228
- info="Process every Nth frame. Higher = faster, slightly less accurate.",
229
  )
230
 
231
  submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
@@ -253,12 +276,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
253
 
254
  gr.Markdown(
255
  """
256
- ---
257
- **Detected categories:** person · bicycle · car · truck · bird · cat · dog · horse ·
258
- sheep · cow · elephant · bear · zebra · giraffe
259
-
260
- **Tip:** the first run loads the model (≈45–90 s for Medium). Subsequent runs are much faster.
261
- Use *Frame stride* if processing is slow on CPU.
262
  """
263
  )
264
 
@@ -270,4 +288,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
270
 
271
 
272
  if __name__ == "__main__":
273
- demo.queue(max_size=8).launch()
 
 
 
 
1
  """
2
+ RF-DETR Object Counter — CPU-optimized Gradio app for Hugging Face Spaces.
3
+ Counts people, cars, trucks, and farm animals (cow, sheep/goat, horse/donkey,
4
+ dog) in video using RF-DETR Medium + ByteTrack so each object is counted
5
+ only once across the whole video.
6
  """
7
 
8
  import os
 
13
  import gradio as gr
14
  import numpy as np
15
  import supervision as sv
16
+ import torch
17
  from rfdetr import RFDETRMedium
 
18
 
19
  # ---------------------------------------------------------------------------
20
+ # Target classes (COCO indices)
21
  # ---------------------------------------------------------------------------
22
+ # Note: COCO does NOT contain "goat" or "donkey". We approximate:
23
+ # goat ~ sheep (closest 4-legged ruminant in COCO)
24
+ # donkey ~ horse (closest equid in COCO)
25
+ # Counts for these will be roughly right; labels will say sheep/horse.
26
  TARGET_CLASSES = {
27
  0: "person",
 
28
  2: "car",
29
  7: "truck",
 
 
 
30
  16: "dog",
31
+ 17: "horse", # also catches donkeys
32
+ 18: "sheep", # also catches goats
33
  19: "cow",
 
 
 
 
34
  }
35
  TARGET_IDS = list(TARGET_CLASSES.keys())
36
 
37
+ # Friendly UI labels
38
+ DISPLAY_NAMES = {
39
+ "person": "person",
40
+ "car": "car",
41
+ "truck": "truck",
42
+ "dog": "dog",
43
+ "horse": "horse / donkey",
44
+ "sheep": "sheep / goat",
45
+ "cow": "cow",
46
+ }
47
+
48
+ # Per-class colours for the live overlay panel (BGR)
49
  CLASS_COLORS = {
50
+ "person": (66, 135, 245),
51
+ "car": (66, 245, 167),
52
+ "truck": (245, 66, 161),
53
+ "dog": (120, 245, 200),
54
+ "horse": (245, 120, 120),
55
+ "sheep": (220, 220, 220),
56
+ "cow": (140, 90, 60),
 
 
 
 
 
 
 
57
  }
58
 
59
  # Example video lives next to app.py
 
61
  EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
62
 
63
  # ---------------------------------------------------------------------------
64
+ # Load model pinned to CPU
65
  # ---------------------------------------------------------------------------
66
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
67
+ print(f"Loading RF-DETR Medium on {DEVICE}…")
68
+ MODEL = RFDETRMedium(device=DEVICE)
69
+
70
+ # optimize_for_inference is GPU-only (TensorRT-style ops). Skip on CPU.
71
+ if DEVICE == "cuda":
72
+ try:
73
+ MODEL.optimize_for_inference()
74
+ print("Optimized for GPU inference.")
75
+ except Exception as e:
76
+ print(f"GPU optimization skipped: {e}")
77
+
78
+ # Use a few threads for torch CPU inference; tune to your Space's vCPU count
79
  try:
80
+ torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
81
+ except Exception:
82
+ pass
83
+
84
  print("Model ready.")
85
 
86
  # Annotators
 
89
 
90
 
91
  def draw_counter_panel(frame: np.ndarray, counts: dict) -> np.ndarray:
92
+ """Translucent live-count panel in the top-left corner of the frame."""
93
  active = [(name, n) for name, n in counts.items() if n > 0]
94
  if not active:
95
  active = [("No targets yet", 0)]
96
 
97
+ panel_w = 280
98
  panel_h = 40 + 22 * len(active)
99
  overlay = frame.copy()
100
  cv2.rectangle(overlay, (12, 12), (12 + panel_w, 12 + panel_h), (20, 20, 20), -1)
 
107
  for name, n in active:
108
  color = CLASS_COLORS.get(name, (200, 200, 200))
109
  cv2.circle(frame, (28, y - 5), 5, color, -1)
110
+ display = DISPLAY_NAMES.get(name, name)
111
+ cv2.putText(frame, f"{display}: {n}", (44, y),
112
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (240, 240, 240), 1, cv2.LINE_AA)
113
  y += 22
114
  return frame
115
 
116
 
117
+ def process_video(video_path, confidence, frame_stride,
118
+ progress=gr.Progress(track_tqdm=True)):
119
  if video_path is None:
120
  return None, "⚠️ Please upload a video first.", []
121
 
 
123
  frame_gen = sv.get_video_frames_generator(video_path)
124
  tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
125
 
126
+ unique_ids = defaultdict(set) # class_name -> {tracker_id, ...}
127
  last_detections = sv.Detections.empty()
128
 
129
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
130
 
131
  with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
132
+ for i, frame in enumerate(progress.tqdm(
133
+ frame_gen,
134
+ total=video_info.total_frames,
135
+ desc="Analyzing video")):
136
+
137
+ # Detect every Nth frame; reuse previous detections in-between
138
+ # so the output video stays smooth even with high stride.
139
  if i % frame_stride == 0:
140
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
141
  detections = MODEL.predict(rgb, threshold=confidence)
 
148
  detections = tracker.update_with_detections(detections)
149
  last_detections = detections
150
 
151
+ # Register unique tracker IDs per class
152
  for cid, tid in zip(detections.class_id, detections.tracker_id):
153
  if tid is None:
154
  continue
 
158
  else:
159
  detections = last_detections
160
 
161
+ # Annotate frame
162
  if len(detections) > 0:
163
+ tids = (detections.tracker_id
164
+ if detections.tracker_id is not None
165
+ else [None] * len(detections))
166
+ labels = []
167
+ for cid, tid, conf in zip(detections.class_id, tids, detections.confidence):
168
+ name = TARGET_CLASSES.get(int(cid), "obj")
169
+ display = DISPLAY_NAMES.get(name, name)
170
+ tid_str = f"#{tid} " if tid is not None else ""
171
+ labels.append(f"{tid_str}{display} {conf:.2f}")
172
+
173
  frame = BOX_ANNOTATOR.annotate(frame, detections)
174
  frame = LABEL_ANNOTATOR.annotate(frame, detections, labels)
175
 
 
177
  frame = draw_counter_panel(frame, counts_now)
178
  sink.write_frame(frame)
179
 
180
+ # ---------- Build summary outputs ----------
181
  total = sum(len(ids) for ids in unique_ids.values())
182
  if total == 0:
183
+ summary_md = ("### ℹ️ No target objects detected.\n"
184
+ "Try lowering the confidence threshold or the frame stride.")
185
  else:
186
  lines = [f"### ✅ Total unique objects detected: **{total}**", ""]
187
  for name in TARGET_CLASSES.values():
188
  n = len(unique_ids.get(name, set()))
189
  if n > 0:
190
+ display = DISPLAY_NAMES.get(name, name).capitalize()
191
+ lines.append(f"- **{display}** — {n}")
192
  summary_md = "\n".join(lines)
193
 
194
+ table = []
195
+ for name in TARGET_CLASSES.values():
196
+ n = len(unique_ids.get(name, set()))
197
+ if n > 0:
198
+ display = DISPLAY_NAMES.get(name, name).capitalize()
199
+ table.append([display, n])
200
  if not table:
201
  table = [["—", 0]]
202
 
 
216
  footer {visibility: hidden;}
217
  """
218
 
219
+ with gr.Blocks(title="RF-DETR Object Counter") as demo:
 
220
 
221
  with gr.Row(elem_id="title-row"):
222
  gr.Markdown(
223
  """
224
+ # 🐄 RF-DETR Object Counter
225
+ Count **people, cars, trucks, and farm animals** in any video.
226
+ Powered by [RF-DETR Medium](https://github.com/roboflow/rf-detr) + ByteTrack —
227
  each object is counted **only once** as it moves across frames.
228
  """
229
  )
 
241
 
242
  with gr.Accordion("⚙️ Advanced settings", open=False):
243
  confidence = gr.Slider(
244
+ minimum=0.1, maximum=0.9, value=0.45, step=0.05,
245
  label="Confidence threshold",
246
  info="Higher = fewer but more certain detections.",
247
  )
248
  frame_stride = gr.Slider(
249
+ minimum=1, maximum=15, value=5, step=1,
250
+ label="Frame stride (CPU speed control)",
251
+ info="Process every Nth frame. On CPU, 5–8 is a good balance.",
252
  )
253
 
254
  submit_btn = gr.Button("🔍 Count Objects", variant="primary", size="lg")
 
276
 
277
  gr.Markdown(
278
  """
279
+ **Detected categories:** person · car · truck · dog · horse / donkey · sheep / goat · cow
 
 
 
 
 
280
  """
281
  )
282
 
 
288
 
289
 
290
  if __name__ == "__main__":
291
+ demo.queue(max_size=4).launch(
292
+ theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
293
+ css=CUSTOM_CSS,
294
+ )