Stage 3 (Angle + Temporal + Zone-Aware + Entry Gating)

#9
Files changed (1) hide show
  1. app.py +134 -178
app.py CHANGED
@@ -1,211 +1,167 @@
1
- import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr
2
- from ultralytics import YOLO
3
- from filterpy.kalman import KalmanFilter
4
- from scipy.optimize import linear_sum_assignment
 
 
 
 
5
 
6
  # ------------------------------------------------------------
7
- # 🔧 Safe-load fix for PyTorch 2.6
8
  # ------------------------------------------------------------
9
- import torch, ultralytics.nn.tasks as ultralytics_tasks
10
- torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
 
 
11
 
12
  # ------------------------------------------------------------
13
- # ⚙️ YOLO setup
14
  # ------------------------------------------------------------
15
- MODEL_PATH = "yolov8n.pt"
16
- model = YOLO(MODEL_PATH)
17
- VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
18
-
19
 
20
  # ------------------------------------------------------------
21
- # 🧩 Kalman tracker
22
  # ------------------------------------------------------------
23
- class Track:
24
- def __init__(self, bbox, tid):
25
- self.id = tid
26
- self.kf = KalmanFilter(dim_x=4, dim_z=2)
27
- self.kf.F = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]])
28
- self.kf.H = np.array([[1,0,0,0],[0,1,0,0]])
29
- self.kf.P *= 1000.0
30
- self.kf.R *= 10.0
31
- self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
- self.trace = []
33
-
34
- def centroid(self, b):
35
- x1, y1, x2, y2 = b
36
- return [(x1+x2)/2, (y1+y2)/2]
37
-
38
- def predict(self):
39
- self.kf.predict()
40
- return self.kf.x[:2].reshape(2)
41
-
42
- def update(self, b):
43
- z = np.array(self.centroid(b)).reshape(2,1)
44
- self.kf.update(z)
45
- cx, cy = self.kf.x[:2].reshape(2)
46
- self.trace.append((float(cx), float(cy)))
47
- return (cx, cy)
48
-
49
 
50
  # ------------------------------------------------------------
51
- # 🧮 Direction analyzer
52
  # ------------------------------------------------------------
53
- def analyze_direction(trace, centers):
54
- if len(trace) < 3:
55
- return "NA", 1.0
56
- v = np.array(trace[-1]) - np.array(trace[-3])
57
- if np.linalg.norm(v) < 1e-6:
58
- return "NA", 1.0
59
- v = v / np.linalg.norm(v)
60
- sims = np.dot(centers, v)
61
- max_sim = np.max(sims)
62
- if max_sim < 0:
63
- return "WRONG", float(max_sim)
64
- return "OK", float(max_sim)
65
-
66
 
67
  # ------------------------------------------------------------
68
- # 🧭 Load normalized flow centers
69
  # ------------------------------------------------------------
70
- def load_flow_centers(flow_json):
71
- data = json.load(open(flow_json))
72
- centers = np.array(data["flow_centers"])
73
- centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
74
- return centers
75
-
76
 
77
  # ------------------------------------------------------------
