# arm-model / model / hybrid_system.py
# Uploaded by pragadeeshv23 using huggingface_hub (commit 5b86813, verified)
#!/usr/bin/env python3
"""
Hybrid Road Anomaly Detection System
Based on: Nature Scientific Reports (Nov 2025) - YOLOv11 + CNN-BiGRU
Integrates:
1. YOLOv11 – real-time spatial detection (bounding boxes)
2. CNN-BiGRU – temporal severity prediction (Minor / Moderate / Severe / Critical)
Pipeline (per video frame):
Frame → YOLOv11 detect → crop anomaly regions →
maintain temporal buffer per tracked anomaly →
CNN-BiGRU severity prediction → annotated output
Target performance:
mAP@0.5 = 96.92% | 105 FPS | 9.5 ms latency
"""
import os
import sys
import time
import json
import logging
from pathlib import Path
from collections import defaultdict, deque
from typing import List, Dict, Any, Optional, Tuple, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from yolo_detection import YOLOv11Detector, CLASS_NAMES, CLASS_COLORS, SEVERITY_WEIGHTS
from cnn_bigru import CNNBiGRU, CNNSeverityClassifier, SEVERITY_LABELS, PATCH_SIZE
# ---------------------------------------------------------------------------
# Module-wide logging: all tracker / hybrid-system messages are routed
# through the "HybridSystem" logger with timestamped output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("HybridSystem")
# ---------------------------------------------------------------------------
# Colour helpers
# ---------------------------------------------------------------------------
# BGR colours keyed by severity index, used for box/label drawing.
# Presumably aligned with the ordering of SEVERITY_LABELS imported from
# cnn_bigru (Minor / Moderate / Severe / Critical) — TODO confirm.
SEVERITY_COLORS = {
    0: (0, 255, 0),    # Green – Minor
    1: (0, 255, 255),  # Yellow – Moderate
    2: (0, 165, 255),  # Orange – Severe
    3: (0, 0, 255),    # Red – Critical
}
# ═══════════════════════════════════════════════════════════════════════════
# Simple IoU-based tracker (lightweight, no extra deps)
# ═══════════════════════════════════════════════════════════════════════════
class SimpleTracker:
    """
    Lightweight IoU-based multi-object tracker (no extra dependencies).

    Greedily associates detections across frames using IoU overlap and
    maintains a bounded temporal buffer of cropped patches per track,
    which is later consumed by the CNN-BiGRU severity model.
    """

    def __init__(self, iou_threshold: float = 0.3, max_age: int = 5, buffer_len: int = 8):
        """
        Args:
            iou_threshold: minimum IoU for a detection to extend an existing track.
            max_age: consecutive unmatched frames before a track is dropped.
            buffer_len: temporal sequence length buffered for the BiGRU.
        """
        self.iou_threshold = iou_threshold
        self.max_age = max_age        # frames before a track is dropped
        self.buffer_len = buffer_len  # temporal sequence length for BiGRU
        self.tracks: Dict[int, Dict[str, Any]] = {}
        self.next_id = 0

    @staticmethod
    def _iou(boxA: List[int], boxB: List[int]) -> float:
        """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])
        inter = max(0, xB - xA) * max(0, yB - yA)
        # Areas are clamped to >= 1 so degenerate boxes cannot divide by zero.
        areaA = max(1, (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
        areaB = max(1, (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
        return inter / (areaA + areaB - inter + 1e-6)

    def update(
        self,
        detections: List[Dict[str, Any]],
        frame: np.ndarray,
    ) -> Dict[int, Dict[str, Any]]:
        """
        Match current detections to existing tracks.

        Returns dict track_id -> {bbox, class_id, class_name, confidence,
        severity, patches: deque}.

        Fix: tracks created in the current frame are no longer aged by the
        stale-track pass (previously they started life at age 1 instead of 0,
        inconsistent with matched tracks whose age is reset to 0).
        """
        h, w = frame.shape[:2]
        # Greedy assignment: for each detection find the best matching track.
        used_tracks: set = set()
        matched: List[Tuple[int, Dict]] = []
        unmatched_dets: List[Dict] = []
        for det in detections:
            best_iou = 0.0
            best_tid = -1
            for tid, track in self.tracks.items():
                if tid in used_tracks:
                    continue
                iou_val = self._iou(det["bbox"], track["bbox"])
                if iou_val > best_iou:
                    best_iou = iou_val
                    best_tid = tid
            if best_iou >= self.iou_threshold and best_tid >= 0:
                matched.append((best_tid, det))
                used_tracks.add(best_tid)
            else:
                unmatched_dets.append(det)
        # Update matched tracks in place.
        for tid, det in matched:
            track = self.tracks[tid]
            track["bbox"] = det["bbox"]
            track["class_id"] = det["class_id"]
            track["class_name"] = det["class_name"]
            track["confidence"] = det["confidence"]
            track["age"] = 0
            # Crop and buffer the patch for temporal severity prediction.
            patch = self._crop(frame, det["bbox"], h, w)
            if patch is not None:
                track["patches"].append(patch)
        # Create new tracks for unmatched detections.
        active = set(used_tracks)  # ids touched this frame (matched or new)
        for det in unmatched_dets:
            tid = self.next_id
            self.next_id += 1
            active.add(tid)
            patch = self._crop(frame, det["bbox"], h, w)
            patches: deque = deque(maxlen=self.buffer_len)
            if patch is not None:
                patches.append(patch)
            self.tracks[tid] = {
                **det,
                "age": 0,
                "patches": patches,
                "severity_pred": None,
                "severity_probs": None,
            }
        # Age out stale tracks: only those neither matched nor created now.
        for tid in [t for t in self.tracks if t not in active]:
            self.tracks[tid]["age"] += 1
            if self.tracks[tid]["age"] > self.max_age:
                del self.tracks[tid]
        return self.tracks

    def _crop(self, frame: np.ndarray, bbox: List[int], h: int, w: int) -> Optional[np.ndarray]:
        """Clip bbox to the frame, crop, and resize to the BiGRU patch size.

        Returns None when the clipped bbox is empty (fully outside the frame).
        """
        x1, y1, x2, y2 = bbox
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)
        crop = frame[y1:y2, x1:x2]
        if crop.size == 0:
            return None
        crop = cv2.resize(crop, (PATCH_SIZE, PATCH_SIZE))
        return crop

    def get_sequences(self, min_len: int = 2) -> Dict[int, np.ndarray]:
        """
        Return tracks that have at least `min_len` buffered patches,
        formatted as (T, C, H, W) float32 arrays in [0, 1].
        """
        seqs: Dict[int, np.ndarray] = {}
        for tid, track in self.tracks.items():
            patches = list(track["patches"])
            if len(patches) < min_len:
                continue
            # Convert BGR -> RGB, normalise to [0, 1], reorder to (C, H, W).
            frames = []
            for p in patches:
                p_rgb = cv2.cvtColor(p, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
                p_chw = np.transpose(p_rgb, (2, 0, 1))  # (C, H, W)
                frames.append(p_chw)
            seqs[tid] = np.stack(frames)  # (T, C, H, W)
        return seqs
# ═══════════════════════════════════════════════════════════════════════════
# Hybrid System
# ═══════════════════════════════════════════════════════════════════════════
class HybridRoadAnomalySystem:
    """
    Complete YOLOv11 + CNN-BiGRU road anomaly detection & severity system.

    Per-frame pipeline:
        YOLOv11 detection -> IoU tracking (SimpleTracker) -> temporal patch
        buffering -> CNN-BiGRU severity prediction (or a class/confidence
        heuristic when no BiGRU weights are available) -> annotated output.
    """

    def __init__(
        self,
        yolo_weights: str = "best.pt",
        bigru_weights: Optional[str] = None,
        img_size: int = 416,
        conf_threshold: float = 0.02,
        iou_threshold: float = 0.45,
        seq_len: int = 8,
        device: Optional[Union[int, str]] = None,
        use_attention: bool = False,
    ):
        """
        Args:
            yolo_weights: path to YOLOv11 weights (.pt / .engine).
            bigru_weights: optional CNN-BiGRU checkpoint (.pth); when absent,
                severity falls back to the class/confidence heuristic.
            img_size: YOLO inference resolution.
            conf_threshold: YOLO confidence threshold.
            iou_threshold: YOLO NMS IoU threshold.
            seq_len: temporal sequence length fed to the BiGRU.
            device: None for auto (CUDA if available) or an explicit device.
            use_attention: use the attention-weighted BiGRU forward pass.
        """
        # Device selection: auto-detect CUDA unless explicitly overridden.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = str(device)
        # --- YOLOv11 spatial detector ---
        self.yolo = YOLOv11Detector(
            model_path=yolo_weights,
            img_size=img_size,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            device=0 if self.device == "cuda" else "cpu",
        )
        # --- CNN-BiGRU severity model (optional) ---
        self.seq_len = seq_len
        self.use_attention = use_attention
        self.bigru: Optional[CNNBiGRU] = None
        if bigru_weights and Path(bigru_weights).exists():
            self.bigru = CNNBiGRU(
                in_channels=3,
                hidden_size=128,
                num_gru_layers=2,
                num_severity_classes=len(SEVERITY_LABELS),
            )
            # NOTE(review): torch.load without weights_only=True can unpickle
            # arbitrary objects -- only load checkpoints from trusted sources.
            self.bigru.load_state_dict(
                torch.load(bigru_weights, map_location=self.device)
            )
            self.bigru.to(self.device)
            self.bigru.eval()
            logger.info("CNN-BiGRU loaded from %s", bigru_weights)
        else:
            logger.warning(
                "No BiGRU weights – severity will use heuristic from YOLO class."
            )
        # --- Tracker (buffer length matches the BiGRU sequence length) ---
        self.tracker = SimpleTracker(
            iou_threshold=0.3,
            max_age=5,
            buffer_len=seq_len,
        )
        logger.info("Hybrid system ready (YOLO=%s BiGRU=%s)",
                    yolo_weights, bigru_weights or "heuristic")

    # ------------------------------------------------------------------
    # Core processing: single frame
    # ------------------------------------------------------------------
    def process_frame(
        self,
        frame: np.ndarray,
    ) -> Tuple[List[Dict[str, Any]], np.ndarray]:
        """
        Process one video frame through the full pipeline.

        Returns:
            enriched_detections: list of dicts with severity info added.
            annotated_frame: BGR image with boxes + severity drawn.
        """
        # 1. YOLO detection
        detections = self.yolo.detect(frame)
        # 2. Update tracker (crops + buffers patches per tracked anomaly)
        tracks = self.tracker.update(detections, frame)
        # 3. Severity prediction (BiGRU or heuristic)
        enriched = self._predict_severity(tracks)
        # 4. Annotate
        annotated = self._annotate(frame, enriched)
        return enriched, annotated

    # ------------------------------------------------------------------
    # Severity prediction
    # ------------------------------------------------------------------
    def _predict_severity(
        self,
        tracks: Dict[int, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        For each track, run CNN-BiGRU if enough temporal context is
        available; otherwise fall back to a heuristic mapping.

        Args:
            tracks: live tracks from SimpleTracker.update; mutated in place
                with "severity_pred" / "severity_probs" for eligible tracks.

        Returns:
            One enriched dict per track (bbox, class, confidence, severity).
        """
        results: List[Dict[str, Any]] = []
        if self.bigru is not None:
            # Batch all eligible sequences
            seqs = self.tracker.get_sequences(min_len=2)
            if seqs:
                # Pad / subsample every sequence to exactly self.seq_len frames
                batch_ids: List[int] = []
                batch_tensors: List[torch.Tensor] = []
                for tid, seq_np in seqs.items():
                    t = seq_np.shape[0]
                    if t < self.seq_len:
                        # Pad by repeating the last frame
                        pad = np.tile(seq_np[-1:], (self.seq_len - t, 1, 1, 1))
                        seq_np = np.concatenate([seq_np, pad], axis=0)
                    elif t > self.seq_len:
                        # Uniformly subsample longer sequences
                        indices = np.linspace(0, t - 1, self.seq_len, dtype=int)
                        seq_np = seq_np[indices]
                    batch_ids.append(tid)
                    batch_tensors.append(torch.from_numpy(seq_np))
                batch = torch.stack(batch_tensors).to(self.device)  # (N, T, C, H, W)
                with torch.no_grad():
                    if self.use_attention:
                        # Attention weights are not used downstream; discard.
                        logits, _ = self.bigru.forward_with_attention(batch)
                    else:
                        logits = self.bigru(batch)
                    probs = F.softmax(logits, dim=-1)
                    preds = probs.argmax(dim=-1)
                for i, tid in enumerate(batch_ids):
                    if tid in tracks:
                        tracks[tid]["severity_pred"] = int(preds[i])
                        tracks[tid]["severity_probs"] = probs[i].cpu().numpy().tolist()
        # Build output list
        for tid, track in tracks.items():
            severity_idx = track.get("severity_pred")
            if severity_idx is None:
                # Heuristic fallback based on class + confidence
                severity_idx = self._heuristic_severity(
                    track["class_id"], track["confidence"]
                )
            entry = {
                "track_id": tid,
                "bbox": track["bbox"],
                "class_id": track["class_id"],
                "class_name": track["class_name"],
                "confidence": track["confidence"],
                "severity_idx": severity_idx,
                "severity_label": SEVERITY_LABELS[severity_idx],
                "severity_probs": track.get("severity_probs"),
            }
            results.append(entry)
        return results

    @staticmethod
    def _heuristic_severity(class_id: int, confidence: float) -> int:
        """
        Map anomaly class + detection confidence to a severity index
        when the BiGRU model is unavailable.

        The per-class base weight is scaled into [0.6*base, base] by the
        detection confidence, then thresholded into four severity bands.
        """
        base = SEVERITY_WEIGHTS.get(class_id, 0.5)  # unknown class -> mid weight
        score = base * (0.6 + 0.4 * confidence)  # scale by confidence
        if score >= 0.85:
            return 3  # Critical
        elif score >= 0.65:
            return 2  # Severe
        elif score >= 0.40:
            return 1  # Moderate
        return 0  # Minor

    # ------------------------------------------------------------------
    # Annotation drawing
    # ------------------------------------------------------------------
    def _annotate(
        self,
        frame: np.ndarray,
        detections: List[Dict[str, Any]],
    ) -> np.ndarray:
        """Draw bounding boxes with anomaly class + severity label.

        Operates on a copy; the input frame is left untouched.
        """
        img = frame.copy()
        for det in detections:
            x1, y1, x2, y2 = det["bbox"]
            sev_idx = det["severity_idx"]
            # Box colour from severity (white when the index is unknown)
            box_color = SEVERITY_COLORS.get(sev_idx, (255, 255, 255))
            # Labels
            cls_label = f"{det['class_name']} {det['confidence']:.2f}"
            sev_label = f"Sev: {det['severity_label']}"
            tid_label = f"ID:{det['track_id']}"
            # Draw box
            cv2.rectangle(img, (x1, y1), (x2, y2), box_color, 2)
            # Class label (filled background above box)
            (tw, th), _ = cv2.getTextSize(cls_label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(img, (x1, y1 - th - 8), (x1 + tw + 4, y1), box_color, -1)
            cv2.putText(img, cls_label, (x1 + 2, y1 - 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            # Severity + track-id label (filled background below box)
            combined = f"{sev_label} {tid_label}"
            (tw2, th2), _ = cv2.getTextSize(combined, cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1)
            cv2.rectangle(img, (x1, y2), (x1 + tw2 + 4, y2 + th2 + 8), box_color, -1)
            cv2.putText(img, combined, (x1 + 2, y2 + th2 + 4),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1)
        return img

    # ------------------------------------------------------------------
    # Video processing
    # ------------------------------------------------------------------
    def process_video(
        self,
        video_source: Union[str, int] = 0,
        output_path: Optional[str] = "hybrid_output.mp4",
        show: bool = True,
        save_json: bool = True,
    ) -> Dict[str, Any]:
        """
        End-to-end video processing.

        Args:
            video_source: file path or camera index.
            output_path: annotated video destination (None disables writing).
            show: display a live preview window ('q' quits).
            save_json: dump per-frame results as JSON next to the output.

        Returns:
            Summary dict with per-frame statistics; empty dict when the
            source cannot be opened.
        """
        cap = cv2.VideoCapture(video_source)
        if not cap.isOpened():
            logger.error("Cannot open: %s", video_source)
            return {}
        # Fall back to 30 fps when the container reports 0 / unknown.
        fps_in = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        writer = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            writer = cv2.VideoWriter(output_path, fourcc, fps_in, (w, h))
        all_results: List[Dict[str, Any]] = []
        frame_times: List[float] = []
        frame_idx = 0
        logger.info("Processing %dx%d @ %d fps (%d frames)", w, h, fps_in, total)
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            t0 = time.perf_counter()
            detections, annotated = self.process_frame(frame)
            dt = time.perf_counter() - t0
            frame_times.append(dt)
            # FPS overlay
            live_fps = 1.0 / dt if dt > 0 else 0
            cv2.putText(annotated, f"FPS: {live_fps:.1f}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)
            cv2.putText(annotated, f"Detections: {len(detections)}", (10, 65),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            if writer:
                writer.write(annotated)
            if show:
                cv2.imshow("Hybrid Road Anomaly Detection", annotated)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break
            # Collect results (drop the bulky per-class probability vectors)
            all_results.append({
                "frame": frame_idx,
                "latency_ms": dt * 1000,
                "detections": [
                    {k: v for k, v in d.items() if k != "severity_probs"}
                    for d in detections
                ],
            })
            frame_idx += 1
            if frame_idx % 200 == 0:
                avg_fps = 1.0 / np.mean(frame_times[-200:])
                logger.info(" frame %d/%d avg FPS=%.1f", frame_idx, total, avg_fps)
        cap.release()
        if writer:
            writer.release()
            logger.info("Saved β†’ %s", output_path)
        cv2.destroyAllWindows()
        # Summary. Guard both averages: np.mean([]) would yield NaN when the
        # capture opened but produced no frames.
        summary = {
            "total_frames": frame_idx,
            "avg_latency_ms": float(np.mean(frame_times) * 1000) if frame_times else 0,
            "avg_fps": float(1.0 / np.mean(frame_times)) if frame_times else 0,
            "total_detections": sum(len(r["detections"]) for r in all_results),
        }
        if save_json:
            json_path = Path(output_path or "hybrid_output").with_suffix(".json")
            with open(json_path, "w") as f:
                json.dump({"summary": summary, "frames": all_results}, f, indent=2)
            logger.info("JSON results β†’ %s", json_path)
        logger.info(
            "Done – %d frames, avg %.1f ms/frame (%.1f FPS), %d total detections",
            summary["total_frames"],
            summary["avg_latency_ms"],
            summary["avg_fps"],
            summary["total_detections"],
        )
        return summary

    # ------------------------------------------------------------------
    # Single image (no temporal context β†’ heuristic severity)
    # ------------------------------------------------------------------
    def process_image(
        self,
        image_path: str,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Detect anomalies in a single image.

        Severity is heuristic (no temporal data is available for one frame).

        Returns:
            Enriched detection dicts (empty list if the image is unreadable).
        """
        frame = cv2.imread(image_path)
        if frame is None:
            logger.error("Cannot read image: %s", image_path)
            return []
        detections = self.yolo.detect(frame)
        enriched = []
        for det in detections:
            sev = self._heuristic_severity(det["class_id"], det["confidence"])
            # track_id is -1: single images bypass the tracker entirely.
            enriched.append({
                "track_id": -1,
                "bbox": det["bbox"],
                "class_id": det["class_id"],
                "class_name": det["class_name"],
                "confidence": det["confidence"],
                "severity_idx": sev,
                "severity_label": SEVERITY_LABELS[sev],
                "severity_probs": None,
            })
        annotated = self._annotate(frame, enriched)
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(output_path, annotated)
            logger.info("Saved β†’ %s", output_path)
        if show:
            cv2.imshow("Hybrid Detection", annotated)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        return enriched

    # ------------------------------------------------------------------
    # Batch image directory
    # ------------------------------------------------------------------
    def process_directory(
        self,
        image_dir: str,
        output_dir: str = "hybrid_results",
    ) -> Dict[str, Any]:
        """Process all images in a directory.

        Annotated copies are written as result_NNNN.jpg plus a results.json
        with the full detection dump.
        """
        img_dir = Path(image_dir)
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        exts = {".jpg", ".jpeg", ".png", ".bmp"}
        images = sorted(f for f in img_dir.iterdir() if f.suffix.lower() in exts)
        all_dets = []
        for i, img_path in enumerate(images, 1):
            out_path = out_dir / f"result_{i:04d}.jpg"
            dets = self.process_image(str(img_path), str(out_path))
            all_dets.append({"image": str(img_path), "detections": dets})
            logger.info(" [%d/%d] %s β†’ %d detections",
                        i, len(images), img_path.name, len(dets))
        summary = {
            "images_processed": len(images),
            "total_detections": sum(len(d["detections"]) for d in all_dets),
        }
        # Save JSON (default=str stringifies any non-serialisable values)
        with open(out_dir / "results.json", "w") as f:
            json.dump({"summary": summary, "images": all_dets}, f, indent=2, default=str)
        logger.info("Batch complete – %d images, %d detections",
                    summary["images_processed"], summary["total_detections"])
        return summary

    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        return (
            f"HybridRoadAnomalySystem(yolo={self.yolo}, "
            f"bigru={'loaded' if self.bigru else 'heuristic'}, "
            f"device={self.device})"
        )
# ═══════════════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import argparse

    # Command-line interface: build the hybrid system, then dispatch on the
    # kind of source the user supplied (camera / directory / video / image).
    cli = argparse.ArgumentParser(
        description="Hybrid YOLOv11 + CNN-BiGRU Road Anomaly Detection",
    )
    cli.add_argument("--yolo", default="best.pt", help="YOLO weights (.pt / .engine)")
    cli.add_argument("--bigru", default=None, help="BiGRU weights (.pth)")
    cli.add_argument("--source", required=True,
                     help="Image / video / directory / 'camera'")
    cli.add_argument("--output", "-o", default=None, help="Output path")
    cli.add_argument("--conf", type=float, default=0.02)
    cli.add_argument("--imgsz", type=int, default=416)
    cli.add_argument("--no-show", action="store_true")
    cli.add_argument("--attention", action="store_true",
                     help="Use attention-weighted BiGRU")
    opts = cli.parse_args()

    system = HybridRoadAnomalySystem(
        yolo_weights=opts.yolo,
        bigru_weights=opts.bigru,
        img_size=opts.imgsz,
        conf_threshold=opts.conf,
        use_attention=opts.attention,
    )

    display = not opts.no_show
    source = opts.source
    source_path = Path(source)
    suffix = source_path.suffix.lower()
    video_exts = {".mp4", ".avi", ".mov", ".mkv"}
    image_exts = {".jpg", ".jpeg", ".png", ".bmp"}

    if source.lower() == "camera":
        # Live webcam feed (device 0).
        system.process_video(video_source=0, show=display)
    elif source_path.is_dir():
        system.process_directory(source, output_dir=opts.output or "hybrid_results")
    elif suffix in video_exts:
        system.process_video(video_source=source,
                             output_path=opts.output or "hybrid_output.mp4",
                             show=display)
    elif suffix in image_exts:
        system.process_image(source, output_path=opts.output, show=display)
    else:
        logger.error("Unsupported source: %s", source)
        sys.exit(1)