✅ Angle-Aware ✅ Temporal Smoothing

#12
Files changed (1) hide show
  1. app.py +113 -55
app.py CHANGED
@@ -16,9 +16,16 @@ MODEL_PATH = "yolov8n.pt"
16
  model = YOLO(MODEL_PATH)
17
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
18
 
 
 
 
 
 
 
 
19
 
20
  # ------------------------------------------------------------
21
- # 🧩 Kalman tracker
22
  # ------------------------------------------------------------
23
  class Track:
24
  def __init__(self, bbox, tid):
@@ -28,8 +35,12 @@ class Track:
28
  self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
29
  self.kf.P *= 1000.0
30
  self.kf.R *= 10.0
31
- self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
  self.trace = []
 
 
 
 
33
 
34
  def centroid(self, b):
35
  x1, y1, x2, y2 = b
@@ -37,6 +48,7 @@ class Track:
37
 
38
  def predict(self):
39
  self.kf.predict()
 
40
  return self.kf.x[:2].reshape(2)
41
 
42
  def update(self, b):
@@ -44,67 +56,93 @@ class Track:
44
  self.kf.update(z)
45
  cx, cy = self.kf.x[:2].reshape(2)
46
  self.trace.append((float(cx), float(cy)))
 
47
  return (cx, cy)
48
 
49
-
50
  # ------------------------------------------------------------
51
- # 🧮 Direction analyzer
52
  # ------------------------------------------------------------
53
- def analyze_direction(trace, centers):
54
  if len(trace) < 3:
55
  return "NA", 1.0
 
 
56
  v = np.array(trace[-1]) - np.array(trace[-3])
57
  if np.linalg.norm(v) < 1e-6:
58
  return "NA", 1.0
59
- v = v / np.linalg.norm(v)
 
 
 
 
60
  sims = np.dot(centers, v)
61
  max_sim = np.max(sims)
62
- if max_sim < 0:
63
- return "WRONG", float(max_sim)
64
- return "OK", float(max_sim)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # ------------------------------------------------------------
68
- # 🧭 Load normalized flow centers
69
  # ------------------------------------------------------------
70
- def load_flow_centers(flow_json):
71
  data = json.load(open(flow_json))
72
  centers = np.array(data["flow_centers"])
73
  centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
74
- return centers
 
 
 
 
 
 
 
 
 
 
75
 
 
 
76
 
77
  # ------------------------------------------------------------
78
- # 🎥 Process video
79
  # ------------------------------------------------------------
80
  def process_video(video_path, flow_json, show_only_wrong=False):
81
- centers = load_flow_centers(flow_json)
 
82
  cap = cv2.VideoCapture(video_path)
83
- fps = cap.get(cv2.CAP_PROP_FPS) or 25
84
  w, h = int(cap.get(3)), int(cap.get(4))
85
 
86
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
87
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
88
- out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
89
 
90
  tracks, next_id, log = [], 0, []
91
 
92
  while True:
93
  ret, frame = cap.read()
94
- if not ret:
95
- break
96
 
97
  results = model(frame, verbose=False)[0]
98
  detections = []
99
  for box in results.boxes:
100
- if int(box.cls) in VEHICLE_CLASSES and box.conf > 0.3:
101
  detections.append(box.xyxy[0].cpu().numpy())
102
 
103
- # Predict existing
104
- predicted = [t.predict() for t in tracks]
105
- predicted = np.array(predicted) if len(predicted) > 0 else np.empty((0,2))
106
 
107
- # Assign detections to tracks
108
  assigned = set()
109
  if len(predicted) > 0 and len(detections) > 0:
110
  cost = np.zeros((len(predicted), len(detections)))
@@ -118,49 +156,73 @@ def process_video(video_path, flow_json, show_only_wrong=False):
118
  assigned.add(j)
119
  tracks[i].update(detections[j])
120
 
121
- # New tracks
122
  for j, d in enumerate(detections):
123
  if j not in assigned:
124
  t = Track(d, next_id)
125
  next_id += 1
126
  t.update(d)
 
 
 
 
127
  tracks.append(t)