78
- # 🎥 Process video
79
- # ------------------------------------------------------------
80
- def process_video(video_path, flow_json, show_only_wrong=False):
81
- centers = load_flow_centers(flow_json)
82
- cap = cv2.VideoCapture(video_path)
83
- fps = cap.get(cv2.CAP_PROP_FPS) or 25
84
- w, h = int(cap.get(3)), int(cap.get(4))
85
-
86
- out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
87
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
88
- out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
89
-
90
- tracks, next_id, log = [], 0, []
91
-
92
- while True:
93
- ret, frame = cap.read()
94
- if not ret:
95
- break
96
-
97
- results = model(frame, verbose=False)[0]
98
- detections = []
99
- for box in results.boxes:
100
- if int(box.cls) in VEHICLE_CLASSES and box.conf > 0.3:
101
- detections.append(box.xyxy[0].cpu().numpy())
102
-
103
- # Predict existing
104
- predicted = [t.predict() for t in tracks]
105
- predicted = np.array(predicted) if len(predicted) > 0 else np.empty((0,2))
106
-
107
- # Assign detections to tracks
108
- assigned = set()
109
- if len(predicted) > 0 and len(detections) > 0:
110
- cost = np.zeros((len(predicted), len(detections)))
111
- for i, p in enumerate(predicted):
112
- for j, d in enumerate(detections):
113
- cx, cy = ((d[0]+d[2])/2, (d[1]+d[3])/2)
114
- cost[i,j] = np.linalg.norm(p - np.array([cx,cy]))
115
- r, c = linear_sum_assignment(cost)
116
- for i, j in zip(r, c):
117
- if cost[i,j] < 80:
118
- assigned.add(j)
119
- tracks[i].update(detections[j])
120
-
121
- # New tracks
122
- for j, d in enumerate(detections):
123
- if j not in assigned:
124
- t = Track(d, next_id)
125
- next_id += 1
126
- t.update(d)
127
- tracks.append(t)
128
-
129
- # --- 🧩 Draw + Log (toggle support) ---
130
- for trk in tracks:
131
- if len(trk.trace) < 3:
132
- continue
133
- status, sim = analyze_direction(trk.trace, centers)
134
-
135
- # Skip OKs if toggle is enabled
136
- if show_only_wrong and status != "WRONG":
137
- continue
138
-
139
- x, y = map(int, trk.trace[-1])
140
- color = (0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (200,200,200))
141
- cv2.circle(frame,(x,y),4,color,-1)
142
- cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),
143
- cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
144
- for i in range(1,len(trk.trace)):
145
- cv2.line(frame,
146
- (int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
147
- (int(trk.trace[i][0]),int(trk.trace[i][1])),
148
- color,1)
149
-
150
- # Log once per unique vehicle
151
- if len(trk.trace) > 5 and not any(entry["id"] == trk.id for entry in log):
152
- log.append({"id": trk.id, "status": status, "cos_sim": round(sim,3)})
153
-
154
- out.write(frame)
155
-
156
- cap.release()
157
- out.release()
158
-
159
- # Unique summary
160
- unique_ids = {entry["id"] for entry in log}
161
- summary = {"vehicles_analyzed": len(unique_ids)}
162
-
163
- # Create ZIP bundle
164
- zip_path = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
165
- with zipfile.ZipFile(zip_path, "w") as zf:
166
- zf.write(out_path, arcname="violation_output.mp4")
167
- zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
168
- zf.writestr("summary.json", json.dumps(summary, indent=2))
169
-
170
- return out_path, log, summary, zip_path
171
-
172
 
173
  # ------------------------------------------------------------
174
- # 🖥️ Gradio interface
175
  # ------------------------------------------------------------
176
- def run_app(video, flow_file, show_only_wrong):
177
- vid, log_json, summary, zip_file = process_video(video, flow_file, show_only_wrong)
178
- return vid, log_json, summary, zip_file
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
 
 
 
181
  description_text = """
182
- ### 🚦 Wrong-Direction Detection (Stage 3)
183
- Upload your traffic video and the **flow_stats.json** from Stage 2.
184
- You can toggle whether to display all detections or only WRONG-direction vehicles.
 
 
185
  """
186
 
187
  demo = gr.Interface(
188
- fn=run_app,
189
  inputs=[
190
- gr.Video(label="Upload Traffic Video (.mp4)"),
191
- gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
192
- gr.Checkbox(label="Show Only Wrong Labels", value=False)
193
  ],
194
  outputs=[
195
- gr.Video(label="Violation Output Video"),
196
- gr.JSON(label="Per-Vehicle Log"),
197
- gr.JSON(label="Summary"),
198
- gr.File(label="⬇️ Download All Outputs (ZIP)")
199
  ],
200
- title="🚗 Wrong-Direction Detection – Stage 3 (Toggle + ZIP)",
201
- description=description_text,
202
- examples=None,
203
  )
204
 
205
- # Disable analytics / flagging / SSR
206
- demo.flagging_mode = "never"
207
- demo.cache_examples = False
208
- os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
209
-
210
  if __name__ == "__main__":
211
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)
 
1
+ # ============================================================
2
+ # 🚦 Stage 3 – Wrong-Direction Detection
3
+ # (Angle + Temporal + Zone-Aware + Entry Gating + Confidence)
4
+ # ============================================================
5
+
6
+ import gradio as gr
7
+ import numpy as np, cv2, json, os, tempfile
8
+ from collections import defaultdict
9
 
