# Car_Counting / app.py
# (author: lyimo · commit "Update app.py" · 35e66da, verified)
"""
Fast Bridge Traffic + Livestock Load Demo
"""
import os
import time
import tempfile
import warnings
from pathlib import Path
from functools import lru_cache
from typing import Dict, List, Tuple, Optional
import cv2
import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import supervision as sv
import torch
# Optional detection engines. Any failure while importing a backend simply
# disables that engine (the name is set to None) instead of stopping the app.
try:
    from ultralytics import YOLO
except Exception:
    YOLO = None

try:
    from rfdetr import RFDETRMedium
except Exception:
    RFDETRMedium = None

# ---------------------------------------------------------------------
# Quiet noisy dependency warning that is not controlled by this app.
# The RF-DETR/transformers warning is internal to the dependency stack.
# ---------------------------------------------------------------------
for _warning_pattern in (".*use_return_dict.*", ".*`use_return_dict` is deprecated.*"):
    warnings.filterwarnings("ignore", message=_warning_pattern)
del _warning_pattern
# ---------------------------------------------------------------------
# App paths and default local video
# ---------------------------------------------------------------------
# Folder that contains this script; bundled videos/models are looked up here.
APP_DIR = Path(__file__).resolve().parent

# Extensions accepted when no preferred file name is present, in priority order.
VIDEO_EXTENSIONS = ".mp4 .mov .avi .mkv .webm".split()

# File names tried first, in this order.
PREFERRED_VIDEO_NAMES = [
    f"{stem}.mp4"
    for stem in ("bridge", "traffic", "cars", "video", "input", "example", "sample")
]
def find_default_video(
    directory: Optional[Path] = None,
    preferred_names: Optional[List[str]] = None,
    extensions: Optional[List[str]] = None,
) -> Optional[str]:
    """Find a video sitting next to app.py.

    Preferred file names are tried first, in order. Otherwise the first file
    (in sorted path order) whose extension matches is returned, trying the
    extensions in order. Extension matching is case-insensitive, so e.g.
    ``CLIP.MP4`` is found on case-sensitive file systems. Only regular files
    qualify; a directory that happens to be named ``bridge.mp4`` is ignored.

    Args:
        directory: Folder to search. Defaults to ``APP_DIR``.
        preferred_names: Names to try first. Defaults to ``PREFERRED_VIDEO_NAMES``.
        extensions: Extensions to accept (with leading dot). Defaults to
            ``VIDEO_EXTENSIONS``.

    Returns:
        The path as a string, or None when nothing suitable exists.
    """
    folder = APP_DIR if directory is None else Path(directory)
    names = PREFERRED_VIDEO_NAMES if preferred_names is None else preferred_names
    exts = VIDEO_EXTENSIONS if extensions is None else extensions

    for name in names:
        candidate = folder / name
        if candidate.is_file():
            return str(candidate)

    if not folder.is_dir():
        return None

    files = sorted(p for p in folder.iterdir() if p.is_file())
    for ext in exts:
        wanted = ext.lower()
        for path in files:
            if path.suffix.lower() == wanted:
                return str(path)
    return None
# Video shown (and auto-analysed) on page load, or None if none is bundled.
DEFAULT_VIDEO = find_default_video()

# ---------------------------------------------------------------------
# Device and speed setup
# ---------------------------------------------------------------------
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Let torch use all CPU cores but one (at least one thread); best-effort.
try:
    torch.set_num_threads(max(1, (os.cpu_count() or 2) - 1))
except Exception:
    pass

# On GPU, let cuDNN benchmark and pick its fastest kernels; best-effort.
if DEVICE == "cuda":
    try:
        torch.backends.cudnn.benchmark = True
    except Exception:
        pass
# ---------------------------------------------------------------------
# Target classes and estimated weights
# ---------------------------------------------------------------------
# For YOLO COCO:
# person=0, bicycle=1, car=2, motorcycle=3, bus=5, truck=7,
# horse=17, sheep=18, cow=19.
#
# COCO does not have goat or donkey. We map:
# sheep -> sheep/goat
# horse -> horse/donkey
_CANONICAL_CLASS_ORDER = (
    "person", "bicycle", "car", "motorcycle", "bus", "truck",
    "cow", "sheep", "goat", "horse", "donkey",
)

# Every class name the counter accepts from either detector.
TARGET_CANONICAL_NAMES = set(_CANONICAL_CLASS_ORDER)

# Label shown in the UI for each class. Sheep/horse detections stand in for
# goats/donkeys, so their labels mention both animals.
DISPLAY_NAME = {name: name for name in _CANONICAL_CLASS_ORDER}
DISPLAY_NAME.update({"sheep": "sheep / goat", "horse": "horse / donkey"})
# COCO class names for RF-DETR outputs.
# RF-DETR's pretrained COCO checkpoints report the *original* COCO category
# ids (1-based with gaps: person=1, bicycle=2, car=3, motorcycle=4, bus=6,
# truck=8, horse=19, sheep=20, cow=21). These are NOT the contiguous 0-based
# ids Ultralytics YOLO uses (person=0, car=2, ...); reusing the YOLO ids here
# would drop people and mislabel e.g. bicycles as cars and trains as trucks.
_RFDETR_COCO_FALLBACK = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    6: "bus",
    8: "truck",
    19: "horse",
    20: "sheep",
    21: "cow",
}

try:
    # Prefer the id -> name table that ships with rfdetr, when available.
    from rfdetr.util.coco_classes import COCO_CLASSES as _RFDETR_COCO_CLASSES
except Exception:
    _RFDETR_COCO_CLASSES = None

COCO_NAMES = dict(_RFDETR_COCO_FALLBACK)
if isinstance(_RFDETR_COCO_CLASSES, dict):
    _wanted_names = set(_RFDETR_COCO_FALLBACK.values())
    _derived_names = {
        int(cid): str(label).strip().lower()
        for cid, label in _RFDETR_COCO_CLASSES.items()
        if str(label).strip().lower() in _wanted_names
    }
    if _derived_names:
        COCO_NAMES = _derived_names
