lyimo commited on
Commit
7176850
·
verified ·
1 Parent(s): 6ebe736

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +831 -301
app.py CHANGED
@@ -1,384 +1,914 @@
1
  """
2
- RF-DETR Truck Counter counts trucks crossing a fixed horizontal line
3
- at the center of the video (either direction counts as one crossing).
4
  """
5
 
6
  import os
7
- import tempfile
8
  import time
9
- from collections import defaultdict
 
 
10
 
11
  import cv2
12
  import gradio as gr
 
 
 
13
  import numpy as np
 
14
  import supervision as sv
15
  import torch
16
- from rfdetr import RFDETRNano
17
-
18
- # ---------------------------------------------------------------------------
19
- # Target classes (COCO indices)
20
- # ---------------------------------------------------------------------------
21
- # We still DETECT all of these so users see them tracked on screen,
22
- # but only TRUCKS contribute to the line-crossing count.
23
- TARGET_CLASSES = {
24
- 0: "person",
25
- 2: "car",
26
- 7: "truck",
27
- 16: "dog",
28
- 17: "horse",
29
- 18: "sheep",
30
- 19: "cow",
31
  }
32
- TARGET_IDS = list(TARGET_CLASSES.keys())
33
- TRUCK_CLASS_ID = 7
34
-
35
- DISPLAY_NAMES = {
36
- "person": "person",
37
- "car": "car",
38
- "truck": "truck",
39
- "dog": "dog",
40
- "horse": "horse / donkey",
41
- "sheep": "sheep / goat",
42
- "cow": "cow",
43
  }
44
 
45
- CLASS_COLORS = { # BGR
46
- "person": (66, 135, 245),
47
- "car": (66, 245, 167),
48
- "truck": (245, 66, 161),
49
- "dog": (120, 245, 200),
50
- "horse": (245, 120, 120),
51
- "sheep": (220, 220, 220),
52
- "cow": (140, 90, 60),
53
  }
54
 
55
- APP_DIR = os.path.dirname(os.path.abspath(__file__))
56
- EXAMPLE_VIDEO = os.path.join(APP_DIR, "example.mp4")
 
 
57
 
58
- # ---------------------------------------------------------------------------
59
- # Load model — CPU-pinned
60
- # ---------------------------------------------------------------------------
61
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
62
- print(f"Loading RF-DETR Nano on {DEVICE}…")
63
- MODEL = RFDETRNano(device=DEVICE)
64
-
65
- if DEVICE == "cuda":
66
- try:
67
- MODEL.optimize_for_inference()
68
- print("Optimized for GPU inference.")
69
- except Exception as e:
70
- print(f"GPU optimization skipped: {e}")
71
 
72
  try:
73
  torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
74
  except Exception:
75
  pass
76
 
77
- print("Model ready.")
78
-
79
 
80
- # ---------------------------------------------------------------------------
81
- # Draw a custom truck counter directly on the centerline
82
- # ---------------------------------------------------------------------------
83
- def draw_truck_label_on_line(frame: np.ndarray, line_y: int, total: int) -> np.ndarray:
84
- text = f"TRUCKS CROSSED: {total}"
85
- font = cv2.FONT_HERSHEY_SIMPLEX
86
- scale = max(0.6, frame.shape[1] / 1600)
87
- thickness = max(2, int(2 * scale))
88
 
89
- (tw, th), baseline = cv2.getTextSize(text, font, scale, thickness)
90
- pad_x, pad_y = 14, 10
91
- box_w = tw + 2 * pad_x
92
- box_h = th + 2 * pad_y + baseline
 
 
 
93
 
94
- cx = frame.shape[1] // 2
95
- x1 = cx - box_w // 2
96
- y1 = line_y - box_h // 2
97
- x2 = x1 + box_w
98
- y2 = y1 + box_h
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  overlay = frame.copy()
101
- cv2.rectangle(overlay, (x1, y1), (x2, y2), (245, 66, 161), -1) # truck-pink
102
- frame = cv2.addWeighted(overlay, 0.88, frame, 0.12, 0)
103
- cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 255, 255), 2)
104
-
105
- text_x = x1 + pad_x
106
- text_y = y1 + pad_y + th
107
- cv2.putText(frame, text, (text_x, text_y),
108
- font, scale, (255, 255, 255), thickness, cv2.LINE_AA)
 
 
 
 
 
 
 
 
109
  return frame
110
 
111
 