10
# ------------------------------------------------------------
# ⚙️ CONFIG
# ------------------------------------------------------------
ANGLE_THRESHOLD = 60         # degrees; angular deviation >= this → label WRONG
SMOOTH_FRAMES = 5            # trailing frames averaged for temporal smoothing
ENTRY_ZONE_RATIO = 0.15      # top 15% of the frame = entry region (skipped)
CONF_MIN, CONF_MAX = 0, 100  # confidence percentage bounds
17
 
18
# ------------------------------------------------------------
# 1️⃣ Load flow model (Stage 2)
# ------------------------------------------------------------
def load_flow_model(flow_model_json):
    """Load the Stage-2 flow model and return its per-zone flow centers.

    Parameters
    ----------
    flow_model_json : str
        Path to the Stage-2 JSON file. Must contain a "zone_flow_centers"
        key: one entry per zone, each a list of direction vectors.

    Returns
    -------
    list[np.ndarray]
        One array of dominant-flow direction vectors per zone.
    """
    # `with` closes the handle promptly; the original json.load(open(...))
    # leaked it until garbage collection.
    with open(flow_model_json) as f:
        model = json.load(f)
    return [np.array(zone) for zone in model["zone_flow_centers"]]
25
 
26
# ------------------------------------------------------------
# 2️⃣ Extract trajectories
# ------------------------------------------------------------
def extract_trajectories(json_file):
    """Load Stage-1 trajectories and keep only tracks long enough to analyze.

    Parameters
    ----------
    json_file : str
        Path to a JSON mapping of track id → list of [x, y] points.

    Returns
    -------
    dict[str, np.ndarray]
        Track id → (N, 2) point array, restricted to tracks with more than
        two points (shorter tracks carry no usable direction signal).
    """
    # `with` fixes the unclosed-file leak of json.load(open(...)).
    with open(json_file) as f:
        data = json.load(f)
    return {tid: np.array(pts) for tid, pts in data.items() if len(pts) > 2}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
# ------------------------------------------------------------
# 3️⃣ Smoothed direction for a trajectory
# ------------------------------------------------------------
def smooth_direction(pts, window=SMOOTH_FRAMES):
    """Return the unit mean step direction over the last `window` points.

    Trajectories with fewer than two points yield the zero vector.
    """
    if len(pts) < 2:
        return np.array([0, 0])
    recent = pts[-window:]
    step_vectors = np.diff(recent, axis=0)
    mean_step = np.mean(step_vectors, axis=0)
    # The epsilon guards against division by zero for a stationary track.
    magnitude = np.linalg.norm(mean_step)
    return mean_step / (magnitude + 1e-6)
 
 
 
 
 
 
44
 
45
  # ------------------------------------------------------------
46
+ # 4️⃣ Compute angular difference (deg)
47
  # ------------------------------------------------------------
48
+ def angle_between(v1, v2):
49
+ v1 = v1 / (np.linalg.norm(v1) + 1e-6)
50
+ v2 = v2 / (np.linalg.norm(v2) + 1e-6)
51
+ cosang = np.clip(np.dot(v1, v2), -1, 1)
52
+ return np.degrees(np.arccos(cosang))
 
53
 