# Approximate demo weights in kg.
# Adjust in the UI for your bridge/traffic context.
# NOTE: process_video() overwrites these entries with the UI values per run.
DEFAULT_WEIGHTS_KG = dict(
    person=75,
    bicycle=120,  # bicycle + rider approximation
    motorcycle=250,
    car=1500,
    bus=12000,
    truck=18000,
    cow=450,
    sheep=60,
    goat=45,
    horse=350,
    donkey=180,
)
# Box / label colour per class, as OpenCV BGR tuples.
COLOR_BY_NAME_BGR = dict(
    person=(70, 160, 245),
    bicycle=(240, 190, 80),
    motorcycle=(255, 150, 80),
    car=(60, 210, 130),
    bus=(50, 130, 245),
    truck=(220, 70, 180),
    cow=(160, 120, 80),
    sheep=(220, 220, 220),
    goat=(210, 210, 230),
    horse=(130, 90, 60),
    donkey=(120, 110, 95),
)
# ---------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------
@lru_cache(maxsize=4)
def load_yolo_model(model_file: str):
    """Load an Ultralytics YOLO model, memoised per ``model_file``.

    A file of that name next to app.py takes precedence; otherwise the name
    is handed to ``YOLO`` unchanged. Moving the model to ``DEVICE`` is
    best-effort.

    Raises:
        RuntimeError: if ultralytics could not be imported.
    """
    if YOLO is None:
        raise RuntimeError("Ultralytics is not installed. Run: pip install ultralytics")

    bundled = APP_DIR / model_file
    source = str(bundled) if bundled.exists() else model_file

    print(f"Loading YOLO model: {source} on {DEVICE}")
    yolo = YOLO(source)
    try:
        yolo.to(DEVICE)
    except Exception:
        # Keep whatever device the model loaded on.
        pass
    return yolo
@lru_cache(maxsize=1)
def load_rfdetr_medium():
    """Create (once) an RF-DETR Medium model and try to optimise it for inference.

    Raises:
        RuntimeError: if rfdetr could not be imported.
    """
    if RFDETRMedium is None:
        raise RuntimeError("RF-DETR is not installed. Run: pip install rfdetr")

    print(f"Loading RF-DETR Medium on {DEVICE}")
    try:
        detector = RFDETRMedium(device=DEVICE)
    except TypeError:
        # This rfdetr build rejects the ``device`` keyword: use the default constructor.
        detector = RFDETRMedium()

    # This directly addresses:
    # "Model is not optimized for inference. Latency may be higher..."
    try:
        detector.optimize_for_inference()
    except Exception as exc:
        print(f"RF-DETR optimize_for_inference skipped: {exc}")
    else:
        print("RF-DETR Medium optimized for inference.")
    return detector
# ---------------------------------------------------------------------
# Detection conversion
# ---------------------------------------------------------------------
def yolo_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    imgsz: int,
) -> Tuple[sv.Detections, List[str]]:
    """
    Run YOLO and return supervision Detections plus canonical class names.

    Only boxes whose lower-cased class label is a target class, or one of the
    aliases ``automobile`` -> car / ``lorry`` -> truck, are kept. The returned
    name list is aligned index-for-index with the returned detections, whose
    ``class_id`` values are the model's own ids.
    """
    results = model.predict(
        source=frame_bgr,
        conf=float(confidence),
        imgsz=int(imgsz),
        device=0 if DEVICE == "cuda" else "cpu",
        verbose=False,
    )[0]
    boxes = results.boxes
    if boxes is None or len(boxes) == 0:
        return sv.Detections.empty(), []

    xyxy = boxes.xyxy.detach().cpu().numpy()
    conf = boxes.conf.detach().cpu().numpy()
    cls = boxes.cls.detach().cpu().numpy().astype(int)

    labels = getattr(model, "names", None) or {}
    if isinstance(labels, (list, tuple)):
        # Accept a plain list indexed by class id as well as a dict.
        labels = dict(enumerate(labels))
    aliases = {"automobile": "car", "lorry": "truck"}

    keep: List[int] = []
    canonical_names: List[str] = []
    for i, class_id in enumerate(cls):
        raw = str(labels.get(int(class_id), class_id)).lower().strip()
        name = aliases.get(raw, raw)
        if name in TARGET_CANONICAL_NAMES:
            keep.append(i)
            canonical_names.append(name)

    if not keep:
        return sv.Detections.empty(), []

    idx = np.asarray(keep, dtype=int)
    detections = sv.Detections(
        xyxy=xyxy[idx],
        confidence=conf[idx],
        class_id=cls[idx],
    )
    return detections, canonical_names
def rfdetr_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """
    Run RF-DETR Medium. Resize frame before inference for speed, then scale boxes back.

    Frames wider than ``inference_width`` are shrunk to exactly that width
    (aspect ratio kept). Boxes are mapped back with separate x/y factors,
    because rounding the resized height to whole pixels makes the vertical
    ratio differ slightly from the horizontal one.

    Returns:
        ``(detections, canonical_names)`` with only target classes kept (looked
        up via ``COCO_NAMES``); the name list is aligned index-for-index with
        the detections.
    """
    h, w = frame_bgr.shape[:2]
    scale_x = scale_y = 1.0
    resized = frame_bgr
    if inference_width > 0 and w > inference_width:
        new_w = int(inference_width)
        new_h = max(1, int(round(h * new_w / float(w))))
        resized = cv2.resize(frame_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)
        scale_x = w / float(new_w)
        scale_y = h / float(new_h)

    # OpenCV frames are BGR; RF-DETR is given RGB.
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    with torch.inference_mode():
        detections = model.predict(rgb, threshold=float(confidence))

    if len(detections) == 0 or detections.class_id is None:
        return sv.Detections.empty(), []

    keep: List[int] = []
    canonical_names: List[str] = []
    for i, class_id in enumerate(detections.class_id):
        name = COCO_NAMES.get(int(class_id))
        if name in TARGET_CANONICAL_NAMES:
            keep.append(i)
            canonical_names.append(name)

    if not keep:
        return sv.Detections.empty(), []

    detections = detections[np.asarray(keep, dtype=int)]
    if scale_x != 1.0 or scale_y != 1.0:
        detections.xyxy = detections.xyxy * np.array(
            [scale_x, scale_y, scale_x, scale_y], dtype=float
        )
    return detections, canonical_names
