# nishanth-saka — "update" — commit 6eb9596 (verified) — 6.56 kB
# (Hugging Face file-viewer header: raw / history / blame)
import os, cv2, json, time, tempfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
# ------------------------------------------------------------
# ๐Ÿ”ง Safe-load fix for PyTorch 2.6
# ------------------------------------------------------------
import torch
import ultralytics.nn.tasks as ultralytics_tasks
torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
# ------------------------------------------------------------
# โš™๏ธ Model + constants
# ------------------------------------------------------------
MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
# ------------------------------------------------------------
# ๐Ÿงฉ Kalman Tracker
# ------------------------------------------------------------
class Track:
def __init__(self, bbox, tid):
self.id = tid
self.kf = KalmanFilter(dim_x=4, dim_z=2)
self.kf.F = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]])
self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
self.kf.P *= 1000.0
self.kf.R *= 10.0
self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
self.trace = []
def centroid(self, b):
x1,y1,x2,y2=b
return [(x1+x2)/2,(y1+y2)/2]
def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
def update(self,b):
z=np.array(self.centroid(b)).reshape(2,1)
self.kf.update(z)
cx,cy=self.kf.x[:2].reshape(2)
self.trace.append((float(cx),float(cy)))
return (cx,cy)
# ------------------------------------------------------------
# ๐Ÿงฎ Direction check
# ------------------------------------------------------------
def analyze_direction(trace, centers):
if len(trace) < 3:
return "NA", 1.0
v = np.array(trace[-1]) - np.array(trace[-3])
if np.linalg.norm(v) < 1e-6:
return "NA", 1.0
v = v / np.linalg.norm(v)
sims = np.dot(centers, v)
max_sim = np.max(sims)
if max_sim < 0:
return "WRONG", float(max_sim)
return "OK", float(max_sim)
# ------------------------------------------------------------
# ๐Ÿšฆ Main Processing
# ------------------------------------------------------------
def load_flow_centers(flow_json):
data = json.load(open(flow_json))
centers = np.array(data["flow_centers"])
centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
return centers
def process_video(video_path, flow_json):
centers = load_flow_centers(flow_json)
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) or 25
w,h = int(cap.get(3)), int(cap.get(4))
out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(out_path, fourcc, fps, (w,h))
tracks, next_id, log = [], 0, []
while True:
ret, frame = cap.read()
if not ret:
break
results = model(frame, verbose=False)[0]
detections=[]
for box in results.boxes:
if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
detections.append(box.xyxy[0].cpu().numpy())
# Predict existing
predicted=[t.predict() for t in tracks]
predicted=np.array(predicted) if len(predicted)>0 else np.empty((0,2))
# Assign detections
assigned=set()
if len(predicted)>0 and len(detections)>0:
cost=np.zeros((len(predicted),len(detections)))
for i,p in enumerate(predicted):
for j,d in enumerate(detections):
cx,cy=((d[0]+d[2])/2,(d[1]+d[3])/2)
cost[i,j]=np.linalg.norm(p-np.array([cx,cy]))
r,c=linear_sum_assignment(cost)
for i,j in zip(r,c):
if cost[i,j]<80:
assigned.add(j)
tracks[i].update(detections[j])
# New tracks
for j,d in enumerate(detections):
if j not in assigned:
t=Track(d,next_id); next_id+=1
t.update(d); tracks.append(t)
# Draw results
for trk in tracks:
if len(trk.trace)<3: continue
status, sim = analyze_direction(trk.trace, centers)
x,y=map(int,trk.trace[-1])
color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (200,200,200))
cv2.circle(frame,(x,y),4,color,-1)
cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),
cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
for i in range(1,len(trk.trace)):
cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
(int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
log.append({"id":trk.id,"status":status,"cos_sim":round(sim,3)})
out.write(frame)
cap.release(); out.release()
log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
with open(log_path,"w") as f: json.dump(log,f,indent=2)
return out_path, log_path
# ------------------------------------------------------------
# ๐Ÿ–ฅ๏ธ Gradio UI
# ------------------------------------------------------------
def run_app(video, flow_file):
out_path, log_path = process_video(video, flow_file)
log_data = json.load(open(log_path))
summary = {"vehicles_analyzed": len(log_data)}
return out_path, log_data, summary
description_text = """
### ๐Ÿšฆ Wrong-Direction Detection (Stage 3)
Uploads your traffic video and the **flow_stats.json** from Stage 2.
Outputs an annotated video with โœ… OK / ๐Ÿšซ WRONG labels and a JSON log.
"""
example_vid = "10.mp4" if os.path.exists("10.mp4") else None
example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
demo = gr.Interface(
fn=run_app,
inputs=[gr.Video(label="Upload Traffic Video (.mp4)"),
gr.File(label="Upload flow_stats.json (Stage 2 Output)")],
outputs=[gr.Video(label="Violation Output Video"),
gr.JSON(label="Per-Vehicle Log"),
gr.JSON(label="Summary")],
title="๐Ÿš— Wrong-Direction Detection โ€“ Stage 3",
description=description_text,
examples=None, # disables example caching
)
if __name__ == "__main__":
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)