|
|
import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr |
|
|
from ultralytics import YOLO |
|
|
from filterpy.kalman import KalmanFilter |
|
|
from scipy.optimize import linear_sum_assignment |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch, ultralytics.nn.tasks as ultralytics_tasks |
|
|
# Allow-list Ultralytics' DetectionModel class for torch's weights-only
# unpickler (torch.serialization.add_safe_globals); presumably required so that
# YOLO(MODEL_PATH) below can load the checkpoint on torch versions where
# torch.load defaults to weights_only=True.
torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])

# Detector weights, loaded once at import time and shared by every request.
MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)

# Class indices kept as "vehicles". These are the COCO indices for the listed
# names, assuming the stock COCO-trained YOLOv8 weights are used.
_COCO_VEHICLE_IDS = {"car": 2, "motorcycle": 3, "bus": 5, "truck": 7}
VEHICLE_CLASSES = list(_COCO_VEHICLE_IDS.values())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Track:
    """One tracked object: a Kalman filter over its box centre plus a trail.

    The state layout follows from ``F`` and ``H`` set in ``__init__``:
    ``[cx, cy, vx, vy]``. Each ``predict`` step adds the velocity to the
    position, and only the centre ``(cx, cy)`` is measured.
    """

    def __init__(self, bbox, tid):
        self.id = tid
        self.kf = KalmanFilter(dim_x=4, dim_z=2)
        kf = self.kf

        # Constant-velocity transition: cx += vx, cy += vy on every step.
        kf.F = np.array([
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        # The measurement selects the two position components only.
        kf.H = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ])
        kf.P *= 1000.0  # inflate the initial state covariance
        kf.R *= 10.0    # inflate the measurement noise

        # Seed only the position entries of the state with the box centre.
        kf.x[:2] = np.array(self.centroid(bbox)).reshape(2, 1)

        # Filtered centres as (float, float), appended on every update().
        self.trace = []

    def centroid(self, b):
        """Return the centre ``[cx, cy]`` of a box given as ``(x1, y1, x2, y2)``."""
        left, top, right, bottom = b
        cx = (left + right) / 2
        cy = (top + bottom) / 2
        return [cx, cy]

    def predict(self):
        """Step the filter forward once and return the predicted centre, shape (2,)."""
        self.kf.predict()
        state = self.kf.x
        return state[:2].reshape(2)

    def update(self, b):
        """Correct the filter with box ``b``, record the filtered centre, and return it."""
        observed = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(observed)
        cx, cy = self.kf.x[:2].ravel()
        self.trace.append((float(cx), float(cy)))
        return (cx, cy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def analyze_direction(trace, centers, *, lag=2):
    """Classify a track's most recent motion against the reference flow directions.

    Args:
        trace: sequence of ``(x, y)`` positions, oldest first.
        centers: 2-D array of flow direction vectors, one per row. These are
            expected to be unit length (as returned by ``load_flow_centers``)
            so that the dot product below is a cosine similarity.
        lag: number of steps back used to form the motion vector
            (``trace[-1] - trace[-1 - lag]``). The default of 2 matches the
            original behaviour.

    Returns:
        ``(status, similarity)``:
        - ``"WRONG"`` when the motion has a negative similarity with every
          flow direction.
        - ``"OK"`` otherwise, paired with the best similarity.
        - ``("NA", 1.0)`` when the trace is too short, the object has not
          moved measurably, or there are no flow directions to compare with.

    Raises:
        ValueError: if ``lag`` is smaller than 1.
    """
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag}")
    if len(trace) < lag + 1:
        return "NA", 1.0

    motion = np.asarray(trace[-1], dtype=float) - np.asarray(trace[-1 - lag], dtype=float)
    length = np.linalg.norm(motion)
    # A (near-)stationary object has no meaningful heading.
    if length < 1e-6:
        return "NA", 1.0

    centers = np.asarray(centers, dtype=float)
    # np.max on an empty array raises; with no reference flow we cannot judge.
    if centers.size == 0:
        return "NA", 1.0

    max_sim = float(np.max(np.dot(centers, motion / length)))
    return ("WRONG" if max_sim < 0 else "OK"), max_sim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_flow_centers(flow_json):
    """Load the flow direction vectors from a Stage 2 JSON file and unit-normalise them.

    Args:
        flow_json: path to a JSON file containing a ``"flow_centers"`` entry:
            a list of ``[dx, dy]`` vectors, or a single flat ``[dx, dy]`` pair.

    Returns:
        Float array of shape ``(k, 2)``. Each row is divided by its norm
        (plus ``1e-6`` to avoid division by zero).

    Raises:
        KeyError: if the file has no ``"flow_centers"`` entry.
        ValueError: if the JSON is malformed, or the centers are empty or are
            not 2-D vectors.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(flow_json, encoding="utf-8") as fh:
        data = json.load(fh)

    centers = np.asarray(data["flow_centers"], dtype=float)
    if centers.ndim == 1 and centers.shape[0] == 2:
        # Accept a single direction given as a flat [dx, dy] pair.
        centers = centers.reshape(1, 2)
    if centers.ndim != 2 or centers.shape[0] == 0 or centers.shape[1] != 2:
        raise ValueError(
            "flow_centers must be a non-empty list of [dx, dy] vectors, "
            f"got array of shape {centers.shape}"
        )

    return centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# BGR colours used to annotate each direction status; any other status
# (e.g. "NA") is drawn in grey.
_STATUS_COLORS = {"OK": (0, 255, 0), "WRONG": (0, 0, 255)}
_NA_COLOR = (200, 200, 200)


def _detect_vehicles(frame, conf_thresh):
    """Return ``[x1, y1, x2, y2]`` boxes for vehicle detections above ``conf_thresh``."""
    result = model(frame, verbose=False)[0]
    return [
        box.xyxy[0].cpu().numpy()
        for box in result.boxes
        if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > conf_thresh
    ]


def _associate(tracks, detections, max_dist):
    """Predict every track one step, then match tracks to detections.

    Uses Hungarian assignment on the distance between each track's predicted
    centre and each detection's box centre. Pairs at least ``max_dist`` apart
    are discarded. Returns a list of ``(track_index, detection_index)``.
    """
    # Every track is predicted each frame, matched or not, so its filter keeps
    # coasting through missed detections.
    predicted = np.array([t.predict() for t in tracks], dtype=float).reshape(-1, 2)
    if len(predicted) == 0 or len(detections) == 0:
        return []
    boxes = np.asarray(detections, dtype=float).reshape(-1, 4)
    centres = (boxes[:, :2] + boxes[:, 2:]) / 2
    cost = np.linalg.norm(predicted[:, None, :] - centres[None, :, :], axis=2)
    rows, cols = linear_sum_assignment(cost)
    return [(i, j) for i, j in zip(rows, cols) if cost[i, j] < max_dist]


def _draw_track(frame, trk, status):
    """Draw a track's latest position, ID/status label and trail onto ``frame`` in place."""
    color = _STATUS_COLORS.get(status, _NA_COLOR)
    x, y = map(int, trk.trace[-1])
    cv2.circle(frame, (x, y), 4, color, -1)
    cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    # One polyline call replaces one cv2.line call per trail segment.
    trail = np.asarray(trk.trace, dtype=np.float64).astype(np.int32).reshape(-1, 1, 2)
    cv2.polylines(frame, [trail], False, color, 1)


