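# Listing captured from a Hugging Face Space page (status shown: "Sleeping").
"""
Fast Bridge Traffic + Livestock Load Demo

YOLO-small / RF-DETR Medium detection, ByteTrack tracking, line-crossing
counts, estimated object weights, and a live bridge load index over time.
"""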
| import os | |
| import time | |
| import tempfile | |
| import warnings | |
| from pathlib import Path | |
| from functools import lru_cache | |
| from typing import Dict, List, Tuple, Optional | |
| import cv2 | |
| import gradio as gr | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import supervision as sv | |
| import torch | |
# Optional detection engines.
# Each name is left as None when its package cannot be imported; the model
# loaders check for None and raise an install hint instead.
try:
    from ultralytics import YOLO
except Exception:  # any import-time failure disables this engine
    YOLO = None

try:
    from rfdetr import RFDETRMedium
except Exception:  # any import-time failure disables this engine
    RFDETRMedium = None
# ---------------------------------------------------------------------
# Silence a noisy warning raised inside the RF-DETR/transformers
# dependency stack; this app has no control over it.
# ---------------------------------------------------------------------
for _noisy_pattern in (
    ".*use_return_dict.*",
    ".*`use_return_dict` is deprecated.*",
):
    warnings.filterwarnings("ignore", message=_noisy_pattern)
# ---------------------------------------------------------------------
# App paths and default local video
# ---------------------------------------------------------------------
# Folder that contains this script; default videos and local model
# weights are looked up here.
APP_DIR = Path(__file__).resolve().parent

# Extensions searched, in this order, when no preferred file name exists.
VIDEO_EXTENSIONS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]

# File names checked first, in priority order.
PREFERRED_VIDEO_NAMES = [
    f"{stem}.mp4"
    for stem in (
        "bridge",
        "traffic",
        "cars",
        "video",
        "input",
        "example",
        "sample",
    )
]
def find_default_video() -> Optional[str]:
    """Return the path of a video file stored next to this script, if any.

    Names listed in ``PREFERRED_VIDEO_NAMES`` win.  Otherwise the extensions
    in ``VIDEO_EXTENSIONS`` are tried in order and the alphabetically first
    matching file is returned.  Only regular files are considered, and the
    extension comparison ignores case (``CLIP.MP4`` is accepted).

    Returns:
        The path as a string, or ``None`` when no video is found.
    """
    for name in PREFERRED_VIDEO_NAMES:
        candidate = APP_DIR / name
        # is_file() rather than exists(): a directory with a video-like
        # name must not be handed to the video reader.
        if candidate.is_file():
            return str(candidate)

    # List the folder once, then scan it per extension so the extension
    # priority order is preserved.
    files = sorted(path for path in APP_DIR.iterdir() if path.is_file())
    for ext in VIDEO_EXTENSIONS:
        for path in files:
            if path.suffix.lower() == ext:
                return str(path)
    return None


# Resolved once at import time; None when no video sits next to the script.
DEFAULT_VIDEO = find_default_video()
# ---------------------------------------------------------------------
# Device and speed setup
# ---------------------------------------------------------------------
# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Best-effort tuning: use all but one of the reported CPU cores (never
# fewer than one thread).  Failures are ignored on purpose.
try:
    torch.set_num_threads(max((os.cpu_count() or 2) - 1, 1))
except Exception:
    pass

if DEVICE == "cuda":
    # Also best-effort: turn on cuDNN benchmark mode.
    try:
        torch.backends.cudnn.benchmark = True
    except Exception:
        pass
| # --------------------------------------------------------------------- | |
| # Target classes and estimated weights | |
| # --------------------------------------------------------------------- | |
| # For YOLO COCO: | |
| # person=0, bicycle=1, car=2, motorcycle=3, bus=5, truck=7, | |
| # horse=17, sheep=18, cow=19. | |
| # | |
| # COCO does not have goat or donkey. We map: | |
| # sheep -> sheep/goat | |
| # horse -> horse/donkey | |
# Canonical class names that detector outputs are filtered down to.
TARGET_CANONICAL_NAMES = {
    # people and vehicles
    "person", "bicycle", "motorcycle", "car", "bus", "truck",
    # livestock
    "cow", "sheep", "goat", "horse", "donkey",
}
# Label shown for each canonical class.  COCO has no goat or donkey class,
# so "sheep" and "horse" carry combined labels.
DISPLAY_NAME = dict(
    person="person",
    bicycle="bicycle",
    car="car",
    motorcycle="motorcycle",
    bus="bus",
    truck="truck",
    cow="cow",
    sheep="sheep / goat",
    goat="goat",
    horse="horse / donkey",
    donkey="donkey",
)
# Class-id -> name lookup applied to RF-DETR outputs; only the ids this app
# keeps are listed.  The ids match the 0-based YOLO COCO numbering quoted in
# the comment above the class tables.
# NOTE(review): RF-DETR's own COCO table may number classes differently
# (1-based) - verify against rfdetr.util.coco_classes before relying on it.
COCO_NAMES = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle",
    5: "bus", 7: "truck",
    17: "horse", 18: "sheep", 19: "cow",
}
# Approximate per-object demo weights in kilograms, keyed by canonical class.
# The values are adjustable from the UI for a given bridge/traffic context.
DEFAULT_WEIGHTS_KG = dict(
    person=75,
    bicycle=120,  # bicycle + rider approximation
    motorcycle=250,
    car=1500,
    bus=12000,
    truck=18000,
    cow=450,
    sheep=60,
    goat=45,
    horse=350,
    donkey=180,
)
# Box/label colour per canonical class, as (blue, green, red) tuples.
COLOR_BY_NAME_BGR = dict(
    person=(70, 160, 245),
    bicycle=(240, 190, 80),
    motorcycle=(255, 150, 80),
    car=(60, 210, 130),
    bus=(50, 130, 245),
    truck=(220, 70, 180),
    cow=(160, 120, 80),
    sheep=(220, 220, 220),
    goat=(210, 210, 230),
    horse=(130, 90, 60),
    donkey=(120, 110, 95),
)
| # --------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------- | |
@lru_cache(maxsize=4)
def load_yolo_model(model_file: str):
    """Load an Ultralytics YOLO model, cached per ``model_file``.

    ``predict_objects`` calls this loader for every analysed frame, so the
    result is memoised; without the cache the weights would be loaded again
    on each call.  A failed load raises and is therefore not cached.

    Args:
        model_file: Weights file name.  A file of that name next to this
            script is preferred; otherwise the name is passed to Ultralytics
            unchanged.

    Returns:
        The loaded ``YOLO`` model, moved to ``DEVICE`` when that succeeds.

    Raises:
        RuntimeError: If the ``ultralytics`` package could not be imported.
    """
    if YOLO is None:
        raise RuntimeError(
            "Ultralytics is not installed. Run: pip install ultralytics"
        )

    local_candidate = APP_DIR / model_file
    model_path = str(local_candidate) if local_candidate.exists() else model_file

    print(f"Loading YOLO model: {model_path} on {DEVICE}")
    model = YOLO(model_path)
    try:
        model.to(DEVICE)
    except Exception as exc:
        # Best effort: keep the model usable, but report the failure
        # instead of hiding it.
        print(f"Could not move YOLO model to {DEVICE}: {exc}")
    return model
@lru_cache(maxsize=1)
def load_rfdetr_medium():
    """Build the RF-DETR Medium model once and reuse it on later calls.

    ``predict_objects`` calls this loader for every analysed frame, so the
    model is memoised; without the cache it would be constructed and
    optimised again on each call.

    Returns:
        The ``RFDETRMedium`` instance, after a best-effort call to
        ``optimize_for_inference()``.

    Raises:
        RuntimeError: If the ``rfdetr`` package could not be imported.
    """
    if RFDETRMedium is None:
        raise RuntimeError(
            "RF-DETR is not installed. Run: pip install rfdetr"
        )

    print(f"Loading RF-DETR Medium on {DEVICE}")
    try:
        model = RFDETRMedium(device=DEVICE)
    except TypeError:
        # Constructor rejected the ``device`` keyword; use its defaults.
        model = RFDETRMedium()

    # This directly addresses the library warning:
    # "Model is not optimized for inference. Latency may be higher..."
    try:
        model.optimize_for_inference()
        print("RF-DETR Medium optimized for inference.")
    except Exception as exc:
        print(f"RF-DETR optimize_for_inference skipped: {exc}")
    return model
| # --------------------------------------------------------------------- | |
| # Detection conversion | |
| # --------------------------------------------------------------------- | |
def yolo_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    imgsz: int,
) -> Tuple[sv.Detections, List[str]]:
    """Run YOLO on one frame and keep only the target classes.

    Args:
        model: Loaded Ultralytics YOLO model.
        frame_bgr: Frame passed to ``model.predict`` as ``source``.
        confidence: Confidence threshold forwarded as ``conf``.
        imgsz: Inference image size forwarded as ``imgsz``.

    Returns:
        ``(detections, canonical_names)`` where ``canonical_names[i]`` is the
        canonical class name of ``detections[i]``.  Both are empty when
        nothing relevant is detected.
    """
    # Alternative labels that map onto a canonical class.
    aliases = {"automobile": "car", "lorry": "truck"}

    result = model.predict(
        source=frame_bgr,
        conf=float(confidence),
        imgsz=int(imgsz),
        device=0 if DEVICE == "cuda" else "cpu",
        verbose=False,
    )[0]

    boxes = result.boxes
    if boxes is None or len(boxes) == 0:
        return sv.Detections.empty(), []

    xyxy = boxes.xyxy.detach().cpu().numpy()
    conf = boxes.conf.detach().cpu().numpy()
    cls = boxes.cls.detach().cpu().numpy().astype(int)

    names = getattr(model, "names", None) or {}
    if not isinstance(names, dict):
        # Tolerate a plain sequence of names indexed by class id.
        names = dict(enumerate(names))

    keep: List[int] = []
    canonical_names: List[str] = []
    for i, class_id in enumerate(cls):
        raw = str(names.get(int(class_id), class_id)).lower().strip()
        name = aliases.get(raw, raw)
        if name in TARGET_CANONICAL_NAMES:
            keep.append(i)
            canonical_names.append(name)

    if not keep:
        return sv.Detections.empty(), []

    keep_idx = np.array(keep, dtype=int)
    detections = sv.Detections(
        xyxy=xyxy[keep_idx],
        confidence=conf[keep_idx],
        class_id=cls[keep_idx],
    )
    return detections, canonical_names
@lru_cache(maxsize=1)
def _rfdetr_class_names() -> Dict[int, str]:
    """Return the class-id -> name table used to decode RF-DETR outputs.

    RF-DETR ships its own ``COCO_CLASSES`` table, and its documentation
    labels ``model.predict`` results with it, so that table is preferred.
    ``COCO_NAMES`` is the fallback when the table cannot be imported or read.
    """
    try:
        from rfdetr.util.coco_classes import COCO_CLASSES

        if isinstance(COCO_CLASSES, dict):
            table = COCO_CLASSES
        else:
            table = dict(enumerate(COCO_CLASSES))
        return {
            int(class_id): str(label).lower().strip()
            for class_id, label in table.items()
        }
    except Exception:
        return dict(COCO_NAMES)


def rfdetr_predict_to_supervision(
    model,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """Run RF-DETR Medium on one frame and keep only the target classes.

    Frames wider than ``inference_width`` are downscaled before inference for
    speed, and the resulting boxes are scaled back to original coordinates.

    Args:
        model: Loaded RF-DETR model.
        frame_bgr: Frame in BGR channel order; converted to RGB for the model.
        confidence: Threshold forwarded to ``model.predict``.
        inference_width: Maximum inference width in pixels; ``<= 0`` disables
            downscaling.

    Returns:
        ``(detections, canonical_names)`` where ``canonical_names[i]`` is the
        canonical class name of ``detections[i]``.
    """
    h, w = frame_bgr.shape[:2]

    scale_x = scale_y = 1.0
    resized = frame_bgr
    if inference_width > 0 and w > inference_width:
        scale = float(inference_width) / float(w)
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        resized = cv2.resize(frame_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)
        # Truncation makes the real ratios differ slightly per axis, so
        # derive them from the resized dimensions.
        scale_x = new_w / float(w)
        scale_y = new_h / float(h)

    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
    with torch.inference_mode():
        detections = model.predict(rgb, threshold=float(confidence))

    if len(detections) == 0:
        return detections, []

    class_names = _rfdetr_class_names()
    keep: List[int] = []
    canonical_names: List[str] = []
    for i, class_id in enumerate(detections.class_id):
        name = class_names.get(int(class_id))
        if name in TARGET_CANONICAL_NAMES:
            keep.append(i)
            canonical_names.append(name)

    if not keep:
        return sv.Detections.empty(), []

    detections = detections[np.array(keep, dtype=int)]
    if scale_x != 1.0 or scale_y != 1.0:
        detections.xyxy = detections.xyxy / np.array(
            [scale_x, scale_y, scale_x, scale_y]
        )
    return detections, canonical_names
def predict_objects(
    engine: str,
    yolo_model_file: str,
    frame_bgr: np.ndarray,
    confidence: float,
    inference_width: int,
) -> Tuple[sv.Detections, List[str]]:
    """Detect objects in one frame with the engine selected by ``engine``.

    An ``engine`` label starting with ``"YOLO"`` selects the YOLO path, with
    ``inference_width`` passed on as the YOLO image size; any other label
    selects RF-DETR Medium.

    Returns:
        ``(detections, canonical_names)`` as produced by the chosen engine.
    """
    wants_yolo = engine.startswith("YOLO")
    if not wants_yolo:
        return rfdetr_predict_to_supervision(
            model=load_rfdetr_medium(),
            frame_bgr=frame_bgr,
            confidence=confidence,
            inference_width=inference_width,
        )

    return yolo_predict_to_supervision(
        model=load_yolo_model(yolo_model_file),
        frame_bgr=frame_bgr,
        confidence=confidence,
        imgsz=inference_width,
    )
| # --------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------- | |
def side_of_line(y: float, line_y: int, dead_zone_px: int = 5) -> int:
    """Classify a vertical position relative to the counting line.

    Returns:
        ``0`` when ``y`` is within ``dead_zone_px`` of ``line_y`` (inclusive),
        ``-1`` when it is above the line (smaller ``y``), ``1`` when below.
    """
    offset = y - line_y
    inside_dead_zone = abs(offset) <= dead_zone_px
    if inside_dead_zone:
        return 0
    if offset < 0:
        return -1
    return 1
def detection_centres(detections: sv.Detections) -> np.ndarray:
    """Return the box centres of ``detections`` as an ``(N, 2)`` array.

    Column 0 is the centre x, column 1 the centre y.  An empty detection set
    gives an empty ``(0, 2)`` float array.
    """
    if len(detections) == 0:
        return np.empty((0, 2), dtype=float)

    boxes = detections.xyxy
    # (x1 + x2) / 2 and (y1 + y2) / 2 for every box in one step.
    return (boxes[:, 0:2] + boxes[:, 2:4]) / 2.0
def make_empty_plot() -> np.ndarray:
    """Return a white 620x300 placeholder image with a hint caption.

    It is shown in place of the load chart while no history exists yet.
    """
    canvas = np.full((300, 620, 3), 255, dtype=np.uint8)
    cv2.putText(
        img=canvas,
        text="Bridge load index chart will appear here",
        org=(70, 155),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.75,
        color=(90, 90, 90),
        thickness=2,
        lineType=cv2.LINE_AA,
    )
    return canvas
def render_load_plot(history: List[Dict]) -> np.ndarray:
    """Render the load-index history as an RGB image.

    Args:
        history: Per-frame records; each must contain ``"time_s"`` and
            ``"load_index_percent"``.

    Returns:
        The chart as an RGB array, or the placeholder image from
        ``make_empty_plot()`` when ``history`` is empty.
    """
    if not history:
        return make_empty_plot()

    max_points = 600
    df = pd.DataFrame(history)
    if len(df) > max_points:
        # Plot an evenly spaced subset so long videos stay cheap to draw.
        df = df.iloc[np.linspace(0, len(df) - 1, max_points).astype(int)]

    fig, ax = plt.subplots(figsize=(8.0, 3.5), dpi=100)
    try:
        ax.plot(df["time_s"], df["load_index_percent"], linewidth=2)
        ax.set_title("Estimated Bridge Load Index Over Time")
        ax.set_xlabel("Video time (seconds)")
        ax.set_ylabel("Load index (%)")
        ax.grid(True, alpha=0.25)
        ax.set_ylim(bottom=0)
        fig.tight_layout()
        fig.canvas.draw()

        rgba = np.asarray(fig.canvas.buffer_rgba())
        return cv2.cvtColor(rgba, cv2.COLOR_RGBA2RGB)
    finally:
        # Always release the figure, even when plotting raises; this runs
        # repeatedly during analysis and open figures would pile up.
        plt.close(fig)
def build_metrics_html(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    frame_idx: int,
    total_frames: int,
    elapsed: float,
    proc_fps: float,
    engine: str,
) -> str:
    """Build the HTML metrics panel shown beside the live frame.

    Args:
        total_count: Total number of counted crossings.
        class_counts: Crossings per canonical class; missing classes count 0.
        cumulative_kg: Summed estimated mass of all crossings, in kg.
        live_load_kg: Estimated load currently on the deck, in kg.
        load_index_percent: Live load as a percentage of the reference capacity.
        frame_idx: Number of frames handled so far.
        total_frames: Frame count reported for the video; ``0`` if unknown.
        elapsed: Wall-clock seconds since processing started.
        proc_fps: Processing speed in frames per second.
        engine: Engine label shown in the status line.

    Returns:
        The HTML fragment as a string.
    """
    # Reported frame counts are only estimates, so keep the progress value
    # within 0-100 for both the text and the bar width.
    if total_frames > 0:
        pct = min(100.0, max(0.0, frame_idx / total_frames * 100.0))
    else:
        pct = 0.0
    tonnes = cumulative_kg / 1000.0
    live_tonnes = live_load_kg / 1000.0

    def c(name: str) -> int:
        """Return the crossing count for ``name`` (0 when absent)."""
        return int(class_counts.get(name, 0))

    return f"""
<div style="font-family:Inter,system-ui,Arial;">
<div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
<div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#1d4ed8,#312e81);color:white;">
<div style="font-size:11px;letter-spacing:1px;opacity:.86;">OBJECTS CROSSED</div>
<div style="font-size:46px;font-weight:850;line-height:1;">{total_count}</div>
</div>
<div style="padding:16px;border-radius:18px;background:linear-gradient(135deg,#be185d,#7e22ce);color:white;">
<div style="font-size:11px;letter-spacing:1px;opacity:.86;">CUMULATIVE EST. MASS</div>
<div style="font-size:36px;font-weight:850;line-height:1;">{tonnes:.1f} t</div>
</div>
</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:10px;margin-bottom:12px;">
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
<div style="font-size:12px;color:#6b7280;">Live bridge load</div>
<div style="font-size:28px;font-weight:800;color:#111827;">{live_tonnes:.1f} t</div>
</div>
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:white;">
<div style="font-size:12px;color:#6b7280;">Load index</div>
<div style="font-size:28px;font-weight:800;color:#111827;">{load_index_percent:.1f}%</div>
</div>
</div>
<div style="padding:14px;border:1px solid #e5e7eb;border-radius:14px;background:#ffffff;margin-bottom:12px;">
<div style="font-size:12px;color:#6b7280;margin-bottom:8px;">Crossings by class</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:7px;font-size:13px;">
<div>🚶 People: <b>{c("person")}</b></div>
<div>🚗 Cars: <b>{c("car")}</b></div>
<div>🏍️ Motorcycles: <b>{c("motorcycle")}</b></div>
<div>🚲 Bicycles: <b>{c("bicycle")}</b></div>
<div>🚌 Buses: <b>{c("bus")}</b></div>
<div>🚛 Trucks: <b>{c("truck")}</b></div>
<div>🐄 Cows: <b>{c("cow")}</b></div>
<div>🐑 Sheep/goats: <b>{c("sheep") + c("goat")}</b></div>
<div>🐴 Horse/donkey: <b>{c("horse") + c("donkey")}</b></div>
</div>
</div>
<div style="font-size:12px;color:#6b7280;margin-bottom:4px;display:flex;justify-content:space-between;">
<span>Frame {frame_idx} / {total_frames}</span>
<span>{pct:.1f}% · {elapsed:.1f}s · {proc_fps:.1f} FPS · {DEVICE} · {engine}</span>
</div>
<div style="height:8px;background:#e5e7eb;border-radius:999px;overflow:hidden;">
<div style="height:100%;width:{pct:.2f}%;background:#4f46e5;"></div>
</div>
</div>
"""
def draw_dashboard(
    frame: np.ndarray,
    total_count: int,
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    proc_fps: float,
    engine: str,
) -> np.ndarray:
    """Draw the semi-transparent summary panel in the top-left corner.

    The input ``frame`` is not modified by the blend; the panel and its four
    text rows are drawn on the blended copy, which is returned.
    """
    # Dark panel blended over the frame at 82% opacity.
    panel_layer = frame.copy()
    cv2.rectangle(panel_layer, (18, 18), (600, 164), (18, 24, 38), -1)
    frame = cv2.addWeighted(panel_layer, 0.82, frame, 0.18, 0)

    cumulative_t = cumulative_kg / 1000.0
    live_t = live_load_kg / 1000.0

    # (text, origin, font scale, colour, thickness) for each row.
    rows = (
        ("BRIDGE TRAFFIC + LIVESTOCK DEMO", (34, 48), 0.72, (255, 255, 255), 2),
        (
            f"Crossed: {total_count} | Cumulative est. mass: {cumulative_t:.1f} t",
            (34, 82),
            0.58,
            (230, 240, 255),
            2,
        ),
        (
            f"Live load: {live_t:.1f} t | Load index: {load_index_percent:.1f}%",
            (34, 114),
            0.58,
            (220, 245, 230),
            2,
        ),
        (
            f"{proc_fps:.1f} processing FPS | {DEVICE} | {engine}",
            (34, 144),
            0.52,
            (230, 230, 255),
            1,
        ),
    )
    for text, origin, font_scale, colour, thickness in rows:
        cv2.putText(
            frame,
            text,
            origin,
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            colour,
            thickness,
            cv2.LINE_AA,
        )
    return frame
def annotate_frame(
    frame: np.ndarray,
    detections: sv.Detections,
    canonical_names: List[str],
    line_y: int,
    roi_top_y: int,
    roi_bottom_y: int,
    class_counts: Dict[str, int],
    total_count: int,
    cumulative_kg: float,
    live_load_kg: float,
    load_index_percent: float,
    proc_fps: float,
    engine: str,
) -> np.ndarray:
    """Draw all overlays for one frame and return the annotated image.

    Drawn in order: the tinted deck ROI band, the counting line with its
    caption, the ROI borders, one labelled box per detection, the dashboard
    panel, and a bottom line listing the non-zero crossing counts.

    ``canonical_names[i]`` names ``detections[i]``; detections without a
    matching entry are labelled ``"object"``.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    frame_h, frame_w = frame.shape[:2]

    # Bridge deck ROI: a faint grey band blended over the frame.
    roi_layer = frame.copy()
    cv2.rectangle(roi_layer, (0, roi_top_y), (frame_w, roi_bottom_y), (90, 90, 90), -1)
    frame = cv2.addWeighted(roi_layer, 0.08, frame, 0.92, 0)

    # Counting line and caption (caption kept at least 28 px from the top).
    line_colour = (40, 230, 255)
    cv2.line(frame, (0, line_y), (frame_w, line_y), line_colour, 3)
    cv2.putText(
        frame,
        "COUNTING LINE",
        (24, max(28, line_y - 12)),
        font,
        0.60,
        line_colour,
        2,
        cv2.LINE_AA,
    )

    # ROI borders.
    for border_y in (roi_top_y, roi_bottom_y):
        cv2.line(frame, (0, border_y), (frame_w, border_y), (170, 170, 170), 1)

    box_total = len(detections)
    if box_total > 0:
        track_ids = detections.tracker_id
        if track_ids is None:
            track_ids = [None] * box_total
        scores = detections.confidence
        if scores is None:
            scores = [0.0] * box_total

        for idx, (box, score, track_id) in enumerate(
            zip(detections.xyxy, scores, track_ids)
        ):
            name = canonical_names[idx] if idx < len(canonical_names) else "object"
            left, top, right, bottom = map(int, box)
            colour = COLOR_BY_NAME_BGR.get(name, (80, 220, 255))
            shown_name = DISPLAY_NAME.get(name, name)
            weight_t = DEFAULT_WEIGHTS_KG.get(name, 0) / 1000.0

            cv2.rectangle(frame, (left, top), (right, bottom), colour, 2)

            # Prefix the tracker id only when the detection has a valid one.
            if track_id is not None and int(track_id) >= 0:
                id_prefix = f"#{int(track_id)} "
            else:
                id_prefix = ""
            label = f"{id_prefix}{shown_name} {float(score):.2f} ~{weight_t:.2f}t"

            (text_w, text_h), baseline = cv2.getTextSize(label, font, 0.52, 1)
            label_top = max(0, top - text_h - baseline - 8)
            cv2.rectangle(frame, (left, label_top), (left + text_w + 10, top), colour, -1)
            cv2.putText(
                frame,
                label,
                (left + 5, top - 6),
                font,
                0.52,
                (255, 255, 255),
                1,
                cv2.LINE_AA,
            )

    frame = draw_dashboard(
        frame=frame,
        total_count=total_count,
        cumulative_kg=cumulative_kg,
        live_load_kg=live_load_kg,
        load_index_percent=load_index_percent,
        proc_fps=proc_fps,
        engine=engine,
    )

    # Bottom line: only classes that have at least one crossing.
    class_order = (
        "person", "car", "motorcycle", "bicycle", "bus", "truck",
        "cow", "sheep", "goat", "horse", "donkey",
    )
    summary_parts = [
        f"{DISPLAY_NAME.get(key, key)}: {int(class_counts.get(key, 0))}"
        for key in class_order
        if int(class_counts.get(key, 0)) > 0
    ]
    summary = " | ".join(summary_parts) if summary_parts else "No crossings yet"
    cv2.putText(
        frame,
        summary[:140],
        (22, frame_h - 24),
        font,
        0.58,
        (255, 255, 255),
        2,
        cv2.LINE_AA,
    )
    return frame
def final_summary_md(
    total_count: int,
    class_counts: Dict[str, int],
    cumulative_kg: float,
    peak_live_load_kg: float,
    peak_load_index: float,
    auto_video_used: str,
) -> str:
    """Build the Markdown summary shown once the analysis has finished.

    Args:
        total_count: Total number of counted crossings.
        class_counts: Crossings per canonical class; zero-count classes are
            left out of the table.
        cumulative_kg: Summed estimated mass of all crossings, in kg.
        peak_live_load_kg: Highest live load seen, in kg.
        peak_load_index: Highest load index seen, in percent.
        auto_video_used: Path of the default video, or an empty string to
            omit that line.

    Returns:
        The Markdown text.  Sections are separated by blank lines so each
        paragraph and the table render as separate blocks.
    """
    class_order = [
        "person", "bicycle", "car", "motorcycle", "bus", "truck",
        "cow", "sheep", "goat", "horse", "donkey",
    ]
    rows = []
    for name in class_order:
        count = int(class_counts.get(name, 0))
        if count > 0:
            rows.append(f"| {DISPLAY_NAME.get(name, name)} | {count} |")
    if not rows:
        rows.append("| None | 0 |")

    sections = ["### Final summary"]
    if auto_video_used:
        sections.append(f"**Default video used:** `{auto_video_used}`")
    sections.append(f"**Total crossings:** {total_count}")
    sections.append("\n".join(["| Class | Count |", "|---|---:|", *rows]))
    sections.append(
        f"**Cumulative estimated mass:** {cumulative_kg/1000.0:.2f} tonnes"
    )
    sections.append(
        f"**Peak estimated live load:** {peak_live_load_kg/1000.0:.2f} tonnes"
    )
    sections.append(f"**Peak bridge load index:** {peak_load_index:.1f}%")
    sections.append(
        "This is a demonstration traffic-load indicator. Real bridge stress needs axle loads, bridge geometry, material properties, span length, lane position and engineering calibration."
    )
    return "\n" + "\n\n".join(sections) + "\n"
| # --------------------------------------------------------------------- | |
| # Main video processing generator | |
| # --------------------------------------------------------------------- | |
def process_video(
    video_path,
    engine,
    yolo_model_file,
    confidence,
    frame_stride,
    inference_width,
    line_position_percent,
    roi_top_percent,
    roi_bottom_percent,
    reference_capacity_tonnes,
    person_weight_kg,
    bicycle_weight_kg,
    motorcycle_weight_kg,
    car_weight_t,
    bus_weight_t,
    truck_weight_t,
    cow_weight_kg,
    sheep_weight_kg,
    goat_weight_kg,
    horse_weight_kg,
    donkey_weight_kg,
):
    """Analyse a video and stream annotated frames, metrics and a load chart.

    This is a generator.  Every yield is a 6-tuple
    ``(frame_rgb, metrics_html, plot_image, status_markdown, video_file,
    csv_file)``.  ``video_file`` and ``csv_file`` are ``None`` until the
    final yield, which carries the annotated video and the CSV export.

    Detection runs on every ``frame_stride``-th frame; frames in between
    reuse the latest tracked detections.  A tracked object is counted once,
    the first time its box centre moves from one side of the counting line
    to the other.  The live load sums the configured weights of tracked
    objects whose centre lies inside the deck ROI; the load index is that
    sum as a percentage of ``reference_capacity_tonnes``.

    Side effect: the weights passed in overwrite the module-level
    ``DEFAULT_WEIGHTS_KG`` table, which the frame labels also read.

    Raises:
        RuntimeError: If the video cannot be opened or reports no frame size.
    """
    engine_name = str(engine)

    def idle_outputs(message, total_frames=0):
        """Build a yield value with no frame and zeroed metrics."""
        return (
            None,
            build_metrics_html(0, {}, 0, 0, 0, 0, total_frames, 0, 0, engine_name),
            make_empty_plot(),
            message,
            None,
            None,
        )

    if video_path is None:
        yield idle_outputs(
            "No video found. Put an `.mp4` file in the same folder as `app.py`, or upload one."
        )
        return

    # Gradio can pass a dict in some versions.
    if isinstance(video_path, dict):
        video_path = video_path.get("path") or video_path.get("name")

    if not video_path or not os.path.exists(video_path):
        yield idle_outputs(f"Video not found: {video_path}")
        return

    DEFAULT_WEIGHTS_KG.update({
        "person": int(person_weight_kg),
        "bicycle": int(bicycle_weight_kg),
        "motorcycle": int(motorcycle_weight_kg),
        "car": int(float(car_weight_t) * 1000),
        "bus": int(float(bus_weight_t) * 1000),
        "truck": int(float(truck_weight_t) * 1000),
        "cow": int(cow_weight_kg),
        "sheep": int(sheep_weight_kg),
        "goat": int(goat_weight_kg),
        "horse": int(horse_weight_kg),
        "donkey": int(donkey_weight_kg),
    })

    cap = cv2.VideoCapture(video_path)
    writer = None
    # The capture and writer are released in ``finally`` so that errors and
    # an early close of this generator do not leak them.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")

        fps = float(cap.get(cv2.CAP_PROP_FPS) or 25.0)
        if fps <= 1:
            fps = 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        if width <= 0 or height <= 0:
            raise RuntimeError("Could not read video dimensions.")

        line_y = int(height * float(line_position_percent) / 100.0)
        roi_top_y = int(height * float(roi_top_percent) / 100.0)
        roi_bottom_y = int(height * float(roi_bottom_percent) / 100.0)
        if roi_bottom_y <= roi_top_y:
            # Inverted or empty ROI: fall back to a 20%-90% band.
            roi_top_y = int(height * 0.20)
            roi_bottom_y = int(height * 0.90)

        reference_capacity_kg = max(1.0, float(reference_capacity_tonnes) * 1000.0)
        # Guard the modulo below against a stride of zero.
        stride = max(1, int(frame_stride))

        yield idle_outputs(
            f"### Starting analysis on `{Path(video_path).name}`...",
            total_frames,
        )

        # Load the model once, before the loop, and reuse it for every frame.
        use_yolo = engine_name.startswith("YOLO")
        if use_yolo:
            model = load_yolo_model(str(yolo_model_file))
        else:
            model = load_rfdetr_medium()

        tracker = sv.ByteTrack(frame_rate=int(round(fps)))

        # Reserve output paths; close the handles right away so the writer
        # and the CSV export can reopen the files.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_video:
            out_video_path = tmp_video.name
        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as tmp_csv:
            out_csv_path = tmp_csv.name

        writer = cv2.VideoWriter(
            out_video_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )

        last_detections = sv.Detections.empty()
        last_names: List[str] = []
        last_side_by_id: Dict[int, int] = {}
        counted_ids = set()
        track_name_by_id: Dict[int, str] = {}
        class_counts = {name: 0 for name in TARGET_CANONICAL_NAMES}
        total_count = 0
        cumulative_kg = 0.0
        history: List[Dict] = []
        events: List[Dict] = []
        peak_live_load_kg = 0.0
        peak_load_index = 0.0
        start_wall = time.time()
        last_yield_wall = 0.0
        last_plot_wall = 0.0
        latest_plot = make_empty_plot()
        processed = 0
        frame_idx = 0
        final_frame_rgb = None

        while True:
            ok, frame = cap.read()
            if not ok:
                break

            if frame_idx % stride == 0:
                if use_yolo:
                    detections, names = yolo_predict_to_supervision(
                        model=model,
                        frame_bgr=frame,
                        confidence=float(confidence),
                        imgsz=int(inference_width),
                    )
                else:
                    detections, names = rfdetr_predict_to_supervision(
                        model=model,
                        frame_bgr=frame,
                        confidence=float(confidence),
                        inference_width=int(inference_width),
                    )

                # The tracker may drop detections, so names cannot be
                # re-aligned by position.  Within one frame a class id maps
                # to exactly one canonical name, so rebuild names from the
                # class ids of the tracked detections.
                name_by_class: Dict[int, str] = {}
                if detections.class_id is not None:
                    name_by_class = {
                        int(class_id): name
                        for class_id, name in zip(detections.class_id, names)
                    }

                detections = tracker.update_with_detections(detections)

                if detections.class_id is not None and name_by_class:
                    names = [
                        name_by_class.get(int(class_id), "object")
                        for class_id in detections.class_id
                    ]
                else:
                    # No class ids to key on: pad or trim by position.
                    names = (list(names) + ["object"] * len(detections))[:len(detections)]

                last_detections = detections
                last_names = names
            else:
                detections = last_detections
                names = last_names

            centres = detection_centres(detections)
            live_load_kg = 0.0

            if len(detections) > 0 and detections.tracker_id is not None:
                for i, (centre, tid) in enumerate(zip(centres, detections.tracker_id)):
                    if tid is None or int(tid) < 0:
                        continue
                    tid = int(tid)
                    name = names[i] if i < len(names) else track_name_by_id.get(tid, "object")
                    if name == "object":
                        continue
                    track_name_by_id[tid] = name
                    cy = float(centre[1])

                    # Live load only for objects currently inside bridge deck ROI.
                    if roi_top_y <= cy <= roi_bottom_y:
                        live_load_kg += float(DEFAULT_WEIGHTS_KG.get(name, 0))

                    current_side = side_of_line(cy, line_y)
                    if current_side == 0:
                        # Inside the dead zone: keep the last real side so a
                        # slow pass through the zone still counts later.
                        continue

                    previous_side = last_side_by_id.get(tid)
                    crossed = (
                        previous_side is not None
                        and previous_side != 0
                        and previous_side != current_side
                    )
                    if crossed and tid not in counted_ids:
                        counted_ids.add(tid)
                        total_count += 1
                        class_counts[name] = int(class_counts.get(name, 0)) + 1
                        weight_kg = float(DEFAULT_WEIGHTS_KG.get(name, 0))
                        cumulative_kg += weight_kg
                        direction = "down" if previous_side < current_side else "up"
                        events.append({
                            "video_time_s": frame_idx / fps,
                            "frame": frame_idx,
                            "tracker_id": tid,
                            "object_type": name,
                            "display_type": DISPLAY_NAME.get(name, name),
                            "direction": direction,
                            "estimated_weight_kg": weight_kg,
                            "cumulative_estimated_mass_kg": cumulative_kg,
                        })
                    last_side_by_id[tid] = current_side

            load_index_percent = (live_load_kg / reference_capacity_kg) * 100.0
            peak_live_load_kg = max(peak_live_load_kg, live_load_kg)
            peak_load_index = max(peak_load_index, load_index_percent)

            elapsed = time.time() - start_wall
            processed += 1
            proc_fps = processed / max(elapsed, 1e-6)

            history.append({
                "time_s": frame_idx / fps,
                "frame": frame_idx,
                "total_crossings": total_count,
                "people_crossed": class_counts.get("person", 0),
                "bicycles_crossed": class_counts.get("bicycle", 0),
                "cars_crossed": class_counts.get("car", 0),
                "motorcycles_crossed": class_counts.get("motorcycle", 0),
                "buses_crossed": class_counts.get("bus", 0),
                "trucks_crossed": class_counts.get("truck", 0),
                "cows_crossed": class_counts.get("cow", 0),
                "sheep_goats_crossed": class_counts.get("sheep", 0) + class_counts.get("goat", 0),
                "horse_donkey_crossed": class_counts.get("horse", 0) + class_counts.get("donkey", 0),
                "live_load_kg": live_load_kg,
                "live_load_tonnes": live_load_kg / 1000.0,
                "load_index_percent": load_index_percent,
                "cumulative_estimated_mass_kg": cumulative_kg,
                "cumulative_estimated_mass_tonnes": cumulative_kg / 1000.0,
            })

            annotated = annotate_frame(
                frame=frame,
                detections=detections,
                canonical_names=names,
                line_y=line_y,
                roi_top_y=roi_top_y,
                roi_bottom_y=roi_bottom_y,
                class_counts=class_counts,
                total_count=total_count,
                cumulative_kg=cumulative_kg,
                live_load_kg=live_load_kg,
                load_index_percent=load_index_percent,
                proc_fps=proc_fps,
                engine=engine_name,
            )
            writer.write(annotated)
            final_frame_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

            # Throttle UI work: redraw the chart at most once per second and
            # push an update at most every 0.35 s.
            now = time.time()
            if now - last_plot_wall >= 1.0:
                latest_plot = render_load_plot(history)
                last_plot_wall = now
            if now - last_yield_wall >= 0.35:
                last_yield_wall = now
                yield (
                    final_frame_rgb,
                    build_metrics_html(
                        total_count=total_count,
                        class_counts=class_counts,
                        cumulative_kg=cumulative_kg,
                        live_load_kg=live_load_kg,
                        load_index_percent=load_index_percent,
                        frame_idx=frame_idx + 1,
                        total_frames=total_frames,
                        elapsed=elapsed,
                        proc_fps=proc_fps,
                        engine=engine_name,
                    ),
                    latest_plot,
                    "### Live analysis running...",
                    None,
                    None,
                )
            frame_idx += 1
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    history_df = pd.DataFrame(history)
    events_df = pd.DataFrame(events)
    if not events_df.empty:
        # Save both frame-level history and crossing events in one CSV-like
        # file by writing two separate CSV sections.
        with open(out_csv_path, "w", encoding="utf-8", newline="") as f:
            f.write("# FRAME_LEVEL_LOAD_INDEX\n")
            history_df.to_csv(f, index=False)
            f.write("\n# CROSSING_EVENTS\n")
            events_df.to_csv(f, index=False)
    else:
        history_df.to_csv(out_csv_path, index=False)

    elapsed = time.time() - start_wall
    proc_fps = processed / max(elapsed, 1e-6)
    final_plot = render_load_plot(history)
    frames_done = total_frames if total_frames else frame_idx

    yield (
        final_frame_rgb,
        build_metrics_html(
            total_count=total_count,
            class_counts=class_counts,
            cumulative_kg=cumulative_kg,
            live_load_kg=0,
            load_index_percent=0,
            frame_idx=frames_done,
            total_frames=frames_done,
            elapsed=elapsed,
            proc_fps=proc_fps,
            engine=engine_name,
        ),
        final_plot,
        final_summary_md(
            total_count=total_count,
            class_counts=class_counts,
            cumulative_kg=cumulative_kg,
            peak_live_load_kg=peak_live_load_kg,
            peak_load_index=peak_load_index,
            auto_video_used=video_path if str(video_path).startswith(str(APP_DIR)) else "",
        ),
        out_video_path,
        out_csv_path,
    )
| # --------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------- | |
# Stylesheet passed to gr.Blocks(css=...): page width, hero header, card
# panels, rounded image corners, and a hidden footer.
CUSTOM_CSS = """
.gradio-container {
max-width: 1360px !important;
margin: auto !important;
}
#hero {
text-align: center;
padding: 16px 8px 6px 8px;
}
#hero h1 {
font-weight: 850;
letter-spacing: -0.8px;
margin-bottom: 2px;
}
#hero p {
color: #64748b;
font-size: 16px;
margin-top: 0;
}
.panel {
border: 1px solid #e5e7eb;
border-radius: 18px;
padding: 16px;
background: #ffffff;
box-shadow: 0 8px 24px rgba(15, 23, 42, 0.045);
}
#live-frame img, #load-plot img {
border-radius: 14px;
}
footer {
visibility: hidden;
}
"""
# Gradio page: header, settings column, live outputs, final outputs, and the
# event wiring that streams process_video results into the output components.
#
# NOTE(review): this layout was reconstructed from a copy whose indentation
# had been stripped. The control flow is unambiguous, but confirm that the
# "live metrics / load plot" row and the "final video / summary" row belong
# inside the right-hand column rather than at the top level of the page.
with gr.Blocks(
    title="Fast Bridge Traffic + Livestock Load Demo",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    css=CUSTOM_CSS,
) as demo:
    # ---- Header ------------------------------------------------------
    with gr.Row(elem_id="hero"):
        gr.Markdown(
            """
            # 🌉 Fast Bridge Traffic + Livestock Load Demo
            YOLO-small / RF-DETR Medium detection, ByteTrack tracking, line-crossing counts,
            estimated object weights, and live bridge load-index over time.
            """
        )

    # Tell the user whether a bundled video was found next to app.py.
    if DEFAULT_VIDEO:
        gr.Markdown(
            f"✅ Found default video next to `app.py`: "
            f"`{Path(DEFAULT_VIDEO).name}`. "
            "The app will auto-start inference when opened."
        )
    else:
        gr.Markdown(
            "⚠️ No local video found next to `app.py`. "
            "Upload a video or place `bridge.mp4`, `traffic.mp4`, `input.mp4`, "
            "or any `.mp4` in the same folder."
        )

    with gr.Row():
        # ---- Left column: input video and all tunable settings --------
        with gr.Column(scale=1):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### 1) Video")
                video_input = gr.Video(
                    label="Video input",
                    sources=["upload"],
                    value=DEFAULT_VIDEO,
                    format="mp4",
                    height=260,
                )
                start_btn = gr.Button(
                    "▶ Start / rerun analysis",
                    variant="primary",
                    size="lg",
                )

                gr.Markdown("### 2) Inference engine")
                engine = gr.Radio(
                    choices=[
                        "YOLO small - fastest recommended",
                        "RF-DETR Medium - slower but strong",
                    ],
                    value="YOLO small - fastest recommended",
                    label="Engine",
                )
                yolo_model_file = gr.Textbox(
                    value="yolo11s.pt",
                    label="YOLO model file/name",
                    info=(
                        "Use yolo11s.pt for small. "
                        "Put your custom .pt in the same folder as app.py "
                        "and type its filename here."
                    ),
                )
                confidence = gr.Slider(
                    minimum=0.10,
                    maximum=0.90,
                    value=0.35,
                    step=0.05,
                    label="Confidence threshold",
                )
                frame_stride = gr.Slider(
                    minimum=1,
                    maximum=12,
                    value=3,
                    step=1,
                    label="Frame stride",
                    info="Detect every Nth frame. 3-5 is much faster than every frame.",
                )
                inference_width = gr.Slider(
                    minimum=384,
                    maximum=1280,
                    value=640,
                    step=64,
                    label="Inference image size / width",
                    info="Lower is faster. Try 512 or 640 for fast demos.",
                )

                # Geometry of the counting line / deck region (percent
                # sliders) and the capacity used for the demo load index.
                with gr.Accordion("Bridge settings", open=False):
                    line_position_percent = gr.Slider(
                        minimum=10,
                        maximum=90,
                        value=55,
                        step=1,
                        label="Counting line vertical position (%)",
                    )
                    roi_top_percent = gr.Slider(
                        minimum=0,
                        maximum=90,
                        value=20,
                        step=1,
                        label="Bridge deck ROI top (%)",
                    )
                    roi_bottom_percent = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=90,
                        step=1,
                        label="Bridge deck ROI bottom (%)",
                    )
                    reference_capacity_tonnes = gr.Slider(
                        minimum=1,
                        maximum=250,
                        value=40,
                        step=1,
                        label="Reference live-load capacity for demo index (tonnes)",
                    )

                # Per-class weight estimates. Units differ by class:
                # car / bus / truck are in tonnes, everything else in kg.
                with gr.Accordion("Estimated weights", open=False):
                    person_weight_kg = gr.Number(value=75, label="Person weight estimate (kg)")
                    bicycle_weight_kg = gr.Number(value=120, label="Bicycle + rider estimate (kg)")
                    motorcycle_weight_kg = gr.Number(value=250, label="Motorcycle estimate (kg)")
                    car_weight_t = gr.Number(value=1.5, label="Car estimate (tonnes)")
                    bus_weight_t = gr.Number(value=12.0, label="Bus estimate (tonnes)")
                    truck_weight_t = gr.Number(value=18.0, label="Truck estimate (tonnes)")
                    cow_weight_kg = gr.Number(value=450, label="Cow estimate (kg)")
                    sheep_weight_kg = gr.Number(value=60, label="Sheep estimate (kg)")
                    goat_weight_kg = gr.Number(value=45, label="Goat estimate (kg)")
                    horse_weight_kg = gr.Number(value=350, label="Horse estimate (kg)")
                    donkey_weight_kg = gr.Number(value=180, label="Donkey estimate (kg)")

                gr.Markdown(
                    """
                    **Fast demo settings:** YOLO small, confidence 0.30-0.40,
                    frame stride 3-5, image size 512-640.
                    """
                )

        # ---- Right column: live frame, live metrics, final results ----
        with gr.Column(scale=2):
            with gr.Group(elem_classes="panel"):
                gr.Markdown("### Live annotated video")
                live_frame = gr.Image(
                    show_label=False,
                    elem_id="live-frame",
                    height=500,
                )

            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Live metrics")
                        # Zeroed placeholder shown until the first run.
                        metrics_html = gr.HTML(
                            value=build_metrics_html(
                                total_count=0,
                                class_counts={},
                                cumulative_kg=0,
                                live_load_kg=0,
                                load_index_percent=0,
                                frame_idx=0,
                                total_frames=0,
                                elapsed=0,
                                proc_fps=0,
                                engine="not started",
                            )
                        )
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Load index over time")
                        load_plot = gr.Image(
                            show_label=False,
                            elem_id="load-plot",
                            height=300,
                            value=make_empty_plot(),
                        )

            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Final annotated video")
                        video_output = gr.Video(
                            label="Replay / download annotated video",
                            height=270,
                        )
                with gr.Column(scale=1):
                    with gr.Group(elem_classes="panel"):
                        gr.Markdown("### Final summary")
                        summary_output = gr.Markdown("The summary will appear after analysis.")
                        csv_output = gr.File(label="Download CSV")

    # ---- Event wiring ------------------------------------------------
    # Gradio hands these component values to process_video positionally, so
    # the order here presumably mirrors its parameter list -- keep the two in
    # sync when adding or removing a control.
    inputs = [
        video_input,
        engine,
        yolo_model_file,
        confidence,
        frame_stride,
        inference_width,
        line_position_percent,
        roi_top_percent,
        roi_bottom_percent,
        reference_capacity_tonnes,
        person_weight_kg,
        bicycle_weight_kg,
        motorcycle_weight_kg,
        car_weight_t,
        bus_weight_t,
        truck_weight_t,
        cow_weight_kg,
        sheep_weight_kg,
        goat_weight_kg,
        horse_weight_kg,
        donkey_weight_kg,
    ]
    # Each value produced by process_video is routed, in order, to one of
    # these components.
    outputs = [
        live_frame,
        metrics_html,
        load_plot,
        summary_output,
        video_output,
        csv_output,
    ]

    start_btn.click(
        fn=process_video,
        inputs=inputs,
        outputs=outputs,
    )

    # Auto-start when a local video exists beside app.py.
    if DEFAULT_VIDEO:
        demo.load(
            fn=process_video,
            inputs=inputs,
            outputs=outputs,
        )
# Script entry point: enable the request queue (at most 2 pending requests)
# and start the Gradio server. Skipped when the module is merely imported.
if __name__ == "__main__":
    demo.queue(max_size=2).launch()