# (Hosting-page header residue, commented out so this file parses as Python.)
# nishanth-saka's picture
# revert (#19)
# 2315f61 verified
# ============================================================
# 🚦 Stage 3 — Wrong Direction Detection (Stable + Confidence + Hysteresis + Filter)
# ============================================================
import os, cv2, json, tempfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
# ------------------------------------------------------------
# 🧠 Safe-load fix for PyTorch 2.6
# ------------------------------------------------------------
import torch, ultralytics.nn.tasks as ultralytics_tasks
# PyTorch 2.6 switched torch.load to weights_only=True by default; allow-list
# the ultralytics DetectionModel class so the YOLO checkpoint still unpickles.
torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])

# The detector is loaded once at import time and shared by every video run.
MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)

# Detector class ids counted as vehicles.
VEHICLE_CLASSES = [
    2,  # car
    3,  # motorcycle
    5,  # bus
    7,  # truck
]
# ============================================================
# 🧩 Kalman-based Tracker
# ============================================================
class Track:
    """State for one tracked vehicle.

    Holds a constant-velocity Kalman filter over the (x, y) position, a bounded
    trail of filtered positions, and the direction-classification state
    (debounced status label, its confidence and a smoothed similarity score).
    """

    def __init__(self, bbox, tid, max_history=30):
        """Start a track at a detection.

        Args:
            bbox: Sequence whose first two entries are the (x, y) position used
                as the initial state (process_video passes box centres).
            tid: Integer id assigned by the caller.
            max_history: Maximum number of filtered positions kept in
                ``history``; the oldest are dropped first.
        """
        self.id = tid
        self.max_history = max_history
        # State is [x, y, vx, vy], measurement is [x, y]; F adds one step of
        # velocity to the position per predict, i.e. dt = 1 frame.
        self.kf = KalmanFilter(dim_x=4, dim_z=2)
        self.kf.F = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]])
        self.kf.H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
        self.kf.P *= 10  # large initial uncertainty: the velocity is not known yet
        self.kf.R *= 1  # no-op; measurement noise left at filterpy's default (tuning knob)
        self.kf.x[:2] = np.array(bbox[:2]).reshape(2, 1)
        self.history = []  # filtered [x, y] positions, oldest first
        self.frames_seen = 0  # number of update() calls so far
        self.status = "OK"  # debounced label: "OK" or "WRONG"
        self.status_history = []  # raw labels voted on by stable_status
        self.confidence = 1.0  # confidence that came with the adopted status
        self.ema_sim = 1.0  # smoothed lane-flow similarity, updated by the caller

    def update(self, bbox):
        """Advance the filter one step and correct it with a measurement.

        Runs Kalman predict + update using ``bbox[:2]`` as the measured
        position, appends the filtered position to ``history`` (trimmed to
        ``max_history`` entries) and increments ``frames_seen``.

        Returns:
            The filtered position as ``[x, y]``.
        """
        self.kf.predict()
        self.kf.update(np.array(bbox[:2]))
        x, y = self.kf.x[:2].reshape(-1)
        self.history.append([x, y])
        excess = len(self.history) - self.max_history
        if excess > 0:
            del self.history[:excess]
        self.frames_seen += 1
        return [x, y]

    def stable_status(self, new_status, new_conf, window=10, agree_ratio=0.6):
        """Debounce label flicker with a sliding-window vote.

        ``new_status`` is recorded among the last ``window`` raw labels. It
        replaces ``self.status`` (together with ``new_conf``) only if it makes
        up at least ``agree_ratio`` of that window.

        Returns:
            ``(status, confidence)`` after the vote.
        """
        self.status_history.append(new_status)
        excess = len(self.status_history) - window
        if excess > 0:
            del self.status_history[:excess]
        votes = self.status_history.count(new_status)
        # Compare against the exact fraction. Truncating it with int() let
        # e.g. 1 of 3 labels (33%) pass a 60% agreement requirement.
        if votes >= agree_ratio * len(self.status_history):
            self.status = new_status
            self.confidence = new_conf
        return self.status, self.confidence
# ============================================================
# ⚙️ Utility Functions
# ============================================================
def compute_cosine_similarity(v1, v2):
    """Cosine of the angle between two vectors.

    Each vector is divided by its Euclidean norm plus a small epsilon (1e-6),
    so a zero-length input yields 0 instead of a division by zero.
    """
    eps = 1e-6
    a_hat, b_hat = (v / (np.linalg.norm(v) + eps) for v in (v1, v2))
    return np.dot(a_hat, b_hat)