def process_video(video_path, flow_json, show_only_wrong=False, *,
                  conf_thresh=0.3, max_dist=80.0, max_missed=30):
    """Detect, track and direction-check vehicles in a video.

    Args:
        video_path: path to the input video readable by ``cv2.VideoCapture``.
        flow_json: path to the Stage 2 flow statistics JSON (see ``load_flow_centers``).
        show_only_wrong: if True, only tracks currently classified ``"WRONG"``
            are drawn. This affects the video only, not the log or summary.
        conf_thresh: minimum detector confidence for a vehicle box.
        max_dist: maximum centre distance (pixels) for a track/detection match.
        max_missed: a track not matched for more than this many consecutive
            frames is dropped, so it is no longer drawn or matched.

    Returns:
        ``(out_path, log, summary, zip_path)``:
        - ``out_path``: the annotated MP4 (``mp4v`` fourcc).
        - ``log``: one dict ``{"id", "status", "cos_sim"}`` per track, recorded
          the first frame its trace has more than 5 points.
        - ``summary``: ``{"vehicles_analyzed": <number of logged tracks>}``.
        - ``zip_path``: an archive bundling the video, log and summary.

    Raises:
        ValueError: if the video cannot be opened or the flow file is invalid.
        RuntimeError: if the output video writer cannot be opened.
    """
    centers = load_flow_centers(flow_json)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path!r}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 25  # fall back when FPS is unreported (0)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp + close: reserve a persistent temp path without leaving a handle open.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    if not out.isOpened():
        cap.release()
        out.release()
        raise RuntimeError(f"Could not open video writer for {out_path!r}")

    tracks, next_id, log = [], 0, []
    last_seen = {}      # track id -> index of the last frame it was matched/created
    logged_ids = set()  # O(1) "already logged" check

    try:
        frame_idx = -1
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame_idx += 1

            detections = _detect_vehicles(frame, conf_thresh)

            assigned = set()
            for i, j in _associate(tracks, detections, max_dist):
                tracks[i].update(detections[j])
                last_seen[tracks[i].id] = frame_idx
                assigned.add(j)

            # Unmatched detections start new tracks.
            for j, det in enumerate(detections):
                if j in assigned:
                    continue
                trk = Track(det, next_id)
                next_id += 1
                trk.update(det)
                last_seen[trk.id] = frame_idx
                tracks.append(trk)

            # Drop tracks that have gone unmatched too long. Otherwise they are
            # drawn forever at their last position and keep competing for matches.
            alive = []
            for trk in tracks:
                if frame_idx - last_seen[trk.id] <= max_missed:
                    alive.append(trk)
                else:
                    del last_seen[trk.id]
            tracks = alive

            for trk in tracks:
                if len(trk.trace) < 3:
                    continue
                status, sim = analyze_direction(trk.trace, centers)

                # Log before the display filter so the log and summary do not
                # depend on the display-only toggle.
                if len(trk.trace) > 5 and trk.id not in logged_ids:
                    logged_ids.add(trk.id)
                    log.append({"id": trk.id, "status": status, "cos_sim": round(sim, 3)})

                if show_only_wrong and status != "WRONG":
                    continue
                _draw_track(frame, trk, status)

            out.write(frame)
    finally:
        cap.release()
        out.release()

    summary = {"vehicles_analyzed": len(logged_ids)}

    fd, zip_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.write(out_path, arcname="violation_output.mp4")
        zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
        zf.writestr("summary.json", json.dumps(summary, indent=2))

    return out_path, log, summary, zip_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_app(video, flow_file, show_only_wrong):
    """Gradio callback: validate the uploads, then run the wrong-direction pipeline.

    Args:
        video: path of the uploaded video (from ``gr.Video``).
        flow_file: path of the uploaded flow statistics JSON (from ``gr.File``).
        show_only_wrong: checkbox value; only WRONG tracks are drawn when True.

    Returns:
        The ``(video_path, log, summary, zip_path)`` tuple from ``process_video``.

    Raises:
        gr.Error: when an input is missing or cannot be processed. Gradio shows
            the message to the user instead of a generic error.
    """
    if not video:
        raise gr.Error("Please upload a traffic video.")
    if not flow_file:
        raise gr.Error("Please upload the flow_stats.json file from Stage 2.")
    try:
        return process_video(video, flow_file, bool(show_only_wrong))
    except KeyError as err:
        # load_flow_centers raises KeyError when "flow_centers" is missing.
        raise gr.Error(f"Invalid flow file, missing key: {err}") from err
    except ValueError as err:
        raise gr.Error(f"Could not process inputs: {err}") from err
