Files changed (1) hide show
  1. app.py +127 -223
app.py CHANGED
@@ -1,269 +1,173 @@
1
- import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr
 
 
 
 
2
  from ultralytics import YOLO
3
  from filterpy.kalman import KalmanFilter
4
  from scipy.optimize import linear_sum_assignment
5
 
6
  # ------------------------------------------------------------
7
- # 🔧 Safe-load fix for PyTorch 2.6
8
  # ------------------------------------------------------------
9
  import torch, ultralytics.nn.tasks as ultralytics_tasks
10
  torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
11
 
12
- # ------------------------------------------------------------
13
- # ⚙️ YOLO setup
14
- # ------------------------------------------------------------
15
  MODEL_PATH = "yolov8n.pt"
16
  model = YOLO(MODEL_PATH)
17
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
18
 
19
- # ------------------------------------------------------------
20
- # 🧭 Utility: rotate vector by angle
21
- # ------------------------------------------------------------
22
def rotate_vec(v, theta_deg):
    """Apply the standard 2-D rotation matrix for ``theta_deg`` degrees to ``v``.

    Args:
        v: A 2-element vector (anything ``numpy`` can matrix-multiply).
        theta_deg: Rotation angle in degrees.

    Returns:
        The rotated vector as a NumPy array.
    """
    angle = np.deg2rad(theta_deg)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    return rotation @ v
26
-
27
# ------------------------------------------------------------
# 🧩 Kalman tracker with temporal smoothing + entry gate flag
# ------------------------------------------------------------
class Track:
    """One tracked vehicle, followed with a 4-state Kalman filter.

    The filter state is ``[x, y, vx, vy]`` (see the transition matrix ``F``);
    only the centroid ``[x, y]`` is measured (see ``H``).

    Attributes set by ``__init__``:
        id: Identifier supplied by the caller.
        kf: The underlying ``KalmanFilter``.
        trace: Filtered centroids, one ``(x, y)`` tuple per ``update`` call.
        status_hist: Scratch list, initially empty, for per-track history
            kept by callers.
        entry_flag: Initially ``False``; callers may set it.
        active: Initially ``True``; callers may clear it.
        missed_frames: Number of ``predict`` calls since the last ``update``.
    """

    def __init__(self, bbox, tid):
        """Create a track whose initial position is the centroid of ``bbox``.

        Args:
            bbox: ``(x1, y1, x2, y2)`` box.
            tid: Identifier stored as ``self.id``.
        """
        self.id = tid

        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
        kf.P *= 1000.0
        kf.R *= 10.0
        kf.x[:2, 0] = np.array(self.centroid(bbox), dtype=float)
        self.kf = kf

        self.trace = []
        self.status_hist = []
        self.entry_flag = False
        self.active = True
        self.missed_frames = 0

    def centroid(self, b):
        """Return ``[cx, cy]``, the centre of the box ``(x1, y1, x2, y2)``."""
        left, top, right, bottom = b
        return [(left + right) / 2, (top + bottom) / 2]

    def predict(self):
        """Advance the filter one step and return the predicted ``[x, y]``.

        Also counts the step as a missed frame; ``update`` resets the counter.
        """
        self.kf.predict()
        self.missed_frames += 1
        return self.kf.x[:2].reshape(2)

    def update(self, b):
        """Correct the filter with the centroid of box ``b``.

        Appends the filtered position to ``trace``, resets ``missed_frames``
        and returns the filtered ``(cx, cy)``.
        """
        measurement = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(measurement)
        cx, cy = self.kf.x[:2].reshape(2)
        self.trace.append((float(cx), float(cy)))
        self.missed_frames = 0
        return (cx, cy)
