# Stage 3 Wrong-Direction Detection (#1)
# Source: Hugging Face Space by nishanth-saka, commit 9453af9 (verified), 6.17 kB
import gradio as gr
import cv2, os, numpy as np, json, tempfile, time
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
# ------------------------------------------------------------
# βš™οΈ Load YOLO and flow centers
# ------------------------------------------------------------
MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
def load_flow_centers(flow_json):
data = json.load(open(flow_json))
centers = np.array(data["flow_centers"])
centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
return centers
# ------------------------------------------------------------
# 🧩 Simple Kalman Tracker
# ------------------------------------------------------------
class Track:
def __init__(self, bbox, tid):
self.id = tid
self.kf = KalmanFilter(dim_x=4, dim_z=2)
self.kf.F = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]])
self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
self.kf.P *= 1000.0
self.kf.R *= 10.0
self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
self.trace = []
def centroid(self,b):
x1,y1,x2,y2=b
return [(x1+x2)/2,(y1+y2)/2]
def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
def update(self,b):
z=np.array(self.centroid(b)).reshape(2,1)
self.kf.update(z)
cx,cy=self.kf.x[:2].reshape(2)
self.trace.append((float(cx),float(cy)))
return (cx,cy)
# ------------------------------------------------------------
# 🚦 Wrong-Direction Analyzer
# ------------------------------------------------------------
def analyze_direction(trace, centers):
if len(trace)<3: return "NA",1.0
v = np.array(trace[-1]) - np.array(trace[-3]) # motion vector
if np.linalg.norm(v)<1e-6: return "NA",1.0
v = v / np.linalg.norm(v)
sims = np.dot(centers, v)
max_sim = np.max(sims)
if max_sim < 0: return "WRONG", float(max_sim)
return "OK", float(max_sim)
# ------------------------------------------------------------
# 🎥 Process Video
# ------------------------------------------------------------
def process_video(video_path, flow_json):
centers = load_flow_centers(flow_json)
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 25
w,h = int(cap.get(3)), int(cap.get(4))
tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
out = cv2.VideoWriter(tmp_out.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
tracks, next_id = [], 0
log = []
while True:
ret, frame = cap.read()
if not ret: break
results = model(frame, verbose=False)[0]
detections=[]
for box in results.boxes:
if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
detections.append(box.xyxy[0].cpu().numpy())
# predict existing tracks
predicted = [trk.predict() for trk in tracks]
predicted = np.array(predicted) if len(predicted)>0 else np.empty((0,2))
assigned=set()
if len(predicted)>0 and len(detections)>0:
cost=np.zeros((len(predicted),len(detections)))
for i,p in enumerate(predicted):
for j,d in enumerate(detections):
cx,cy=((d[0]+d[2])/2,(d[1]+d[3])/2)
cost[i,j]=np.linalg.norm(p-np.array([cx,cy]))
r,c=linear_sum_assignment(cost)
for i,j in zip(r,c):
if cost[i,j]<80:
assigned.add(j)
tracks[i].update(detections[j])
# new tracks
for j,d in enumerate(detections):
if j not in assigned:
trk=Track(d,next_id); next_id+=1
trk.update(d)
tracks.append(trk)
# draw
for trk in tracks:
if len(trk.trace)<3: continue
status, sim = analyze_direction(trk.trace, centers)
x,y=map(int,trk.trace[-1])
color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (255,255,255))
cv2.circle(frame,(x,y),4,color,-1)
cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
for i in range(1,len(trk.trace)):
cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
(int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
log.append({"id":trk.id,"status":status,"cos_sim":round(sim,3)})
out.write(frame)
cap.release(); out.release()
log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
with open(log_path,"w") as f: json.dump(log,f,indent=2)
return tmp_out.name, log_path
# ------------------------------------------------------------
# 🖥️ Gradio Interface
# ------------------------------------------------------------
def run_app(video, flow_file):
out_path, log_path = process_video(video, flow_file)
log_data = json.load(open(log_path))
summary = {"vehicles_analyzed": len(log_data)}
return out_path, log_data, summary
description_text = """
### 🚦 Wrong-Direction Detection (Stage 3)
Uploads your traffic video and the **flow_stats.json** from Stage 2.
Outputs an annotated video with βœ… OK / 🚫 WRONG labels per vehicle, plus a JSON log.
"""
example_vid = "10.mp4" if os.path.exists("10.mp4") else None
example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
demo = gr.Interface(
fn=run_app,
inputs=[
gr.Video(label="Upload Traffic Video (.mp4)"),
gr.File(label="Upload flow_stats.json (Stage 2 Output)")
],
outputs=[
gr.Video(label="Violation Output Video"),
gr.JSON(label="Per-Vehicle Log"),
gr.JSON(label="Summary")
],
title="πŸš— Wrong-Direction Detection – Stage 3",
description=description_text,
examples=[[example_vid, example_flow]] if example_vid and example_flow else None,
)
if __name__ == "__main__":
demo.launch()