128
 
129
- # --- 🧩 Draw + Log (toggle support) ---
 
 
 
 
 
130
  for trk in tracks:
131
- if len(trk.trace) < 3:
132
  continue
133
- status, sim = analyze_direction(trk.trace, centers)
134
 
135
- # Skip OKs if toggle is enabled
136
- if show_only_wrong and status != "WRONG":
 
 
 
 
 
 
 
137
  continue
138
 
139
- x, y = map(int, trk.trace[-1])
140
- color = (0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (200,200,200))
141
- cv2.circle(frame,(x,y),4,color,-1)
142
- cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),
143
- cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
144
- for i in range(1,len(trk.trace)):
145
- cv2.line(frame,
146
- (int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
147
- (int(trk.trace[i][0]),int(trk.trace[i][1])),
148
- color,1)
149
 
150
- # Log once per unique vehicle
151
- if len(trk.trace) > 5 and not any(entry["id"] == trk.id for entry in log):
152
- log.append({"id": trk.id, "status": status, "cos_sim": round(sim,3)})
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  out.write(frame)
155
 
156
  cap.release()
157
  out.release()
158
 
159
- # Unique summary
160
- unique_ids = {entry["id"] for entry in log}
161
- summary = {"vehicles_analyzed": len(unique_ids)}
 
 
 
 
162
 
163
- # Create ZIP bundle
164
  zip_path = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
165
  with zipfile.ZipFile(zip_path, "w") as zf:
166
  zf.write(out_path, arcname="violation_output.mp4")
@@ -169,7 +231,6 @@ def process_video(video_path, flow_json, show_only_wrong=False):
169
 
170
  return out_path, log, summary, zip_path
171
 
172
-
173
  # ------------------------------------------------------------
174
  # 🖥️ Gradio interface
175
  # ------------------------------------------------------------
@@ -177,11 +238,10 @@ def run_app(video, flow_file, show_only_wrong):
177
  vid, log_json, summary, zip_file = process_video(video, flow_file, show_only_wrong)
178
  return vid, log_json, summary, zip_file
179
 
180
-
181
  description_text = """
182
- ### 🚦 Wrong-Direction Detection (Stage 3)
183
  Upload your traffic video and the **flow_stats.json** from Stage 2.
184
- You can toggle whether to display all detections or only WRONG-direction vehicles.
185
  """
186
 
187
  demo = gr.Interface(
@@ -197,15 +257,13 @@ demo = gr.Interface(
197
  gr.JSON(label="Summary"),
198
  gr.File(label="⬇️ Download All Outputs (ZIP)")
199
  ],
200
- title="🚗 Wrong-Direction Detection – Stage 3 (Toggle + ZIP)",
201
  description=description_text,
202
- examples=None,
203
  )
204
 
205
- # Disable analytics / flagging / SSR
206
  demo.flagging_mode = "never"
207
  demo.cache_examples = False
208
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
209
 
210
  if __name__ == "__main__":
211
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)
 
16
  model = YOLO(MODEL_PATH)
17
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
18
 
19
+ # ------------------------------------------------------------
20
+ # 🧭 Utility: rotate vector by angle
21
+ # ------------------------------------------------------------
22
+ def rotate_vec(v, theta_deg):
23
+ t = np.deg2rad(theta_deg)
24
+ R = np.array([[np.cos(t), -np.sin(t)], [np.sin(t), np.cos(t)]])
25
+ return R @ v
26
 
27
  # ------------------------------------------------------------
28
+ # 🧩 Kalman tracker with temporal smoothing + entry gate flag
29
  # ------------------------------------------------------------
30
  class Track:
31
  def __init__(self, bbox, tid):
 
35
  self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
36
  self.kf.P *= 1000.0
37
  self.kf.R *= 10.0
38
+ self.kf.x[:2,0] = np.array(self.centroid(bbox), dtype=float)
39
  self.trace = []
40
+ self.status_hist = []
41
+ self.entry_flag = False
42
+ self.active = True
43
+ self.missed_frames = 0
44
 
45
  def centroid(self, b):