def predict_objects(
    engine: str,
    yolo_model_file: str,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """Run one frame through the selected engine.

    Engine labels starting with "YOLO" use ``yolo_model_file`` (with
    ``inference_width`` as ``imgsz``); anything else uses RF-DETR Medium.
    Models come from the cached loaders.
    """
    if not engine.startswith("YOLO"):
        return rfdetr_predict_to_supervision(
            model=load_rfdetr_medium(),
            frame_bgr=frame_bgr,
            confidence=confidence,
            inference_width=inference_width,
        )
    return yolo_predict_to_supervision(
        model=load_yolo_model(yolo_model_file),
        frame_bgr=frame_bgr,
        confidence=confidence,
        imgsz=inference_width,
    )
# ---------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------
def side_of_line(y: float, line_y: int, dead_zone_px: int = 5) -> int:
    """Classify ``y`` relative to a horizontal line at ``line_y``.

    Returns 0 when ``y`` lies within ``dead_zone_px`` pixels of the line
    (inclusive), -1 when ``y`` is smaller than the line, and 1 otherwise.
    """
    offset = y - line_y
    if -dead_zone_px <= offset <= dead_zone_px:
        return 0
    return -1 if offset < 0 else 1
def detection_centres(detections: sv.Detections) -> np.ndarray:
    """Return an (N, 2) array of box centres ``(cx, cy)`` from ``detections.xyxy``."""
    if len(detections) == 0:
        return np.empty((0, 2), dtype=float)
    boxes = detections.xyxy
    centre_x = (boxes[:, 0] + boxes[:, 2]) / 2.0
    centre_y = (boxes[:, 1] + boxes[:, 3]) / 2.0
    return np.stack((centre_x, centre_y), axis=1)
def make_empty_plot() -> np.ndarray:
    """Return a white 620x300, 3-channel placeholder image with a grey hint message."""
    canvas = np.full((300, 620, 3), 255, dtype=np.uint8)
    hint = "Bridge load index chart will appear here"
    cv2.putText(canvas, hint, (70, 155), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (90, 90, 90), 2, cv2.LINE_AA)
    return canvas
def render_load_plot(history: List[Dict]) -> np.ndarray:
    """Render the load-index-over-time chart as an RGB image array.

    ``history`` rows must contain ``time_s`` and ``load_index_percent``. Long
    histories are evenly downsampled to 600 points before plotting.

    The chart is drawn on an explicit ``Figure`` with an Agg canvas instead of
    ``pyplot``. pyplot keeps global figure state, which is unsafe when Gradio
    runs this handler on worker threads. The figure is also released with its
    object, so a drawing error cannot leak it (previously ``plt.close`` was
    skipped on error).
    """
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    from matplotlib.figure import Figure

    if not history:
        return make_empty_plot()

    df = pd.DataFrame(history)
    if len(df) > 600:
        df = df.iloc[np.linspace(0, len(df) - 1, 600).astype(int)]

    fig = Figure(figsize=(8.0, 3.5), dpi=100)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(df["time_s"], df["load_index_percent"], linewidth=2)
    ax.set_title("Estimated Bridge Load Index Over Time")
    ax.set_xlabel("Video time (seconds)")
    ax.set_ylabel("Load index (%)")
    ax.grid(True, alpha=0.25)
    ax.set_ylim(bottom=0)
    fig.tight_layout()

    canvas.draw()
    rgba = np.asarray(canvas.buffer_rgba())
    return cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
def build_metrics_html(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    frame_idx: int,
    total_frames: int,
    elapsed: float,
    proc_fps: float,
    engine: str,
) -> str:
    """Build the HTML metrics panel (cards, per-class tally, progress bar).

    Masses are passed in kg and shown in tonnes. Sheep+goat and horse+donkey
    counts are shown combined.
    """
    import html

    pct = (frame_idx / total_frames * 100.0) if total_frames else 0.0
    # The frame count reported by a container is often only an estimate, so
    # the processed index can overshoot it; keep the percentage in 0-100.
    pct = min(100.0, max(0.0, pct))
    tonnes = cumulative_kg / 1000.0
    live_tonnes = live_load_kg / 1000.0
    # The engine label is free text inserted into markup: escape it.
    engine_text = html.escape(str(engine))

    def c(name: str) -> int:
        return int(class_counts.get(name, 0))

    return f"""
<div style="font-family:Inter,system-ui,Arial;">
<div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
<div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#1d4ed8,#312e81);color:white;">
<div style="font-size:11px;letter-spacing:1px;opacity:.86;">OBJECTS CROSSED</div>
<div style="font-size:46px;font-weight:850;line-height:1;">{total_count}</div>
</div>
<div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#be185d,#7e22ce);color:white;">
<div style="font-size:11px;letter-spacing:1px;opacity:.86;">CUMULATIVE EST. MASS</div>
<div style="font-size:36px;font-weight:850;line-height:1;">{tonnes:.1f} t</div>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
<div style="font-size:12px;color:#6b7280;">Live bridge load</div>
<div style="font-size:28px;font-weight:800;color:#111827;">{live_tonnes:.1f} t</div>
</div>
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
<div style="font-size:12px;color:#6b7280;">Load index</div>
<div style="font-size:28px;font-weight:800;color:#111827;">{load_index_percent:.1f}%</div>
</div>
</div>
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:#ffffff;margin-bottom:12px;">
<div style="font-size:12px;color:#6b7280;margin-bottom:8px;">Crossings by class</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:7px;font-size:13px;">
<div>🚶 People: <b>{c("person")}</b></div>
<div>🚗 Cars: <b>{c("car")}</b></div>
<div>🏍️ Motorcycles: <b>{c("motorcycle")}</b></div>
<div>🚲 Bicycles: <b>{c("bicycle")}</b></div>
<div>🚌 Buses: <b>{c("bus")}</b></div>
<div>🚛 Trucks: <b>{c("truck")}</b></div>
<div>🐄 Cows: <b>{c("cow")}</b></div>
<div>🐑 Sheep/goats: <b>{c("sheep") + c("goat")}</b></div>
<div>🐴 Horse/donkey: <b>{c("horse") + c("donkey")}</b></div>
</div>
</div>
<div style="font-size:12px;color:#6b7280;margin-bottom:4px;display:flex;justify-content:space-between;">
<span>Frame {frame_idx} / {total_frames}</span>
<span>{pct:.1f}% · {elapsed:.1f}s · {proc_fps:.1f} FPS · {DEVICE} · {engine_text}</span>
</div>
<div style="height:8px;background:#e5e7eb;border-radius:999px;overflow:hidden;">
<div style="height:100%;width:{pct:.2f}%;background:#4f46e5;"></div>
</div>
</div>
"""
def draw_dashboard(
    frame: np.ndarray,
    total_count: int,
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    proc_fps: float,
    engine: str,
) -> np.ndarray:
    """Blend a dark stats panel into the top-left of ``frame`` and write four lines on it.

    The panel covers (18, 18)-(600, 164) at 82 % opacity. A new image is
    returned; the input frame is not modified.
    """
    shaded = frame.copy()
    cv2.rectangle(shaded, (18, 18), (600, 164), (18, 24, 38), -1)
    out = cv2.addWeighted(shaded, 0.82, frame, 0.18, 0)

    # (text, baseline y, font scale, BGR colour, thickness), drawn top to bottom.
    rows = (
        ("BRIDGE TRAFFIC + LIVESTOCK DEMO", 48, 0.72, (255, 255, 255), 2),
        (
            f"Crossed: {total_count} | Cumulative est. mass: {cumulative_kg/1000.0:.1f} t",
            82, 0.58, (230, 240, 255), 2,
        ),
        (
            f"Live load: {live_load_kg/1000.0:.1f} t | Load index: {load_index_percent:.1f}%",
            114, 0.58, (220, 245, 230), 2,
        ),
        (
            f"{proc_fps:.1f} processing FPS | {DEVICE} | {engine}",
            144, 0.52, (230, 230, 255), 1,
        ),
    )
    for text, baseline_y, scale, colour, thickness in rows:
        cv2.putText(
            out, text, (34, baseline_y), cv2.FONT_HERSHEY_SIMPLEX,
            scale, colour, thickness, cv2.LINE_AA,
        )
    return out
def annotate_frame(
    frame: np.ndarray,
    detections: sv.Detections,
    canonical_names: List[str],
    line_y: int,
    roi_top_y: int,
    roi_bottom_y: int,
    class_counts: Dict[str, int],
    total_count: int,
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    proc_fps: float,
    engine: str,
) -> np.ndarray:
    """Draw the ROI, counting line, labelled boxes, dashboard and class tally.

    ``canonical_names[i]`` names ``detections[i]``; missing entries are shown
    as "object". Returns a new BGR image (the ROI tint is blended into a copy).
    """
    h, w = frame.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX

    # Bridge deck ROI: faint grey tint.
    overlay = frame.copy()
    cv2.rectangle(overlay, (0, roi_top_y), (w, roi_bottom_y), (90, 90, 90), -1)
    frame = cv2.addWeighted(overlay, 0.08, frame, 0.92, 0)

    # Counting line.
    cv2.line(frame, (0, line_y), (w, line_y), (40, 230, 255), 3)
    cv2.putText(
        frame, "COUNTING LINE", (24, max(28, line_y - 12)),
        font, 0.60, (40, 230, 255), 2, cv2.LINE_AA,
    )

    # ROI borders.
    cv2.line(frame, (0, roi_top_y), (w, roi_top_y), (170, 170, 170), 1)
    cv2.line(frame, (0, roi_bottom_y), (w, roi_bottom_y), (170, 170, 170), 1)

    count = len(detections)
    if count > 0:
        tracker_ids = detections.tracker_id if detections.tracker_id is not None else [None] * count
        confidences = detections.confidence if detections.confidence is not None else [0.0] * count
        for i, (xyxy, conf, tid) in enumerate(zip(detections.xyxy, confidences, tracker_ids)):
            name = canonical_names[i] if i < len(canonical_names) else "object"
            x1, y1, x2, y2 = map(int, xyxy)
            color = COLOR_BY_NAME_BGR.get(name, (80, 220, 255))
            display = DISPLAY_NAME.get(name, name)
            weight_t = DEFAULT_WEIGHTS_KG.get(name, 0) / 1000.0

            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)

            id_txt = f"#{int(tid)} " if tid is not None and int(tid) >= 0 else ""
            label = f"{id_txt}{display} {float(conf):.2f} ~{weight_t:.2f}t"
            (tw, th), base = cv2.getTextSize(label, font, 0.52, 1)
            tag_h = th + base + 8
            if y1 - tag_h >= 0:
                # Normal case: tag sits on top of the box.
                tag_top, tag_bottom = y1 - tag_h, y1
            else:
                # Box touches the top edge: no room above it, so draw the tag
                # just inside the box instead of at a negative (clipped) y.
                tag_top = max(0, y1)
                tag_bottom = tag_top + tag_h
            cv2.rectangle(frame, (x1, tag_top), (x1 + tw + 10, tag_bottom), color, -1)
            cv2.putText(
                frame, label, (x1 + 5, tag_bottom - 6),
                font, 0.52, (255, 255, 255), 1, cv2.LINE_AA,
            )

    frame = draw_dashboard(
        frame=frame,
        total_count=total_count,
        cumulative_kg=cumulative_kg,
        live_load_kg=live_load_kg,
        load_index_percent=load_index_percent,
        proc_fps=proc_fps,
        engine=engine,
    )

    # One-line tally of non-zero crossing counts along the bottom edge.
    order = ["person", "car", "motorcycle", "bicycle", "bus", "truck", "cow", "sheep", "goat", "horse", "donkey"]
    compact_items = [
        f"{DISPLAY_NAME.get(k, k)}: {int(class_counts.get(k, 0))}"
        for k in order
        if int(class_counts.get(k, 0)) > 0
    ]
    text = " | ".join(compact_items) if compact_items else "No crossings yet"
    cv2.putText(frame, text[:140], (22, h - 24), font, 0.58, (255, 255, 255), 2, cv2.LINE_AA)
    return frame
