# Source: Hugging Face Space by nishanth-saka — "Angle-Aware, Temporal Smoothing"
# (commit f066f31, ~9.93 kB). Web-viewer chrome converted to this comment header.
import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
# ------------------------------------------------------------
# 🔧 Safe-load fix for PyTorch 2.6
# ------------------------------------------------------------
# PyTorch >= 2.6 loads checkpoints with weights_only=True by default;
# allow-list the ultralytics DetectionModel class so YOLO .pt weights
# can still be unpickled.
import torch, ultralytics.nn.tasks as ultralytics_tasks
torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
# ------------------------------------------------------------
# ⚙️ YOLO setup
# ------------------------------------------------------------
MODEL_PATH = "yolov8n.pt"  # COCO-pretrained YOLOv8-nano weights (downloaded on first use)
model = YOLO(MODEL_PATH)
# COCO class ids kept as "vehicles" for detection filtering
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
# ------------------------------------------------------------
# 🧭 Utility: rotate vector by angle
# ------------------------------------------------------------
def rotate_vec(v, theta_deg):
    """Rotate a 2-D vector counter-clockwise by ``theta_deg`` degrees."""
    theta = np.deg2rad(theta_deg)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t,  cos_t]])
    return rotation @ v
# ------------------------------------------------------------
# 🧩 Kalman tracker with temporal smoothing + entry gate flag
# ------------------------------------------------------------
class Track:
    """One vehicle track: a constant-velocity Kalman filter over the bbox centroid.

    State vector is [cx, cy, vx, vy]; only the position (cx, cy) is measured.
    """

    def __init__(self, bbox, tid):
        self.id = tid
        kf = KalmanFilter(dim_x=4, dim_z=2)
        # constant-velocity transition; position-only measurement matrix
        kf.F = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0],
                         [0, 1, 0, 0]])
        kf.P *= 1000.0  # large initial state uncertainty
        kf.R *= 10.0    # measurement noise
        kf.x[:2, 0] = np.array(self.centroid(bbox), dtype=float)
        self.kf = kf
        self.trace = []         # smoothed centroid history [(cx, cy), ...]
        self.status_hist = []   # rolling cosine-similarity history for smoothing
        self.entry_flag = False  # set when the track first appears inside an entry zone
        self.active = True
        self.missed_frames = 0

    def centroid(self, b):
        """Return the centre point of an (x1, y1, x2, y2) box as [cx, cy]."""
        x1, y1, x2, y2 = b
        return [(x1 + x2) / 2, (y1 + y2) / 2]

    def predict(self):
        """Advance the filter one step; counts as a miss until update() is called."""
        self.kf.predict()
        self.missed_frames += 1
        return self.kf.x[:2].reshape(2)

    def update(self, b):
        """Fold a detection box into the filter and extend the smoothed trace."""
        measurement = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(measurement)
        cx, cy = self.kf.x[:2].reshape(2)
        self.trace.append((float(cx), float(cy)))
        self.missed_frames = 0
        return (cx, cy)
# ------------------------------------------------------------
# 🧮 Direction analyzer (angle + temporal aware)
# ------------------------------------------------------------
def analyze_direction(trace, centers, road_angle_deg, hist):
    """Classify a track's travel direction against the learned flow centers.

    Takes the track's centroid trace, unit flow-center vectors, the road
    angle (degrees), and the track's similarity history (mutated in place
    as a rolling 5-sample window). Returns (status, smoothed_similarity)
    where status is "OK", "WRONG", or "NA" (too short / ambiguous).
    """
    if len(trace) < 3:
        return "NA", 1.0
    # displacement over the last two steps
    disp = np.array(trace[-1]) - np.array(trace[-3])
    speed = np.linalg.norm(disp)
    if speed < 1e-6:
        return "NA", 1.0
    # express the unit motion vector in the road's reference frame
    aligned = rotate_vec(disp / speed, -road_angle_deg)
    # best cosine similarity against any dominant flow direction
    best = np.max(np.dot(centers, aligned))
    # temporal smoothing over the last 5 observations
    hist.append(best)
    if len(hist) > 5:
        hist.pop(0)
    smoothed = np.mean(hist)
    if smoothed < -0.2:
        return "WRONG", float(smoothed)
    if smoothed > 0.2:
        return "OK", float(smoothed)
    return "NA", float(smoothed)