61
-
62
- # ------------------------------------------------------------
63
- # 🧮 Direction analyzer (angle + temporal aware)
64
- # ------------------------------------------------------------
65
def analyze_direction(trace, centers, road_angle_deg, hist):
    """Classify a track's heading against the dominant flow directions.

    Args:
        trace: Sequence of ``(x, y)`` positions, oldest first.
        centers: Array of flow direction vectors, one per row; each row is
            dotted with the track's unit heading.
        road_angle_deg: Road angle in degrees; the heading is rotated by its
            negative before comparison.
        hist: Mutable list of recent similarity values. It is modified in
            place and trimmed to the 5 most recent entries.

    Returns:
        ``(status, similarity)`` where ``status`` is ``"OK"``, ``"WRONG"`` or
        ``"NA"``. ``("NA", 1.0)`` is returned when there are fewer than three
        points or the track has not moved.
    """
    if len(trace) < 3:
        return "NA", 1.0

    # Displacement across the last two steps.
    displacement = np.array(trace[-1]) - np.array(trace[-3])
    magnitude = np.linalg.norm(displacement)
    if magnitude < 1e-6:
        return "NA", 1.0

    # Express the unit heading in the road's reference frame.
    heading = rotate_vec(displacement / magnitude, -road_angle_deg)

    # Best agreement with any of the flow directions.
    best_match = np.max(np.dot(centers, heading))

    # Rolling mean over the last five observations.
    hist.append(best_match)
    if len(hist) > 5:
        hist.pop(0)
    mean_sim = np.mean(hist)

    label = "NA"
    if mean_sim < -0.2:
        label = "WRONG"
    elif mean_sim > 0.2:
        label = "OK"
    return label, float(mean_sim)