def final_summary_md(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    peak_live_load_kg: float,
    peak_load_index: float,
    auto_video_used: str,
) -> str:
    """Build the Markdown end-of-run summary (totals, per-class table, peaks).

    Each section is separated by a blank line. Without blank lines, Markdown
    renderers continue a table through the following non-blank lines, which
    swallowed the mass/peak lines into the table. Adjacent bold lines were
    also merged into a single paragraph.
    """
    rows = []
    for name in ["person", "bicycle", "car", "motorcycle", "bus", "truck", "cow", "sheep", "goat", "horse", "donkey"]:
        count = int(class_counts.get(name, 0))
        if count > 0:
            rows.append(f"| {DISPLAY_NAME.get(name, name)} | {count} |")
    if not rows:
        rows.append("| None | 0 |")

    sections = ["### Final summary"]
    if auto_video_used:
        sections.append(f"**Default video used:** `{auto_video_used}`")
    sections.append(f"**Total crossings:** {total_count}")
    sections.append("\n".join(["| Class | Count |", "|---|---:|", *rows]))
    sections.append(f"**Cumulative estimated mass:** {cumulative_kg/1000.0:.2f} tonnes")
    sections.append(f"**Peak estimated live load:** {peak_live_load_kg/1000.0:.2f} tonnes")
    sections.append(f"**Peak bridge load index:** {peak_load_index:.1f}%")
    sections.append(
        "This is a demonstration traffic-load indicator. Real bridge stress needs axle loads, "
        "bridge geometry, material properties, span length, lane position and engineering calibration."
    )
    return "\n" + "\n\n".join(sections) + "\n"