def smooth_direction(points, window=5, min_speed=1.0):
    """Return the mean per-step displacement over the last ``window`` points.

    Args:
        points: Sequence/array of positions, oldest first (process_video
            passes ``np.array(track.history)``).
        window: Number of trailing points to average over; must be >= 2,
            because fewer points give no displacement at all.
        min_speed: Motion vectors with a norm below this are treated as
            "no reliable direction".

    Returns:
        The averaged displacement vector, or ``None`` if there are fewer than
        ``window + 1`` points (a conservative guard: only the last ``window``
        are used) or the motion is slower than ``min_speed``.

    Raises:
        ValueError: If ``window < 2``. Previously this silently produced a
            NaN vector from the mean of an empty difference array.
    """
    if window < 2:
        raise ValueError(f"window must be >= 2, got {window}")
    if len(points) < window + 1:
        return None
    # The mean of consecutive differences equals (last - first) / (window - 1).
    diffs = np.diff(np.asarray(points)[-window:], axis=0)
    avg_vec = diffs.mean(axis=0)
    if np.linalg.norm(avg_vec) < min_speed:
        return None
    return avg_vec
# ============================================================
# 🧭 Wrong-Direction Detection Core
# ============================================================
def _load_lane_flows(stage2_json):
    """Read the lane-flow direction vectors from the Stage 2 JSON file.

    ``stage2_json`` may be a path string or a file wrapper exposing ``.name``
    (what ``gr.File`` yields depends on the Gradio version). A missing *or
    empty* ``flow_centers`` falls back to the single flow ``[1, 0]``, so the
    later ``max`` over similarities never sees an empty sequence.
    """
    path = getattr(stage2_json, "name", stage2_json)
    with open(path) as fh:
        data = json.load(fh)
    return np.asarray(data.get("flow_centers") or [[1, 0]], dtype=float)


def _detect_vehicle_centers(frame):
    """Run the YOLO model on one frame; return vehicle box centres as an (M, 2) array."""
    result = model(frame, verbose=False)[0]
    centers = []
    for box in result.boxes:
        if int(box.cls[0]) in VEHICLE_CLASSES:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            centers.append(((x1 + x2) / 2, (y1 + y2) / 2))
    # reshape keeps a (0, 2) shape when nothing is detected.
    return np.array(centers, dtype=float).reshape(-1, 2)


def _associate_detections(tracks, dets, max_dist):
    """Match detections to existing tracks by centroid distance.

    Solves the assignment with ``linear_sum_assignment`` (Hungarian method)
    and rejects pairs farther apart than ``max_dist``.

    Returns:
        ``(matches, unmatched)``: ``matches`` is a list of
        ``(track_id, det_index)``; ``unmatched`` lists detection indices that
        should start new tracks.
    """
    if not tracks or len(dets) == 0:
        return [], list(range(len(dets)))
    track_ids = list(tracks)
    positions = np.array([tracks[t].kf.x[:2].reshape(-1) for t in track_ids])
    dists = np.linalg.norm(positions[:, None, :] - dets[None, :, :], axis=2)
    rows, cols = linear_sum_assignment(dists)
    matches = [(track_ids[r], c) for r, c in zip(rows, cols) if dists[r, c] < max_dist]
    matched = {c for _, c in matches}
    return matches, [i for i in range(len(dets)) if i not in matched]


def _label_confidence(status, sim):
    """Map a cosine similarity in [-1, 1] to a [0, 1] confidence for ``status``.

    High similarity supports an OK label. For a WRONG label the evidence is
    the opposite, so the scale is mirrored. Without the mirror, WRONG tracks
    (similarity <= 0.45) would always score low and be hidden by any
    confidence filter above ~0.45.
    """
    if status == "WRONG":
        return (1.0 - sim) / 2.0
    return (1.0 + sim) / 2.0


def _to_pixel(point):
    """Convert an (x, y) float pair to the int tuple OpenCV drawing calls expect."""
    return int(point[0]), int(point[1])


