# ============================================================
# 🚦 Stage 3 — Wrong Direction Detection
#     (Stable + Confidence + Hysteresis + Filter)
# ============================================================
import os
import json
import math
import tempfile

import cv2
import numpy as np
import gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# ------------------------------------------------------------
# 🧠 Safe-load fix for PyTorch 2.6
# ------------------------------------------------------------
import torch
import ultralytics.nn.tasks as ultralytics_tasks

torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])

MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)

VEHICLE_CLASSES = [2, 3, 5, 7]  # car, motorcycle, bus, truck


# ============================================================
# 🧩 Kalman-based Tracker
# ============================================================
class Track:
    """One tracked vehicle.

    A constant-velocity Kalman filter smooths the bbox-centre trajectory,
    and a small voting buffer debounces the OK/WRONG direction status so
    labels do not flicker frame to frame.
    """

    # Number of recent smoothed centre points kept for direction estimation.
    HISTORY_LEN = 30

    def __init__(self, bbox, tid):
        self.id = tid
        # State vector [x, y, vx, vy]; measurement vector [x, y].
        self.kf = KalmanFilter(dim_x=4, dim_z=2)
        self.kf.F = np.array([[1, 0, 1, 0],
                              [0, 1, 0, 1],
                              [0, 0, 1, 0],
                              [0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0],
                              [0, 1, 0, 0]])
        self.kf.P *= 10   # high initial uncertainty in the state estimate
        self.kf.R *= 1    # measurement noise left at default (kept explicit)
        self.kf.x[:2] = np.array(bbox[:2]).reshape(2, 1)
        self.history = []          # recent filtered [x, y] centre points
        self.frames_seen = 0       # number of measurement updates received
        self.status = "OK"         # committed (debounced) direction label
        self.status_history = []   # raw per-frame status proposals
        self.confidence = 1.0      # similarity at the last committed status
        self.ema_sim = 1.0         # exponentially smoothed flow similarity

    def update(self, bbox):
        """Predict, then correct with a new centre measurement.

        Appends the filtered centre to ``history`` (bounded to
        ``HISTORY_LEN`` points) and returns it as ``[x, y]``.
        """
        self.kf.predict()
        self.kf.update(np.array(bbox[:2]))
        x, y = self.kf.x[:2].reshape(-1)
        self.history.append([x, y])
        if len(self.history) > self.HISTORY_LEN:
            self.history.pop(0)
        self.frames_seen += 1
        return [x, y]

    def stable_status(self, new_status, new_conf, window=10, agree_ratio=0.6):
        """Debounce flicker using recent window consensus.

        The committed ``status`` changes only when at least
        ``agree_ratio`` of the last ``window`` proposals agree.

        BUGFIX: the original truncated the vote threshold with ``int()``,
        which evaluates to 0 while the history is short (``int(0.6) == 0``
        for one sample), so ``count >= 0`` always held and any single
        proposal was accepted immediately — no debouncing at all early in
        a track's life.  The threshold now rounds up and is at least 1.
        """
        self.status_history.append(new_status)
        if len(self.status_history) > window:
            self.status_history.pop(0)
        needed = max(1, math.ceil(agree_ratio * len(self.status_history)))
        if self.status_history.count(new_status) >= needed:
            self.status = new_status
            self.confidence = new_conf
        return self.status, self.confidence
# ============================================================
# ⚙️ Utility Functions
# ============================================================
def compute_cosine_similarity(v1, v2):
    """Cosine similarity of two motion vectors (eps-guarded norms)."""
    v1 = v1 / (np.linalg.norm(v1) + 1e-6)
    v2 = v2 / (np.linalg.norm(v2) + 1e-6)
    return np.dot(v1, v2)


def smooth_direction(points, window=5):
    """Compute smoothed motion vector using last N points.

    Returns ``None`` while the trajectory is too short or the mean step
    is below 1 px (near-static), so jitter does not yield bogus headings.
    """
    if len(points) < window + 1:
        return None
    diffs = np.diff(points[-window:], axis=0)
    avg_vec = np.mean(diffs, axis=0)
    if np.linalg.norm(avg_vec) < 1:
        return None
    return avg_vec


# ============================================================
# 🧭 Wrong-Direction Detection Core
# ============================================================
def _detect_vehicle_centres(frame):
    """Run YOLO on one frame; return an (N, 2) array of vehicle bbox centres."""
    results = model(frame)[0]
    dets = []
    for box in results.boxes:
        if int(box.cls[0]) in VEHICLE_CLASSES:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            dets.append([(x1 + x2) / 2, (y1 + y2) / 2])
    return np.array(dets)


