lyimo committed on
Commit
35e66da
·
verified ·
1 Parent(s): 7176850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +777 -406
app.py CHANGED
@@ -1,13 +1,14 @@
1
  """
2
- Bridge Traffic & Load Demo App
3
- Fast RF-DETR + ByteTrack vehicle counting for bridge videos.
4
  """
5
 
6
  import os
7
  import time
8
  import tempfile
 
 
9
  from functools import lru_cache
10
- from typing import Dict, List, Tuple
11
 
12
  import cv2
13
  import gradio as gr
@@ -19,42 +20,65 @@ import pandas as pd
19
  import supervision as sv
20
  import torch
21
 
22
- from rfdetr import RFDETRNano, RFDETRMedium
 
 
 
 
 
 
 
 
 
23
 
24
 
25
  # ---------------------------------------------------------------------
26
- # Vehicle classes from COCO
 
27
  # ---------------------------------------------------------------------
28
- # COCO IDs used by RF-DETR:
29
- # 2 = car, 3 = motorcycle, 5 = bus, 7 = truck
30
- VEHICLE_CLASSES: Dict[int, str] = {
31
- 2: "car",
32
- 3: "motorcycle",
33
- 5: "bus",
34
- 7: "truck",
35
- }
36
 
37
- # Very rough demonstration weights in kg.
38
- # Adjust these for your local traffic profile.
39
- DEFAULT_WEIGHTS_KG: Dict[int, int] = {
40
- 2: 1500, # car / small vehicle
41
- 3: 250, # motorcycle
42
- 5: 12000, # bus
43
- 7: 18000, # truck / lorry
44
- }
45
 
46
- CLASS_COLORS_BGR: Dict[int, Tuple[int, int, int]] = {
47
- 2: (40, 190, 120), # car
48
- 3: (255, 170, 70), # motorcycle
49
- 5: (245, 120, 45), # bus
50
- 7: (220, 70, 180), # truck
51
- }
52
 