46
  x1, y1, x2, y2 = b
 
48
 
49
  def predict(self):
50
  self.kf.predict()
51
+ self.missed_frames += 1
52
  return self.kf.x[:2].reshape(2)
53
 
54
  def update(self, b):
 
56
  self.kf.update(z)
57
  cx, cy = self.kf.x[:2].reshape(2)
58
  self.trace.append((float(cx), float(cy)))
59
+ self.missed_frames = 0
60
  return (cx, cy)
61
 
 
62
  # ------------------------------------------------------------
63
+ # 🧮 Direction analyzer (angle + temporal aware)
64
  # ------------------------------------------------------------
65
+ def analyze_direction(trace, centers, road_angle_deg, hist):
66
  if len(trace) < 3:
67
  return "NA", 1.0
68
+
69
+ # motion vector
70
  v = np.array(trace[-1]) - np.array(trace[-3])
71
  if np.linalg.norm(v) < 1e-6:
72
  return "NA", 1.0
73
+
74
+ # rotate to road reference
75
+ v = rotate_vec(v / np.linalg.norm(v), -road_angle_deg)
76
+
77
+ # cosine similarity vs dominant centers
78
  sims = np.dot(centers, v)
79
  max_sim = np.max(sims)
 
 
 
80
 
81
+ # temporal averaging
82
+ hist.append(max_sim)
83
+ if len(hist) > 5:
84
+ hist.pop(0)
85
+ avg_sim = np.mean(hist)
86
+
87
+ if avg_sim < -0.2:
88
+ return "WRONG", float(avg_sim)
89
+ elif avg_sim > 0.2:
90
+ return "OK", float(avg_sim)
91
+ else:
92
+ return "NA", float(avg_sim)
93
 
94
  # ------------------------------------------------------------
95
+ # 🗺️ Load Stage-2 flow stats (centers, angle, zones)
96
  # ------------------------------------------------------------
97
+ def load_flow_stats(flow_json):
98
  data = json.load(open(flow_json))
99
  centers = np.array(data["flow_centers"])
100
  centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
101
+ road_angle_deg = float(data.get("road_angle_deg", 0.0))
102
+ drive_zone = data.get("drive_zone", None)
103
+ entry_zones = data.get("entry_zones", [])
104
+ return centers, road_angle_deg, drive_zone, entry_zones
105
+
106
+ # ------------------------------------------------------------
107
+ # 🧾 Zone tests
108
+ # ------------------------------------------------------------
109
+ def inside_zone(pt, zone):
110
+ if zone is None: return True
111
+ return cv2.pointPolygonTest(np.array(zone, np.int32), pt, False) >= 0
112
 
113
+ def inside_any(pt, zones):
114
+ return any(cv2.pointPolygonTest(np.array(z, np.int32), pt, False) >= 0 for z in zones)
115
 
116
  # ------------------------------------------------------------
117
+ # 🎥 Process video (angle + temporal + zone + entry-gating)
118
  # ------------------------------------------------------------
119
  def process_video(video_path, flow_json, show_only_wrong=False):
120
+ centers, road_angle, drive_zone, entry_zones = load_flow_stats(flow_json)
121
+
122
  cap = cv2.VideoCapture(video_path)
123
+ fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25
124
  w, h = int(cap.get(3)), int(cap.get(4))
125
 
126
  out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
127
+ out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
 
128
 
129
  tracks, next_id, log = [], 0, []
130
 
131
  while True:
132
  ret, frame = cap.read()
133
+ if not ret: break
 
134
 
135
  results = model(frame, verbose=False)[0]
136
  detections = []
137
  for box in results.boxes:
138
+ if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > 0.3:
139
  detections.append(box.xyxy[0].cpu().numpy())
140
 
141
+ # predict existing
142
+ predicted = [t.predict() for t in tracks if t.active]
143
+ predicted = np.array(predicted) if predicted else np.empty((0,2))
144
 
145
+ # assign detections
146
  assigned = set()
147
  if len(predicted) > 0 and len(detections) > 0:
148
  cost = np.zeros((len(predicted), len(detections)))
 