93
-
94
- # ------------------------------------------------------------
95
- # 🗺️ Load Stage-2 flow stats (centers, angle, zones)
96
- # ------------------------------------------------------------
97
def load_flow_stats(flow_json):
    """Read the Stage-2 flow statistics from a JSON file.

    Args:
        flow_json: Path to the JSON file. It must contain ``"flow_centers"``
            (a list of 2-D direction vectors) and may contain
            ``"road_angle_deg"``, ``"drive_zone"`` and ``"entry_zones"``.

    Returns:
        ``(centers, road_angle_deg, drive_zone, entry_zones)`` where
        ``centers`` is the row-normalised ``flow_centers`` array,
        ``road_angle_deg`` defaults to ``0.0``, ``drive_zone`` defaults to
        ``None`` and ``entry_zones`` defaults to ``[]``.

    Raises:
        KeyError: If ``"flow_centers"`` is missing.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(flow_json, encoding="utf-8") as fh:
        data = json.load(fh)

    centers = np.array(data["flow_centers"])
    # The small epsilon keeps an all-zero row from dividing by zero.
    centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
    road_angle_deg = float(data.get("road_angle_deg", 0.0))
    drive_zone = data.get("drive_zone", None)
    entry_zones = data.get("entry_zones", [])
    return centers, road_angle_deg, drive_zone, entry_zones
105
-
106
- # ------------------------------------------------------------
107
- # 🧾 Zone tests
108
- # ------------------------------------------------------------
109
def inside_zone(pt, zone):
    """Return whether ``pt`` lies inside or on the edge of the polygon ``zone``.

    A ``zone`` of ``None`` means "no restriction" and always yields ``True``.
    """
    if zone is None:
        return True
    polygon = np.array(zone, np.int32)
    signed = cv2.pointPolygonTest(polygon, pt, False)
    return signed >= 0
112
-
113
def inside_any(pt, zones):
    """Return ``True`` if ``pt`` is inside or on the edge of any polygon in ``zones``.

    An empty ``zones`` collection yields ``False``.
    """
    for zone in zones:
        polygon = np.array(zone, np.int32)
        if cv2.pointPolygonTest(polygon, pt, False) >= 0:
            return True
    return False
115
-
116
- # ------------------------------------------------------------
117
- # 🎥 Process video (angle + temporal + zone + entry-gating)
118
- # ------------------------------------------------------------
119
def process_video(video_path, flow_json, show_only_wrong=False):
    """Detect wrong-direction vehicles in a video and write annotated outputs.

    Each frame is run through the YOLO ``model``; vehicle detections are
    associated with Kalman ``Track`` objects by Hungarian assignment on
    centroid distance, and every track inside the drive zone is classified
    with ``analyze_direction``.

    Args:
        video_path: Path of the input video.
        flow_json: Path of the Stage-2 JSON read by ``load_flow_stats``.
        show_only_wrong: When true, only tracks whose status is ``"WRONG"``
            or ``"WRONG_ENTRY"`` are drawn and logged.

    Returns:
        ``(out_path, log, summary, zip_path)``: the annotated video path, a
        list with one dict per logged vehicle, a summary dict, and the path
        of a ZIP bundling all three.
    """
    centers, road_angle, drive_zone, entry_zones = load_flow_stats(flow_json)

    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # some containers report 0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Only the name is needed; close the handle so the writer can own the file.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    tracks, next_id = [], 0
    log, logged_ids = [], set()

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model(frame, verbose=False)[0]
            detections = [
                box.xyxy[0].cpu().numpy()
                for box in results.boxes
                if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > 0.3
            ]

            # Predict every live track. ``tracks`` only ever holds live
            # tracks (stale ones are pruned below), so row ``i`` of the cost
            # matrix always refers to ``tracks[i]``.
            predicted = [t.predict() for t in tracks]
            predicted = np.array(predicted) if predicted else np.empty((0, 2))

            # Assign detections to tracks by minimum total centroid distance.
            assigned = set()
            if len(predicted) > 0 and len(detections) > 0:
                det_centers = np.array(
                    [[(d[0] + d[2]) / 2, (d[1] + d[3]) / 2] for d in detections]
                )
                cost = np.linalg.norm(
                    predicted[:, None, :] - det_centers[None, :, :], axis=2
                )
                rows, cols = linear_sum_assignment(cost)
                for i, j in zip(rows, cols):
                    if cost[i, j] < 80:
                        assigned.add(j)
                        tracks[i].update(detections[j])

            # Unassigned detections start new tracks.
            for j, d in enumerate(detections):
                if j in assigned:
                    continue
                t = Track(d, next_id)
                next_id += 1
                t.update(d)
                first_pt = tuple(map(int, t.trace[-1]))
                # Entry gating: mark tracks that first appear in an entry zone.
                if inside_any(first_pt, entry_zones):
                    t.entry_flag = True
                tracks.append(t)

            # Retire tracks that have gone unseen for too long and drop them
            # so the list (and the per-frame work) stays bounded.
            live = []
            for t in tracks:
                if t.missed_frames > 15:
                    t.active = False
                else:
                    live.append(t)
            tracks = live

            # Draw + log.
            for trk in tracks:
                if len(trk.trace) < 3:
                    continue

                x, y = map(int, trk.trace[-1])
                if not inside_zone((x, y), drive_zone):
                    continue  # skip outside drive zone

                status, sim = analyze_direction(
                    trk.trace, centers, road_angle, trk.status_hist
                )
                if trk.entry_flag:
                    status = "WRONG_ENTRY"

                if show_only_wrong and status not in ["WRONG", "WRONG_ENTRY"]:
                    continue

                if status == "OK":
                    color = (0, 255, 0)
                elif status.startswith("WRONG"):
                    color = (0, 0, 255)
                else:
                    color = (200, 200, 200)

                cv2.circle(frame, (x, y), 4, color, -1)
                cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
                for prev, cur in zip(trk.trace, trk.trace[1:]):
                    cv2.line(frame,
                             (int(prev[0]), int(prev[1])),
                             (int(cur[0]), int(cur[1])),
                             color, 1)

                # Each vehicle is logged once, with the status it has at the
                # first frame where its trace exceeds five points.
                if len(trk.trace) > 5 and trk.id not in logged_ids:
                    logged_ids.add(trk.id)
                    log.append({
                        "id": trk.id,
                        "status": status,
                        "cos_sim": round(sim, 3),
                        "entry_flag": trk.entry_flag,
                    })

            out.write(frame)
    finally:
        # Release the capture and finalise the output even on failure.
        cap.release()
        out.release()

    # Summary.
    summary = {
        "vehicles_analyzed": len(logged_ids),
        "wrong_count": sum(1 for e in log if e["status"].startswith("WRONG")),
        "road_angle_deg": road_angle,
    }

    # Zip outputs.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
        zip_path = tmp.name
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.write(out_path, arcname="violation_output.mp4")
        zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
        zf.writestr("summary.json", json.dumps(summary, indent=2))

    return out_path, log, summary, zip_path
233
-
234
- # ------------------------------------------------------------
235
- # 🖥️ Gradio interface
236
- # ------------------------------------------------------------
237
def run_app(video, flow_file, show_only_wrong):
    """Gradio callback: forward the inputs to ``process_video``.

    Returns the four outputs in the order the interface expects: annotated
    video, per-vehicle log, summary, and ZIP archive.
    """
    outputs = process_video(video, flow_file, show_only_wrong)
    annotated_video, per_vehicle_log, stats, archive = outputs
    return annotated_video, per_vehicle_log, stats, archive
240
-
241
# Gradio decides whether analytics are enabled when the Interface is built,
# so the opt-out has to be in the environment before that point.
# NOTE(review): ideally this is set before ``import gradio`` — confirm.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

description_text = """
### 🚦 Wrong-Direction Detection (Stage 3 — Angle + Temporal + Zone + Entry-Aware)
Upload your traffic video and the **flow_stats.json** from Stage 2.
Stage 3 will respect the learned road angle, driving zones, and entry gates.
"""

# Flagging and example caching are constructor options; assigning them as
# attributes after construction does not change the already-built UI.
demo = gr.Interface(
    fn=run_app,
    inputs=[
        gr.Video(label="Upload Traffic Video (.mp4)"),
        gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
        gr.Checkbox(label="Show Only Wrong Labels", value=False),
    ],
    outputs=[
        gr.Video(label="Violation Output Video"),
        gr.JSON(label="Per-Vehicle Log"),
        gr.JSON(label="Summary"),
        gr.File(label="⬇️ Download All Outputs (ZIP)"),
    ],
    title="🚗 Wrong-Direction Detection – Stage 3 (Angle + Temporal + Zone + Entry)",
    description=description_text,
    flagging_mode="never",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)
 
1
+ # ============================================================
2
+ # 🚦 Stage 3 — Wrong Direction Detection (Improved)
3
+ # ============================================================
4
+
5
+ import os, cv2, json, tempfile, numpy as np, gradio as gr
6
  from ultralytics import YOLO
7
  from filterpy.kalman import KalmanFilter
8
  from scipy.optimize import linear_sum_assignment
9
 
10
  # ------------------------------------------------------------
11
+ # 🧠 Safe-load fix for PyTorch 2.6
12
  # ------------------------------------------------------------
13
  import torch, ultralytics.nn.tasks as ultralytics_tasks
14
  torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
15
 
 
 
 
16
  MODEL_PATH = "yolov8n.pt"
17
  model = YOLO(MODEL_PATH)
18
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
19
 
20
# ============================================================
# 🧩 Kalman-based Tracker
# ============================================================
class Track:
    """A tracked point followed with a 4-state Kalman filter.

    The filter state is ``[x, y, vx, vy]`` (see ``F``) and only ``[x, y]`` is
    measured (see ``H``).

    Attributes set by ``__init__``:
        id: Identifier supplied by the caller.
        kf: The underlying ``KalmanFilter``.
        history: Up to the 30 most recent filtered ``[x, y]`` positions.
        frames_seen: Number of ``update`` calls so far.
        status: Label string, initially ``"OK"``.
    """

    def __init__(self, bbox, tid):
        """Create a track whose initial position is ``bbox[:2]``.

        Args:
            bbox: Sequence whose first two entries seed the ``[x, y]`` state.
            tid: Identifier stored as ``self.id``.
        """
        self.id = tid

        kf = KalmanFilter(dim_x=4, dim_z=2)
        kf.F = np.array([[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]])
        kf.H = np.array([[1, 0, 0, 0], [0, 1, 0, 0]])
        kf.P *= 10
        kf.R *= 1
        kf.x[:2] = np.array(bbox[:2]).reshape(2, 1)
        self.kf = kf

        self.history = []
        self.frames_seen = 0
        self.status = "OK"

    def update(self, bbox):
        """Run one predict + correct cycle using ``bbox[:2]`` as the measurement.

        The filtered position is appended to ``history`` (oldest entries are
        dropped beyond 30) and ``frames_seen`` is incremented.

        Returns:
            The filtered position as ``[x, y]``.
        """
        self.kf.predict()
        self.kf.update(np.array(bbox[:2]))

        x, y = self.kf.x[:2].reshape(-1)
        self.history.append([x, y])
        if len(self.history) > 30:
            del self.history[0]

        self.frames_seen += 1
        return [x, y]
45
+
46
+ # ============================================================
47
+ # ⚙️ Utilities
48
+ # ============================================================
49
def compute_cosine_similarity(v1, v2):
    """Return the cosine similarity of two vectors.

    Args:
        v1: First vector; any array-like (list, tuple or NumPy array).
        v2: Second vector, same length as ``v1``.

    Returns:
        The dot product of the two normalised vectors, in ``[-1, 1]``.
        A zero-length vector yields ``0.0`` rather than NaN because a small
        epsilon is added to each norm.
    """
    # Coerce first so plain lists (e.g. straight from JSON) divide correctly.
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    v1 = v1 / (np.linalg.norm(v1) + 1e-6)
    v2 = v2 / (np.linalg.norm(v2) + 1e-6)
    return np.dot(v1, v2)
53
+
54
def smooth_direction(points, window=5, min_speed=1.0):
    """Return the mean per-step motion vector over the last ``window`` steps.

    Args:
        points: Sequence of ``[x, y]`` positions, oldest first (list or
            NumPy array).
        window: Number of steps to average; should be at least 1.
            ``window`` steps need ``window + 1`` points.
        min_speed: Minimum length of the averaged vector. Shorter vectors
            are treated as jitter.

    Returns:
        The averaged ``[dx, dy]`` as a NumPy array, or ``None`` when there
        are fewer than ``window + 1`` points or the motion is below
        ``min_speed``.
    """
    if len(points) < window + 1:
        return None
    # Take window + 1 points so the average really spans ``window`` steps,
    # matching the length guard above.
    recent = np.asarray(points[-(window + 1):], dtype=float)
    avg_vec = np.mean(np.diff(recent, axis=0), axis=0)
    if np.linalg.norm(avg_vec) < min_speed:
        return None
    return avg_vec
63
+
64
+ # ============================================================
65
+ # 🧭 Wrong-Direction Detection Core
66
+ # ============================================================
67
def process_video(video_file, stage2_json):
    """Annotate a traffic video with per-vehicle ``OK`` / ``WRONG`` labels.

    Each frame is run through the YOLO ``model``; vehicle centroids are
    associated with Kalman ``Track`` objects by Hungarian assignment, and a
    track's smoothed motion vector is compared with the lane flow directions
    by cosine similarity.

    Args:
        video_file: Path of the input video.
        stage2_json: Path of the Stage-2 JSON. Only ``"flow_centers"`` is
            read; it defaults to ``[[1, 0]]`` when missing or empty.

    Returns:
        Path of the annotated ``.mp4`` written to a temporary file.
    """
    SIM_THRESH = 0.5       # cosine similarity threshold
    DELAY_FRAMES = 8       # wait N updates before flagging
    MIN_FLOW_SPEED = 1.2   # ignore jitter
    MATCH_DIST = 50        # max centroid distance (px) for a match
    MAX_MISSED = 15        # drop a track after this many unmatched frames

    # Close the JSON handle deterministically instead of leaking it.
    with open(stage2_json, encoding="utf-8") as fh:
        data = json.load(fh)
    lane_flows = np.array(data.get("flow_centers", [[1, 0]]), dtype=float)
    if lane_flows.size == 0:
        # An explicit empty list would otherwise crash max() below.
        lane_flows = np.array([[1.0, 0.0]])
    lane_flows = lane_flows.reshape(-1, 2)

    cap = cv2.VideoCapture(video_file)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # some containers report 0
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Only the name is needed; close the handle so the writer can own the file.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))

    tracks, missed, next_id = {}, {}, 0

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model(frame, verbose=False)[0]
            dets = []
            for box in results.boxes:
                if int(box.cls[0]) in VEHICLE_CLASSES:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    dets.append([(x1 + x2) / 2, (y1 + y2) / 2])
            dets = np.array(dets, dtype=float).reshape(-1, 2)

            # --- Tracker update ---
            track_ids = list(tracks)  # snapshot: row r <-> track_ids[r]
            matched_tracks, matched_dets = set(), set()
            if len(dets) > 0 and track_ids:
                existing = np.array(
                    [tracks[tid].kf.x[:2].reshape(-1) for tid in track_ids]
                )
                dists = np.linalg.norm(
                    existing[:, None, :] - dets[None, :, :], axis=2
                )
                row_idx, col_idx = linear_sum_assignment(dists)
                for r, c in zip(row_idx, col_idx):
                    if dists[r, c] < MATCH_DIST:
                        tid = track_ids[r]
                        tracks[tid].update(dets[c])
                        missed[tid] = 0
                        matched_tracks.add(tid)
                        matched_dets.add(c)

            # Unmatched tracks coast on the motion model and are dropped once
            # they have been unseen for too long. They are never fed their
            # own state as a measurement.
            for tid in track_ids:
                if tid in matched_tracks:
                    continue
                missed[tid] += 1
                if missed[tid] > MAX_MISSED:
                    del tracks[tid]
                    del missed[tid]
                else:
                    tracks[tid].kf.predict()

            # Unmatched detections start new tracks, seeded with one update
            # so the history begins at the detection.
            for i, d in enumerate(dets):
                if i in matched_dets:
                    continue
                trk = Track(d, next_id)
                trk.update(d)
                tracks[next_id] = trk
                missed[next_id] = 0
                next_id += 1

            # --- Draw & classify ---
            for tid, trk in tracks.items():
                if missed[tid]:
                    continue  # not observed this frame; nothing new to show

                pts = np.array(trk.history)
                for prev, cur in zip(pts[:-1], pts[1:]):
                    cv2.line(frame,
                             (int(prev[0]), int(prev[1])),
                             (int(cur[0]), int(cur[1])),
                             (0, 0, 255), 1)

                # Compute smooth direction; skip stationary/jittery tracks.
                motion = smooth_direction(pts)
                if motion is None or np.linalg.norm(motion) < MIN_FLOW_SPEED:
                    continue

                # Only classify after some frames (to reduce false early flags).
                if trk.frames_seen <= DELAY_FRAMES:
                    continue

                # Cosine similarity to the closest lane flow.
                best_sim = max(
                    compute_cosine_similarity(motion, flow) for flow in lane_flows
                )
                if best_sim < SIM_THRESH:
                    trk.status = "WRONG"
                    color = (0, 0, 255)
                else:
                    trk.status = "OK"
                    color = (0, 255, 0)

                x, y = pts[-1]
                cv2.putText(frame, f"ID:{tid} {trk.status}", (int(x), int(y)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)

            out.write(frame)
    finally:
        # Release the capture and finalise the output even on failure.
        cap.release()
        out.release()

    return out_path
150
+
151
+ # ============================================================
152
+ # 🎛️ Gradio Interface
153
+ # ============================================================
154
# Markdown shown under the title of the web UI.
description = """
### 🚦 Stage 3 — Wrong Direction Detection (Improved)
- Uses cosine similarity instead of raw angle comparison
- Lane-wise flow support for curved roads
- Temporal smoothing & delayed classification
"""

# Two file uploads in, one annotated video out; the callback is process_video.
demo = gr.Interface(
    process_video,
    [gr.File(label="Input Video"), gr.File(label="Stage 2 Flow JSON")],
    gr.Video(label="Output (with WRONG/OK labels)"),
    title="🚗 Stage 3 – Improved Wrong-Direction Detection",
    description=description,
)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()