53
- MODEL_OPTIONS = {
54
- "Nano - fastest": RFDETRNano,
55
- "Medium - more accurate, slower": RFDETRMedium,
56
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
59
 
60
  try:
@@ -63,91 +87,281 @@ except Exception:
63
  pass
64
 
65
  if DEVICE == "cuda":
66
- torch.backends.cudnn.benchmark = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  # ---------------------------------------------------------------------
70
  # Model loading
71
  # ---------------------------------------------------------------------
72
- @lru_cache(maxsize=2)
73
- def load_model(model_name: str):
74
- """Load RF-DETR once and reuse it across runs."""
75
- model_cls = MODEL_OPTIONS[model_name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
- print(f"Loading {model_name} on {DEVICE}...")
78
 
79
  try:
80
- model = model_cls(device=DEVICE)
81
  except TypeError:
82
- # Fallback for older RF-DETR builds
83
- model = model_cls()
84
 
85
- if DEVICE == "cuda":
86
- try:
87
- model.optimize_for_inference()
88
- print("RF-DETR optimized for inference.")
89
- except Exception as exc:
90
- print(f"optimize_for_inference skipped: {exc}")
 
91
 
92
- print("Model ready.")
93
  return model
94
 
95
 
96
  # ---------------------------------------------------------------------
97
- # Detection helper
98
  # ---------------------------------------------------------------------
99
- def predict_vehicles(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  model,
101
  frame_bgr: np.ndarray,
102
  confidence: float,
103
  inference_width: int,
104
- ) -> sv.Detections:
105
  """
106
- Resize frame before inference for speed, then scale boxes back to
107
- original video coordinates.
108
  """
109
  h, w = frame_bgr.shape[:2]
110
- inference_width = int(inference_width)
111
 
112
  if inference_width > 0 and w > inference_width:
113
- scale = inference_width / float(w)
114
- resized_w = inference_width
115
- resized_h = int(h * scale)
116
- model_frame = cv2.resize(frame_bgr, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
 
 
117
  else:
118
  scale = 1.0
119
- model_frame = frame_bgr
120
 
121
- frame_rgb = cv2.cvtColor(model_frame, cv2.COLOR_BGR2RGB)
122
 
123
  with torch.inference_mode():
124
- detections = model.predict(frame_rgb, threshold=float(confidence))
125
 
126
  if len(detections) == 0:
127
- return detections
128
 
129
- # Keep only vehicle classes.
130
- mask = np.isin(detections.class_id, list(VEHICLE_CLASSES.keys()))
131
- detections = detections[mask]
132
 
133
- if len(detections) == 0:
134
- return detections
 
 
 
 
135
 
136
- # Scale boxes back to original frame size.
137
- if scale != 1.0:
 
 
 
 
 
138
  detections.xyxy = detections.xyxy / scale
139
 
140
- return detections
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  # ---------------------------------------------------------------------
144
- # Counting and load helpers
145
  # ---------------------------------------------------------------------
146
- def side_of_line(y: float, line_y: int, dead_zone_px: int = 4) -> int:
147
- """
148
- Returns -1 above the line, +1 below the line, 0 inside a small dead zone.
149
- The dead zone prevents jitter around the line from causing false crossings.
150
- """
151
  diff = y - line_y
152
  if abs(diff) <= dead_zone_px:
153
  return 0
@@ -158,47 +372,181 @@ def detection_centres(detections: sv.Detections) -> np.ndarray:
158
  if len(detections) == 0:
159
  return np.empty((0, 2), dtype=float)
160
  xyxy = detections.xyxy
161
- cx = (xyxy[:, 0] + xyxy[:, 2]) / 2.0
162
- cy = (xyxy[:, 1] + xyxy[:, 3]) / 2.0
163
- return np.column_stack([cx, cy])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
 
166
- def get_class_weight_kg(class_id: int, weights: Dict[int, int]) -> int:
167
- return int(weights.get(int(class_id), 0))
 
168
 
 
 
 
169
 
170
- def draw_header_panel(
171
- frame: np.ndarray,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  total_count: int,
 
173
  cumulative_kg: float,
174
  live_load_kg: float,
175
  load_index_percent: float,
176
- fps_text: str,
177
- ) -> np.ndarray:
178
- """Draw a clean dashboard panel at the top-left of the frame."""
179
- overlay = frame.copy()
180
- x1, y1, x2, y2 = 18, 18, 520, 158
181
- cv2.rectangle(overlay, (x1, y1), (x2, y2), (20, 24, 36), -1)
182
- frame = cv2.addWeighted(overlay, 0.82, frame, 0.18, 0)
 
 
 
 
 
183
 
184
- cv2.putText(frame, "BRIDGE TRAFFIC LOAD DEMO", (34, 46),
185
- cv2.FONT_HERSHEY_SIMPLEX, 0.72, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- cv2.putText(frame, f"Vehicles crossed: {total_count}", (34, 78),
188
- cv2.FONT_HERSHEY_SIMPLEX, 0.62, (230, 240, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- cv2.putText(frame, f"Cumulative estimated mass: {cumulative_kg / 1000.0:.1f} tonnes", (34, 106),
191
- cv2.FONT_HERSHEY_SIMPLEX, 0.58, (220, 240, 230), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
192
 
193
- cv2.putText(frame, f"Live load: {live_load_kg / 1000.0:.1f} t | Load index: {load_index_percent:.1f}% | {fps_text}", (34, 134),
194
- cv2.FONT_HERSHEY_SIMPLEX, 0.52, (230, 230, 255), 1, cv2.LINE_AA)
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  return frame
197
 
198
 
199
  def annotate_frame(
200
  frame: np.ndarray,
201
  detections: sv.Detections,
 
202
  line_y: int,
203
  roi_top_y: int,
204
  roi_bottom_y: int,
@@ -207,22 +555,30 @@ def annotate_frame(
207
  cumulative_kg: float,
208
  live_load_kg: float,
209
  load_index_percent: float,
210
- fps_text: str,
 
211
  ) -> np.ndarray:
212
- """Draw ROI, counting line, boxes, labels and dashboard."""
213
  h, w = frame.shape[:2]
214
 
215
- # Bridge deck ROI overlay
216
  overlay = frame.copy()
217
- cv2.rectangle(overlay, (0, roi_top_y), (w, roi_bottom_y), (80, 80, 80), -1)
218
  frame = cv2.addWeighted(overlay, 0.08, frame, 0.92, 0)
219
 
220
- # Counting line
221
  cv2.line(frame, (0, line_y), (w, line_y), (40, 230, 255), 3)
222
- cv2.putText(frame, "COUNTING LINE", (24, max(28, line_y - 12)),
223
- cv2.FONT_HERSHEY_SIMPLEX, 0.58, (40, 230, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
224
 
225
- # ROI borders
226
  cv2.line(frame, (0, roi_top_y), (w, roi_top_y), (170, 170, 170), 1)
227
  cv2.line(frame, (0, roi_bottom_y), (w, roi_bottom_y), (170, 170, 170), 1)
228
 
@@ -235,180 +591,100 @@ def annotate_frame(
235
  if confidences is None:
236
  confidences = [0.0] * len(detections)
237
 
238
- for xyxy, class_id, conf, track_id in zip(
239
- detections.xyxy,
240
- detections.class_id,
241
- confidences,
242
- tracker_ids,
243
- ):
244
- class_id = int(class_id)
245
  x1, y1, x2, y2 = map(int, xyxy)
246
- name = VEHICLE_CLASSES.get(class_id, "vehicle")
247
- color = CLASS_COLORS_BGR.get(class_id, (80, 220, 255))
248
- weight_t = DEFAULT_WEIGHTS_KG.get(class_id, 0) / 1000.0
249
 
250
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
251
 
252
- id_text = f"#{int(track_id)} " if track_id is not None and int(track_id) >= 0 else ""
253
- label = f"{id_text}{name} {float(conf):.2f} ~{weight_t:.1f}t"
254
  (tw, th), base = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.52, 1)
255
-
256
  label_y1 = max(0, y1 - th - base - 8)
257
  cv2.rectangle(frame, (x1, label_y1), (x1 + tw + 10, y1), color, -1)
258
- cv2.putText(frame, label, (x1 + 5, y1 - 6),
259
- cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 1, cv2.LINE_AA)
 
 
 
 
 
 
 
 
260
 
261
- frame = draw_header_panel(
262
  frame=frame,
263
  total_count=total_count,
264
  cumulative_kg=cumulative_kg,
265
  live_load_kg=live_load_kg,
266
  load_index_percent=load_index_percent,
267
- fps_text=fps_text,
 
268
  )
269
 
270
- # Compact class counts at bottom
271
- items = [f"{k}: {v}" for k, v in class_counts.items() if v > 0]
272
- count_text = " | ".join(items) if items else "No crossings yet"
273
- cv2.putText(frame, count_text, (22, h - 24),
274
- cv2.FONT_HERSHEY_SIMPLEX, 0.62, (255, 255, 255), 2, cv2.LINE_AA)
275
-
276
- return frame
277
-
278
-
279
- def build_metrics_html(
280
- total_count: int,
281
- class_counts: Dict[str, int],
282
- cumulative_kg: float,
283
- live_load_kg: float,
284
- load_index_percent: float,
285
- frame_idx: int,
286
- total_frames: int,
287
- elapsed: float,
288
- device: str,
289
- ) -> str:
290
- pct = (frame_idx / total_frames * 100.0) if total_frames else 0.0
291
- tonnes = cumulative_kg / 1000.0
292
- live_tonnes = live_load_kg / 1000.0
293
-
294
- car = class_counts.get("car", 0)
295
- motorcycle = class_counts.get("motorcycle", 0)
296
- bus = class_counts.get("bus", 0)
297
- truck = class_counts.get("truck", 0)
298
-
299
- return f"""
300
- <div style="font-family:Inter,system-ui,Arial;">
301
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
302
- <div style="padding:16px;border-radius:16px;background:linear-gradient(135deg,#1d4ed8,#312e81);color:white;">
303
- <div style="font-size:11px;letter-spacing:1px;opacity:.85;">VEHICLES CROSSED</div>
304
- <div style="font-size:46px;font-weight:800;line-height:1;">{total_count}</div>
305
- </div>
306
- <div style="padding:16px;border-radius:16px;background:linear-gradient(135deg,#be185d,#7e22ce);color:white;">
307
- <div style="font-size:11px;letter-spacing:1px;opacity:.85;">EST. CUMULATIVE MASS</div>
308
- <div style="font-size:36px;font-weight:800;line-height:1;">{tonnes:.1f} t</div>
309
- </div>
310
- </div>
311
-
312
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
313
- <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
314
- <div style="font-size:12px;color:#6b7280;">Live bridge load</div>
315
- <div style="font-size:28px;font-weight:750;color:#111827;">{live_tonnes:.1f} t</div>
316
- </div>
317
- <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
318
- <div style="font-size:12px;color:#6b7280;">Load index</div>
319
- <div style="font-size:28px;font-weight:750;color:#111827;">{load_index_percent:.1f}%</div>
320
- </div>
321
- </div>
322
-
323
- <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:#ffffff;margin-bottom:12px;">
324
- <div style="font-size:12px;color:#6b7280;margin-bottom:8px;">Crossings by class</div>
325
- <div style="display:grid;grid-template-columns:1fr 1fr;gap:8px;font-size:14px;">
326
- <div>🚗 Cars: <b>{car}</b></div>
327
- <div>🏍️ Motorcycles: <b>{motorcycle}</b></div>
328
- <div>🚌 Buses: <b>{bus}</b></div>
329
- <div>🚛 Trucks: <b>{truck}</b></div>
330
- </div>
331
- </div>
332
-
333
- <div style="font-size:12px;color:#6b7280;margin-bottom:4px;display:flex;justify-content:space-between;">
334
- <span>Frame {frame_idx} / {total_frames}</span>
335
- <span>{pct:.1f}% · {elapsed:.1f}s · {device}</span>
336
- </div>
337
- <div style="height:8px;background:#e5e7eb;border-radius:99px;overflow:hidden;">
338
- <div style="height:100%;width:{pct:.2f}%;background:#4f46e5;"></div>
339
- </div>
340
- </div>
341
- """
342
-
343
-
344
- def render_load_plot(history: List[Dict]) -> np.ndarray:
345
- """Render load-index chart as an RGB image for Gradio."""
346
- if not history:
347
- img = np.ones((320, 600, 3), dtype=np.uint8) * 255
348
- cv2.putText(img, "Load index chart will appear here", (60, 165),
349
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (80, 80, 80), 2, cv2.LINE_AA)
350
- return img
351
-
352
- df = pd.DataFrame(history)
353
- # Plot only a manageable number of points for speed.
354
- if len(df) > 500:
355
- df = df.iloc[np.linspace(0, len(df) - 1, 500).astype(int)]
356
 
357
- fig, ax = plt.subplots(figsize=(8.0, 3.8), dpi=100)
358
- ax.plot(df["time_s"], df["load_index_percent"], linewidth=2)
359
- ax.set_title("Estimated Bridge Load Index Over Time")
360
- ax.set_xlabel("Video time (seconds)")
361
- ax.set_ylabel("Load index (%)")
362
- ax.grid(True, alpha=0.25)
363
- ax.set_ylim(bottom=0)
364
- fig.tight_layout()
365
 
366
- fig.canvas.draw()
367
- rgba = np.asarray(fig.canvas.buffer_rgba())
368
- rgb = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
369
- plt.close(fig)
370
- return rgb
371
 
372
 
373
- def build_final_summary(
374
  total_count: int,
375
  class_counts: Dict[str, int],
376
  cumulative_kg: float,
377
  peak_live_load_kg: float,
378
  peak_load_index: float,
379
- csv_path: str,
380
  ) -> str:
381
- tonnes = cumulative_kg / 1000.0
382
- peak_tonnes = peak_live_load_kg / 1000.0
 
 
 
383
 
384
- return f"""
385
- ### Final bridge traffic summary
386
 
387
- **Vehicles crossed:** {total_count}
388
 
389
- | Vehicle class | Count |
390
- |---|---:|
391
- | Cars | {class_counts.get("car", 0)} |
392
- | Motorcycles | {class_counts.get("motorcycle", 0)} |
393
- | Buses | {class_counts.get("bus", 0)} |
394
- | Trucks | {class_counts.get("truck", 0)} |
395
 
396
- **Cumulative estimated mass:** {tonnes:.2f} tonnes
397
- **Peak estimated live load:** {peak_tonnes:.2f} tonnes
398
- **Peak load index:** {peak_load_index:.1f}%
399
 
400
- The CSV output contains the estimated load-index time series for later plotting or analysis.
 
 
401
 
402
- > Note: This is a demonstration traffic-load indicator, not a certified structural stress calculation.
403
  """
404
 
405
 
406
  # ---------------------------------------------------------------------
407
- # Main processing generator
408
  # ---------------------------------------------------------------------
409
  def process_video(
410
  video_path,
411
- model_name,
 
412
  confidence,
413
  frame_stride,
414
  inference_width,
@@ -416,38 +692,64 @@ def process_video(
416
  roi_top_percent,
417
  roi_bottom_percent,
418
  reference_capacity_tonnes,
 
 
 
419
  car_weight_t,
420
- motorcycle_weight_t,
421
  bus_weight_t,
422
  truck_weight_t,
 
 
 
 
 
423
  ):
424
  if video_path is None:
425
  yield (
426
  None,
427
- build_metrics_html(0, {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}, 0, 0, 0, 0, 0, 0, DEVICE),
428
- render_load_plot([]),
429
- "Upload a video to start analysis.",
430
  None,
431
  None,
432
  )
433
  return
434
 
435
- # Update demo weights from UI.
436
- weights_kg = {
437
- 2: int(float(car_weight_t) * 1000),
438
- 3: int(float(motorcycle_weight_t) * 1000),
439
- 5: int(float(bus_weight_t) * 1000),
440
- 7: int(float(truck_weight_t) * 1000),
441
- }
442
- # Keep global-like drawing labels consistent for this run.
443
- DEFAULT_WEIGHTS_KG.update(weights_kg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  cap = cv2.VideoCapture(video_path)
446
  if not cap.isOpened():
447
  raise RuntimeError(f"Could not open video: {video_path}")
448
 
449
- fps = cap.get(cv2.CAP_PROP_FPS)
450
- if fps is None or fps <= 1:
451
  fps = 25.0
452
 
453
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
@@ -463,119 +765,134 @@ def process_video(
463
  roi_bottom_y = int(height * float(roi_bottom_percent) / 100.0)
464
 
465
  if roi_bottom_y <= roi_top_y:
466
- roi_top_y = int(height * 0.25)
467
  roi_bottom_y = int(height * 0.90)
468
 
469
  reference_capacity_kg = max(1.0, float(reference_capacity_tonnes) * 1000.0)
470
 
471
  yield (
472
  None,
473
- build_metrics_html(0, {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}, 0, 0, 0, 0, total_frames, 0, DEVICE),
474
- render_load_plot([]),
475
- "### Loading RF-DETR model and starting analysis...",
476
  None,
477
  None,
478
  )
479
 
480
- model = load_model(str(model_name))
 
 
 
 
 
481
  tracker = sv.ByteTrack(frame_rate=int(round(fps)))
482
 
483
- # Output files
484
  out_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
485
  out_csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
486
 
487
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
488
- writer = cv2.VideoWriter(out_video_path, fourcc, fps, (width, height))
 
 
 
 
489
 
490
- # State
491
  last_detections = sv.Detections.empty()
 
 
492
  last_side_by_id: Dict[int, int] = {}
493
  counted_ids = set()
494
- track_class: Dict[int, int] = {}
495
 
496
- class_counts = {"car": 0, "motorcycle": 0, "bus": 0, "truck": 0}
497
  total_count = 0
498
  cumulative_kg = 0.0
499
 
500
  history: List[Dict] = []
501
- event_rows: List[Dict] = []
502
-
503
- start_wall = time.time()
504
- last_yield_wall = 0.0
505
- last_plot = render_load_plot([])
506
- processed_frames = 0
507
 
508
  peak_live_load_kg = 0.0
509
  peak_load_index = 0.0
510
 
 
 
 
 
 
511
  frame_idx = 0
 
512
 
513
  while True:
514
  ok, frame = cap.read()
515
  if not ok:
516
  break
517
 
518
- detect_this_frame = (frame_idx % int(frame_stride) == 0)
519
-
520
- if detect_this_frame:
521
- detections = predict_vehicles(
522
- model=model,
523
  frame_bgr=frame,
524
  confidence=float(confidence),
525
  inference_width=int(inference_width),
526
  )
527
  detections = tracker.update_with_detections(detections)
 
 
 
 
 
 
 
 
528
  last_detections = detections
 
529
  else:
530
  detections = last_detections
 
531
 
532
- # Update per-track class and line crossing only when we have tracked detections.
533
  centres = detection_centres(detections)
534
 
535
  live_load_kg = 0.0
536
- active_track_ids = set()
537
 
538
  if len(detections) > 0 and detections.tracker_id is not None:
539
- for det_i, (centre, class_id, track_id) in enumerate(
540
- zip(centres, detections.class_id, detections.tracker_id)
541
- ):
542
- if track_id is None or int(track_id) < 0:
543
  continue
544
 
545
- tid = int(track_id)
546
- cid = int(class_id)
547
- cy = float(centre[1])
 
548
 
549
- track_class[tid] = cid
550
- active_track_ids.add(tid)
551
 
552
- # Live bridge-deck load, only if the object is inside the deck ROI.
 
 
553
  if roi_top_y <= cy <= roi_bottom_y:
554
- live_load_kg += get_class_weight_kg(cid, weights_kg)
555
 
556
  current_side = side_of_line(cy, line_y)
557
  previous_side = last_side_by_id.get(tid)
558
 
559
  if current_side != 0:
560
- if previous_side is not None and previous_side != 0:
561
- crossed = previous_side != current_side
562
- if crossed and tid not in counted_ids:
563
- vehicle_name = VEHICLE_CLASSES.get(cid, "vehicle")
564
- vehicle_weight = get_class_weight_kg(cid, weights_kg)
565
- direction = "down" if previous_side < current_side else "up"
566
-
567
  counted_ids.add(tid)
568
  total_count += 1
569
- class_counts[vehicle_name] = class_counts.get(vehicle_name, 0) + 1
570
- cumulative_kg += vehicle_weight
 
571
 
572
- event_rows.append({
 
573
  "video_time_s": frame_idx / fps,
574
  "frame": frame_idx,
575
  "tracker_id": tid,
576
- "vehicle_type": vehicle_name,
 
577
  "direction": direction,
578
- "estimated_vehicle_weight_kg": vehicle_weight,
579
  "cumulative_estimated_mass_kg": cumulative_kg,
580
  })
581
 
@@ -585,15 +902,23 @@ def process_video(
585
  peak_live_load_kg = max(peak_live_load_kg, live_load_kg)
586
  peak_load_index = max(peak_load_index, load_index_percent)
587
 
 
 
 
 
588
  history.append({
589
- "video_time_s": frame_idx / fps,
590
  "time_s": frame_idx / fps,
591
  "frame": frame_idx,
592
- "vehicles_crossed_total": total_count,
 
 
593
  "cars_crossed": class_counts.get("car", 0),
594
  "motorcycles_crossed": class_counts.get("motorcycle", 0),
595
  "buses_crossed": class_counts.get("bus", 0),
596
  "trucks_crossed": class_counts.get("truck", 0),
 
 
 
597
  "live_load_kg": live_load_kg,
598
  "live_load_tonnes": live_load_kg / 1000.0,
599
  "load_index_percent": load_index_percent,
@@ -601,14 +926,10 @@ def process_video(
601
  "cumulative_estimated_mass_tonnes": cumulative_kg / 1000.0,
602
  })
603
 
604
- elapsed_wall = time.time() - start_wall
605
- processed_frames += 1
606
- current_processing_fps = processed_frames / max(elapsed_wall, 1e-6)
607
- fps_text = f"{current_processing_fps:.1f} proc FPS"
608
-
609
  annotated = annotate_frame(
610
  frame=frame,
611
  detections=detections,
 
612
  line_y=line_y,
613
  roi_top_y=roi_top_y,
614
  roi_bottom_y=roi_bottom_y,
@@ -617,18 +938,22 @@ def process_video(
617
  cumulative_kg=cumulative_kg,
618
  live_load_kg=live_load_kg,
619
  load_index_percent=load_index_percent,
620
- fps_text=fps_text,
 
621
  )
 
622
  writer.write(annotated)
 
623
 
624
  now = time.time()
 
 
 
 
625
  if now - last_yield_wall >= 0.35:
626
  last_yield_wall = now
627
- # Refresh the chart less often than the frame display.
628
- last_plot = render_load_plot(history)
629
- rgb_frame = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
630
  yield (
631
- rgb_frame,
632
  build_metrics_html(
633
  total_count=total_count,
634
  class_counts=class_counts,
@@ -637,10 +962,11 @@ def process_video(
637
  load_index_percent=load_index_percent,
638
  frame_idx=frame_idx + 1,
639
  total_frames=total_frames,
640
- elapsed=elapsed_wall,
641
- device=DEVICE,
 
642
  ),
643
- last_plot,
644
  "### Live analysis running...",
645
  None,
646
  None,
@@ -651,33 +977,26 @@ def process_video(
651
  cap.release()
652
  writer.release()
653
 
654
- # Save CSV time series. Add event-level detail as separate columns where possible.
655
  history_df = pd.DataFrame(history)
656
- history_df.to_csv(out_csv_path, index=False)
 
 
 
 
 
 
 
 
 
 
 
657
 
 
 
658
  final_plot = render_load_plot(history)
659
- final_summary = build_final_summary(
660
- total_count=total_count,
661
- class_counts=class_counts,
662
- cumulative_kg=cumulative_kg,
663
- peak_live_load_kg=peak_live_load_kg,
664
- peak_load_index=peak_load_index,
665
- csv_path=out_csv_path,
666
- )
667
-
668
- final_frame = None
669
- if history:
670
- # Try to show the last annotated frame from the output video.
671
- cap2 = cv2.VideoCapture(out_video_path)
672
- if cap2.isOpened():
673
- cap2.set(cv2.CAP_PROP_POS_FRAMES, max(0, frame_idx - 1))
674
- ok, last = cap2.read()
675
- if ok:
676
- final_frame = cv2.cvtColor(last, cv2.COLOR_BGR2RGB)
677
- cap2.release()
678
 
679
  yield (
680
- final_frame,
681
  build_metrics_html(
682
  total_count=total_count,
683
  class_counts=class_counts,
@@ -686,31 +1005,39 @@ def process_video(
686
  load_index_percent=0,
687
  frame_idx=total_frames if total_frames else frame_idx,
688
  total_frames=total_frames if total_frames else frame_idx,
689
- elapsed=time.time() - start_wall,
690
- device=DEVICE,
 
691
  ),
692
  final_plot,
693
- final_summary,
 
 
 
 
 
 
 
694
  out_video_path,
695
  out_csv_path,
696
  )
697
 
698
 
699
  # ---------------------------------------------------------------------
700
- # Gradio UI
701
  # ---------------------------------------------------------------------
702
  CUSTOM_CSS = """
703
  .gradio-container {
704
- max-width: 1320px !important;
705
  margin: auto !important;
706
  }
707
  #hero {
708
  text-align: center;
709
- padding: 18px 8px 8px 8px;
710
  }
711
  #hero h1 {
712
  font-weight: 850;
713
- letter-spacing: -0.6px;
714
  margin-bottom: 2px;
715
  }
716
  #hero p {
@@ -723,7 +1050,7 @@ CUSTOM_CSS = """
723
  border-radius: 18px;
724
  padding: 16px;
725
  background: #ffffff;
726
- box-shadow: 0 8px 24px rgba(15, 23, 42, 0.04);
727
  }
728
  #live-frame img, #load-plot img {
729
  border-radius: 14px;
@@ -734,7 +1061,7 @@ footer {
734
  """
735
 
736
  with gr.Blocks(
737
- title="Bridge Traffic Load Demo",
738
  theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
739
  css=CUSTOM_CSS,
740
  ) as demo:
@@ -742,53 +1069,68 @@ with gr.Blocks(
742
  with gr.Row(elem_id="hero"):
743
  gr.Markdown(
744
  """
745
- # 🌉 Bridge Traffic Load Demo
746
- Fast RF-DETR vehicle detection, ByteTrack tracking, line-crossing counts,
747
- estimated cumulative vehicle mass, and live bridge load-index over time.
748
  """
749
  )
750
 
 
 
 
 
 
751
  with gr.Row():
752
  with gr.Column(scale=1):
753
  with gr.Group(elem_classes="panel"):
754
- gr.Markdown("### 1) Upload video")
755
  video_input = gr.Video(
756
- label="Bridge traffic video",
757
  sources=["upload"],
 
758
  format="mp4",
759
  height=260,
760
  )
761
 
762
- start_btn = gr.Button("▶ Start analysis", variant="primary", size="lg")
763
 
764
- gr.Markdown("### 2) Speed settings")
765
- model_name = gr.Radio(
766
- choices=list(MODEL_OPTIONS.keys()),
767
- value="Nano - fastest",
768
- label="RF-DETR model",
 
 
 
769
  )
 
 
 
 
 
 
770
  confidence = gr.Slider(
771
  minimum=0.10,
772
  maximum=0.90,
773
- value=0.40,
774
  step=0.05,
775
  label="Confidence threshold",
776
  )
777
  frame_stride = gr.Slider(
778
  minimum=1,
779
- maximum=10,
780
  value=3,
781
  step=1,
782
  label="Frame stride",
783
- info="Detect every Nth frame. 1 is most accurate. 3-5 is much faster.",
784
  )
785
  inference_width = gr.Slider(
786
  minimum=384,
787
  maximum=1280,
788
  value=640,
789
  step=64,
790
- label="Inference width",
791
- info="Lower is faster. Try 512 or 640 for CPU demos.",
792
  )
793
 
794
  with gr.Accordion("Bridge settings", open=False):
@@ -814,23 +1156,30 @@ with gr.Blocks(
814
  label="Bridge deck ROI bottom (%)",
815
  )
816
  reference_capacity_tonnes = gr.Slider(
817
- minimum=5,
818
- maximum=200,
819
  value=40,
820
- step=5,
821
  label="Reference live-load capacity for demo index (tonnes)",
822
  )
823
 
824
- with gr.Accordion("Estimated class weights", open=False):
825
- car_weight_t = gr.Number(value=1.5, label="Car weight estimate (tonnes)")
826
- motorcycle_weight_t = gr.Number(value=0.25, label="Motorcycle weight estimate (tonnes)")
827
- bus_weight_t = gr.Number(value=12.0, label="Bus weight estimate (tonnes)")
828
- truck_weight_t = gr.Number(value=18.0, label="Truck weight estimate (tonnes)")
 
 
 
 
 
 
 
829
 
830
  gr.Markdown(
831
  """
832
- **For speed:** use **Nano**, inference width **512-640**, and frame stride **3-5**.
833
- Use **Medium** only when you need better detection and have a GPU.
834
  """
835
  )
836
 
@@ -840,7 +1189,7 @@ with gr.Blocks(
840
  live_frame = gr.Image(
841
  show_label=False,
842
  elem_id="live-frame",
843
- height=470,
844
  )
845
 
846
  with gr.Row():
@@ -850,14 +1199,15 @@ with gr.Blocks(
850
  metrics_html = gr.HTML(
851
  value=build_metrics_html(
852
  total_count=0,
853
- class_counts={"car": 0, "motorcycle": 0, "bus": 0, "truck": 0},
854
  cumulative_kg=0,
855
  live_load_kg=0,
856
  load_index_percent=0,
857
  frame_idx=0,
858
  total_frames=0,
859
  elapsed=0,
860
- device=DEVICE,
 
861
  )
862
  )
863
 
@@ -867,8 +1217,8 @@ with gr.Blocks(
867
  load_plot = gr.Image(
868
  show_label=False,
869
  elem_id="load-plot",
870
- height=310,
871
- value=render_load_plot([]),
872
  )
873
 
874
  with gr.Row():
@@ -876,39 +1226,60 @@ with gr.Blocks(
876
  with gr.Group(elem_classes="panel"):
877
  gr.Markdown("### Final annotated video")
878
  video_output = gr.Video(label="Replay / download annotated video", height=270)
 
879
  with gr.Column(scale=1):
880
  with gr.Group(elem_classes="panel"):
881
  gr.Markdown("### Final summary")
882
- summary_output = gr.Markdown("Run an analysis to see the final summary.")
883
- csv_output = gr.File(label="Download load-index CSV")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
  start_btn.click(
886
  fn=process_video,
887
- inputs=[
888
- video_input,
889
- model_name,
890
- confidence,
891
- frame_stride,
892
- inference_width,
893
- line_position_percent,
894
- roi_top_percent,
895
- roi_bottom_percent,
896
- reference_capacity_tonnes,
897
- car_weight_t,
898
- motorcycle_weight_t,
899
- bus_weight_t,
900
- truck_weight_t,
901
- ],
902
- outputs=[
903
- live_frame,
904
- metrics_html,
905
- load_plot,
906
- summary_output,
907
- video_output,
908
- csv_output,
909
- ],
910
  )
911
 
 
 
 
 
 
 
 
 
912
 
913
  if __name__ == "__main__":
914
- demo.queue(max_size=3).launch()
 
1
  """
2
+ Fast Bridge Traffic + Livestock Load Demo
 
3
  """
4
 
5
  import os
6
  import time
7
  import tempfile
8
+ import warnings
9
+ from pathlib import Path
10
  from functools import lru_cache
11
+ from typing import Dict, List, Tuple, Optional
12
 
13
  import cv2
14
  import gradio as gr
 
20
  import supervision as sv
21
  import torch
22
 
23
+ # Optional engines
24
+ try:
25
+ from ultralytics import YOLO
26
+ except Exception:
27
+ YOLO = None
28
+
29
+ try:
30
+ from rfdetr import RFDETRMedium
31
+ except Exception:
32
+ RFDETRMedium = None
33
 
34
 
35
  # ---------------------------------------------------------------------
36
+ # Quiet noisy dependency warning that is not controlled by this app.
37
+ # The RF-DETR/transformers warning is internal to the dependency stack.
38
  # ---------------------------------------------------------------------
39
+ warnings.filterwarnings("ignore", message=".*use_return_dict.*")
40
+ warnings.filterwarnings("ignore", message=".*`use_return_dict` is deprecated.*")
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
42
 
43
+ # ---------------------------------------------------------------------
44
+ # App paths and default local video
45
+ # ---------------------------------------------------------------------
46
# Directory that contains this file; bundled videos and model weights are
# looked up relative to it.
APP_DIR = Path(__file__).resolve().parent

# Recognised video suffixes, in priority order for the fallback search.
VIDEO_EXTENSIONS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]

# File names that win over any other video found in the directory,
# checked in this order.
PREFERRED_VIDEO_NAMES = [
    "bridge.mp4",
    "traffic.mp4",
    "cars.mp4",
    "video.mp4",
    "input.mp4",
    "example.mp4",
    "sample.mp4",
]


def find_default_video(search_dir: Optional[Path] = None) -> Optional[str]:
    """Find a video sitting next to app.py.

    Args:
        search_dir: Directory to search. Defaults to ``APP_DIR``.

    Returns:
        The path (as ``str``) of the first match, or ``None`` when the
        directory holds no usable video. Preferred names are tried first;
        otherwise the alphabetically first file of the highest-priority
        extension in ``VIDEO_EXTENSIONS`` is returned.
    """
    root = APP_DIR if search_dir is None else Path(search_dir)

    for name in PREFERRED_VIDEO_NAMES:
        candidate = root / name
        # is_file() rather than exists(): a directory that happens to be
        # called "bridge.mp4" must not be handed to the video reader.
        if candidate.is_file():
            return str(candidate)

    try:
        files = sorted(p for p in root.iterdir() if p.is_file())
    except OSError:
        # Missing or unreadable directory: behave as "no default video".
        return None

    for ext in VIDEO_EXTENSIONS:
        for path in files:
            # Compare case-insensitively so "CLIP.MP4" is found as well.
            if path.suffix.lower() == ext:
                return str(path)

    return None


DEFAULT_VIDEO = find_default_video()
77
+
78
+
79
+ # ---------------------------------------------------------------------
80
+ # Device and speed setup
81
+ # ---------------------------------------------------------------------
82
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
83
 
84
  try:
 
87
  pass
88
 
89
  if DEVICE == "cuda":
90
+ try:
91
+ torch.backends.cudnn.benchmark = True
92
+ except Exception:
93
+ pass
94
+
95
+
96
# ---------------------------------------------------------------------
# Target classes and estimated weights
# ---------------------------------------------------------------------
# For YOLO COCO:
# person=0, bicycle=1, car=2, motorcycle=3, bus=5, truck=7,
# horse=17, sheep=18, cow=19.
#
# COCO does not have goat or donkey. We map:
# sheep -> sheep/goat
# horse -> horse/donkey

# Canonical class names the app counts; detections with any other name
# are dropped.
TARGET_CANONICAL_NAMES = {
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "bus",
    "truck",
    "cow",
    "sheep",
    "goat",
    "horse",
    "donkey",
}

# Label shown to the user for each canonical name.
DISPLAY_NAME: Dict[str, str] = {
    "person": "person",
    "bicycle": "bicycle",
    "car": "car",
    "motorcycle": "motorcycle",
    "bus": "bus",
    "truck": "truck",
    "cow": "cow",
    "sheep": "sheep / goat",
    "goat": "goat",
    "horse": "horse / donkey",
    "donkey": "donkey",
}

# COCO class names for RF-DETR outputs.
COCO_NAMES: Dict[int, str] = {
    0: "person",
    1: "bicycle",
    2: "car",
    3: "motorcycle",
    5: "bus",
    7: "truck",
    17: "horse",
    18: "sheep",
    19: "cow",
}

# Approximate demo weights in kg.
# Adjust in the UI for your bridge/traffic context.
DEFAULT_WEIGHTS_KG: Dict[str, int] = {
    "person": 75,
    "bicycle": 120,  # bicycle + rider approximation
    "motorcycle": 250,
    "car": 1500,
    "bus": 12000,
    "truck": 18000,
    "cow": 450,
    "sheep": 60,
    "goat": 45,
    "horse": 350,
    "donkey": 180,
}

# Box/label colour per canonical name, in OpenCV's BGR channel order.
COLOR_BY_NAME_BGR: Dict[str, Tuple[int, int, int]] = {
    "person": (70, 160, 245),
    "bicycle": (240, 190, 80),
    "motorcycle": (255, 150, 80),
    "car": (60, 210, 130),
    "bus": (50, 130, 245),
    "truck": (220, 70, 180),
    "cow": (160, 120, 80),
    "sheep": (220, 220, 220),
    "goat": (210, 210, 230),
    "horse": (130, 90, 60),
    "donkey": (120, 110, 95),
}
176
 
177
 
178
  # ---------------------------------------------------------------------
179
  # Model loading
180
  # ---------------------------------------------------------------------
181
@lru_cache(maxsize=4)
def load_yolo_model(model_file: str):
    """Load an Ultralytics YOLO model, cached per ``model_file``.

    Args:
        model_file: Weight file name or model identifier. A file with this
            name next to ``app.py`` takes precedence; otherwise the value
            is handed to ``YOLO`` unchanged.

    Returns:
        The constructed ``YOLO`` model object.

    Raises:
        RuntimeError: If the optional ``ultralytics`` import failed.
    """
    if YOLO is None:
        raise RuntimeError(
            "Ultralytics is not installed. Run: pip install ultralytics"
        )

    local_candidate = APP_DIR / model_file
    model_path = str(local_candidate) if local_candidate.exists() else model_file

    print(f"Loading YOLO model: {model_path} on {DEVICE}")
    model = YOLO(model_path)

    # Moving the model is best-effort; report the reason instead of
    # hiding it so a wrong device is visible in the logs.
    try:
        model.to(DEVICE)
    except Exception as exc:
        print(f"Could not move YOLO model to {DEVICE}: {exc}")

    return model
200
+
201
+
202
@lru_cache(maxsize=1)
def load_rfdetr_medium():
    """Build the RF-DETR Medium model once and reuse it on later calls.

    Returns:
        The ``RFDETRMedium`` instance, with ``optimize_for_inference()``
        applied when that call succeeds.

    Raises:
        RuntimeError: If the optional ``rfdetr`` import failed.
    """
    if RFDETRMedium is None:
        raise RuntimeError(
            "RF-DETR is not installed. Run: pip install rfdetr"
        )

    print(f"Loading RF-DETR Medium on {DEVICE}")

    # Fall back to the no-argument constructor when this rfdetr build
    # does not accept a ``device`` keyword.
    try:
        model = RFDETRMedium(device=DEVICE)
    except TypeError:
        model = RFDETRMedium()

    # This directly addresses:
    # "Model is not optimized for inference. Latency may be higher..."
    # Optimisation is optional: on failure the unoptimised model is used.
    try:
        model.optimize_for_inference()
        print("RF-DETR Medium optimized for inference.")
    except Exception as exc:
        print(f"RF-DETR optimize_for_inference skipped: {exc}")

    return model
225
 
226
 
227
  # ---------------------------------------------------------------------
228
+ # Detection conversion
229
  # ---------------------------------------------------------------------
230
def yolo_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    imgsz: int,
) -> Tuple[sv.Detections, List[str]]:
    """Run YOLO on one frame and keep only the target classes.

    Args:
        model: Loaded YOLO model; its ``predict`` method and ``names``
            attribute are used.
        frame_bgr: Frame passed to ``model.predict`` as ``source``.
        confidence: Confidence threshold forwarded as ``conf``.
        imgsz: Inference image size forwarded as ``imgsz``.

    Returns:
        ``(detections, canonical_names)`` where ``canonical_names[i]`` is
        the canonical class name of ``detections[i]``. Both are empty when
        nothing relevant was detected.
    """
    results = model.predict(
        source=frame_bgr,
        conf=float(confidence),
        imgsz=int(imgsz),
        device=0 if DEVICE == "cuda" else "cpu",
        verbose=False,
    )[0]

    boxes = results.boxes
    if boxes is None or len(boxes) == 0:
        return sv.Detections.empty(), []

    xyxy = boxes.xyxy.detach().cpu().numpy()
    conf = boxes.conf.detach().cpu().numpy()
    cls = boxes.cls.detach().cpu().numpy().astype(int)

    # Accept either an {id: name} mapping or a plain sequence of names.
    names = getattr(model, "names", None) or {}
    if not isinstance(names, dict):
        names = dict(enumerate(names))

    # Labels used by some custom models, mapped to canonical names.
    aliases = {"automobile": "car", "lorry": "truck"}

    canonical_names: List[str] = []
    keep: List[int] = []
    for i, class_id in enumerate(cls):
        raw = str(names.get(int(class_id), class_id)).lower().strip()
        name = aliases.get(raw, raw)
        if name in TARGET_CANONICAL_NAMES:
            canonical_names.append(name)
            keep.append(i)

    if not keep:
        return sv.Detections.empty(), []

    keep_idx = np.array(keep, dtype=int)
    detections = sv.Detections(
        xyxy=xyxy[keep_idx],
        confidence=conf[keep_idx],
        class_id=cls[keep_idx],
    )

    return detections, canonical_names
282
+
283
+
284
def rfdetr_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """Run RF-DETR Medium on one frame and keep only the target classes.

    The frame is downscaled to ``inference_width`` before inference when it
    is wider than that, and the resulting boxes are scaled back to the
    original frame coordinates.

    Args:
        model: Loaded RF-DETR model; its ``predict`` method is used.
        frame_bgr: Frame in BGR channel order (converted to RGB here).
        confidence: Threshold forwarded to ``model.predict``.
        inference_width: Maximum width used for inference; values <= 0
            disable resizing.

    Returns:
        ``(detections, canonical_names)`` where ``canonical_names[i]`` is
        the canonical class name of ``detections[i]``.
    """
    h, w = frame_bgr.shape[:2]

    if inference_width > 0 and w > inference_width:
        scale = float(inference_width) / float(w)
        # Guard against a dimension rounding down to 0, which cv2.resize
        # rejects (very small widths or extreme aspect ratios).
        target_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        resized = cv2.resize(frame_bgr, target_size, interpolation=cv2.INTER_AREA)
    else:
        scale = 1.0
        resized = frame_bgr

    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)

    with torch.inference_mode():
        detections = model.predict(rgb, threshold=float(confidence))

    if len(detections) == 0:
        return detections, []

    # NOTE(review): COCO_NAMES is keyed by the same ids listed for YOLO
    # COCO (person=0 ... cow=19). Confirm the installed rfdetr version
    # emits class ids in that id space; otherwise labels will be shifted.
    canonical_names: List[str] = []
    keep: List[int] = []
    for i, cid in enumerate(detections.class_id):
        name = COCO_NAMES.get(int(cid))
        if name in TARGET_CANONICAL_NAMES:
            keep.append(i)
            canonical_names.append(name)

    if not keep:
        return sv.Detections.empty(), []

    detections = detections[np.array(keep, dtype=int)]

    # Map boxes from the resized frame back to original pixel coordinates.
    if scale != 1.0 and len(detections) > 0:
        detections.xyxy = detections.xyxy / scale

    return detections, canonical_names
334
+
335
+
336
def predict_objects(
    engine: str,
    yolo_model_file: str,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """Detect target objects in one frame with the selected engine.

    Args:
        engine: Engine label; any value starting with ``"YOLO"`` selects
            YOLO, everything else selects RF-DETR Medium.
        yolo_model_file: Weight file/name used when YOLO is selected.
        frame_bgr: Frame to analyse.
        confidence: Detection confidence threshold.
        inference_width: Inference size; used as ``imgsz`` for YOLO and as
            the resize width for RF-DETR.

    Returns:
        ``(detections, canonical_names)`` as produced by the engine-specific
        helper.
    """
    if engine.startswith("YOLO"):
        return yolo_predict_to_supervision(
            model=load_yolo_model(yolo_model_file),
            frame_bgr=frame_bgr,
            confidence=confidence,
            imgsz=inference_width,
        )

    return rfdetr_predict_to_supervision(
        model=load_rfdetr_medium(),
        frame_bgr=frame_bgr,
        confidence=confidence,
        inference_width=inference_width,
    )
359
 
360
 
361
  # ---------------------------------------------------------------------
362
+ # Helpers
363
  # ---------------------------------------------------------------------
364
+ def side_of_line(y: float, line_y: int, dead_zone_px: int = 5) -> int:
 
 
 
 
365
  diff = y - line_y
366
  if abs(diff) <= dead_zone_px:
367
  return 0
 
372
  if len(detections) == 0:
373
  return np.empty((0, 2), dtype=float)
374
  xyxy = detections.xyxy
375
+ return np.column_stack([
376
+ (xyxy[:, 0] + xyxy[:, 2]) / 2.0,
377
+ (xyxy[:, 1] + xyxy[:, 3]) / 2.0,
378
+ ])
379
+
380
+
381
def make_empty_plot() -> np.ndarray:
    """Return a white 620x300 placeholder image with a hint caption.

    Used wherever the load chart is shown before any data exists.
    """
    img = np.full((300, 620, 3), 255, dtype=np.uint8)
    cv2.putText(
        img,
        "Bridge load index chart will appear here",
        (70, 155),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.75,
        (90, 90, 90),
        2,
        cv2.LINE_AA,
    )
    return img
394
 
395
 
396
def render_load_plot(history: List[Dict]) -> np.ndarray:
    """Render the load-index-over-time chart as an RGB image array.

    Args:
        history: Per-frame records; the ``"time_s"`` and
            ``"load_index_percent"`` keys are plotted.

    Returns:
        The rendered chart, or the placeholder image when ``history`` is
        empty.
    """
    if not history:
        return make_empty_plot()

    df = pd.DataFrame(history)
    # Cap the number of plotted points so redraw cost stays bounded on
    # long videos; samples are taken evenly across the whole history.
    if len(df) > 600:
        df = df.iloc[np.linspace(0, len(df) - 1, 600).astype(int)]

    fig, ax = plt.subplots(figsize=(8.0, 3.5), dpi=100)
    try:
        ax.plot(df["time_s"], df["load_index_percent"], linewidth=2)
        ax.set_title("Estimated Bridge Load Index Over Time")
        ax.set_xlabel("Video time (seconds)")
        ax.set_ylabel("Load index (%)")
        ax.grid(True, alpha=0.25)
        ax.set_ylim(bottom=0)
        fig.tight_layout()

        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        rgb = cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
    finally:
        # Always release the figure; this function runs repeatedly during
        # processing and an error here must not leak open figures.
        plt.close(fig)
    return rgb
418
+
419
+
420
def build_metrics_html(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    frame_idx: int,
    total_frames: int,
    elapsed: float,
    proc_fps: float,
    engine: str,
) -> str:
    """Build the HTML metrics panel shown next to the live frame.

    Args:
        total_count: Number of counted line crossings.
        class_counts: Crossings per canonical class name; missing names
            count as 0.
        cumulative_kg: Sum of estimated weights of all crossings, in kg.
        live_load_kg: Estimated load currently on the deck, in kg.
        load_index_percent: Load index to display, in percent.
        frame_idx: Current frame number.
        total_frames: Total frame count, or 0 when unknown.
        elapsed: Processing time so far, in seconds.
        proc_fps: Processing speed in frames per second.
        engine: Engine label shown in the status line.

    Returns:
        An HTML fragment as a string.
    """
    from html import escape

    pct = (frame_idx / total_frames * 100.0) if total_frames else 0.0
    # Frame-count metadata can be wrong; keep the progress bar in range.
    pct = min(100.0, max(0.0, pct))
    tonnes = cumulative_kg / 1000.0
    live_tonnes = live_load_kg / 1000.0
    # The engine label arrives from a UI/API input; escape it for HTML.
    engine_label = escape(str(engine))

    def c(name: str) -> int:
        return int(class_counts.get(name, 0))

    return f"""
    <div style="font-family:Inter,system-ui,Arial;">
      <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
        <div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#1d4ed8,#312e81);color:white;">
          <div style="font-size:11px;letter-spacing:1px;opacity:.86;">OBJECTS CROSSED</div>
          <div style="font-size:46px;font-weight:850;line-height:1;">{total_count}</div>
        </div>
        <div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#be185d,#7e22ce);color:white;">
          <div style="font-size:11px;letter-spacing:1px;opacity:.86;">CUMULATIVE EST. MASS</div>
          <div style="font-size:36px;font-weight:850;line-height:1;">{tonnes:.1f} t</div>
        </div>
      </div>

      <div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
        <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
          <div style="font-size:12px;color:#6b7280;">Live bridge load</div>
          <div style="font-size:28px;font-weight:800;color:#111827;">{live_tonnes:.1f} t</div>
        </div>
        <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
          <div style="font-size:12px;color:#6b7280;">Load index</div>
          <div style="font-size:28px;font-weight:800;color:#111827;">{load_index_percent:.1f}%</div>
        </div>
      </div>

      <div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:#ffffff;margin-bottom:12px;">
        <div style="font-size:12px;color:#6b7280;margin-bottom:8px;">Crossings by class</div>
        <div style="display:grid;grid-template-columns:1fr 1fr;gap:7px;font-size:13px;">
          <div>🚶 People: <b>{c("person")}</b></div>
          <div>🚗 Cars: <b>{c("car")}</b></div>
          <div>🏍️ Motorcycles: <b>{c("motorcycle")}</b></div>
          <div>🚲 Bicycles: <b>{c("bicycle")}</b></div>
          <div>🚌 Buses: <b>{c("bus")}</b></div>
          <div>🚛 Trucks: <b>{c("truck")}</b></div>
          <div>🐄 Cows: <b>{c("cow")}</b></div>
          <div>🐑 Sheep/goats: <b>{c("sheep") + c("goat")}</b></div>
          <div>🐴 Horse/donkey: <b>{c("horse") + c("donkey")}</b></div>
        </div>
      </div>

      <div style="font-size:12px;color:#6b7280;margin-bottom:4px;display:flex;justify-content:space-between;">
        <span>Frame {frame_idx} / {total_frames}</span>
        <span>{pct:.1f}% · {elapsed:.1f}s · {proc_fps:.1f} FPS · {DEVICE} · {engine_label}</span>
      </div>
      <div style="height:8px;background:#e5e7eb;border-radius:999px;overflow:hidden;">
        <div style="height:100%;width:{pct:.2f}%;background:#4f46e5;"></div>
      </div>
    </div>
    """
487
 
 
 
488
 
489
def draw_dashboard(
    frame: np.ndarray,
    total_count: int,
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    proc_fps: float,
    engine: str,
) -> np.ndarray:
    """Overlay the summary panel in the top-left corner of a frame.

    Args:
        frame: Frame to annotate.
        total_count: Number of counted crossings.
        cumulative_kg: Cumulative estimated mass of crossings, in kg.
        live_load_kg: Estimated load currently on the deck, in kg.
        load_index_percent: Load index to display, in percent.
        proc_fps: Processing speed in frames per second.
        engine: Engine label shown on the last row.

    Returns:
        A new frame with the panel blended in and the text drawn on it.
    """
    # Semi-transparent dark backing panel.
    overlay = frame.copy()
    x1, y1, x2, y2 = 18, 18, 600, 164
    cv2.rectangle(overlay, (x1, y1), (x2, y2), (18, 24, 38), -1)
    frame = cv2.addWeighted(overlay, 0.82, frame, 0.18, 0)

    # One row per text line: (text, origin, font scale, colour, thickness).
    rows = [
        (
            "BRIDGE TRAFFIC + LIVESTOCK DEMO",
            (34, 48),
            0.72,
            (255, 255, 255),
            2,
        ),
        (
            f"Crossed: {total_count} | Cumulative est. mass: {cumulative_kg/1000.0:.1f} t",
            (34, 82),
            0.58,
            (230, 240, 255),
            2,
        ),
        (
            f"Live load: {live_load_kg/1000.0:.1f} t | Load index: {load_index_percent:.1f}%",
            (34, 114),
            0.58,
            (220, 245, 230),
            2,
        ),
        (
            f"{proc_fps:.1f} processing FPS | {DEVICE} | {engine}",
            (34, 144),
            0.52,
            (230, 230, 255),
            1,
        ),
    ]
    for text, origin, font_scale, color, thickness in rows:
        cv2.putText(
            frame,
            text,
            origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            thickness,
            cv2.LINE_AA,
        )
    return frame
544
 
545
 
546
  def annotate_frame(
547
  frame: np.ndarray,
548
  detections: sv.Detections,
549
+ canonical_names: List[str],
550
  line_y: int,
551
  roi_top_y: int,
552
  roi_bottom_y: int,
 
555
  cumulative_kg: float,
556
  live_load_kg: float,
557
  load_index_percent: float,
558
+ proc_fps: float,
559
+ engine: str,
560
  ) -> np.ndarray:
 
561
  h, w = frame.shape[:2]
562
 
563
+ # Bridge deck ROI.
564
  overlay = frame.copy()
565
+ cv2.rectangle(overlay, (0, roi_top_y), (w, roi_bottom_y), (90, 90, 90), -1)
566
  frame = cv2.addWeighted(overlay, 0.08, frame, 0.92, 0)
567
 
568
+ # Counting line.
569
  cv2.line(frame, (0, line_y), (w, line_y), (40, 230, 255), 3)
570
+ cv2.putText(
571
+ frame,
572
+ "COUNTING LINE",
573
+ (24, max(28, line_y - 12)),
574
+ cv2.FONT_HERSHEY_SIMPLEX,
575
+ 0.60,
576
+ (40, 230, 255),
577
+ 2,
578
+ cv2.LINE_AA,
579
+ )
580
 
581
+ # ROI borders.
582
  cv2.line(frame, (0, roi_top_y), (w, roi_top_y), (170, 170, 170), 1)
583
  cv2.line(frame, (0, roi_bottom_y), (w, roi_bottom_y), (170, 170, 170), 1)
584
 
 
591
  if confidences is None:
592
  confidences = [0.0] * len(detections)
593
 
594
+ for i, (xyxy, conf, tid) in enumerate(zip(detections.xyxy, confidences, tracker_ids)):
595
+ if i >= len(canonical_names):
596
+ name = "object"
597
+ else:
598
+ name = canonical_names[i]
599
+
 
600
  x1, y1, x2, y2 = map(int, xyxy)
601
+ color = COLOR_BY_NAME_BGR.get(name, (80, 220, 255))
602
+ display = DISPLAY_NAME.get(name, name)
603
+ weight_t = DEFAULT_WEIGHTS_KG.get(name, 0) / 1000.0
604
 
605
  cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
606
 
607
+ id_txt = f"#{int(tid)} " if tid is not None and int(tid) >= 0 else ""
608
+ label = f"{id_txt}{display} {float(conf):.2f} ~{weight_t:.2f}t"
609
  (tw, th), base = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.52, 1)
 
610
  label_y1 = max(0, y1 - th - base - 8)
611
  cv2.rectangle(frame, (x1, label_y1), (x1 + tw + 10, y1), color, -1)
612
+ cv2.putText(
613
+ frame,
614
+ label,
615
+ (x1 + 5, y1 - 6),
616
+ cv2.FONT_HERSHEY_SIMPLEX,
617
+ 0.52,
618
+ (255, 255, 255),
619
+ 1,
620
+ cv2.LINE_AA,
621
+ )
622
 
623
+ frame = draw_dashboard(
624
  frame=frame,
625
  total_count=total_count,
626
  cumulative_kg=cumulative_kg,
627
  live_load_kg=live_load_kg,
628
  load_index_percent=load_index_percent,
629
+ proc_fps=proc_fps,
630
+ engine=engine,
631
  )
632
 
633
+ compact_items = []
634
+ for k in ["person", "car", "motorcycle", "bicycle", "bus", "truck", "cow", "sheep", "goat", "horse", "donkey"]:
635
+ v = int(class_counts.get(k, 0))
636
+ if v > 0:
637
+ compact_items.append(f"{DISPLAY_NAME.get(k, k)}: {v}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
638
 
639
+ text = " | ".join(compact_items) if compact_items else "No crossings yet"
640
+ cv2.putText(frame, text[:140], (22, h - 24), cv2.FONT_HERSHEY_SIMPLEX, 0.58, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
641
 
642
+ return frame
 
 
 
 
643
 
644
 
645
def final_summary_md(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    peak_live_load_kg: float,
    peak_load_index: float,
    auto_video_used: str,
) -> str:
    """Build the Markdown summary shown after processing finishes.

    Args:
        total_count: Number of counted crossings.
        class_counts: Crossings per canonical class name; classes with a
            zero count are left out of the table.
        cumulative_kg: Cumulative estimated mass of crossings, in kg.
        peak_live_load_kg: Highest estimated live load seen, in kg.
        peak_load_index: Highest load index seen, in percent.
        auto_video_used: Path of the bundled video that was analysed, or an
            empty string to omit that line.

    Returns:
        The summary as a Markdown string.
    """
    class_order = [
        "person", "bicycle", "car", "motorcycle", "bus", "truck",
        "cow", "sheep", "goat", "horse", "donkey",
    ]
    rows = []
    for name in class_order:
        count = int(class_counts.get(name, 0))
        if count > 0:
            rows.append(f"| {DISPLAY_NAME.get(name, name)} | {count} |")

    if not rows:
        rows.append("| None | 0 |")

    table_rows = "\n".join(rows)
    video_line = f"\n**Default video used:** `{auto_video_used}`\n" if auto_video_used else ""

    # Each metric sits in its own paragraph (blank line between them);
    # adjacent Markdown lines would otherwise be joined into one line.
    return f"""
### Final summary
{video_line}
**Total crossings:** {total_count}

| Class | Count |
|---|---:|
{table_rows}

**Cumulative estimated mass:** {cumulative_kg/1000.0:.2f} tonnes

**Peak estimated live load:** {peak_live_load_kg/1000.0:.2f} tonnes

**Peak bridge load index:** {peak_load_index:.1f}%

This is a demonstration traffic-load indicator. Real bridge stress needs axle loads, bridge geometry, material properties, span length, lane position and engineering calibration.
"""
679
 
680
 
681
  # ---------------------------------------------------------------------
682
+ # Main video processing generator
683
  # ---------------------------------------------------------------------
684
  def process_video(
685
  video_path,
686
+ engine,
687
+ yolo_model_file,
688
  confidence,
689
  frame_stride,
690
  inference_width,
 
692
  roi_top_percent,
693
  roi_bottom_percent,
694
  reference_capacity_tonnes,
695
+ person_weight_kg,
696
+ bicycle_weight_kg,
697
+ motorcycle_weight_kg,
698
  car_weight_t,
 
699
  bus_weight_t,
700
  truck_weight_t,
701
+ cow_weight_kg,
702
+ sheep_weight_kg,
703
+ goat_weight_kg,
704
+ horse_weight_kg,
705
+ donkey_weight_kg,
706
  ):
707
  if video_path is None:
708
  yield (
709
  None,
710
+ build_metrics_html(0, {}, 0, 0, 0, 0, 0, 0, 0, str(engine)),
711
+ make_empty_plot(),
712
+ "No video found. Put an `.mp4` file in the same folder as `app.py`, or upload one.",
713
  None,
714
  None,
715
  )
716
  return
717
 
718
+ # Gradio can pass a dict in some versions.
719
+ if isinstance(video_path, dict):
720
+ video_path = video_path.get("path") or video_path.get("name")
721
+
722
+ if not video_path or not os.path.exists(video_path):
723
+ yield (
724
+ None,
725
+ build_metrics_html(0, {}, 0, 0, 0, 0, 0, 0, 0, str(engine)),
726
+ make_empty_plot(),
727
+ f"Video not found: {video_path}",
728
+ None,
729
+ None,
730
+ )
731
+ return
732
+
733
+ DEFAULT_WEIGHTS_KG.update({
734
+ "person": int(person_weight_kg),
735
+ "bicycle": int(bicycle_weight_kg),
736
+ "motorcycle": int(motorcycle_weight_kg),
737
+ "car": int(float(car_weight_t) * 1000),
738
+ "bus": int(float(bus_weight_t) * 1000),
739
+ "truck": int(float(truck_weight_t) * 1000),
740
+ "cow": int(cow_weight_kg),
741
+ "sheep": int(sheep_weight_kg),
742
+ "goat": int(goat_weight_kg),
743
+ "horse": int(horse_weight_kg),
744
+ "donkey": int(donkey_weight_kg),
745
+ })
746
 
747
  cap = cv2.VideoCapture(video_path)
748
  if not cap.isOpened():
749
  raise RuntimeError(f"Could not open video: {video_path}")
750
 
751
+ fps = float(cap.get(cv2.CAP_PROP_FPS) or 25.0)
752
+ if fps <= 1:
753
  fps = 25.0
754
 
755
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
 
765
  roi_bottom_y = int(height * float(roi_bottom_percent) / 100.0)
766
 
767
  if roi_bottom_y <= roi_top_y:
768
+ roi_top_y = int(height * 0.20)
769
  roi_bottom_y = int(height * 0.90)
770
 
771
  reference_capacity_kg = max(1.0, float(reference_capacity_tonnes) * 1000.0)
772
 
773
  yield (
774
  None,
775
+ build_metrics_html(0, {}, 0, 0, 0, 0, total_frames, 0, 0, str(engine)),
776
+ make_empty_plot(),
777
+ f"### Starting analysis on `{Path(video_path).name}`...",
778
  None,
779
  None,
780
  )
781
 
782
+ # Preload model before loop.
783
+ if str(engine).startswith("YOLO"):
784
+ _ = load_yolo_model(str(yolo_model_file))
785
+ else:
786
+ _ = load_rfdetr_medium()
787
+
788
  tracker = sv.ByteTrack(frame_rate=int(round(fps)))
789
 
 
790
  out_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
791
  out_csv_path = tempfile.NamedTemporaryFile(suffix=".csv", delete=False).name
792
 
793
+ writer = cv2.VideoWriter(
794
+ out_video_path,
795
+ cv2.VideoWriter_fourcc(*"mp4v"),
796
+ fps,
797
+ (width, height),
798
+ )
799
 
 
800
  last_detections = sv.Detections.empty()
801
+ last_names: List[str] = []
802
+
803
  last_side_by_id: Dict[int, int] = {}
804
  counted_ids = set()
805
+ track_name_by_id: Dict[int, str] = {}
806
 
807
+ class_counts = {name: 0 for name in TARGET_CANONICAL_NAMES}
808
  total_count = 0
809
  cumulative_kg = 0.0
810
 
811
  history: List[Dict] = []
812
+ events: List[Dict] = []
 
 
 
 
 
813
 
814
  peak_live_load_kg = 0.0
815
  peak_load_index = 0.0
816
 
817
+ start_wall = time.time()
818
+ last_yield_wall = 0.0
819
+ last_plot_wall = 0.0
820
+ latest_plot = make_empty_plot()
821
+ processed = 0
822
  frame_idx = 0
823
+ final_frame_rgb = None
824
 
825
  while True:
826
  ok, frame = cap.read()
827
  if not ok:
828
  break
829
 
830
+ if frame_idx % int(frame_stride) == 0:
831
+ detections, names = predict_objects(
832
+ engine=str(engine),
833
+ yolo_model_file=str(yolo_model_file),
 
834
  frame_bgr=frame,
835
  confidence=float(confidence),
836
  inference_width=int(inference_width),
837
  )
838
  detections = tracker.update_with_detections(detections)
839
+
840
+ # Preserve name alignment after tracker update.
841
+ # ByteTrack keeps detections order, so this is usually aligned.
842
+ if len(names) != len(detections):
843
+ names = names[:len(detections)]
844
+ if len(names) < len(detections):
845
+ names += ["object"] * (len(detections) - len(names))
846
+
847
  last_detections = detections
848
+ last_names = names
849
  else:
850
  detections = last_detections
851
+ names = last_names
852
 
 
853
  centres = detection_centres(detections)
854
 
855
  live_load_kg = 0.0
 
856
 
857
  if len(detections) > 0 and detections.tracker_id is not None:
858
+ for i, (centre, tid) in enumerate(zip(centres, detections.tracker_id)):
859
+ if tid is None or int(tid) < 0:
 
 
860
  continue
861
 
862
+ tid = int(tid)
863
+ name = names[i] if i < len(names) else track_name_by_id.get(tid, "object")
864
+ if name == "object":
865
+ continue
866
 
867
+ track_name_by_id[tid] = name
 
868
 
869
+ cy = float(centre[1])
870
+
871
+ # Live load only for objects currently inside bridge deck ROI.
872
  if roi_top_y <= cy <= roi_bottom_y:
873
+ live_load_kg += float(DEFAULT_WEIGHTS_KG.get(name, 0))
874
 
875
  current_side = side_of_line(cy, line_y)
876
  previous_side = last_side_by_id.get(tid)
877
 
878
  if current_side != 0:
879
+ if previous_side is not None and previous_side != 0 and previous_side != current_side:
880
+ if tid not in counted_ids:
 
 
 
 
 
881
  counted_ids.add(tid)
882
  total_count += 1
883
+ class_counts[name] = int(class_counts.get(name, 0)) + 1
884
+ weight_kg = float(DEFAULT_WEIGHTS_KG.get(name, 0))
885
+ cumulative_kg += weight_kg
886
 
887
+ direction = "down" if previous_side < current_side else "up"
888
+ events.append({
889
  "video_time_s": frame_idx / fps,
890
  "frame": frame_idx,
891
  "tracker_id": tid,
892
+ "object_type": name,
893
+ "display_type": DISPLAY_NAME.get(name, name),
894
  "direction": direction,
895
+ "estimated_weight_kg": weight_kg,
896
  "cumulative_estimated_mass_kg": cumulative_kg,
897
  })
898
 
 
902
  peak_live_load_kg = max(peak_live_load_kg, live_load_kg)
903
  peak_load_index = max(peak_load_index, load_index_percent)
904
 
905
+ elapsed = time.time() - start_wall
906
+ processed += 1
907
+ proc_fps = processed / max(elapsed, 1e-6)
908
+
909
  history.append({
 
910
  "time_s": frame_idx / fps,
911
  "frame": frame_idx,
912
+ "total_crossings": total_count,
913
+ "people_crossed": class_counts.get("person", 0),
914
+ "bicycles_crossed": class_counts.get("bicycle", 0),
915
  "cars_crossed": class_counts.get("car", 0),
916
  "motorcycles_crossed": class_counts.get("motorcycle", 0),
917
  "buses_crossed": class_counts.get("bus", 0),
918
  "trucks_crossed": class_counts.get("truck", 0),
919
+ "cows_crossed": class_counts.get("cow", 0),
920
+ "sheep_goats_crossed": class_counts.get("sheep", 0) + class_counts.get("goat", 0),
921
+ "horse_donkey_crossed": class_counts.get("horse", 0) + class_counts.get("donkey", 0),
922
  "live_load_kg": live_load_kg,
923
  "live_load_tonnes": live_load_kg / 1000.0,
924
  "load_index_percent": load_index_percent,
 
926
  "cumulative_estimated_mass_tonnes": cumulative_kg / 1000.0,
927
  })
928
 
 
 
 
 
 
929
  annotated = annotate_frame(
930
  frame=frame,
931
  detections=detections,
932
+ canonical_names=names,
933
  line_y=line_y,
934
  roi_top_y=roi_top_y,
935
  roi_bottom_y=roi_bottom_y,
 
938
  cumulative_kg=cumulative_kg,
939
  live_load_kg=live_load_kg,
940
  load_index_percent=load_index_percent,
941
+ proc_fps=proc_fps,
942
+ engine=str(engine),
943
  )
944
+
945
  writer.write(annotated)
946
+ final_frame_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
947
 
948
  now = time.time()
949
+ if now - last_plot_wall >= 1.0:
950
+ latest_plot = render_load_plot(history)
951
+ last_plot_wall = now
952
+
953
  if now - last_yield_wall >= 0.35:
954
  last_yield_wall = now
 
 
 
955
  yield (
956
+ final_frame_rgb,
957
  build_metrics_html(
958
  total_count=total_count,
959
  class_counts=class_counts,
 
962
  load_index_percent=load_index_percent,
963
  frame_idx=frame_idx + 1,
964
  total_frames=total_frames,
965
+ elapsed=elapsed,
966
+ proc_fps=proc_fps,
967
+ engine=str(engine),
968
  ),
969
+ latest_plot,
970
  "### Live analysis running...",
971
  None,
972
  None,
 
977
  cap.release()
978
  writer.release()
979
 
 
980
  history_df = pd.DataFrame(history)
981
+ events_df = pd.DataFrame(events)
982
+
983
+ if not events_df.empty:
984
+ # Save both frame-level history and crossing events in one CSV-like file
985
+ # by writing two separate CSV sections.
986
+ with open(out_csv_path, "w", encoding="utf-8") as f:
987
+ f.write("# FRAME_LEVEL_LOAD_INDEX\n")
988
+ history_df.to_csv(f, index=False)
989
+ f.write("\n# CROSSING_EVENTS\n")
990
+ events_df.to_csv(f, index=False)
991
+ else:
992
+ history_df.to_csv(out_csv_path, index=False)
993
 
994
+ elapsed = time.time() - start_wall
995
+ proc_fps = processed / max(elapsed, 1e-6)
996
  final_plot = render_load_plot(history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
997
 
998
  yield (
999
+ final_frame_rgb,
1000
  build_metrics_html(
1001
  total_count=total_count,
1002
  class_counts=class_counts,
 
1005
  load_index_percent=0,
1006
  frame_idx=total_frames if total_frames else frame_idx,
1007
  total_frames=total_frames if total_frames else frame_idx,
1008
+ elapsed=elapsed,
1009
+ proc_fps=proc_fps,
1010
+ engine=str(engine),
1011
  ),
1012
  final_plot,
1013
+ final_summary_md(
1014
+ total_count=total_count,
1015
+ class_counts=class_counts,
1016
+ cumulative_kg=cumulative_kg,
1017
+ peak_live_load_kg=peak_live_load_kg,
1018
+ peak_load_index=peak_load_index,
1019
+ auto_video_used=video_path if str(video_path).startswith(str(APP_DIR)) else "",
1020
+ ),
1021
  out_video_path,
1022
  out_csv_path,
1023
  )
1024
 
1025
 
1026
  # ---------------------------------------------------------------------
1027
+ # UI
1028
  # ---------------------------------------------------------------------
1029
  CUSTOM_CSS = """
1030
  .gradio-container {
1031
+ max-width: 1360px !important;
1032
  margin: auto !important;
1033
  }
1034
  #hero {
1035
  text-align: center;
1036
+ padding: 16px 8px 6px 8px;
1037
  }
1038
  #hero h1 {
1039
  font-weight: 850;
1040
+ letter-spacing: -0.8px;
1041
  margin-bottom: 2px;
1042
  }
1043
  #hero p {
 
1050
  border-radius: 18px;
1051
  padding: 16px;
1052
  background: #ffffff;
1053
+ box-shadow: 0 8px 24px rgba(15, 23, 42, 0.045);
1054
  }
1055
  #live-frame img, #load-plot img {
1056
  border-radius: 14px;
 
1061
  """
1062
 
1063
  with gr.Blocks(
1064
+ title="Fast Bridge Traffic + Livestock Load Demo",
1065
  theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
1066
  css=CUSTOM_CSS,
1067
  ) as demo:
 
1069
  with gr.Row(elem_id="hero"):
1070
  gr.Markdown(
1071
  """
1072
+ # 🌉 Fast Bridge Traffic + Livestock Load Demo
1073
+ YOLO-small / RF-DETR Medium detection, ByteTrack tracking, line-crossing counts,
1074
+ estimated object weights, and live bridge load-index over time.
1075
  """
1076
  )
1077
 
1078
+ if DEFAULT_VIDEO:
1079
+ gr.Markdown(f"✅ Found default video next to `app.py`: `{Path(DEFAULT_VIDEO).name}`. The app will auto-start inference when opened.")
1080
+ else:
1081
+ gr.Markdown("⚠️ No local video found next to `app.py`. Upload a video or place `bridge.mp4`, `traffic.mp4`, `input.mp4`, or any `.mp4` in the same folder.")
1082
+
1083
  with gr.Row():
1084
  with gr.Column(scale=1):
1085
  with gr.Group(elem_classes="panel"):
1086
+ gr.Markdown("### 1) Video")
1087
  video_input = gr.Video(
1088
+ label="Video input",
1089
  sources=["upload"],
1090
+ value=DEFAULT_VIDEO,
1091
  format="mp4",
1092
  height=260,
1093
  )
1094
 
1095
+ start_btn = gr.Button("▶ Start / rerun analysis", variant="primary", size="lg")
1096
 
1097
+ gr.Markdown("### 2) Inference engine")
1098
+ engine = gr.Radio(
1099
+ choices=[
1100
+ "YOLO small - fastest recommended",
1101
+ "RF-DETR Medium - slower but strong",
1102
+ ],
1103
+ value="YOLO small - fastest recommended",
1104
+ label="Engine",
1105
  )
1106
+ yolo_model_file = gr.Textbox(
1107
+ value="yolo11s.pt",
1108
+ label="YOLO model file/name",
1109
+ info="Use yolo11s.pt for small. Put your custom .pt in the same folder as app.py and type its filename here.",
1110
+ )
1111
+
1112
  confidence = gr.Slider(
1113
  minimum=0.10,
1114
  maximum=0.90,
1115
+ value=0.35,
1116
  step=0.05,
1117
  label="Confidence threshold",
1118
  )
1119
  frame_stride = gr.Slider(
1120
  minimum=1,
1121
+ maximum=12,
1122
  value=3,
1123
  step=1,
1124
  label="Frame stride",
1125
+ info="Detect every Nth frame. 3-5 is much faster than every frame.",
1126
  )
1127
  inference_width = gr.Slider(
1128
  minimum=384,
1129
  maximum=1280,
1130
  value=640,
1131
  step=64,
1132
+ label="Inference image size / width",
1133
+ info="Lower is faster. Try 512 or 640 for fast demos.",
1134
  )
1135
 
1136
  with gr.Accordion("Bridge settings", open=False):
 
1156
  label="Bridge deck ROI bottom (%)",
1157
  )
1158
  reference_capacity_tonnes = gr.Slider(
1159
+ minimum=1,
1160
+ maximum=250,
1161
  value=40,
1162
+ step=1,
1163
  label="Reference live-load capacity for demo index (tonnes)",
1164
  )
1165
 
1166
+ with gr.Accordion("Estimated weights", open=False):
1167
+ person_weight_kg = gr.Number(value=75, label="Person weight estimate (kg)")
1168
+ bicycle_weight_kg = gr.Number(value=120, label="Bicycle + rider estimate (kg)")
1169
+ motorcycle_weight_kg = gr.Number(value=250, label="Motorcycle estimate (kg)")
1170
+ car_weight_t = gr.Number(value=1.5, label="Car estimate (tonnes)")
1171
+ bus_weight_t = gr.Number(value=12.0, label="Bus estimate (tonnes)")
1172
+ truck_weight_t = gr.Number(value=18.0, label="Truck estimate (tonnes)")
1173
+ cow_weight_kg = gr.Number(value=450, label="Cow estimate (kg)")
1174
+ sheep_weight_kg = gr.Number(value=60, label="Sheep estimate (kg)")
1175
+ goat_weight_kg = gr.Number(value=45, label="Goat estimate (kg)")
1176
+ horse_weight_kg = gr.Number(value=350, label="Horse estimate (kg)")
1177
+ donkey_weight_kg = gr.Number(value=180, label="Donkey estimate (kg)")
1178
 
1179
  gr.Markdown(
1180
  """
1181
+ **Fast demo settings:** YOLO small, confidence 0.30-0.40,
1182
+ frame stride 3-5, image size 512-640.
1183
  """
1184
  )
1185
 
 
1189
  live_frame = gr.Image(
1190
  show_label=False,
1191
  elem_id="live-frame",
1192
+ height=500,
1193
  )
1194
 
1195
  with gr.Row():
 
1199
  metrics_html = gr.HTML(
1200
  value=build_metrics_html(
1201
  total_count=0,
1202
+ class_counts={},
1203
  cumulative_kg=0,
1204
  live_load_kg=0,
1205
  load_index_percent=0,
1206
  frame_idx=0,
1207
  total_frames=0,
1208
  elapsed=0,
1209
+ proc_fps=0,
1210
+ engine="not started",
1211
  )
1212
  )
1213
 
 
1217
  load_plot = gr.Image(
1218
  show_label=False,
1219
  elem_id="load-plot",
1220
+ height=300,
1221
+ value=make_empty_plot(),
1222
  )
1223
 
1224
  with gr.Row():
 
1226
  with gr.Group(elem_classes="panel"):
1227
  gr.Markdown("### Final annotated video")
1228
  video_output = gr.Video(label="Replay / download annotated video", height=270)
1229
+
1230
  with gr.Column(scale=1):
1231
  with gr.Group(elem_classes="panel"):
1232
  gr.Markdown("### Final summary")
1233
+ summary_output = gr.Markdown("The summary will appear after analysis.")
1234
+ csv_output = gr.File(label="Download CSV")
1235
+
1236
+ inputs = [
1237
+ video_input,
1238
+ engine,
1239
+ yolo_model_file,
1240
+ confidence,
1241
+ frame_stride,
1242
+ inference_width,
1243
+ line_position_percent,
1244
+ roi_top_percent,
1245
+ roi_bottom_percent,
1246
+ reference_capacity_tonnes,
1247
+ person_weight_kg,
1248
+ bicycle_weight_kg,
1249
+ motorcycle_weight_kg,
1250
+ car_weight_t,
1251
+ bus_weight_t,
1252
+ truck_weight_t,
1253
+ cow_weight_kg,
1254
+ sheep_weight_kg,
1255
+ goat_weight_kg,
1256
+ horse_weight_kg,
1257
+ donkey_weight_kg,
1258
+ ]
1259
+
1260
+ outputs = [
1261
+ live_frame,
1262
+ metrics_html,
1263
+ load_plot,
1264
+ summary_output,
1265
+ video_output,
1266
+ csv_output,
1267
+ ]
1268
 
1269
  start_btn.click(
1270
  fn=process_video,
1271
+ inputs=inputs,
1272
+ outputs=outputs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1273
  )
1274
 
1275
+ # Auto-start when a local video exists beside app.py.
1276
+ if DEFAULT_VIDEO:
1277
+ demo.load(
1278
+ fn=process_video,
1279
+ inputs=inputs,
1280
+ outputs=outputs,
1281
+ )
1282
+
1283
 
1284
  if __name__ == "__main__":
1285
+ demo.queue(max_size=2).launch()