156
  assigned.add(j)
157
  tracks[i].update(detections[j])
158
 
159
+ # new tracks
160
  for j, d in enumerate(detections):
161
  if j not in assigned:
162
  t = Track(d, next_id)
163
  next_id += 1
164
  t.update(d)
165
+ first_pt = tuple(map(int, t.trace[-1]))
166
+ # entry gating: mark if starts inside forbidden zone
167
+ if inside_any(first_pt, entry_zones):
168
+ t.entry_flag = True
169
  tracks.append(t)
170
 
171
+ # clean up stale tracks
172
+ for t in tracks:
173
+ if t.missed_frames > 15:
174
+ t.active = False
175
+
176
+ # draw + log
177
  for trk in tracks:
178
+ if not trk.active or len(trk.trace) < 3:
179
  continue
 
180
 
181
+ x, y = map(int, trk.trace[-1])
182
+ if not inside_zone((x, y), drive_zone):
183
+ continue # skip outside drive zone
184
+
185
+ status, sim = analyze_direction(trk.trace, centers, road_angle, trk.status_hist)
186
+ if trk.entry_flag:
187
+ status = "WRONG_ENTRY"
188
+
189
+ if show_only_wrong and status not in ["WRONG", "WRONG_ENTRY"]:
190
  continue
191
 
192
+ color = (0,255,0) if status=="OK" else \
193
+ (0,0,255) if status.startswith("WRONG") else (200,200,200)
 
 
 
 
 
 
 
 
194
 
195
+ cv2.circle(frame, (x,y), 4, color, -1)
196
+ cv2.putText(frame, f"ID:{trk.id} {status}", (x-20,y-10),
197
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
198
+ for i in range(1, len(trk.trace)):
199
+ cv2.line(frame,
200
+ (int(trk.trace[i-1][0]), int(trk.trace[i-1][1])),
201
+ (int(trk.trace[i][0]), int(trk.trace[i][1])),
202
+ color, 1)
203
+
204
+ if len(trk.trace) > 5 and not any(e["id"]==trk.id for e in log):
205
+ log.append({
206
+ "id": trk.id,
207
+ "status": status,
208
+ "cos_sim": round(sim,3),
209
+ "entry_flag": trk.entry_flag
210
+ })
211
 
212
  out.write(frame)
213
 
214
  cap.release()
215
  out.release()
216
 
217
+ # summary
218
+ unique_ids = {e["id"] for e in log}
219
+ summary = {
220
+ "vehicles_analyzed": len(unique_ids),
221
+ "wrong_count": sum(1 for e in log if e["status"].startswith("WRONG")),
222
+ "road_angle_deg": road_angle
223
+ }
224
 
225
+ # zip outputs
226
  zip_path = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
227
  with zipfile.ZipFile(zip_path, "w") as zf:
228
  zf.write(out_path, arcname="violation_output.mp4")
 
231
 
232
  return out_path, log, summary, zip_path
233
 
 
234
  # ------------------------------------------------------------
235
  # 🖥️ Gradio interface
236
  # ------------------------------------------------------------
 
238
  vid, log_json, summary, zip_file = process_video(video, flow_file, show_only_wrong)
239
  return vid, log_json, summary, zip_file
240
 
 
241
  description_text = """
242
+ ### 🚦 Wrong-Direction Detection (Stage 3 — Angle + Temporal + Zone + Entry-Aware)
243
  Upload your traffic video and the **flow_stats.json** from Stage 2.
244
+ Stage 3 will respect the learned road angle, driving zones, and entry gates.
245
  """
246
 
247
  demo = gr.Interface(
 
257
  gr.JSON(label="Summary"),
258
  gr.File(label="⬇️ Download All Outputs (ZIP)")
259
  ],
260
+ title="🚗 Wrong-Direction Detection – Stage 3 (Angle + Temporal + Zone + Entry)",
261
  description=description_text,
 
262
  )
263
 
 
264
  demo.flagging_mode = "never"
265
  demo.cache_examples = False
266
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
267
 
268
  if __name__ == "__main__":
269
+ demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)