54
# ------------------------------------------------------------
# 5️⃣ Determine zone index for y
# ------------------------------------------------------------
def get_zone_idx(y, frame_h, n_zones):
    """Map a vertical pixel coordinate to a horizontal-band zone index.

    The frame is split into `n_zones` equal-height bands; the result is
    clamped to [0, n_zones - 1] so out-of-frame y values remain valid.
    """
    band_height = frame_h / n_zones
    raw_idx = y // band_height
    return int(np.clip(raw_idx, 0, n_zones - 1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
# ------------------------------------------------------------
# 6️⃣ Confidence mapping
# ------------------------------------------------------------
def angle_to_confidence(angle):
    """Map an angular deviation (degrees) to a confidence percentage.

    Linear mapping over the full range: 0° → 100%, 90° → 50%,
    180° (or more) → 0%. Negative angles — which should not occur, since
    arccos output is non-negative — also map to CONF_MIN as a defensive
    guard.

    NOTE(review): the previous docstring claimed "ANGLE_THRESHOLD° → 50%",
    which the linear formula below never implemented (60° maps to ~66.7%).
    The docstring is corrected here; behavior is unchanged.
    """
    if angle < 0:
        return CONF_MIN
    if angle >= 180:
        return CONF_MIN
    # Smaller angle → higher confidence, scaled over 0–180°.
    conf = max(CONF_MIN, CONF_MAX - (angle / 180) * 100)
    return round(conf, 1)
77
 
78
+ # ------------------------------------------------------------
79
+ # 7️⃣ Main logic
80
+ # ------------------------------------------------------------
81
+ def classify_wrong_direction(traj_json, flow_model_json, bg_img=None):
82
+ tracks = extract_trajectories(traj_json)
83
+ centers_by_zone = load_flow_model(flow_model_json)
84
+
85
+ if bg_img and os.path.exists(bg_img):
86
+ bg = cv2.imread(bg_img)
87
+ else:
88
+ bg = np.ones((600, 900, 3), dtype=np.uint8) * 40
89
+ h, w = bg.shape[:2]
90
+
91
+ overlay = bg.copy()
92
+ font = cv2.FONT_HERSHEY_SIMPLEX
93
+ results = []
94
+
95
+ for tid, pts in tracks.items():
96
+ if len(pts) < 3:
97
+ continue
98
+ cur_pt = pts[-1]
99
+ y = cur_pt[1]
100
+ zone_idx = get_zone_idx(y, h, len(centers_by_zone))
101
+
102
+ # Skip entry region
103
+ if y < h * ENTRY_ZONE_RATIO:
104
+ continue
105
+
106
+ v = smooth_direction(pts)
107
+ centers = centers_by_zone[zone_idx]
108
+ angles = [angle_between(v, c) for c in centers]
109
+ best_angle = min(angles)
110
+
111
+ # Confidence & label
112
+ conf = angle_to_confidence(best_angle)
113
+ label = "OK" if best_angle < ANGLE_THRESHOLD else "WRONG"
114
+ color = (0, 255, 0) if label == "OK" else (0, 0, 255)
115
+
116
+ # Draw trajectory & label
117
+ for p1, p2 in zip(pts[:-1], pts[1:]):
118
+ cv2.line(overlay, tuple(p1.astype(int)), tuple(p2.astype(int)), color, 2)
119
+ cv2.circle(overlay, tuple(cur_pt.astype(int)), 5, color, -1)
120
+ cv2.putText(
121
+ overlay,
122
+ f"ID:{tid} {label} ({conf}%)",
123
+ (int(cur_pt[0]) + 5, int(cur_pt[1]) - 5),
124
+ font, 0.6, color, 2
125
+ )
126
+
127
+ results.append({
128
+ "id": tid,
129
+ "zone": int(zone_idx),
130
+ "angle": round(best_angle, 1),
131
+ "confidence": conf,
132
+ "label": label
133
+ })
134
+
135
+ combined = cv2.addWeighted(bg, 0.6, overlay, 0.4, 0)
136
+ out_path = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False).name
137
+ cv2.imwrite(out_path, combined)
138
+ return out_path, results
139
 
140
+ # ------------------------------------------------------------
141
+ # 🖥️ Gradio Interface
142
+ # ------------------------------------------------------------
143
  description_text = """
144
+ ### 🚦 Wrong-Direction Detection (Stage 3 — with Confidence)
145
+ - Compares each vehicle’s motion to its zone’s dominant flow.
146
+ - Uses angular difference smaller angle higher confidence.
147
+ - Ignores entry region to avoid false positives.
148
+ - Displays ID, label, and confidence percentage.
149
  """
150
 
151
  demo = gr.Interface(
152
+ fn=classify_wrong_direction,
153
  inputs=[
154
+ gr.File(label="Trajectories JSON (Stage 1)"),
155
+ gr.File(label="Flow Model JSON (Stage 2)"),
156
+ gr.File(label="Optional background frame (.jpg)")
157
  ],
158
  outputs=[
159
+ gr.Image(label="Annotated Output"),
160
+ gr.JSON(label="Per-Vehicle Results")
 
 
161
  ],
162
+ title="🚗 Stage 3 — Wrong-Direction Detection (with Confidence)",
163
+ description=description_text
 
164
  )
165
 
 
 
 
 
 
166
  if __name__ == "__main__":
167
+ demo.launch()