# ------------------------------------------------------------
# 🗺️ Load Stage-2 flow stats (centers, angle, zones)
# ------------------------------------------------------------
def load_flow_stats(flow_json):
    """Load Stage-2 flow statistics from a JSON file.

    Returns a tuple (centers, road_angle_deg, drive_zone, entry_zones):
    - centers: (k, 2) array of flow-direction vectors, normalized to unit length
    - road_angle_deg: learned road orientation in degrees (0.0 if absent)
    - drive_zone: drivable-area polygon, or None if not provided
    - entry_zones: list of forbidden-entry polygons (possibly empty)
    """
    # close the file deterministically instead of leaking the handle
    with open(flow_json) as f:
        data = json.load(f)
    centers = np.array(data["flow_centers"])
    # normalize each row to a unit vector; epsilon guards zero-length rows
    centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
    road_angle_deg = float(data.get("road_angle_deg", 0.0))
    drive_zone = data.get("drive_zone", None)
    entry_zones = data.get("entry_zones", [])
    return centers, road_angle_deg, drive_zone, entry_zones
# ------------------------------------------------------------
# 🧾 Zone tests
# ------------------------------------------------------------
def inside_zone(pt, zone):
    """Return True when pt lies inside or on the edge of zone; no zone accepts all points."""
    if zone is None:
        return True
    polygon = np.array(zone, np.int32)
    # pointPolygonTest: +1 inside, 0 on the boundary, -1 outside
    return cv2.pointPolygonTest(polygon, pt, False) >= 0
def inside_any(pt, zones):
    """Return True when pt falls inside (or on the edge of) at least one polygon in zones."""
    for zone in zones:
        polygon = np.array(zone, np.int32)
        if cv2.pointPolygonTest(polygon, pt, False) >= 0:
            return True
    return False