|
|
|
|
|
|
|
|
# Markdown shown at the top of the Gradio interface (passed as `description=`).
description_text = """


### 🚦 Wrong-Direction Detection (Stage 3)


Upload your traffic video and the **flow_stats.json** from Stage 2.


You can toggle whether to display all detections or only WRONG-direction vehicles.


"""
|
|
|
|
|
# Disable Gradio telemetry. When `analytics_enabled` is not passed, Gradio falls
# back to this variable at Interface construction time, so set it before `demo`
# is built. The flag is also passed explicitly below.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

demo = gr.Interface(
    fn=run_app,
    inputs=[
        gr.Video(label="Upload Traffic Video (.mp4)"),
        gr.File(label="Upload flow_stats.json (Stage 2 Output)", file_types=[".json"]),
        gr.Checkbox(label="Show Only Wrong Labels", value=False),
    ],
    outputs=[
        gr.Video(label="Violation Output Video"),
        gr.JSON(label="Per-Vehicle Log"),
        gr.JSON(label="Summary"),
        gr.File(label="⬇️ Download All Outputs (ZIP)"),
    ],
    title="🚗 Wrong-Direction Detection – Stage 3 (Toggle + ZIP)",
    description=description_text,
    examples=None,
    # These are constructor options. The UI, including the flag button, is
    # built inside gr.Interface(...), so assigning them afterwards has no effect.
    flagging_mode="never",
    cache_examples=False,
    analytics_enabled=False,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)
|
|
|