# ---------------------------------------------------------------------
# Main video processing generator
# ---------------------------------------------------------------------
def _number_or(value, fallback: float) -> float:
    """Return ``value`` as a float, or ``fallback`` if it is empty or not numeric.

    For example, a cleared ``gr.Number`` field may arrive as ``None``.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return float(fallback)


def _status_outputs(engine_label: str, message: str, total_frames: int = 0) -> Tuple:
    """Build the 6 UI outputs for a status-only update (no frame or files yet)."""
    return (
        None,
        build_metrics_html(0, {}, 0, 0, 0, 0, total_frames, 0, 0, engine_label),
        make_empty_plot(),
        message,
        None,
        None,
    )


def _apply_weight_overrides(kg_values: Dict[str, object], tonne_values: Dict[str, object]) -> None:
    """Store the UI weight estimates in ``DEFAULT_WEIGHTS_KG`` as integer kilograms.

    ``kg_values`` are already in kg; ``tonne_values`` are given in tonnes.
    An empty or non-numeric field keeps the weight currently stored for that class.
    """
    for name, value in kg_values.items():
        DEFAULT_WEIGHTS_KG[name] = int(_number_or(value, DEFAULT_WEIGHTS_KG.get(name, 0)))
    for name, value in tonne_values.items():
        current_tonnes = DEFAULT_WEIGHTS_KG.get(name, 0) / 1000.0
        DEFAULT_WEIGHTS_KG[name] = int(_number_or(value, current_tonnes) * 1000)


def _track_with_names(
    tracker, detections: sv.Detections, names: List[str]
) -> Tuple[sv.Detections, List[str]]:
    """Run ByteTrack and return the tracked detections with aligned class names.

    ``update_with_detections`` drops detections it cannot match to a track.
    The survivors therefore no longer line up with ``names`` by position. The
    names are carried through the tracker in ``detections.data``, which
    supervision filters together with the boxes. Positional alignment is kept
    only as a fallback.
    """
    can_carry = (
        len(detections) > 0
        and len(detections) == len(names)
        and isinstance(getattr(detections, "data", None), dict)
    )
    if can_carry:
        detections.data["canonical_name"] = np.array(names)

    tracked = tracker.update_with_detections(detections)

    carried = None
    if can_carry:
        carried = (getattr(tracked, "data", None) or {}).get("canonical_name")
    if carried is not None and len(carried) == len(tracked):
        return tracked, [str(n) for n in carried]

    aligned = list(names[: len(tracked)])
    aligned += ["object"] * (len(tracked) - len(aligned))
    return tracked, aligned


def _write_results_csv(csv_path: str, history: List[Dict], events: List[Dict]) -> None:
    """Write frame-level history, plus crossing events if any, to ``csv_path``.

    With events, the file holds two CSV sections, each preceded by a ``#``
    marker line.
    """
    history_df = pd.DataFrame(history)
    if not events:
        history_df.to_csv(csv_path, index=False)
        return
    # newline="" stops the text layer from re-translating pandas' line endings.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        f.write("# FRAME_LEVEL_LOAD_INDEX\n")
        history_df.to_csv(f, index=False)
        f.write("\n# CROSSING_EVENTS\n")
        pd.DataFrame(events).to_csv(f, index=False)


def _new_temp_path(suffix: str) -> str:
    """Create an empty temp file that outlives this call and return its path."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def process_video(
    video_path,
    engine,
    yolo_model_file,
    confidence,
    frame_stride,
    inference_width,
    line_position_percent,
    roi_top_percent,
    roi_bottom_percent,
    reference_capacity_tonnes,
    person_weight_kg,
    bicycle_weight_kg,
    motorcycle_weight_kg,
    car_weight_t,
    bus_weight_t,
    truck_weight_t,
    cow_weight_kg,
    sheep_weight_kg,
    goat_weight_kg,
    horse_weight_kg,
    donkey_weight_kg,
):
    """Analyse a video frame by frame, streaming UI updates as a generator.

    Every yielded value is the 6-tuple ``(frame_rgb, metrics_html, load_plot,
    summary_markdown, output_video_path, output_csv_path)``.

    Detection runs on every ``frame_stride``-th frame; frames in between reuse
    the last tracked boxes. A track is counted once, when its box centre moves
    from one side of the counting line to the other. Positions inside the
    small dead zone around the line are ignored, and the last definite side
    is remembered. The live load is the summed estimated weight of tracked
    objects whose centre is inside the bridge-deck ROI.

    Raises:
        RuntimeError: if the video cannot be opened or its size cannot be read.
    """
    engine_label = str(engine)

    if video_path is None:
        yield _status_outputs(
            engine_label,
            "No video found. Put an `.mp4` file in the same folder as `app.py`, or upload one.",
        )
        return

    # Gradio can pass a dict in some versions.
    if isinstance(video_path, dict):
        video_path = video_path.get("path") or video_path.get("name")
    if not video_path or not os.path.exists(video_path):
        yield _status_outputs(engine_label, f"Video not found: {video_path}")
        return

    _apply_weight_overrides(
        kg_values={
            "person": person_weight_kg,
            "bicycle": bicycle_weight_kg,
            "motorcycle": motorcycle_weight_kg,
            "cow": cow_weight_kg,
            "sheep": sheep_weight_kg,
            "goat": goat_weight_kg,
            "horse": horse_weight_kg,
            "donkey": donkey_weight_kg,
        },
        tonne_values={
            "car": car_weight_t,
            "bus": bus_weight_t,
            "truck": truck_weight_t,
        },
    )

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Could not open video: {video_path}")

    writer = None
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 25.0)
        if fps <= 1:
            fps = 25.0  # container reported no usable frame rate
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        if width <= 0 or height <= 0:
            raise RuntimeError("Could not read video dimensions.")

        line_y = int(height * float(line_position_percent) / 100.0)
        roi_top_y = int(height * float(roi_top_percent) / 100.0)
        roi_bottom_y = int(height * float(roi_bottom_percent) / 100.0)
        if roi_bottom_y <= roi_top_y:
            # Empty/inverted ROI from the sliders: fall back to 20 %-90 % of the height.
            roi_top_y = int(height * 0.20)
            roi_bottom_y = int(height * 0.90)
        reference_capacity_kg = max(1.0, float(reference_capacity_tonnes) * 1000.0)
        stride = max(1, int(frame_stride))

        yield _status_outputs(
            engine_label,
            f"### Starting analysis on `{Path(video_path).name}`...",
            total_frames,
        )

        # Preload model before loop.
        if engine_label.startswith("YOLO"):
            load_yolo_model(str(yolo_model_file))
        else:
            load_rfdetr_medium()

        tracker = sv.ByteTrack(frame_rate=int(round(fps)))
        out_video_path = _new_temp_path(".mp4")
        out_csv_path = _new_temp_path(".csv")
        writer = cv2.VideoWriter(
            out_video_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )

        last_detections = sv.Detections.empty()
        last_names: List[str] = []
        last_side_by_id: Dict[int, int] = {}
        counted_ids = set()
        track_name_by_id: Dict[int, str] = {}
        class_counts = {name: 0 for name in TARGET_CANONICAL_NAMES}
        total_count = 0
        cumulative_kg = 0.0
        history: List[Dict] = []
        events: List[Dict] = []
        peak_live_load_kg = 0.0
        peak_load_index = 0.0
        start_wall = time.time()
        last_yield_wall = 0.0
        last_plot_wall = 0.0
        latest_plot = make_empty_plot()
        processed = 0
        frame_idx = 0
        last_annotated = None

        while True:
            ok, frame = cap.read()
            if not ok:
                break

            if frame_idx % stride == 0:
                raw_detections, raw_names = predict_objects(
                    engine=engine_label,
                    yolo_model_file=str(yolo_model_file),
                    frame_bgr=frame,
                    confidence=float(confidence),
                    inference_width=int(inference_width),
                )
                detections, names = _track_with_names(tracker, raw_detections, raw_names)
                last_detections, last_names = detections, names
            else:
                # Skipped frame: reuse the last tracked boxes.
                detections, names = last_detections, last_names

            centres = detection_centres(detections)
            live_load_kg = 0.0
            if len(detections) > 0 and detections.tracker_id is not None:
                for i, (centre, tid) in enumerate(zip(centres, detections.tracker_id)):
                    if tid is None or int(tid) < 0:
                        continue
                    tid = int(tid)
                    name = names[i] if i < len(names) else track_name_by_id.get(tid, "object")
                    if name == "object":
                        continue
                    track_name_by_id[tid] = name
                    cy = float(centre[1])

                    # Live load only for objects currently inside bridge deck ROI.
                    if roi_top_y <= cy <= roi_bottom_y:
                        live_load_kg += float(DEFAULT_WEIGHTS_KG.get(name, 0))

                    current_side = side_of_line(cy, line_y)
                    if current_side == 0:
                        # In the dead zone: keep remembering the last definite side.
                        continue
                    previous_side = last_side_by_id.get(tid)
                    crossed = (
                        previous_side is not None
                        and previous_side != 0
                        and previous_side != current_side
                        and tid not in counted_ids
                    )
                    if crossed:
                        counted_ids.add(tid)
                        total_count += 1
                        class_counts[name] = int(class_counts.get(name, 0)) + 1
                        weight_kg = float(DEFAULT_WEIGHTS_KG.get(name, 0))
                        cumulative_kg += weight_kg
                        events.append({
                            "video_time_s": frame_idx / fps,
                            "frame": frame_idx,
                            "tracker_id": tid,
                            "object_type": name,
                            "display_type": DISPLAY_NAME.get(name, name),
                            "direction": "down" if previous_side < current_side else "up",
                            "estimated_weight_kg": weight_kg,
                            "cumulative_estimated_mass_kg": cumulative_kg,
                        })
                    last_side_by_id[tid] = current_side

            load_index_percent = (live_load_kg / reference_capacity_kg) * 100.0
            peak_live_load_kg = max(peak_live_load_kg, live_load_kg)
            peak_load_index = max(peak_load_index, load_index_percent)

            elapsed = time.time() - start_wall
            processed += 1
            proc_fps = processed / max(elapsed, 1e-6)

            history.append({
                "time_s": frame_idx / fps,
                "frame": frame_idx,
                "total_crossings": total_count,
                "people_crossed": class_counts.get("person", 0),
                "bicycles_crossed": class_counts.get("bicycle", 0),
                "cars_crossed": class_counts.get("car", 0),
                "motorcycles_crossed": class_counts.get("motorcycle", 0),
                "buses_crossed": class_counts.get("bus", 0),
                "trucks_crossed": class_counts.get("truck", 0),
                "cows_crossed": class_counts.get("cow", 0),
                "sheep_goats_crossed": class_counts.get("sheep", 0) + class_counts.get("goat", 0),
                "horse_donkey_crossed": class_counts.get("horse", 0) + class_counts.get("donkey", 0),
                "live_load_kg": live_load_kg,
                "live_load_tonnes": live_load_kg / 1000.0,
                "load_index_percent": load_index_percent,
                "cumulative_estimated_mass_kg": cumulative_kg,
                "cumulative_estimated_mass_tonnes": cumulative_kg / 1000.0,
            })

            annotated = annotate_frame(
                frame=frame,
                detections=detections,
                canonical_names=names,
                line_y=line_y,
                roi_top_y=roi_top_y,
                roi_bottom_y=roi_bottom_y,
                class_counts=class_counts,
                total_count=total_count,
                cumulative_kg=cumulative_kg,
                live_load_kg=live_load_kg,
                load_index_percent=load_index_percent,
                proc_fps=proc_fps,
                engine=engine_label,
            )
            writer.write(annotated)
            last_annotated = annotated

            # Throttle UI work: redraw the chart at most once a second and
            # push a live update at most every 0.35 s.
            now = time.time()
            if now - last_plot_wall >= 1.0:
                latest_plot = render_load_plot(history)
                last_plot_wall = now
            if now - last_yield_wall >= 0.35:
                last_yield_wall = now
                yield (
                    cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB),
                    build_metrics_html(
                        total_count=total_count,
                        class_counts=class_counts,
                        cumulative_kg=cumulative_kg,
                        live_load_kg=live_load_kg,
                        load_index_percent=load_index_percent,
                        frame_idx=frame_idx + 1,
                        total_frames=total_frames,
                        elapsed=elapsed,
                        proc_fps=proc_fps,
                        engine=engine_label,
                    ),
                    latest_plot,
                    "### Live analysis running...",
                    None,
                    None,
                )
            frame_idx += 1
    finally:
        # Runs on completion, on errors and when the generator is closed early
        # (e.g. a cancelled request), so the capture and writer are never leaked.
        cap.release()
        if writer is not None:
            writer.release()

    _write_results_csv(out_csv_path, history, events)

    final_frame_rgb = (
        cv2.cvtColor(last_annotated, cv2.COLOR_BGR2RGB) if last_annotated is not None else None
    )
    elapsed = time.time() - start_wall
    proc_fps = processed / max(elapsed, 1e-6)
    yield (
        final_frame_rgb,
        build_metrics_html(
            total_count=total_count,
            class_counts=class_counts,
            cumulative_kg=cumulative_kg,
            live_load_kg=0,
            load_index_percent=0,
            frame_idx=total_frames if total_frames else frame_idx,
            total_frames=total_frames if total_frames else frame_idx,
            elapsed=elapsed,
            proc_fps=proc_fps,
            engine=engine_label,
        ),
        render_load_plot(history),
        final_summary_md(
            total_count=total_count,
            class_counts=class_counts,
            cumulative_kg=cumulative_kg,
            peak_live_load_kg=peak_live_load_kg,
            peak_load_index=peak_load_index,
            auto_video_used=video_path if str(video_path).startswith(str(APP_DIR)) else "",
        ),
        out_video_path,
        out_csv_path,
    )
