import os, cv2, json, time, tempfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment
# ------------------------------------------------------------
# 🔧 Safe-load fix for PyTorch 2.6
# ------------------------------------------------------------
import torch
import ultralytics.nn.tasks as ultralytics_tasks
# From PyTorch 2.6 `torch.load` defaults to weights_only=True, so the class
# stored in the checkpoint has to be allow-listed before the weights are loaded.
# `add_safe_globals` does not exist on older torch releases; calling it
# unconditionally would raise AttributeError at import time there, so the
# registration is skipped when the API is missing.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
# ------------------------------------------------------------
# ⚙️ Model + constants
# ------------------------------------------------------------
# Weights file handed to the YOLO constructor below.
MODEL_PATH = "yolov8n.pt"
# The detector is built once at import time and reused for every frame.
model = YOLO(MODEL_PATH)
# Class ids that the detection filter in process_video keeps as vehicles
# (presumably COCO indices matching the labels listed here -- confirm against the weights).
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
# ------------------------------------------------------------
# 🧩 Kalman Tracker
# ------------------------------------------------------------
class Track:
    """Single tracked object: a Kalman filter over the box centroid plus its path.

    The filter state is ``[cx, cy, vx, vy]``: the transition matrix adds the
    velocity components to the position on every ``predict`` call, and only the
    position ``(cx, cy)`` is measured.

    Attributes:
        id: Identifier supplied by the caller.
        kf: The underlying ``KalmanFilter`` (4 state dims, 2 measurement dims).
        trace: List of filtered ``(cx, cy)`` float tuples, one per ``update``.
    """

    def __init__(self, bbox, tid):
        """Create a track whose position is seeded from ``bbox``.

        Args:
            bbox: Box as ``(x1, y1, x2, y2)``.
            tid: Identifier stored on ``self.id``.
        """
        self.id = tid
        self.kf = KalmanFilter(dim_x=4, dim_z=2)
        # Position advances by the current velocity each step; velocity is kept.
        self.kf.F = np.array(
            [[1, 0, 1, 0],
             [0, 1, 0, 1],
             [0, 0, 1, 0],
             [0, 0, 0, 1]],
            dtype=float,
        )
        # Only the two position components are observed.
        self.kf.H = np.array(
            [[1, 0, 0, 0],
             [0, 1, 0, 0]],
            dtype=float,
        )
        # Scale the filter's default state covariance and measurement noise.
        self.kf.P *= 1000.0
        self.kf.R *= 10.0
        self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2, 1)
        self.trace = []

    def centroid(self, b):
        """Return the centre ``[cx, cy]`` of box ``b = (x1, y1, x2, y2)``."""
        x1, y1, x2, y2 = b
        return [(x1 + x2) / 2, (y1 + y2) / 2]

    def predict(self):
        """Advance the filter one step and return the predicted ``(cx, cy)`` array."""
        self.kf.predict()
        return self.kf.x[:2].reshape(2)

    def update(self, b):
        """Correct the filter with box ``b`` and append the result to ``trace``.

        Returns:
            The filtered ``(cx, cy)`` as plain floats (the same tuple that is
            appended to ``self.trace``).
        """
        z = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(z)
        cx, cy = self.kf.x[:2].reshape(2)
        point = (float(cx), float(cy))
        self.trace.append(point)
        return point
# ------------------------------------------------------------
# 🧮 Direction check
# ------------------------------------------------------------
def analyze_direction(trace, centers, wrong_threshold=0.0):
    """Compare a track's recent motion with the reference flow directions.

    The motion vector runs from the third-last to the last trace point. It is
    normalised and dotted with every row of ``centers``; the largest value is
    the score.

    Args:
        trace: Sequence of ``(x, y)`` positions, oldest first.
        centers: Array-like of reference direction vectors, one per row.
            The score is a cosine similarity only if the rows are unit length
            (this function does not normalise them).
        wrong_threshold: Scores strictly below this value are reported as
            ``"WRONG"``. Defaults to ``0.0``.

    Returns:
        ``(status, score)`` where ``status`` is ``"OK"``, ``"WRONG"`` or
        ``"NA"``. ``"NA"`` (with score ``1.0``) is returned when there are
        fewer than three trace points, when the object has effectively not
        moved, or when ``centers`` is empty.
    """
    if len(trace) < 3:
        return "NA", 1.0

    centers = np.asarray(centers, dtype=float)
    if centers.size == 0:
        # np.max would raise on an empty array; nothing to compare against.
        return "NA", 1.0

    step = np.asarray(trace[-1], dtype=float) - np.asarray(trace[-3], dtype=float)
    length = np.linalg.norm(step)
    if length < 1e-6:
        # A stationary object has no meaningful heading.
        return "NA", 1.0

    best = float(np.max(np.dot(centers, step / length)))
    status = "WRONG" if best < wrong_threshold else "OK"
    return status, best