112
- # ---------------------------------------------------------------------------
113
- # HTML side panel
114
- # ---------------------------------------------------------------------------
115
- def build_counts_html(truck_total: int, truck_in: int, truck_out: int,
116
- frame_idx: int, total: int, elapsed: float) -> str:
117
- pct = (frame_idx / total * 100) if total else 0
118
-
119
- hero = (
120
- '<div style="text-align:center;padding:22px 14px;margin-bottom:14px;'
121
- 'background:linear-gradient(135deg,#f5318b,#c2185b);color:white;'
122
- 'border-radius:14px;box-shadow:0 4px 12px rgba(245,49,139,0.35);">'
123
- '<div style="font-size:32px;line-height:1;">🚛</div>'
124
- '<div style="font-size:11px;opacity:0.9;letter-spacing:1.5px;'
125
- 'margin-top:6px;">TRUCKS CROSSED</div>'
126
- f'<div style="font-size:54px;font-weight:800;line-height:1.0;'
127
- f'margin-top:4px;">{truck_total}</div>'
128
- '<div style="display:flex;justify-content:center;gap:14px;'
129
- 'margin-top:10px;font-size:12px;opacity:0.95;">'
130
- f'<span>↓ {truck_in} down</span>'
131
- f'<span>↑ {truck_out} up</span>'
132
- '</div></div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
 
135
- progress = (
136
- '<div style="margin:8px 0 4px 0;">'
137
- '<div style="display:flex;justify-content:space-between;font-size:11px;'
138
- 'color:#6b7280;margin-bottom:4px;">'
139
- f'<span>frame {frame_idx} / {total}</span>'
140
- f'<span>{pct:.1f}% · {elapsed:.1f}s</span>'
141
- '</div>'
142
- '<div style="height:6px;background:#e5e7eb;border-radius:3px;overflow:hidden;">'
143
- f'<div style="height:100%;width:{pct}%;background:#6366f1;transition:width 0.2s;"></div>'
144
- '</div></div>'
145
- )
146
 
147
- return hero + progress
148
 
149
 
150
- def build_summary_md(truck_total: int, truck_in: int, truck_out: int) -> str:
151
- if truck_total == 0:
152
- return ("### ℹ️ No trucks crossed the center line.\n"
153
- "Try a lower confidence threshold or a smaller frame stride.")
154
- return (f"### 🚛 Total trucks crossed: **{truck_total}**\n\n"
155
- f"- ↓ Going down: {truck_in}\n"
156
- f"- ↑ Going up: {truck_out}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
 
159
- # ---------------------------------------------------------------------------
160
- # Main streaming generator
161
- # ---------------------------------------------------------------------------
162
- def process_video(video_path, confidence, frame_stride):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  if video_path is None:
164
- yield (None,
165
- '<div style="padding:12px;color:#b91c1c;">⚠️ Please upload a video first.</div>',
166
- "Submit a video to start.",
167
- None)
 
 
 
 
168
  return
169
 
170
- video_info = sv.VideoInfo.from_video_path(video_path)
171
- width, height = video_info.width, video_info.height
172
-
173
- # ---- Line zone fixed at vertical center ----
174
- line_y = height // 2
175
- line_zone = sv.LineZone(
176
- start=sv.Point(0, line_y),
177
- end=sv.Point(width, line_y),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  )
179
 
180
- # ---- Annotators (sized to frame) ----
181
- scale = max(0.5, width / 1280)
182
- box_ann = sv.BoxAnnotator(thickness=max(2, int(2 * scale)))
183
- label_ann = sv.LabelAnnotator(
184
- text_scale=0.5 * scale,
185
- text_thickness=max(1, int(1 * scale)),
186
- text_padding=4,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
- trace_ann = sv.TraceAnnotator(thickness=max(2, int(2 * scale)), trace_length=40)
189
- # We draw the line ourselves so we control the label completely
190
- line_ann = sv.LineZoneAnnotator(
191
- thickness=max(2, int(3 * scale)),
192
- text_thickness=1, text_scale=0.01, # effectively hide default text
193
- display_in_count=False,
194
- display_out_count=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  )
196
 
197
- frame_gen = sv.get_video_frames_generator(video_path)
198
- tracker = sv.ByteTrack(frame_rate=int(video_info.fps))
199
 
200
- last_detections = sv.Detections.empty()
201
- out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
202
- total = video_info.total_frames or 0
203
- start_time = time.time()
204
- last_yield = 0.0
205
- last_rgb = None
206
-
207
- yield (None,
208
- build_counts_html(0, 0, 0, 0, total, 0.0),
209
- "### 🎬 Starting analysis…",
210
- None)
211
-
212
- with sv.VideoSink(target_path=out_path, video_info=video_info) as sink:
213
- for i, frame in enumerate(frame_gen):
214
-
215
- # ---- Detection (every Nth frame) ----
216
- if i % frame_stride == 0:
217
- rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
218
- detections = MODEL.predict(rgb, threshold=confidence)
219
-
220
- if len(detections) > 0:
221
- mask = np.isin(detections.class_id, TARGET_IDS)
222
- detections = detections[mask]
223
-
224
- detections = tracker.update_with_detections(detections)
225
- last_detections = detections
226
-
227
- # ---- Only trucks feed the line zone ----
228
- if len(detections) > 0:
229
- truck_mask = detections.class_id == TRUCK_CLASS_ID
230
- truck_detections = detections[truck_mask]
231
- if len(truck_detections) > 0:
232
- line_zone.trigger(truck_detections)
233
- else:
234
- detections = last_detections
235
-
236
- # ---- Annotate everything detected (visual richness) ----
237
- if len(detections) > 0:
238
- tids = (detections.tracker_id
239
- if detections.tracker_id is not None
240
- else [None] * len(detections))
241
- labels = []
242
- for cid, tid, conf in zip(detections.class_id, tids,
243
- detections.confidence):
244
- name = TARGET_CLASSES.get(int(cid), "obj")
245
- display = DISPLAY_NAMES.get(name, name)
246
- tid_str = f"#{tid} " if tid is not None else ""
247
- labels.append(f"{tid_str}{display} {conf:.2f}")
248
-
249
- frame = trace_ann.annotate(scene=frame, detections=detections)
250
- frame = box_ann.annotate(scene=frame, detections=detections)
251
- frame = label_ann.annotate(scene=frame, detections=detections,
252
- labels=labels)
253
-
254
- # ---- Draw line + custom truck counter on the line ----
255
- frame = line_ann.annotate(frame=frame, line_counter=line_zone)
256
- truck_in = int(line_zone.in_count)
257
- truck_out = int(line_zone.out_count)
258
- truck_total = truck_in + truck_out
259
- frame = draw_truck_label_on_line(frame, line_y, truck_total)
260
-
261
- sink.write_frame(frame)
262
- last_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
263
-
264
- # ---- Yield to UI (throttled ~5/sec) ----
265
- now = time.time()
266
- if now - last_yield > 0.20 or i == total - 1:
267
- last_yield = now
268
- elapsed = time.time() - start_time
269
- yield (last_rgb,
270
- build_counts_html(truck_total, truck_in, truck_out,
271
- i + 1, total, elapsed),
272
- "### 🔴 Live analysis in progress…",
273
- None)
274
-
275
- elapsed = time.time() - start_time
276
- truck_in = int(line_zone.in_count)
277
- truck_out = int(line_zone.out_count)
278
- truck_total = truck_in + truck_out
279
- yield (last_rgb,
280
- build_counts_html(truck_total, truck_in, truck_out, total, total, elapsed),
281
- build_summary_md(truck_total, truck_in, truck_out),
282
- out_path)
283
-
284
-
285
- # ---------------------------------------------------------------------------
286
- # UI
287
- # ---------------------------------------------------------------------------
288
  CUSTOM_CSS = """
289
- .gradio-container {max-width: 1240px !important; margin: auto;}
290
- #title-row {text-align: center; padding: 8px 0 0 0;}
291
- #title-row h1 {font-weight: 700; letter-spacing: -0.5px; margin-bottom: 4px;}
292
- #title-row p {color: #6b7280; margin-top: 0;}
293
- .card {border: 1px solid #e5e7eb; border-radius: 14px; padding: 16px;
294
- background: #ffffff;}
295
- #live-frame img {border-radius: 10px;}
296
- footer {visibility: hidden;}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  """
298
 
299
- with gr.Blocks(title="RF-DETR Truck Counter") as demo:
 
 
 
 
300
 
301
- with gr.Row(elem_id="title-row"):
302
  gr.Markdown(
303
  """
304
- # 🚛 RF-DETR Truck Counter
305
- Counts trucks crossing a horizontal line at the **center of the video** —
306
- in either direction. Powered by
307
- [RF-DETR Nano](https://github.com/roboflow/rf-detr) + ByteTrack +
308
- `sv.LineZone`.
309
  """
310
  )
311
 
312
  with gr.Row():
313
- # ---------- Left: input ----------
314
  with gr.Column(scale=1):
315
- with gr.Group(elem_classes="card"):
316
- gr.Markdown("### 📥 Input")
317
  video_input = gr.Video(
318
- label="Upload a video",
319
  sources=["upload"],
320
  format="mp4",
321
  height=260,
322
  )
323
 
324
- with gr.Accordion("⚙️ Advanced settings", open=False):
325
- confidence = gr.Slider(
326
- minimum=0.1, maximum=0.9, value=0.45, step=0.05,
327
- label="Confidence threshold",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  )
329
- frame_stride = gr.Slider(
330
- minimum=1, maximum=15, value=3, step=1,
331
- label="Frame stride (CPU speed)",
332
- info="Detect every Nth frame. Higher = faster.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  )
334
 
335
- submit_btn = gr.Button("▶️ Start Live Analysis",
336
- variant="primary", size="lg")
337
-
338
- gr.Markdown("#### 🎬 Example video")
339
- gr.Examples(
340
- examples=[[EXAMPLE_VIDEO]],
341
- inputs=video_input,
342
- examples_per_page=4,
 
 
 
343
  )
344
 
345
- # ---------- Right: live view ----------
346
  with gr.Column(scale=2):
347
- with gr.Group(elem_classes="card"):
348
- with gr.Row():
349
- with gr.Column(scale=3):
350
- gr.Markdown("### 🔴 Live View")
351
- live_frame = gr.Image(
352
- show_label=False,
353
- elem_id="live-frame",
354
- height=440,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  )
356
- with gr.Column(scale=1, min_width=240):
357
- gr.Markdown("### 📊 Live Counts")
358
- live_counts = gr.HTML(
359
- value=build_counts_html(0, 0, 0, 0, 0, 0)
 
 
 
 
 
360
  )
361
 
362
- # ---------- Bottom: final results ----------
363
  with gr.Row():
364
  with gr.Column(scale=1):
365
- with gr.Group(elem_classes="card"):
366
- gr.Markdown("### 📤 Final annotated video")
367
- video_output = gr.Video(label="Download / replay", height=260)
368
  with gr.Column(scale=1):
369
- with gr.Group(elem_classes="card"):
370
- gr.Markdown("### 📈 Final summary")
371
- summary_output = gr.Markdown("Run an analysis to see results.")
 
372
 
373
- submit_btn.click(
374
  fn=process_video,
375
- inputs=[video_input, confidence, frame_stride],
376
- outputs=[live_frame, live_counts, summary_output, video_output],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  )
378
 
379
 
380
  if __name__ == "__main__":
381
- demo.queue(max_size=4).launch(
382
- theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
383
- css=CUSTOM_CSS,
384
- )
 
1
  """
2
+ Bridge Traffic & Load Demo App
3
+ Fast RF-DETR + ByteTrack vehicle counting for bridge videos.
4
  """
5
 
6
  import os
 
7
  import time
8
+ import tempfile
9
+ from functools import lru_cache
10
+ from typing import Dict, List, Tuple
11
 
12
  import cv2
13
  import gradio as gr
14
+ import matplotlib
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
  import numpy as np
18
+ import pandas as pd
19
  import supervision as sv
20
  import torch
21
+
22
+ from rfdetr import RFDETRNano, RFDETRMedium
23
+
24
+
25
+ # ---------------------------------------------------------------------
26
+ # Vehicle classes from COCO
27
+ # ---------------------------------------------------------------------
28
+ # COCO IDs used by RF-DETR:
29
+ # 2 = car, 3 = motorcycle, 5 = bus, 7 = truck
30
+ VEHICLE_CLASSES: Dict[int, str] = {
31
+ 2: "car",
32
+ 3: "motorcycle",
33
+ 5: "bus",
34
+ 7: "truck",
 
35
  }
36
+
37
+ # Very rough demonstration weights in kg.
38
+ # Adjust these for your local traffic profile.
39
+ DEFAULT_WEIGHTS_KG: Dict[int, int] = {
40
+ 2: 1500, # car / small vehicle
41
+ 3: 250, # motorcycle
42
+ 5: 12000, # bus
43
+ 7: 18000, # truck / lorry
 
 
 
44
  }
45
 
46
+ CLASS_COLORS_BGR: Dict[int, Tuple[int, int, int]] = {
47
+ 2: (40, 190, 120), # car
48
+ 3: (255, 170, 70), # motorcycle
49
+ 5: (245, 120, 45), # bus
50
+ 7: (220, 70, 180), # truck
 
 
 
51
  }
52
 
53
+ MODEL_OPTIONS = {
54
+ "Nano - fastest": RFDETRNano,
55
+ "Medium - more accurate, slower": RFDETRMedium,
56
+ }
57
 
 
 
 
58
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
59
 
60
  try:
61
  torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
62
  except Exception:
63
  pass
64
 
65
+ if DEVICE == "cuda":
66
+ torch.backends.cudnn.benchmark = True
67
 
 
 
 
 
 
 
 
 
68
 
69
+ # ---------------------------------------------------------------------
70
+ # Model loading
71
+ # ---------------------------------------------------------------------
72
+ @lru_cache(maxsize=2)
73
+ def load_model(model_name: str):
74
+ """Load RF-DETR once and reuse it across runs."""
75
+ model_cls = MODEL_OPTIONS[model_name]
76
 
77
+ print(f"Loading {model_name} on {DEVICE}...")
 
 
 
 
78
 
79
+ try:
80
+ model = model_cls(device=DEVICE)
81
+ except TypeError:
82
+ # Fallback for older RF-DETR builds
83
+ model = model_cls()
84
+
85
+ if DEVICE == "cuda":
86
+ try:
87
+ model.optimize_for_inference()
88
+ print("RF-DETR optimized for inference.")
89
+ except Exception as exc:
90
+ print(f"optimize_for_inference skipped: {exc}")
91
+
92
+ print("Model ready.")
93
+ return model
94
+
95
+
96
+ # ---------------------------------------------------------------------
97
+ # Detection helper
98
+ # ---------------------------------------------------------------------
99
+ def predict_vehicles(
100
+ model,
101
+ frame_bgr: np.ndarray,
102
+ confidence: float,
103
+ inference_width: int,
104
+ ) -> sv.Detections:
105
+ """
106
+ Resize frame before inference for speed, then scale boxes back to
107
+ original video coordinates.
108
+ """
109
+ h, w = frame_bgr.shape[:2]
110
+ inference_width = int(inference_width)
111
+
112
+ if inference_width > 0 and w > inference_width:
113
+ scale = inference_width / float(w)
114
+ resized_w = inference_width
115
+ resized_h = int(h * scale)
116
+ model_frame = cv2.resize(frame_bgr, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
117
+ else:
118
+ scale = 1.0
119
+ model_frame = frame_bgr
120
+
121
+ frame_rgb = cv2.cvtColor(model_frame, cv2.COLOR_BGR2RGB)
122
+
123
+ with torch.inference_mode():
124
+ detections = model.predict(frame_rgb, threshold=float(confidence))
125
+
126
+ if len(detections) == 0:
127
+ return detections
128
+
129
+ # Keep only vehicle classes.
130
+ mask = np.isin(detections.class_id, list(VEHICLE_CLASSES.keys()))
131
+ detections = detections[mask]
132
+
133
+ if len(detections) == 0:
134
+ return detections
135
+
136
+ # Scale boxes back to original frame size.
137
+ if scale != 1.0:
138
+ detections.xyxy = detections.xyxy / scale
139
+
140
+ return detections
141
+
142
+
143
+ # ---------------------------------------------------------------------
144
+ # Counting and load helpers
145
+ # ---------------------------------------------------------------------
146
+ def side_of_line(y: float, line_y: int, dead_zone_px: int = 4) -> int:
147
+ """
148
+ Returns -1 above the line, +1 below the line, 0 inside a small dead zone.
149
+ The dead zone prevents jitter around the line from causing false crossings.
150
+ """
151
+ diff = y - line_y
152
+ if abs(diff) <= dead_zone_px:
153
+ return 0
154
+ return -1 if diff < 0 else 1
155
+
156
+
157
+ def detection_centres(detections: sv.Detections) -> np.ndarray:
158
+ if len(detections) == 0:
159
+ return np.empty((0, 2), dtype=float)
160
+ xyxy = detections.xyxy
161
+ cx = (xyxy[:, 0] + xyxy[:, 2]) / 2.0
162
+ cy = (xyxy[:, 1] + xyxy[:, 3]) / 2.0
163
+ return np.column_stack([cx, cy])
164
+
165
+
166
+ def get_class_weight_kg(class_id: int, weights: Dict[int, int]) -> int:
167
+ return int(weights.get(int(class_id), 0))
168
+
169
+
170
+ def draw_header_panel(
171
+ frame: np.ndarray,
172
+ total_count: int,
173
+ cumulative_kg: float,
174
+ live_load_kg: float,
175
+ load_index_percent: float,
176
+ fps_text: str,
177
+ ) -> np.ndarray:
178
+ """Draw a clean dashboard panel at the top-left of the frame."""
179
  overlay = frame.copy()
180
+ x1, y1, x2, y2 = 18, 18, 520, 158
181
+ cv2.rectangle(overlay, (x1, y1), (x2, y2), (20, 24, 36), -1)
182
+ frame = cv2.addWeighted(overlay, 0.82, frame, 0.18, 0)
183
+
184
+ cv2.putText(frame, "BRIDGE TRAFFIC LOAD DEMO", (34, 46),
185
+ cv2.FONT_HERSHEY_SIMPLEX, 0.72, (255, 255, 255), 2, cv2.LINE_AA)
186
+
187
+ cv2.putText(frame, f"Vehicles crossed: {total_count}", (34, 78),
188
+ cv2.FONT_HERSHEY_SIMPLEX, 0.62, (230, 240, 255), 2, cv2.LINE_AA)
189
+
190
+ cv2.putText(frame, f"Cumulative estimated mass: {cumulative_kg / 1000.0:.1f} tonnes", (34, 106),
191
+ cv2.FONT_HERSHEY_SIMPLEX, 0.58, (220, 240, 230), 2, cv2.LINE_AA)
192
+
193
+ cv2.putText(frame, f"Live load: {live_load_kg / 1000.0:.1f} t | Load index: {load_index_percent:.1f}% | {fps_text}", (34, 134),
194
+ cv2.FONT_HERSHEY_SIMPLEX, 0.52, (230, 230, 255), 1, cv2.LINE_AA)
195
+
196
  return frame
197
 
198
 
199
+ def annotate_frame(
200
+ frame: np.ndarray,
201
+ detections: sv.Detections,
202
+ line_y: int,
203
+ roi_top_y: int,
204
+ roi_bottom_y: int,
205
+ class_counts: Dict[str, int],
206
+ total_count: int,
207
+ cumulative_kg: float,
208
+ live_load_kg: float,
209
+ load_index_percent: float,
210
+ fps_text: str,
211
+ ) -> np.ndarray:
212
+ """Draw ROI, counting line, boxes, labels and dashboard."""
213
+ h, w = frame.shape[:2]
214
+
215
+ # Bridge deck ROI overlay
216
+ overlay = frame.copy()
217
+ cv2.rectangle(overlay, (0, roi_top_y), (w, roi_bottom_y), (80, 80, 80), -1)
218
+ frame = cv2.addWeighted(overlay, 0.08, frame, 0.92, 0)
219
+
220
+ # Counting line
221
+ cv2.line(frame, (0, line_y), (w, line_y), (40, 230, 255), 3)
222
+ cv2.putText(frame, "COUNTING LINE", (24, max(28, line_y - 12)),
223
+ cv2.FONT_HERSHEY_SIMPLEX, 0.58, (40, 230, 255), 2, cv2.LINE_AA)
224
+
225
+ # ROI borders
226
+ cv2.line(frame, (0, roi_top_y), (w, roi_top_y), (170, 170, 170), 1)
227
+ cv2.line(frame, (0, roi_bottom_y), (w, roi_bottom_y), (170, 170, 170), 1)
228
+
229
+ if len(detections) > 0:
230
+ tracker_ids = detections.tracker_id
231
+ if tracker_ids is None:
232
+ tracker_ids = [None] * len(detections)
233
+
234
+ confidences = detections.confidence
235
+ if confidences is None:
236
+ confidences = [0.0] * len(detections)
237
+
238
+ for xyxy, class_id, conf, track_id in zip(
239
+ detections.xyxy,
240
+ detections.class_id,
241
+ confidences,
242
+ tracker_ids,
243
+ ):
244
+ class_id = int(class_id)
245
+ x1, y1, x2, y2 = map(int, xyxy)
246
+ name = VEHICLE_CLASSES.get(class_id, "vehicle")
247
+ color = CLASS_COLORS_BGR.get(class_id, (80, 220, 255))
248
+ weight_t = DEFAULT_WEIGHTS_KG.get(class_id, 0) / 1000.0
249
+
250
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
251
+
252
+ id_text = f"#{int(track_id)} " if track_id is not None and int(track_id) >= 0 else ""
253
+ label = f"{id_text}{name} {float(conf):.2f} ~{weight_t:.1f}t"
254
+ (tw, th), base = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.52, 1)
255
+
256
+ label_y1 = max(0, y1 - th - base - 8)
257
+ cv2.rectangle(frame, (x1, label_y1), (x1 + tw + 10, y1), color, -1)
258
+ cv2.putText(frame, label, (x1 + 5, y1 - 6),
259
+ cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 1, cv2.LINE_AA)
260
+
261
+ frame = draw_header_panel(
262
+ frame=frame,
263
+ total_count=total_count,
264
+ cumulative_kg=cumulative_kg,
265
+ live_load_kg=live_load_kg,
266
+ load_index_percent=load_index_percent,
267
+ fps_text=fps_text,
268
  )
269
 
270
+ # Compact class counts at bottom
271
+ items = [f"{k}: {v}" for k, v in class_counts.items() if v > 0]
272
+ count_text = " | ".join(items) if items else "No crossings yet"
273
+ cv2.putText(frame, count_text, (22, h - 24),
274
+ cv2.FONT_HERSHEY_SIMPLEX, 0.62, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
275
 
276
+ return frame
277
 
278
 
279
+ def build_metrics_html(
280
+ total_count: int,
281
+ class_counts: Dict[str, int],
282
+ cumulative_kg: float,
283
+ live_load_kg: float,
284
+ load_index_percent: float,
285
+ frame_idx: int,
286
+ total_frames: int,
287
+ elapsed: float,
288
+ device: str,
289
+ ) -> str:
290
+ pct = (frame_idx / total_frames * 100.0) if total_frames else 0.0
291
+ tonnes = cumulative_kg / 1000.0
292
+ live_tonnes = live_load_kg / 1000.0
293
+
294
+ car = class_counts.get("car", 0)
295
+ motorcycle = class_counts.get("motorcycle", 0)
296
+ bus = class_counts.get("bus", 0)
297
+ truck = class_counts.get("truck", 0)
298
+
299
+ return f"""
300
+ <div style="font-family:Inter,system-ui,Arial;">
301
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
302
+ <div style="padding:16px;border-radius:16px;background:linear-gradient(135deg,#1d4ed8,#312e81);color:white;">
303
+ <div style="font-size:11px;letter-spacing:1px;opacity:.85;">VEHICLES CROSSED</div>
304
+ <div style="font-size:46px;font-weight:800;line-height:1;">{total_count}</div>
305
+ </div>
306
+ <div style="padding:16px;border-radius:16px;background:linear-gradient(135deg,#be185d,#7e22ce);color:white;">
307
+ <div style="font-size:11px;letter-spacing:1px;opacity:.85;">EST. CUMULATIVE MASS</div>
308
+ <div style="font-size:36px;font-weight:800;line-height:1;">{tonnes:.1f} t</div>
309
+ </div>
310
+ </div>
311
+
312
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
313
+ <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
314
+ <div style="font-size:12px;color:#6b7280;">Live bridge load</div>
315
+ <div style="font-size:28px;font-weight:750;color:#111827;">{live_tonnes:.1f} t</div>
316
+ </div>
317
+ <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
318
+ <div style="font-size:12px;color:#6b7280;">Load index</div>
319
+ <div style="font-size:28px;font-weight:750;color:#111827;">{load_index_percent:.1f}%</div>
320
+ </div>
321
+ </div>
322
+
323
+ <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:#ffffff;margin-bottom:12px;">
324
+ <div style="font-size:12px;color:#6b7280;margin-bottom:8px;">Crossings by class</div>
325
+ <div style="display:grid;grid-template-columns:1fr 1fr;gap:8px;font-size:14px;">
326
+ <div>🚗 Cars: <b>{car}</b></div>
327
+ <div>🏍️ Motorcycles: <b>{motorcycle}</b></div>
328
+ <div>🚌 Buses: <b>{bus}</b></div>
329
+ <div>🚛 Trucks: <b>{truck}</b></div>
330
+ </div>
331
+ </div>
332
+
333
+ <div style="font-size:12px;color:#6b7280;margin-bottom:4px;display:flex;justify-content:space-between;">
334
+ <span>Frame {frame_idx} / {total_frames}</span>
335
+ <span>{pct:.1f}% · {elapsed:.1f}s · {device}</span>
336
+ </div>
337
+ <div style="height:8px;background:#e5e7eb;border-radius:99px;overflow:hidden;">
338
+ <div style="height:100%;width:{pct:.2f}%;background:#4f46e5;"></div>
339
+ </div>
340
+ </div>
341
+ """
342
+
343
+
344
+ def render_load_plot(history: List[Dict]) -> np.ndarray:
345
+ """Render load-index chart as an RGB image for Gradio."""
346
+ if not history:
347
+ img = np.ones((320, 600, 3), dtype=np.uint8) * 255
348
+ cv2.putText(img, "Load index chart will appear here", (60, 165),
349
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (80, 80, 80), 2, cv2.LINE_AA)
350
+ return img
351
+
352
+ df = pd.DataFrame(history)
353
+ # Plot only a manageable number of points for speed.
354
+ if len(df) > 500:
355
+ df = df.iloc[np.linspace(0, len(df) - 1, 500).astype(int)]
356
+
357
+ fig, ax = plt.subplots(figsize=(8.0, 3.8), dpi=100)
358
+ ax.plot(df["time_s"], df["load_index_percent"], linewidth=2)
359
+ ax.set_title("Estimated Bridge Load Index Over Time")
360
+ ax.set_xlabel("Video time (seconds)")
361
+ ax.set_ylabel("Load index (%)")
362
+ ax.grid(True, alpha=0.25)
363
+ ax.set_ylim(bottom=0)
364
+ fig.tight_layout()
365
+
366
+ fig.canvas.draw()
367
+ rgba = np.asarray(fig.canvas.buffer_rgba())
368
+ rgb = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
369
+ plt.close(fig)
370
+ return rgb
371
+
372
+
373
+ def build_final_summary(
374
+ total_count: int,
375
+ class_counts: Dict[str, int],
376
+ cumulative_kg: float,
377
+ peak_live_load_kg: float,
378
+ peak_load_index: float,
379
+ csv_path: str,
380
+ ) -> str:
381
+ tonnes = cumulative_kg / 1000.0
382
+ peak_tonnes = peak_live_load_kg / 1000.0
383
+
384
+ return f"""
385
+ ### Final bridge traffic summary
386
+
387
+ **Vehicles crossed:** {total_count}
388
+
389
+ | Vehicle class | Count |
390
+ |---|---:|
391
+ | Cars | {class_counts.get("car", 0)} |
392
+ | Motorcycles | {class_counts.get("motorcycle", 0)} |
393
+ | Buses | {class_counts.get("bus", 0)} |
394
+ | Trucks | {class_counts.get("truck", 0)} |
395
+
396
+ **Cumulative estimated mass:** {tonnes:.2f} tonnes
397
+ **Peak estimated live load:** {peak_tonnes:.2f} tonnes
398
+ **Peak load index:** {peak_load_index:.1f}%
399
+
400
+ The CSV output contains the estimated load-index time series for later plotting or analysis.
401
+
402
+ > Note: This is a demonstration traffic-load indicator, not a certified structural stress calculation.
403
+ """
404
 
405
 
406
+ # ---------------------------------------------------------------------
407
+ # Main processing generator
408
+ # ---------------------------------------------------------------------
409
+ def process_video(
410
+ video_path,
411
+ model_name,
412
+ confidence,
413
+ frame_stride,
414
+ inference_width,
415
+ line_position_percent,
416
+ roi_top_percent,
417
+ roi_bottom_percent,
418
+ reference_capacity_tonnes,
419
+ car_weight_t,
420
+ motorcycle_weight_t,
421
+ bus_weight_t,
422
+ truck_weight_t,
423
+ ):
424
  if video_path is None:
425
+ yield (
426
+ None,
427
+ build_metrics_html(0, {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}, 0, 0, 0, 0, 0, 0, DEVICE),
428
+ render_load_plot([]),
429
+ "Upload a video to start analysis.",
430
+ None,
431
+ None,
432
+ )
433
  return
434
 
435
+ # Update demo weights from UI.
436
+ weights_kg = {
437
+ 2: int(float(car_weight_t) * 1000),
438
+ 3: int(float(motorcycle_weight_t) * 1000),
439
+ 5: int(float(bus_weight_t) * 1000),
440
+ 7: int(float(truck_weight_t) * 1000),
441
+ }
442
+ # Keep global-like drawing labels consistent for this run.
443
+ DEFAULT_WEIGHTS_KG.update(weights_kg)
444
+
445
+ cap = cv2.VideoCapture(video_path)
446
+ if not cap.isOpened():
447
+ raise RuntimeError(f"Could not open video: {video_path}")
448
+
449
+ fps = cap.get(cv2.CAP_PROP_FPS)
450
+ if fps is None or fps <= 1:
451
+ fps = 25.0
452
+
453
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
454
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
455
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
456
+
457
+ if width <= 0 or height <= 0:
458
+ cap.release()
459
+ raise RuntimeError("Could not read video dimensions.")
460
+
461
+ line_y = int(height * float(line_position_percent) / 100.0)
462
+ roi_top_y = int(height * float(roi_top_percent) / 100.0)
463
+ roi_bottom_y = int(height * float(roi_bottom_percent) / 100.0)
464
+
465
+ if roi_bottom_y <= roi_top_y:
466
+ roi_top_y = int(height * 0.25)
467
+ roi_bottom_y = int(height * 0.90)
468
+
469
+ reference_capacity_kg = max(1.0, float(reference_capacity_tonnes) * 1000.0)
470
+
471
+ yield (
472
+ None,
473
+ build_metrics_html(0, {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}, 0, 0, 0, 0, total_frames, 0, DEVICE),
474
+ render_load_plot([]),
475
+ "### Loading RF-DETR model and starting analysis...",
476
+ None,
477
+ None,
478
  )
479
 
480
+ model = load_model(str(model_name))
481
+ tracker = sv.ByteTrack(frame_rate=int(round(fps)))
482
+
483
+ # Output files
484
+ out_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
485
+ out_csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
486
+
487
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
488
+ writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))
489
+
490
+ # State
491
+ last_detections = sv.Detections.empty()
492
+ last_side_by_id: Dict[int, int] = {}
493
+ counted_ids = set()
494
+ track_class: Dict[int, int] = {}
495
+
496
+ class_counts = {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}
497
+ total_count = 0
498
+ cumulative_kg = 0.0
499
+
500
+ history: List[Dict] = []
501
+ event_rows: List[Dict] = []
502
+
503
+ start_wall = time.time()
504
+ last_yield_wall = 0.0
505
+ last_plot = render_load_plot([])
506
+ processed_frames = 0
507
+
508
+ peak_live_load_kg = 0.0
509
+ peak_load_index = 0.0
510
+
511
+ frame_idx = 0
512
+
513
+ while True:
514
+ ok, frame = cap.read()
515
+ if not ok:
516
+ break
517
+
518
+ detect_this_frame = (frame_idx % int(frame_stride) == 0)
519
+
520
+ if detect_this_frame:
521
+ detections = predict_vehicles(
522
+ model=model,
523
+ frame_bgr=frame,
524
+ confidence=float(confidence),
525
+ inference_width=int(inference_width),
526
+ )
527
+ detections = tracker.update_with_detections(detections)
528
+ last_detections = detections
529
+ else:
530
+ detections = last_detections
531
+
532
+ # Update per-track class and line crossing only when we have tracked detections.
533
+ centres = detection_centres(detections)
534
+
535
+ live_load_kg = 0.0
536
+ active_track_ids = set()
537
+
538
+ if len(detections) > 0 and detections.tracker_id is not None:
539
+ for det_i, (centre, class_id, track_id) in enumerate(
540
+ zip(centres, detections.class_id, detections.tracker_id)
541
+ ):
542
+ if track_id is None or int(track_id) < 0:
543
+ continue
544
+
545
+ tid = int(track_id)
546
+ cid = int(class_id)
547
+ cy = float(centre[1])
548
+
549
+ track_class[tid] = cid
550
+ active_track_ids.add(tid)
551
+
552
+ # Live bridge-deck load, only if the object is inside the deck ROI.
553
+ if roi_top_y <= cy <= roi_bottom_y:
554
+ live_load_kg += get_class_weight_kg(cid, weights_kg)
555
+
556
+ current_side = side_of_line(cy, line_y)
557
+ previous_side = last_side_by_id.get(tid)
558
+
559
+ if current_side != 0:
560
+ if previous_side is not None and previous_side != 0:
561
+ crossed = previous_side != current_side
562
+ if crossed and tid not in counted_ids:
563
+ vehicle_name = VEHICLE_CLASSES.get(cid, "vehicle")
564
+ vehicle_weight = get_class_weight_kg(cid, weights_kg)
565
+ direction = "down" if previous_side < current_side else "up"
566
+
567
+ counted_ids.add(tid)
568
+ total_count += 1
569
+ class_counts[vehicle_name] = class_counts.get(vehicle_name, 0) + 1
570
+ cumulative_kg += vehicle_weight
571
+
572
+ event_rows.append({
573
+ "video_time_s": frame_idx / fps,
574
+ "frame": frame_idx,
575
+ "tracker_id": tid,
576
+ "vehicle_type": vehicle_name,
577
+ "direction": direction,
578
+ "estimated_vehicle_weight_kg": vehicle_weight,
579
+ "cumulative_estimated_mass_kg": cumulative_kg,
580
+ })
581
+
582
+ last_side_by_id[tid] = current_side
583
+
584
+ load_index_percent = (live_load_kg / reference_capacity_kg) * 100.0
585
+ peak_live_load_kg = max(peak_live_load_kg, live_load_kg)
586
+ peak_load_index = max(peak_load_index, load_index_percent)
587
+
588
+ history.append({
589
+ "video_time_s": frame_idx / fps,
590
+ "time_s": frame_idx / fps,
591
+ "frame": frame_idx,
592
+ "vehicles_crossed_total": total_count,
593
+ "cars_crossed": class_counts.get("car", 0),
594
+ "motorcycles_crossed": class_counts.get("motorcycle", 0),
595
+ "buses_crossed": class_counts.get("bus", 0),
596
+ "trucks_crossed": class_counts.get("truck", 0),
597
+ "live_load_kg": live_load_kg,
598
+ "live_load_tonnes": live_load_kg / 1000.0,
599
+ "load_index_percent": load_index_percent,
600
+ "cumulative_estimated_mass_kg": cumulative_kg,
601
+ "cumulative_estimated_mass_tonnes": cumulative_kg / 1000.0,
602
+ })
603
+
604
+ elapsed_wall = time.time() - start_wall
605
+ processed_frames += 1
606
+ current_processing_fps = processed_frames / max(elapsed_wall, 1e-6)
607
+ fps_text = f"{current_processing_fps:.1f} proc FPS"
608
+
609
+ annotated = annotate_frame(
610
+ frame=frame,
611
+ detections=detections,
612
+ line_y=line_y,
613
+ roi_top_y=roi_top_y,
614
+ roi_bottom_y=roi_bottom_y,
615
+ class_counts=class_counts,
616
+ total_count=total_count,
617
+ cumulative_kg=cumulative_kg,
618
+ live_load_kg=live_load_kg,
619
+ load_index_percent=load_index_percent,
620
+ fps_text=fps_text,
621
+ )
622
+ writer.write(annotated)
623
+
624
+ now = time.time()
625
+ if now - last_yield_wall >= 0.35:
626
+ last_yield_wall = now
627
+ # Refresh the chart less often than the frame display.
628
+ last_plot = render_load_plot(history)
629
+ rgb_frame = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
630
+ yield (
631
+ rgb_frame,
632
+ build_metrics_html(
633
+ total_count=total_count,
634
+ class_counts=class_counts,
635
+ cumulative_kg=cumulative_kg,
636
+ live_load_kg=live_load_kg,
637
+ load_index_percent=load_index_percent,
638
+ frame_idx=frame_idx + 1,
639
+ total_frames=total_frames,
640
+ elapsed=elapsed_wall,
641
+ device=DEVICE,
642
+ ),
643
+ last_plot,
644
+ "### Live analysis running...",
645
+ None,
646
+ None,
647
+ )
648
+
649
+ frame_idx += 1
650
+
651
+ cap.release()
652
+ writer.release()
653
+
654
+ # Save CSV time series. Add event-level detail as separate columns where possible.
655
+ history_df = pd.DataFrame(history)
656
+ history_df.to_csv(out_csv_path, index=False)
657
+
658
+ final_plot = render_load_plot(history)
659
+ final_summary = build_final_summary(
660
+ total_count=total_count,
661
+ class_counts=class_counts,
662
+ cumulative_kg=cumulative_kg,
663
+ peak_live_load_kg=peak_live_load_kg,
664
+ peak_load_index=peak_load_index,
665
+ csv_path=out_csv_path,
666
  )
667
+
668
+ final_frame = None
669
+ if history:
670
+ # Try to show the last annotated frame from the output video.
671
+ cap2 = cv2.VideoCapture(out_video_path)
672
+ if cap2.isOpened():
673
+ cap2.set(cv2.CAP_PROP_POS_FRAMES, max(0, frame_idx - 1))
674
+ ok, last = cap2.read()
675
+ if ok:
676
+ final_frame = cv2.cvtColor(last, cv2.COLOR_BGR2RGB)
677
+ cap2.release()
678
+
679
+ yield (
680
+ final_frame,
681
+ build_metrics_html(
682
+ total_count=total_count,
683
+ class_counts=class_counts,
684
+ cumulative_kg=cumulative_kg,
685
+ live_load_kg=0,
686
+ load_index_percent=0,
687
+ frame_idx=total_frames if total_frames else frame_idx,
688
+ total_frames=total_frames if total_frames else frame_idx,
689
+ elapsed=time.time() - start_wall,
690
+ device=DEVICE,
691
+ ),
692
+ final_plot,
693
+ final_summary,
694
+ out_video_path,
695
+ out_csv_path,
696
  )
697
 
 
 
698
 
699
+ # ---------------------------------------------------------------------
700
+ # Gradio UI
701
+ # ---------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
  CUSTOM_CSS = """
703
+ .gradio-container {
704
+ max-width: 1320px !important;
705
+ margin: auto !important;
706
+ }
707
+ #hero {
708
+ text-align: center;
709
+ padding: 18px 8px 8px 8px;
710
+ }
711
+ #hero h1 {
712
+ font-weight: 850;
713
+ letter-spacing: -0.6px;
714
+ margin-bottom: 2px;
715
+ }
716
+ #hero p {
717
+ color: #64748b;
718
+ font-size: 16px;
719
+ margin-top: 0;
720
+ }
721
+ .panel {
722
+ border: 1px solid #e5e7eb;
723
+ border-radius: 18px;
724
+ padding: 16px;
725
+ background: #ffffff;
726
+ box-shadow: 0 8px 24px rgba(15, 23, 42, 0.04);
727
+ }
728
+ #live-frame img, #load-plot img {
729
+ border-radius: 14px;
730
+ }
731
+ footer {
732
+ visibility: hidden;
733
+ }
734
  """
735
 
736
+ with gr.Blocks(
737
+ title="Bridge Traffic Load Demo",
738
+ theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
739
+ css=CUSTOM_CSS,
740
+ ) as demo:
741
 
742
+ with gr.Row(elem_id="hero"):
743
  gr.Markdown(
744
  """
745
+ # 🌉 Bridge Traffic Load Demo
746
+ Fast RF-DETR vehicle detection, ByteTrack tracking, line-crossing counts,
747
+ estimated cumulative vehicle mass, and live bridge load-index over time.
 
 
748
  """
749
  )
750
 
751
  with gr.Row():
 
752
  with gr.Column(scale=1):
753
+ with gr.Group(elem_classes="panel"):
754
+ gr.Markdown("### 1) Upload video")
755
  video_input = gr.Video(
756
+ label="Bridge traffic video",
757
  sources=["upload"],
758
  format="mp4",
759
  height=260,
760
  )
761
 
762
+ start_btn = gr.Button(" Start analysis", variant="primary", size="lg")
763
+
764
+ gr.Markdown("### 2) Speed settings")
765
+ model_name = gr.Radio(
766
+ choices=list(MODEL_OPTIONS.keys()),
767
+ value="Nano - fastest",
768
+ label="RF-DETR model",
769
+ )
770
+ confidence = gr.Slider(
771
+ minimum=0.10,
772
+ maximum=0.90,
773
+ value=0.40,
774
+ step=0.05,
775
+ label="Confidence threshold",
776
+ )
777
+ frame_stride = gr.Slider(
778
+ minimum=1,
779
+ maximum=10,
780
+ value=3,
781
+ step=1,
782
+ label="Frame stride",
783
+ info="Detect every Nth frame. 1 is most accurate. 3-5 is much faster.",
784
+ )
785
+ inference_width = gr.Slider(
786
+ minimum=384,
787
+ maximum=1280,
788
+ value=640,
789
+ step=64,
790
+ label="Inference width",
791
+ info="Lower is faster. Try 512 or 640 for CPU demos.",
792
+ )
793
+
794
+ with gr.Accordion("Bridge settings", open=False):
795
+ line_position_percent = gr.Slider(
796
+ minimum=10,
797
+ maximum=90,
798
+ value=55,
799
+ step=1,
800
+ label="Counting line vertical position (%)",
801
  )
802
+ roi_top_percent = gr.Slider(
803
+ minimum=0,
804
+ maximum=90,
805
+ value=20,
806
+ step=1,
807
+ label="Bridge deck ROI top (%)",
808
+ )
809
+ roi_bottom_percent = gr.Slider(
810
+ minimum=10,
811
+ maximum=100,
812
+ value=90,
813
+ step=1,
814
+ label="Bridge deck ROI bottom (%)",
815
+ )
816
+ reference_capacity_tonnes = gr.Slider(
817
+ minimum=5,
818
+ maximum=200,
819
+ value=40,
820
+ step=5,
821
+ label="Reference live-load capacity for demo index (tonnes)",
822
  )
823
 
824
+ with gr.Accordion("Estimated class weights", open=False):
825
+ car_weight_t = gr.Number(value=1.5, label="Car weight estimate (tonnes)")
826
+ motorcycle_weight_t = gr.Number(value=0.25, label="Motorcycle weight estimate (tonnes)")
827
+ bus_weight_t = gr.Number(value=12.0, label="Bus weight estimate (tonnes)")
828
+ truck_weight_t = gr.Number(value=18.0, label="Truck weight estimate (tonnes)")
829
+
830
+ gr.Markdown(
831
+ """
832
+ **For speed:** use **Nano**, inference width **512-640**, and frame stride **3-5**.
833
+ Use **Medium** only when you need better detection and have a GPU.
834
+ """
835
  )
836
 
 
837
  with gr.Column(scale=2):
838
+ with gr.Group(elem_classes="panel"):
839
+ gr.Markdown("### Live annotated video")
840
+ live_frame = gr.Image(
841
+ show_label=False,
842
+ elem_id="live-frame",
843
+ height=470,
844
+ )
845
+
846
+ with gr.Row():
847
+ with gr.Column(scale=1):
848
+ with gr.Group(elem_classes="panel"):
849
+ gr.Markdown("### Live metrics")
850
+ metrics_html = gr.HTML(
851
+ value=build_metrics_html(
852
+ total_count=0,
853
+ class_counts={"car": 0, "motorcycle": 0, "bus": 0, "truck": 0},
854
+ cumulative_kg=0,
855
+ live_load_kg=0,
856
+ load_index_percent=0,
857
+ frame_idx=0,
858
+ total_frames=0,
859
+ elapsed=0,
860
+ device=DEVICE,
861
+ )
862
  )
863
+
864
+ with gr.Column(scale=1):
865
+ with gr.Group(elem_classes="panel"):
866
+ gr.Markdown("### Load index over time")
867
+ load_plot = gr.Image(
868
+ show_label=False,
869
+ elem_id="load-plot",
870
+ height=310,
871
+ value=render_load_plot([]),
872
  )
873
 
 
874
  with gr.Row():
875
  with gr.Column(scale=1):
876
+ with gr.Group(elem_classes="panel"):
877
+ gr.Markdown("### Final annotated video")
878
+ video_output = gr.Video(label="Replay / download annotated video", height=270)
879
  with gr.Column(scale=1):
880
+ with gr.Group(elem_classes="panel"):
881
+ gr.Markdown("### Final summary")
882
+ summary_output = gr.Markdown("Run an analysis to see the final summary.")
883
+ csv_output = gr.File(label="Download load-index CSV")
884
 
885
+ start_btn.click(
886
  fn=process_video,
887
+ inputs=[
888
+ video_input,
889
+ model_name,
890
+ confidence,
891
+ frame_stride,
892
+ inference_width,
893
+ line_position_percent,
894
+ roi_top_percent,
895
+ roi_bottom_percent,
896
+ reference_capacity_tonnes,
897
+ car_weight_t,
898
+ motorcycle_weight_t,
899
+ bus_weight_t,
900
+ truck_weight_t,
901
+ ],
902
+ outputs=[
903
+ live_frame,
904
+ metrics_html,
905
+ load_plot,
906
+ summary_output,
907
+ video_output,
908
+ csv_output,
909
+ ],
910
  )
911
 
912
 
913
  if __name__ == "__main__":
914
+ demo.queue(max_size=3).launch()