# ---------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...) below. Its selectors target the
# elem_id / elem_classes used in the layout: #hero (title row), .panel
# (card-styled gr.Group containers), and #live-frame / #load-plot (the two
# gr.Image components). The `footer` rule hides footer elements on the page.
CUSTOM_CSS = """
.gradio-container {
max-width: 1360px !important;
margin: auto !important;
}
#hero {
text-align: center;
padding: 16px 8px 6px 8px;
}
#hero h1 {
font-weight: 850;
letter-spacing: -0.8px;
margin-bottom: 2px;
}
#hero p {
color: #64748b;
font-size: 16px;
margin-top: 0;
}
.panel {
border: 1px solid #e5e7eb;
border-radius: 18px;
padding: 16px;
background: #ffffff;
box-shadow: 0 8px 24px rgba(15, 23, 42, 0.045);
}
#live-frame img, #load-plot img {
border-radius: 14px;
}
footer {
visibility: hidden;
}
"""
with gr.Blocks(
    title="Fast Bridge Traffic + Livestock Load Demo",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    css=CUSTOM_CSS,
) as demo:
    # ---- Header ------------------------------------------------------
    with gr.Row(elem_id="hero"):
        gr.Markdown(
            """
# 🌉 Fast Bridge Traffic + Livestock Load Demo
YOLO-small / RF-DETR Medium detection, ByteTrack tracking, line-crossing counts,
estimated object weights, and live bridge load-index over time.
"""
        )

    # Tell the user whether a bundled video was found (and will auto-run).
    if DEFAULT_VIDEO:
        _video_notice = (
            f"✅ Found default video next to `app.py`: `{Path(DEFAULT_VIDEO).name}`. "
            "The app will auto-start inference when opened."
        )
    else:
        _video_notice = (
            "⚠️ No local video found next to `app.py`. Upload a video or place "
            "`bridge.mp4`, `traffic.mp4`, `input.mp4`, or any `.mp4` in the same folder."
        )
    gr.Markdown(_video_notice)

    # ---- Controls (left) and live view (right) -----------------------
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### 1) Video")
                video_input = gr.Video(
                    label="Video input",
                    sources=["upload"],
                    value=DEFAULT_VIDEO,
                    format="mp4",
                    height=260,
                )
                start_btn = gr.Button("▶ Start / rerun analysis", variant="primary", size="lg")

                gr.Markdown("### 2) Inference engine")
                engine = gr.Radio(
                    choices=[
                        "YOLO small - fastest recommended",
                        "RF-DETR Medium - slower but strong",
                    ],
                    value="YOLO small - fastest recommended",
                    label="Engine",
                )
                yolo_model_file = gr.Textbox(
                    value="yolo11s.pt",
                    label="YOLO model file/name",
                    info="Use yolo11s.pt for small. Put your custom .pt in the same folder as app.py and type its filename here.",
                )
                confidence = gr.Slider(
                    minimum=0.10, maximum=0.90, value=0.35, step=0.05,
                    label="Confidence threshold",
                )
                frame_stride = gr.Slider(
                    minimum=1, maximum=12, value=3, step=1,
                    label="Frame stride",
                    info="Detect every Nth frame. 3-5 is much faster than every frame.",
                )
                inference_width = gr.Slider(
                    minimum=384, maximum=1280, value=640, step=64,
                    label="Inference image size / width",
                    info="Lower is faster. Try 512 or 640 for fast demos.",
                )

                with gr.Accordion("Bridge settings", open=False):
                    line_position_percent = gr.Slider(
                        minimum=10, maximum=90, value=55, step=1,
                        label="Counting line vertical position (%)",
                    )
                    roi_top_percent = gr.Slider(
                        minimum=0, maximum=90, value=20, step=1,
                        label="Bridge deck ROI top (%)",
                    )
                    roi_bottom_percent = gr.Slider(
                        minimum=10, maximum=100, value=90, step=1,
                        label="Bridge deck ROI bottom (%)",
                    )
                    reference_capacity_tonnes = gr.Slider(
                        minimum=1, maximum=250, value=40, step=1,
                        label="Reference live-load capacity for demo index (tonnes)",
                    )

                with gr.Accordion("Estimated weights", open=False):
                    person_weight_kg = gr.Number(value=75, label="Person weight estimate (kg)")
                    bicycle_weight_kg = gr.Number(value=120, label="Bicycle + rider estimate (kg)")
                    motorcycle_weight_kg = gr.Number(value=250, label="Motorcycle estimate (kg)")
                    car_weight_t = gr.Number(value=1.5, label="Car estimate (tonnes)")
                    bus_weight_t = gr.Number(value=12.0, label="Bus estimate (tonnes)")
                    truck_weight_t = gr.Number(value=18.0, label="Truck estimate (tonnes)")
                    cow_weight_kg = gr.Number(value=450, label="Cow estimate (kg)")
                    sheep_weight_kg = gr.Number(value=60, label="Sheep estimate (kg)")
                    goat_weight_kg = gr.Number(value=45, label="Goat estimate (kg)")
                    horse_weight_kg = gr.Number(value=350, label="Horse estimate (kg)")
                    donkey_weight_kg = gr.Number(value=180, label="Donkey estimate (kg)")

                gr.Markdown(
                    """
**Fast demo settings:** YOLO small, confidence 0.30-0.40,
frame stride 3-5, image size 512-640.
"""
                )

        with gr.Column(scale=2):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### Live annotated video")
                live_frame = gr.Image(show_label=False, elem_id="live-frame", height=500)

            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Live metrics")
                        metrics_html = gr.HTML(
                            value=build_metrics_html(
                                total_count=0,
                                class_counts={},
                                cumulative_kg=0,
                                live_load_kg=0,
                                load_index_percent=0,
                                frame_idx=0,
                                total_frames=0,
                                elapsed=0,
                                proc_fps=0,
                                engine="not started",
                            )
                        )
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Load index over time")
                        load_plot = gr.Image(
                            show_label=False,
                            elem_id="load-plot",
                            height=300,
                            value=make_empty_plot(),
                        )

    # ---- Results -----------------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### Final annotated video")
                video_output = gr.Video(label="Replay / download annotated video", height=270)
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### Final summary")
                summary_output = gr.Markdown("The summary will appear after analysis.")
                csv_output = gr.File(label="Download CSV")

    # ---- Wiring: order must match process_video's positional parameters ----
    _engine_inputs = [video_input, engine, yolo_model_file, confidence, frame_stride, inference_width]
    _bridge_inputs = [line_position_percent, roi_top_percent, roi_bottom_percent, reference_capacity_tonnes]
    _weight_inputs = [
        person_weight_kg,
        bicycle_weight_kg,
        motorcycle_weight_kg,
        car_weight_t,
        bus_weight_t,
        truck_weight_t,
        cow_weight_kg,
        sheep_weight_kg,
        goat_weight_kg,
        horse_weight_kg,
        donkey_weight_kg,
    ]
    inputs = _engine_inputs + _bridge_inputs + _weight_inputs
    # Same order as each tuple yielded by process_video.
    outputs = [live_frame, metrics_html, load_plot, summary_output, video_output, csv_output]

    start_btn.click(fn=process_video, inputs=inputs, outputs=outputs)

    # Auto-start when a local video exists beside app.py.
    if DEFAULT_VIDEO:
        demo.load(fn=process_video, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    # max_size=2 caps how many events may wait in the request queue.
    demo.queue(max_size=2)
    demo.launch()