nishanth-saka commited on
Commit
9453af9
Β·
verified Β·
1 Parent(s): 7f7e0ab

Stage 3 Wrong-Direction Detection (#1)

Browse files

- Stage 3 Wrong-Direction Detection (79f5efc5d2b637e1b0a7ab2c7e46d5798b75e167)

Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2, os, numpy as np, json, tempfile, time
3
+ from ultralytics import YOLO
4
+ from filterpy.kalman import KalmanFilter
5
+ from scipy.optimize import linear_sum_assignment
6
+
7
+ # ------------------------------------------------------------
8
+ # βš™οΈ Load YOLO and flow centers
9
+ # ------------------------------------------------------------
10
+ MODEL_PATH = "yolov8n.pt"
11
+ model = YOLO(MODEL_PATH)
12
+ VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
13
+
14
+ def load_flow_centers(flow_json):
15
+ data = json.load(open(flow_json))
16
+ centers = np.array(data["flow_centers"])
17
+ centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
18
+ return centers
19
+
20
+ # ------------------------------------------------------------
21
+ # 🧩 Simple Kalman Tracker
22
+ # ------------------------------------------------------------
23
+ class Track:
24
+ def __init__(self, bbox, tid):
25
+ self.id = tid
26
+ self.kf = KalmanFilter(dim_x=4, dim_z=2)
27
+ self.kf.F = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]])
28
+ self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
29
+ self.kf.P *= 1000.0
30
+ self.kf.R *= 10.0
31
+ self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
+ self.trace = []
33
+
34
+ def centroid(self,b):
35
+ x1,y1,x2,y2=b
36
+ return [(x1+x2)/2,(y1+y2)/2]
37
+ def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
38
+ def update(self,b):
39
+ z=np.array(self.centroid(b)).reshape(2,1)
40
+ self.kf.update(z)
41
+ cx,cy=self.kf.x[:2].reshape(2)
42
+ self.trace.append((float(cx),float(cy)))
43
+ return (cx,cy)
44
+
45
+ # ------------------------------------------------------------
46
+ # 🚦 Wrong-Direction Analyzer
47
+ # ------------------------------------------------------------
48
+ def analyze_direction(trace, centers):
49
+ if len(trace)<3: return "NA",1.0
50
+ v = np.array(trace[-1]) - np.array(trace[-3]) # motion vector
51
+ if np.linalg.norm(v)<1e-6: return "NA",1.0
52
+ v = v / np.linalg.norm(v)
53
+ sims = np.dot(centers, v)
54
+ max_sim = np.max(sims)
55
+ if max_sim < 0: return "WRONG", float(max_sim)
56
+ return "OK", float(max_sim)
57
+
58
+ # ------------------------------------------------------------
59
+ # πŸŽ₯ Process Video
60
+ # ------------------------------------------------------------
61
+ def process_video(video_path, flow_json):
62
+ centers = load_flow_centers(flow_json)
63
+ cap = cv2.VideoCapture(video_path)
64
+ fps = cap.get(cv2.CAP_PROP_FPS) or 25
65
+ w,h = int(cap.get(3)), int(cap.get(4))
66
+
67
+ tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
68
+ out = cv2.VideoWriter(tmp_out.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
69
+
70
+ tracks, next_id = [], 0
71
+ log = []
72
+
73
+ while True:
74
+ ret, frame = cap.read()
75
+ if not ret: break
76
+ results = model(frame, verbose=False)[0]
77
+ detections=[]
78
+ for box in results.boxes:
79
+ if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
80
+ detections.append(box.xyxy[0].cpu().numpy())
81
+
82
+ # predict existing tracks
83
+ predicted = [trk.predict() for trk in tracks]
84
+ predicted = np.array(predicted) if len(predicted)>0 else np.empty((0,2))
85
+
86
+ assigned=set()
87
+ if len(predicted)>0 and len(detections)>0:
88
+ cost=np.zeros((len(predicted),len(detections)))
89
+ for i,p in enumerate(predicted):
90
+ for j,d in enumerate(detections):
91
+ cx,cy=((d[0]+d[2])/2,(d[1]+d[3])/2)
92
+ cost[i,j]=np.linalg.norm(p-np.array([cx,cy]))
93
+ r,c=linear_sum_assignment(cost)
94
+ for i,j in zip(r,c):
95
+ if cost[i,j]<80:
96
+ assigned.add(j)
97
+ tracks[i].update(detections[j])
98
+
99
+ # new tracks
100
+ for j,d in enumerate(detections):
101
+ if j not in assigned:
102
+ trk=Track(d,next_id); next_id+=1
103
+ trk.update(d)
104
+ tracks.append(trk)
105
+
106
+ # draw
107
+ for trk in tracks:
108
+ if len(trk.trace)<3: continue
109
+ status, sim = analyze_direction(trk.trace, centers)
110
+ x,y=map(int,trk.trace[-1])
111
+ color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (255,255,255))
112
+ cv2.circle(frame,(x,y),4,color,-1)
113
+ cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
114
+ for i in range(1,len(trk.trace)):
115
+ cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
116
+ (int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
117
+ log.append({"id":trk.id,"status":status,"cos_sim":round(sim,3)})
118
+
119
+ out.write(frame)
120
+
121
+ cap.release(); out.release()
122
+ log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
123
+ with open(log_path,"w") as f: json.dump(log,f,indent=2)
124
+ return tmp_out.name, log_path
125
+
126
+ # ------------------------------------------------------------
127
+ # πŸ–₯️ Gradio Interface
128
+ # ------------------------------------------------------------
129
+ def run_app(video, flow_file):
130
+ out_path, log_path = process_video(video, flow_file)
131
+ log_data = json.load(open(log_path))
132
+ summary = {"vehicles_analyzed": len(log_data)}
133
+ return out_path, log_data, summary
134
+
135
+ description_text = """
136
+ ### 🚦 Wrong-Direction Detection (Stage 3)
137
+ Uploads your traffic video and the **flow_stats.json** from Stage 2.
138
+ Outputs an annotated video with βœ… OK / 🚫 WRONG labels per vehicle, plus a JSON log.
139
+ """
140
+
141
+ example_vid = "10.mp4" if os.path.exists("10.mp4") else None
142
+ example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
143
+
144
+ demo = gr.Interface(
145
+ fn=run_app,
146
+ inputs=[
147
+ gr.Video(label="Upload Traffic Video (.mp4)"),
148
+ gr.File(label="Upload flow_stats.json (Stage 2 Output)")
149
+ ],
150
+ outputs=[
151
+ gr.Video(label="Violation Output Video"),
152
+ gr.JSON(label="Per-Vehicle Log"),
153
+ gr.JSON(label="Summary")
154
+ ],
155
+ title="πŸš— Wrong-Direction Detection – Stage 3",
156
+ description=description_text,
157
+ examples=[[example_vid, example_flow]] if example_vid and example_flow else None,
158
+ )
159
+
160
+ if __name__ == "__main__":
161
+ demo.launch()