# ------------------------------------------------------------
# 🚦 Main Processing
# ------------------------------------------------------------
def load_flow_centers(flow_json):
    """Load the reference flow directions from a JSON file and normalise them.

    Args:
        flow_json: Path to a JSON file with a ``"flow_centers"`` key holding a
            list of direction vectors (one per row).

    Returns:
        A float array with one row per vector, each divided by its Euclidean
        norm plus ``1e-6`` (the epsilon avoids division by zero for a zero row).

    Raises:
        KeyError: If the file has no ``"flow_centers"`` key.
    """
    # Context manager so the file handle is closed deterministically.
    with open(flow_json, encoding="utf-8") as fh:
        data = json.load(fh)
    centers = np.asarray(data["flow_centers"], dtype=float)
    norms = np.linalg.norm(centers, axis=1, keepdims=True)
    return centers / (norms + 1e-6)
def _detect_vehicles(frame, conf_thres):
    """Run the module-level detector on ``frame`` and return vehicle boxes (xyxy arrays)."""
    result = model(frame, verbose=False)[0]
    return [
        box.xyxy[0].cpu().numpy()
        for box in result.boxes
        if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > conf_thres
    ]


def _match_tracks(predicted, detections, max_dist):
    """Match predicted track positions to detections; return ``{track_idx: det_idx}``.

    Matching minimises the total centroid distance; pairs whose distance is
    not strictly below ``max_dist`` are dropped.
    """
    if len(predicted) == 0 or len(detections) == 0:
        return {}
    det_centers = np.array(
        [[(d[0] + d[2]) / 2, (d[1] + d[3]) / 2] for d in detections]
    )
    cost = np.linalg.norm(predicted[:, None, :] - det_centers[None, :, :], axis=2)
    rows, cols = linear_sum_assignment(cost)
    return {int(i): int(j) for i, j in zip(rows, cols) if cost[i, j] < max_dist}


def _draw_track(frame, trk, status):
    """Draw the track's current point, label and path on ``frame`` in the status colour."""
    if status == "OK":
        color = (0, 255, 0)
    elif status == "WRONG":
        color = (0, 0, 255)
    else:
        color = (200, 200, 200)
    x, y = map(int, trk.trace[-1])
    cv2.circle(frame, (x, y), 4, color, -1)
    cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    for prev, cur in zip(trk.trace, trk.trace[1:]):
        cv2.line(frame, (int(prev[0]), int(prev[1])),
                 (int(cur[0]), int(cur[1])), color, 1)