def process_video(video_file, stage2_json, show_only_wrong=False, conf_threshold=0.0):
    """Annotate a video with per-vehicle wrong-direction labels.

    Parameters
    ----------
    video_file : path-like
        Input video readable by OpenCV.
    stage2_json : path-like
        Stage-2 JSON providing ``flow_centers`` (lane flow direction
        vectors).  ``drive_zone`` / ``entry_zones`` are loaded but not
        yet used by this stage.
    show_only_wrong : bool
        If True, draw labels only for tracks classified WRONG.
    conf_threshold : float
        Hide labels whose confidence is below this value.

    Returns
    -------
    str
        Path of the annotated output video.
    """
    # BUGFIX: close the JSON file handle (original leaked it via
    # json.load(open(...))).
    with open(stage2_json) as fh:
        data = json.load(fh)
    lane_flows = np.array(data.get("flow_centers", [[1, 0]]))
    drive_zone = np.array(data.get("drive_zone", []))                 # reserved, unused
    entry_zones = [np.array(z) for z in data.get("entry_zones", [])]  # reserved, unused

    cap = cv2.VideoCapture(video_file)
    # BUGFIX: some containers report CAP_PROP_FPS == 0, which produces a
    # broken VideoWriter; fall back to a sane default.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # BUGFIX: mkstemp + close instead of NamedTemporaryFile(delete=False),
    # which left an open handle on the temp file.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

    tracks, next_id = {}, 0
    DELAY_FRAMES = 8       # frames a track must exist before classification
    MIN_FLOW_SPEED = 1.2   # px/frame below which motion is ignored
    HYST_OK = 0.55         # similarity at/above which the label is OK
    HYST_WRONG = 0.45      # similarity at/below which the label is WRONG
    ALPHA = 0.6            # exponential smoothing weight
    MATCH_DIST = 50        # max px distance for detection-track association
    MAX_MISSES = 15        # drop tracks unmatched for this many frames

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        dets = _detect_vehicle_centres(frame)

        # --- Tracker update (Hungarian assignment on centre distance) ---
        assigned, matched_ids = set(), set()
        if len(dets) > 0 and len(tracks) > 0:
            tids = list(tracks.keys())
            existing = np.array([tracks[t].kf.x[:2].reshape(-1) for t in tids])
            dists = np.linalg.norm(existing[:, None, :] - dets[None, :, :], axis=2)
            row_idx, col_idx = linear_sum_assignment(dists)
            for r, c in zip(row_idx, col_idx):
                if dists[r, c] < MATCH_DIST:
                    tracks[tids[r]].update(dets[c])
                    matched_ids.add(tids[r])
                    assigned.add(c)
        for i, d in enumerate(dets):
            if i not in assigned:
                tracks[next_id] = Track(d, next_id)
                matched_ids.add(next_id)
                next_id += 1

        # --- Coast unmatched tracks, prune stale ones ---
        # BUGFIX: the original called trk.update(trk.kf.x[:2]) on EVERY
        # track each frame — a second predict/update feeding the filter's
        # own prediction back as a fake measurement.  That corrupted the
        # velocity estimate, double-appended history points for matched
        # tracks, and double-counted frames_seen.  Unmatched tracks now
        # only coast (predict), and tracks unseen for MAX_MISSES frames
        # are removed instead of being drawn and classified forever.
        for tid in list(tracks.keys()):
            trk = tracks[tid]
            if tid in matched_ids:
                trk.misses = 0
            else:
                trk.kf.predict()
                trk.misses = getattr(trk, "misses", 0) + 1
                if trk.misses > MAX_MISSES:
                    del tracks[tid]

        # --- Draw & classify ---
        for tid, trk in list(tracks.items()):
            pos = trk.kf.x[:2].reshape(-1)
            pts = np.array(trk.history)
            if len(pts) > 1:
                for i in range(1, len(pts)):
                    cv2.line(frame, tuple(np.int32(pts[i - 1])),
                             tuple(np.int32(pts[i])), (0, 0, 255), 1)

            motion = smooth_direction(pts)
            if motion is None:
                continue
            if np.linalg.norm(motion) < MIN_FLOW_SPEED:
                continue

            best_sim = max(compute_cosine_similarity(motion, f) for f in lane_flows)

            if trk.frames_seen > DELAY_FRAMES:
                # Exponential moving average of the flow similarity.
                trk.ema_sim = ALPHA * best_sim + (1 - ALPHA) * getattr(trk, "ema_sim", best_sim)
                # Hysteresis classification: hold the previous label
                # inside the (HYST_WRONG, HYST_OK) dead band.
                if trk.ema_sim >= HYST_OK:
                    new_status = "OK"
                elif trk.ema_sim <= HYST_WRONG:
                    new_status = "WRONG"
                else:
                    new_status = trk.status
                trk.stable_status(new_status, new_conf=trk.ema_sim,
                                  window=10, agree_ratio=0.6)

            # --- Filter by UI controls ---
            show_label = (trk.confidence >= conf_threshold
                          and not (show_only_wrong and trk.status != "WRONG"))
            if show_label:
                color = (0, 0, 255) if trk.status == "WRONG" else (0, 255, 0)
                label = f"ID:{tid} {trk.status} ({trk.confidence:.2f})"
                cv2.putText(frame, label, tuple(np.int32(pos)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

        out.write(frame)

    cap.release()
    out.release()
    return out_path


# ============================================================
# 🎛️ Gradio Interface
# ============================================================
description = """
### 🚦 Stage 3 — Wrong Direction Detection (Stable + Confidence + Filter)
- ✅ Cosine similarity with exponential smoothing
- ✅ Hysteresis (OK≥0.55 / WRONG≤0.45) for stability
- ✅ 10-frame consensus voting (flicker-free)
- ✅ Confidence-based label filtering
- ✅ “Show Only Wrong” toggle
"""

demo = gr.Interface(
    fn=process_video,
    inputs=[
        gr.File(label="Input Video"),
        gr.File(label="Stage 2 Flow JSON"),
        gr.Checkbox(label="Show ONLY Wrong Labels Overlay", value=False),
        gr.Slider(0.0, 1.0, value=0.0, step=0.05,
                  label="Confidence Level Filter (Show ≥ this value)"),
    ],
    outputs=gr.Video(label="Output Video"),
    title="🚗 Stage 3 – Stable Wrong-Direction Detection (with Confidence Filter)",
    description=description,
)

if __name__ == "__main__":
    demo.launch()