# ------------------------------------------------------------
# 🎥 Process video (angle + temporal + zone + entry-gating)
# ------------------------------------------------------------
def process_video(video_path, flow_json, show_only_wrong=False):
    """Run Stage-3 wrong-direction analysis over a traffic video.

    Detects vehicles per frame with YOLO, associates detections to Kalman
    tracks via Hungarian matching, classifies each track's direction against
    the Stage-2 flow statistics (angle, drive zone, entry zones), and writes
    an annotated video.

    Returns (annotated_video_path, per_vehicle_log, summary_dict, zip_path).
    """
    centers, road_angle, drive_zone, entry_zones = load_flow_stats(flow_json)
    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # some containers report 0 fps
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    tracks, next_id, log = [], 0, []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # --- detect vehicles in this frame ---
        results = model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > 0.3:
                detections.append(box.xyxy[0].cpu().numpy())
        # --- predict active tracks ---
        # BUG FIX: keep a parallel list of active tracks so the Hungarian row
        # index refers to the same track that produced the prediction.
        # Previously `tracks[i]` was indexed, which updates the WRONG track
        # once any earlier track has gone inactive.
        active_tracks = [t for t in tracks if t.active]
        predicted = [t.predict() for t in active_tracks]
        predicted = np.array(predicted) if predicted else np.empty((0, 2))
        # --- associate detections to predictions (Hungarian on centroid distance) ---
        assigned = set()
        if len(predicted) > 0 and len(detections) > 0:
            cost = np.zeros((len(predicted), len(detections)))
            for i, p in enumerate(predicted):
                for j, d in enumerate(detections):
                    cx, cy = ((d[0] + d[2]) / 2, (d[1] + d[3]) / 2)
                    cost[i, j] = np.linalg.norm(p - np.array([cx, cy]))
            r, c = linear_sum_assignment(cost)
            for i, j in zip(r, c):
                if cost[i, j] < 80:  # gate: reject matches farther than 80 px
                    assigned.add(j)
                    active_tracks[i].update(detections[j])
        # --- spawn new tracks for unmatched detections ---
        for j, d in enumerate(detections):
            if j not in assigned:
                t = Track(d, next_id)
                next_id += 1
                t.update(d)
                first_pt = tuple(map(int, t.trace[-1]))
                # entry gating: flag tracks that first appear inside a forbidden zone
                if inside_any(first_pt, entry_zones):
                    t.entry_flag = True
                tracks.append(t)
        # --- retire tracks that have gone unmatched too long ---
        for t in tracks:
            if t.missed_frames > 15:
                t.active = False
        # --- draw + log ---
        for trk in tracks:
            if not trk.active or len(trk.trace) < 3:
                continue
            x, y = map(int, trk.trace[-1])
            if not inside_zone((x, y), drive_zone):
                continue  # skip vehicles outside the learned drive zone
            status, sim = analyze_direction(trk.trace, centers, road_angle, trk.status_hist)
            if trk.entry_flag:
                status = "WRONG_ENTRY"  # entry violation overrides direction status
            if show_only_wrong and status not in ["WRONG", "WRONG_ENTRY"]:
                continue
            color = (0, 255, 0) if status == "OK" else \
                (0, 0, 255) if status.startswith("WRONG") else (200, 200, 200)
            cv2.circle(frame, (x, y), 4, color, -1)
            cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
            for i in range(1, len(trk.trace)):
                cv2.line(frame,
                         (int(trk.trace[i - 1][0]), int(trk.trace[i - 1][1])),
                         (int(trk.trace[i][0]), int(trk.trace[i][1])),
                         color, 1)
            # log each track once, after its trace is long enough to be stable;
            # NOTE(review): the first logged status is never revised afterwards
            if len(trk.trace) > 5 and not any(e["id"] == trk.id for e in log):
                log.append({
                    "id": trk.id,
                    "status": status,
                    "cos_sim": round(sim, 3),
                    "entry_flag": trk.entry_flag
                })
        out.write(frame)
    cap.release()
    out.release()
    # --- summary ---
    unique_ids = {e["id"] for e in log}
    summary = {
        "vehicles_analyzed": len(unique_ids),
        "wrong_count": sum(1 for e in log if e["status"].startswith("WRONG")),
        "road_angle_deg": road_angle
    }
    # --- package all outputs into a single downloadable zip ---
    zip_path = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.write(out_path, arcname="violation_output.mp4")
        zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
        zf.writestr("summary.json", json.dumps(summary, indent=2))
    return out_path, log, summary, zip_path
# ------------------------------------------------------------
# 🖥️ Gradio interface
# ------------------------------------------------------------
def run_app(video, flow_file, show_only_wrong):
    """Gradio callback: run Stage-3 processing and return the four UI outputs
    (annotated video, per-vehicle log, summary, download zip)."""
    return process_video(video, flow_file, show_only_wrong)
# Markdown shown above the Gradio form (runtime string — rendered in the UI).
description_text = """
### 🚦 Wrong-Direction Detection (Stage 3 — Angle + Temporal + Zone + Entry-Aware)
Upload your traffic video and the **flow_stats.json** from Stage 2.
Stage 3 will respect the learned road angle, driving zones, and entry gates.
"""
# Gradio UI: traffic video + Stage-2 stats in; annotated video, logs, and zip out.
demo = gr.Interface(
fn=run_app,
inputs=[
gr.Video(label="Upload Traffic Video (.mp4)"),
gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
gr.Checkbox(label="Show Only Wrong Labels", value=False)
],
outputs=[
gr.Video(label="Violation Output Video"),
gr.JSON(label="Per-Vehicle Log"),
gr.JSON(label="Summary"),
gr.File(label="⬇️ Download All Outputs (ZIP)")
],
title="🚗 Wrong-Direction Detection – Stage 3 (Angle + Temporal + Zone + Entry)",
description=description_text,
)
# Disable flagging, example caching, and usage analytics for this Space.
demo.flagging_mode = "never"
demo.cache_examples = False
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
# Bind to all interfaces on port 7860 (the standard Hugging Face Spaces port).
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)