def process_video(video_path, flow_json, conf_thres=0.3, max_match_dist=80.0,
                  max_missed=30):
    """Detect and track vehicles in a video and flag motion against the flow.

    For every frame: detect vehicles, associate them with existing tracks by
    centroid distance, start new tracks for unmatched detections, classify each
    track with at least three trace points via ``analyze_direction``, draw the
    result on the frame and append one log row per drawn track.

    Args:
        video_path: Path of the input video.
        flow_json: Path of the JSON file read by ``load_flow_centers``.
        conf_thres: Minimum detection confidence (exclusive).
        max_match_dist: Maximum centroid distance (exclusive) for a
            track/detection match.
        max_missed: A track that goes unmatched for more than this many
            consecutive frames is dropped. ``None`` keeps tracks forever.

    Returns:
        ``(out_path, log_path)``: paths of the annotated ``.mp4`` and of a JSON
        file holding a list of ``{"id", "status", "cos_sim"}`` rows.

    Raises:
        ValueError: If the video cannot be opened.
    """
    centers = load_flow_centers(flow_json)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    fps = cap.get(cv2.CAP_PROP_FPS) or 25
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # mkstemp + close: only the path is needed, and leaving the descriptor open
    # would leak a handle for the life of the process.
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))

    tracks, next_id, log = [], 0, []
    missed = {}  # track id -> consecutive frames without a matched detection

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            detections = _detect_vehicles(frame, conf_thres)

            # Predict existing tracks, then associate them with detections.
            predicted = np.array([t.predict() for t in tracks]).reshape(-1, 2)
            matches = _match_tracks(predicted, detections, max_match_dist)
            for i, trk in enumerate(tracks):
                if i in matches:
                    trk.update(detections[matches[i]])
                    missed[trk.id] = 0
                else:
                    missed[trk.id] += 1

            # Retire tracks that have been lost for too long; otherwise they
            # would be drawn and logged on every remaining frame.
            if max_missed is not None:
                alive = []
                for trk in tracks:
                    if missed[trk.id] > max_missed:
                        del missed[trk.id]
                    else:
                        alive.append(trk)
                tracks = alive

            # Unmatched detections start new tracks.
            assigned = set(matches.values())
            for j, det in enumerate(detections):
                if j in assigned:
                    continue
                trk = Track(det, next_id)
                next_id += 1
                trk.update(det)
                missed[trk.id] = 0
                tracks.append(trk)

            # Classify, draw and log every track with enough history.
            for trk in tracks:
                if len(trk.trace) < 3:
                    continue
                status, sim = analyze_direction(trk.trace, centers)
                _draw_track(frame, trk, status)
                log.append({"id": trk.id, "status": status, "cos_sim": round(sim, 3)})

            out.write(frame)
    finally:
        # Release the capture and the writer even if a frame fails mid-way.
        cap.release()
        out.release()

    fd, log_path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=2)
    return out_path, log_path
# ------------------------------------------------------------
# 🖥️ Gradio UI
# ------------------------------------------------------------
def run_app(video, flow_file):
    """Gradio callback: process the upload and build the three UI outputs.

    Args:
        video: Path of the uploaded video, passed straight to ``process_video``.
        flow_file: Path of the uploaded flow JSON, passed straight to
            ``process_video``.

    Returns:
        ``(out_path, log_data, summary)``: the annotated video path, the parsed
        log rows, and a summary dict.
    """
    out_path, log_path = process_video(video, flow_file)
    with open(log_path, encoding="utf-8") as fh:
        log_data = json.load(fh)

    # The log holds one row per track per frame, so vehicles are counted by
    # distinct track id rather than by number of rows.
    vehicle_ids = {row["id"] for row in log_data}
    wrong_ids = {row["id"] for row in log_data if row["status"] == "WRONG"}
    summary = {
        "vehicles_analyzed": len(vehicle_ids),
        "vehicles_flagged_wrong": len(wrong_ids),
        "log_entries": len(log_data),
    }
    return out_path, log_data, summary
# Markdown passed as the `description` of the Gradio interface.
# The emoji were mis-decoded (and one of them had turned into a line break in
# the middle of the last sentence); they are restored here.
description_text = """
### 🚦 Wrong-Direction Detection (Stage 3)
Uploads your traffic video and the **flow_stats.json** from Stage 2.
Outputs an annotated video with ✅ OK / 🚫 WRONG labels and a JSON log.
"""
# Optional sample inputs: each is set to the file name only if that file exists
# in the current working directory, else None.
# NOTE(review): neither name is used by the interface below, which passes
# examples=None -- confirm whether they are still needed.
example_vid = "10.mp4" if os.path.exists("10.mp4") else None
example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
# UI definition: a video and a flow-statistics file go in; the annotated video,
# the per-vehicle log and a summary come back from run_app.
# The GRADIO_ANALYTICS_ENABLED variable is only set inside the __main__ guard,
# i.e. after this object has been constructed, so the opt-out is also passed
# explicitly here.
demo = gr.Interface(
    fn=run_app,
    inputs=[
        gr.Video(label="Upload Traffic Video (.mp4)"),
        gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
    ],
    outputs=[
        gr.Video(label="Violation Output Video"),
        gr.JSON(label="Per-Vehicle Log"),
        gr.JSON(label="Summary"),
    ],
    title="🚗 Wrong-Direction Detection — Stage 3",
    description=description_text,
    examples=None,  # disables example caching
    analytics_enabled=False,
)
if __name__ == "__main__":
    # Opt out of Gradio usage analytics for this process.
    # NOTE(review): `demo` already exists when this line runs -- confirm the
    # variable is still honoured this late.
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    # Listen on all interfaces on port 7860, with server-side rendering and the
    # API page switched off.
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)