def process_video(video_file, stage2_json, show_only_wrong=False, conf_threshold=0.0):
    """Annotate a video with per-vehicle wrong-direction labels.

    For each frame:
      1. Detect vehicles.
      2. Associate the detections with Kalman tracks (one filter step per
         matched detection).
      3. Estimate each track's smoothed motion direction.
      4. Compare that direction with the Stage 2 lane flows by cosine
         similarity.
      5. Classify the track as OK or WRONG using an EMA, hysteresis and a
         sliding-window vote.

    Args:
        video_file: Input video (a path, or an object with ``.name``).
        stage2_json: Stage 2 JSON (a path, or an object with ``.name``)
            containing ``flow_centers``.
        show_only_wrong: Draw labels only for tracks currently labelled WRONG.
        conf_threshold: Hide labels whose confidence is below this value.

    Returns:
        Path of the annotated ``.mp4`` written to a temporary file.

    Raises:
        ValueError: If the input video or the output writer cannot be opened.
    """
    DELAY_FRAMES = 8  # updates a track needs before it is classified
    MIN_FLOW_SPEED = 1.2  # slower motion vectors are not classified
    HYST_OK = 0.55
    HYST_WRONG = 0.45
    ALPHA = 0.6  # exponential smoothing weight
    MAX_MATCH_DIST = 50  # farther detections start a new track
    MAX_MISSES = 15  # frames a track may go unmatched before it is dropped
    CONSENSUS_WINDOW = 10
    CONSENSUS_RATIO = 0.6

    lane_flows = _load_lane_flows(stage2_json)
    video_path = getattr(video_file, "name", video_file)

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open input video: {video_path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # 0 (or NaN) when the container lacks FPS metadata
            fps = 30.0
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)  # only the path is needed; VideoWriter opens the file itself
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        if not out.isOpened():
            raise ValueError(f"Could not open video writer for {out_path}")

        tracks, misses, next_id = {}, {}, 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            # --- Tracker update: exactly one Kalman step per measurement ---
            dets = _detect_vehicle_centers(frame)
            matches, unmatched = _associate_detections(tracks, dets, MAX_MATCH_DIST)
            updated = set()
            for tid, c in matches:
                tracks[tid].update(dets[c])
                updated.add(tid)
            for c in unmatched:
                trk = Track(dets[c], next_id)
                trk.update(dets[c])  # seed the trail with the first position
                tracks[next_id] = trk
                updated.add(next_id)
                next_id += 1

            # Drop tracks that have stopped receiving detections so they
            # neither grow the assignment problem nor steal new detections.
            for tid in list(tracks):
                misses[tid] = 0 if tid in updated else misses.get(tid, 0) + 1
                if misses[tid] > MAX_MISSES:
                    del tracks[tid]
                    del misses[tid]

            # --- Draw & classify ---
            for tid, trk in tracks.items():
                pts = np.array(trk.history)
                for p, q in zip(pts[:-1], pts[1:]):
                    cv2.line(frame, _to_pixel(p), _to_pixel(q), (0, 0, 255), 1)

                motion = smooth_direction(pts)
                if motion is None or np.linalg.norm(motion) < MIN_FLOW_SPEED:
                    continue

                # Classify only on frames with a fresh measurement. A coasting
                # track's trail is unchanged and would just re-vote its last label.
                if tid in updated and trk.frames_seen > DELAY_FRAMES:
                    best_sim = max(compute_cosine_similarity(motion, f) for f in lane_flows)
                    trk.ema_sim = ALPHA * best_sim + (1 - ALPHA) * trk.ema_sim
                    if trk.ema_sim >= HYST_OK:
                        new_status = "OK"
                    elif trk.ema_sim <= HYST_WRONG:
                        new_status = "WRONG"
                    else:
                        new_status = trk.status  # inside the dead band: hold
                    trk.stable_status(
                        new_status,
                        new_conf=_label_confidence(new_status, trk.ema_sim),
                        window=CONSENSUS_WINDOW,
                        agree_ratio=CONSENSUS_RATIO,
                    )

                # --- Filter by UI controls ---
                if trk.confidence < conf_threshold:
                    continue
                if show_only_wrong and trk.status != "WRONG":
                    continue
                color = (0, 0, 255) if trk.status == "WRONG" else (0, 255, 0)
                label = f"ID:{tid} {trk.status} ({trk.confidence:.2f})"
                cv2.putText(frame, label, _to_pixel(trk.history[-1]),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
            out.write(frame)
    finally:
        cap.release()
        if out is not None:
            out.release()
    return out_path
# ============================================================
# 🎛️ Gradio Interface
# ============================================================
description = """
### 🚦 Stage 3 — Wrong Direction Detection (Stable + Confidence + Filter)
- ✅ Cosine similarity with exponential smoothing
- ✅ Hysteresis (OK≥0.55 / WRONG≤0.45) for stability
- ✅ 10-frame consensus voting (flicker-free)
- ✅ Confidence-based label filtering
- ✅ “Show Only Wrong” toggle
"""

# UI components, in the positional order of process_video's parameters:
# (video_file, stage2_json, show_only_wrong, conf_threshold).
_ui_inputs = [
    gr.File(label="Input Video"),
    gr.File(label="Stage 2 Flow JSON"),
    gr.Checkbox(label="Show ONLY Wrong Labels Overlay", value=False),
    gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Confidence Level Filter (Show ≥ this value)"),
]
_ui_output = gr.Video(label="Output Video")

demo = gr.Interface(
    fn=process_video,
    inputs=_ui_inputs,
    outputs=_ui_output,
    title="🚗 Stage 3 – Stable Wrong-Direction Detection (with Confidence Filter)",
    description=description,
)

if __name__ == "__main__":
    demo.launch()