Files changed (1) hide show
  1. app.py +180 -130
app.py CHANGED
@@ -1,161 +1,211 @@
1
- # ============================================================
2
- # 🚦 Stage 3 – Wrong-Direction Detection (Video Output Version)
3
- # ============================================================
4
-
5
- import gradio as gr
6
- import numpy as np, cv2, json, os, tempfile
7
- from collections import defaultdict
8
- import math
9
 
10
  # ------------------------------------------------------------
11
- # ⚙️ CONFIG
12
  # ------------------------------------------------------------
13
ANGLE_THRESHOLD = 60       # degrees; a best-match angle at or above this is labelled WRONG
SMOOTH_FRAMES = 5          # number of trailing points used for temporal smoothing
ENTRY_ZONE_RATIO = 0.15    # points in the top 15 % of the frame height are skipped
CONF_MIN = 0               # lower bound of the reported confidence (%)
CONF_MAX = 100             # upper bound of the reported confidence (%)
FPS = 25                   # frame rate of the rendered output video
18
 
19
  # ------------------------------------------------------------
20
- # 🔧 Helper – universal loader for Gradio inputs
21
  # ------------------------------------------------------------
22
def load_json_input(file_obj):
    """Load a JSON document from any of the file representations Gradio may pass.

    Args:
        file_obj: A filesystem path (``str`` or ``os.PathLike``), a dict whose
            ``"name"`` key holds the path, or an object exposing the path as
            its ``.name`` attribute.

    Returns:
        The decoded JSON document.

    Raises:
        ValueError: If ``file_obj`` is ``None`` or of an unsupported type.
    """
    if file_obj is None:
        raise ValueError("No file provided.")

    # Real paths are checked first: pathlib.Path also has a ``.name``
    # attribute, but it is only the basename and must not be used as the path.
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    elif isinstance(file_obj, dict) and "name" in file_obj:
        path = file_obj["name"]
    elif hasattr(file_obj, "name"):
        path = file_obj.name
    else:
        raise ValueError("Unsupported file input type.")

    # The context manager closes the handle deterministically; the previous
    # json.load(open(...)) form left it open until garbage collection.
    with open(path) as handle:
        return json.load(handle)
34
 
35
  # ------------------------------------------------------------
36
- # 🧩 Load Stage 2 flow model
37
  # ------------------------------------------------------------
38
def load_flow_model(flow_model_json):
    """Read the Stage 2 flow model and return its flow centers.

    Each entry under the ``"zone_flow_centers"`` key is converted to a NumPy
    array; the entries are returned as a list in file order.
    """
    zone_entries = load_json_input(flow_model_json)["zone_flow_centers"]
    per_zone = []
    for entry in zone_entries:
        per_zone.append(np.array(entry))
    return per_zone
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # ------------------------------------------------------------
44
- # 🧩 Extract trajectories (Stage 1)
45
  # ------------------------------------------------------------
46
def extract_trajectories(json_file):
    """Load Stage 1 tracks, keeping only those with more than two points.

    Returns a dict mapping each track id to a NumPy array of its points.
    """
    raw_tracks = load_json_input(json_file)
    kept = {}
    for track_id, points in raw_tracks.items():
        if len(points) <= 2:
            continue  # too short to be kept
        kept[track_id] = np.array(points)
    return kept
 
 
 
 
 
 
 
 
 
50
 
51
  # ------------------------------------------------------------
52
- # 🧮 Direction + Angle + Confidence Helpers
53
  # ------------------------------------------------------------
54
def smooth_direction(pts, window=SMOOTH_FRAMES):
    """Return the normalised mean step over the last ``window`` points.

    Fewer than two points yields the zero vector ``[0, 0]``. The small
    epsilon in the denominator avoids dividing by zero for a null mean step.
    """
    if len(pts) < 2:
        return np.array([0, 0])
    recent = pts[-window:]
    mean_step = np.diff(recent, axis=0).mean(axis=0)
    return mean_step / (np.linalg.norm(mean_step) + 1e-6)
59
-
60
def angle_between(v1, v2):
    """Return the angle in degrees between ``v1`` and ``v2``.

    Both vectors are scaled by ``norm + 1e-6`` so a zero-length vector does
    not divide by zero; the cosine is clipped to ``[-1, 1]`` before arccos.
    """
    norm_a = np.linalg.norm(v1) + 1e-6
    norm_b = np.linalg.norm(v2) + 1e-6
    cosine = np.dot(v1 / norm_a, v2 / norm_b)
    return np.degrees(np.arccos(np.clip(cosine, -1, 1)))
65
-
66
def angle_to_confidence(angle):
    """Map an angle in degrees to a confidence value rounded to one decimal.

    Angles outside ``[0, 180)`` give ``CONF_MIN``. Otherwise the value drops
    linearly from ``CONF_MAX`` by 100 points per 180 degrees and is floored
    at ``CONF_MIN``.
    """
    if angle < 0 or angle >= 180:
        return CONF_MIN
    scaled = CONF_MAX - (angle / 180) * 100
    return round(max(CONF_MIN, scaled), 1)
71
-
72
def get_zone_idx(y, frame_h, n_zones):
    """Return the index of the horizontal band that contains ``y``.

    The frame height is split into ``n_zones`` equal bands; the result is
    clamped to ``[0, n_zones - 1]`` so out-of-frame values map to an edge band.
    """
    band_height = frame_h / n_zones
    band = np.clip(y // band_height, 0, n_zones - 1)
    return int(band)
75
 
76
  # ------------------------------------------------------------
77
- # 🎥 Main logic → annotated video
78
  # ------------------------------------------------------------
79
def classify_wrong_direction_video(traj_json, flow_model_json, bg_img=None):
    """Render an MP4 that labels every trajectory as OK / WRONG frame by frame.

    Args:
        traj_json: Stage 1 trajectories file (anything ``load_json_input`` accepts).
        flow_model_json: Stage 2 flow model file (same accepted forms).
        bg_img: Optional background image as a path, a dict with a ``"name"``
            key, or an object with a ``.name`` attribute. A dark 900x600
            canvas is used when it is missing or cannot be read.

    Returns:
        Path of the temporary ``.mp4`` file that was written.

    Raises:
        ValueError: If no trajectory with more than two points is available.
    """
    tracks = extract_trajectories(traj_json)
    if not tracks:
        # Previously surfaced as an opaque "max() arg is an empty sequence".
        raise ValueError("No trajectories with more than two points were found.")
    centers_by_zone = load_flow_model(flow_model_json)

    # Background: use the supplied image when readable, else a flat dark canvas.
    bg = None
    if bg_img:
        if isinstance(bg_img, dict) and "name" in bg_img:
            bg_path = bg_img["name"]
        elif hasattr(bg_img, "name"):
            bg_path = bg_img.name
        else:
            bg_path = bg_img
        bg = cv2.imread(bg_path)
    if bg is None:
        bg = np.ones((600, 900, 3), dtype=np.uint8) * 40
    h, w = bg.shape[:2]

    # The video lasts as many frames as the longest track has points.
    max_len = max(len(p) for p in tracks.values())
    # Only the name is needed; close the handle right away instead of leaking it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(out_path, fourcc, FPS, (w, h))
    font = cv2.FONT_HERSHEY_SIMPLEX
    neutral = (200, 200, 200)

    try:
        for fi in range(max_len):
            frame = bg.copy()
            for tid, pts in tracks.items():
                if fi >= len(pts):
                    continue
                cur_pt = pts[fi]
                y = cur_pt[1]
                # Entry-zone gating: ignore points in the top band of the frame.
                if y < h * ENTRY_ZONE_RATIO:
                    continue

                # Direction smoothed over the recent past of this track.
                win = pts[max(0, fi - SMOOTH_FRAMES):fi + 1]
                v = smooth_direction(win)
                zone_idx = get_zone_idx(y, h, len(centers_by_zone))
                angles = [angle_between(v, c) for c in centers_by_zone[zone_idx]]

                if not np.any(v) or not angles:
                    # No motion history yet (or no reference flow for this
                    # zone): the direction is undetermined. A zero vector used
                    # to yield 90 degrees and was therefore labelled WRONG.
                    label, color = "NA", neutral
                    caption = f"ID:{tid} {label}"
                else:
                    best_angle = min(angles)
                    conf = angle_to_confidence(best_angle)
                    label = "OK" if best_angle < ANGLE_THRESHOLD else "WRONG"
                    color = (0, 255, 0) if label == "OK" else (0, 0, 255)
                    caption = f"ID:{tid} {label} ({conf}%)"

                # Trajectory travelled so far, then the current position marker.
                for p1, p2 in zip(pts[:fi], pts[1:fi + 1]):
                    cv2.line(frame, tuple(p1.astype(int)), tuple(p2.astype(int)), color, 2)
                cv2.circle(frame, tuple(cur_pt.astype(int)), 5, color, -1)
                cv2.putText(frame, caption,
                            (int(cur_pt[0]) + 5, int(cur_pt[1]) - 5),
                            font, 0.55, color, 2)

            writer.write(frame)
    finally:
        # Release the writer even if rendering fails part-way through.
        writer.release()
    return out_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  # ------------------------------------------------------------
138
- # 🖥️ Gradio Interface
139
  # ------------------------------------------------------------
 
 
 
 
 
140
description_text = """
### 🚦 Stage 3 – Wrong-Direction Detection (Video Output)
Uses **trajectories (Stage 1)** + **flow model (Stage 2)** to create an annotated MP4:
- Angle-based + temporal smoothing + zone awareness
- Entry-zone gating
- Confidence (%) per vehicle
"""

# Upload widgets, in the positional order classify_wrong_direction_video expects.
_input_components = [
    gr.File(label="Trajectories JSON (Stage 1)"),
    gr.File(label="Flow Model JSON (Stage 2)"),
    gr.File(label="Optional background frame (.jpg/.png)"),
]

demo = gr.Interface(
    fn=classify_wrong_direction_video,
    inputs=_input_components,
    outputs=gr.Video(label="Annotated Video Output"),
    title="🚗 Stage 3 — Wrong-Direction Detection (Video Output)",
    description=description_text,
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
 
1
+ import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr
2
+ from ultralytics import YOLO
3
+ from filterpy.kalman import KalmanFilter
4
+ from scipy.optimize import linear_sum_assignment
 
 
 
 
5
 
6
  # ------------------------------------------------------------
7
+ # 🔧 Safe-load fix for PyTorch 2.6
8
  # ------------------------------------------------------------
9
+ import torch, ultralytics.nn.tasks as ultralytics_tasks
10
+ torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
 
 
 
11
 
12
  # ------------------------------------------------------------
13
+ # ⚙️ YOLO setup
14
  # ------------------------------------------------------------
15
# Weights file handed to the YOLO constructor.
MODEL_PATH = "yolov8n.pt"
# The detector is created once at import time and shared by every request.
model = YOLO(MODEL_PATH)
# Detector class ids that count as vehicles.
VEHICLE_CLASSES = [2, 3, 5, 7]  # car, motorcycle, bus, truck
18
+
 
 
 
 
 
 
 
 
19
 
20
  # ------------------------------------------------------------
21
+ # 🧩 Kalman tracker
22
  # ------------------------------------------------------------
23
class Track:
    """Constant-velocity Kalman track of one bounding-box centroid.

    The filter state is ``[x, y, vx, vy]``; only the position ``(x, y)`` is
    measured. ``trace`` holds the filtered position after every ``update``.
    """

    def __init__(self, bbox, tid):
        """Create a track with id ``tid`` positioned at the centroid of ``bbox``."""
        self.id = tid
        kf = KalmanFilter(dim_x=4, dim_z=2)
        # Transition: position advances by velocity each step.
        kf.F = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1],
                         [0, 0, 1, 0],
                         [0, 0, 0, 1]])
        # Measurement picks out the position components of the state.
        kf.H = np.array([[1, 0, 0, 0],
                         [0, 1, 0, 0]])
        kf.P *= 1000.0
        kf.R *= 10.0
        kf.x[:2] = np.array(self.centroid(bbox)).reshape(2, 1)
        self.kf = kf
        self.trace = []

    def centroid(self, b):
        """Return ``[cx, cy]`` for a box given as ``(x1, y1, x2, y2)``."""
        left, top, right, bottom = b
        return [(left + right) / 2, (top + bottom) / 2]

    def predict(self):
        """Advance the filter one step and return the predicted position."""
        self.kf.predict()
        return self.kf.x[:2].reshape(2)

    def update(self, b):
        """Correct the filter with box ``b`` and record the filtered position."""
        measurement = np.array(self.centroid(b)).reshape(2, 1)
        self.kf.update(measurement)
        position = self.kf.x[:2].reshape(2)
        cx, cy = position[0], position[1]
        self.trace.append((float(cx), float(cy)))
        return (cx, cy)
48
+
49
 
50
  # ------------------------------------------------------------
51
+ # 🧮 Direction analyzer
52
  # ------------------------------------------------------------
53
def analyze_direction(trace, centers):
    """Compare a track's latest motion with the reference flow directions.

    The heading is the unit vector from the third-last to the last trace
    point. Its dot product with every row of ``centers`` is taken and the
    largest value is kept.

    Returns:
        ``("NA", 1.0)`` when the trace has fewer than three points or the
        displacement is (near) zero; otherwise ``("WRONG", best)`` if the best
        similarity is negative, else ``("OK", best)``.
    """
    if len(trace) < 3:
        return "NA", 1.0

    displacement = np.array(trace[-1]) - np.array(trace[-3])
    length = np.linalg.norm(displacement)
    if length < 1e-6:
        return "NA", 1.0

    heading = displacement / length
    best = float(np.max(np.dot(centers, heading)))
    status = "WRONG" if best < 0 else "OK"
    return status, best
65
+
66
 
67
  # ------------------------------------------------------------
68
+ # 🧭 Load normalized flow centers
69
  # ------------------------------------------------------------
70
def load_flow_centers(flow_json):
    """Load the Stage 2 flow centers and scale every row to (almost) unit length.

    Args:
        flow_json: Path to the JSON file (``str`` / ``os.PathLike``), a dict
            with a ``"name"`` key holding the path, or an upload wrapper
            exposing the path as ``.name``.

    Returns:
        A float array with one row per entry of ``"flow_centers"``, each row
        divided by its norm plus a small epsilon (so zero rows stay zero).

    Raises:
        ValueError: If ``flow_json`` is ``None``.
    """
    if flow_json is None:
        raise ValueError("No flow_stats.json file provided.")

    # Real paths first: pathlib.Path also has ``.name``, but that is only the
    # basename and must not replace the full path.
    if isinstance(flow_json, (str, os.PathLike)):
        path = flow_json
    elif isinstance(flow_json, dict) and "name" in flow_json:
        path = flow_json["name"]
    else:
        path = getattr(flow_json, "name", flow_json)

    # Close the handle deterministically instead of relying on garbage collection.
    with open(path) as handle:
        data = json.load(handle)

    centers = np.array(data["flow_centers"], dtype=float)
    norms = np.linalg.norm(centers, axis=1, keepdims=True)
    return centers / (norms + 1e-6)
75
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # ------------------------------------------------------------
78
+ # 🎥 Process video
79
  # ------------------------------------------------------------
80
def process_video(video_path, flow_json, show_only_wrong=False):
    """Detect, track and direction-check vehicles in a video.

    Every frame is run through the module-level YOLO ``model``; vehicle
    detections are matched to existing ``Track`` objects by centroid distance
    (Hungarian assignment), unmatched detections start new tracks, and each
    track with at least three trace points is classified with
    ``analyze_direction`` and drawn onto the frame.

    Args:
        video_path: Path of the input video.
        flow_json: Stage 2 flow file, passed to ``load_flow_centers``.
        show_only_wrong: When true, only tracks classified ``"WRONG"`` are
            drawn and logged.

    Returns:
        ``(out_path, log, summary, zip_path)``: the annotated video path, the
        per-vehicle log (one entry per track id), a summary dict, and the path
        of a ZIP bundling all three.
    """
    match_max_dist = 80      # px; larger centroid distances are not matched
    min_det_conf = 0.3       # detections at or below this score are ignored
    max_missed_frames = 30   # unmatched frames tolerated before a track is dropped

    centers = load_flow_centers(flow_json)
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Only the name is needed; close the handle instead of leaking it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))

    tracks, next_id, log = [], 0, []
    logged_ids = set()   # O(1) "already logged?" check instead of scanning the log
    missed = {}          # track id -> consecutive frames without a matched detection

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            results = model(frame, verbose=False)[0]
            detections = [
                box.xyxy[0].cpu().numpy()
                for box in results.boxes
                if int(box.cls) in VEHICLE_CLASSES and float(box.conf) > min_det_conf
            ]

            # Predict existing tracks (every track is advanced once per frame).
            predicted = [t.predict() for t in tracks]

            # Assign detections to tracks by centroid distance.
            assigned = set()
            matched_tracks = set()
            if predicted and detections:
                pred_xy = np.array(predicted)
                det = np.array(detections)
                det_xy = np.column_stack(((det[:, 0] + det[:, 2]) / 2,
                                          (det[:, 1] + det[:, 3]) / 2))
                cost = np.linalg.norm(pred_xy[:, None, :] - det_xy[None, :, :], axis=2)
                rows, cols = linear_sum_assignment(cost)
                for i, j in zip(rows, cols):
                    if cost[i, j] < match_max_dist:
                        assigned.add(j)
                        matched_tracks.add(i)
                        tracks[i].update(detections[j])

            # Retire tracks that have gone unmatched for too long. Without this
            # every vehicle that left the scene stayed in the list forever and
            # kept being drawn at its last position.
            survivors = []
            for i, trk in enumerate(tracks):
                missed[trk.id] = 0 if i in matched_tracks else missed.get(trk.id, 0) + 1
                if missed[trk.id] <= max_missed_frames:
                    survivors.append(trk)
                else:
                    del missed[trk.id]
            tracks = survivors

            # New tracks for detections nobody claimed.
            for j, d in enumerate(detections):
                if j not in assigned:
                    t = Track(d, next_id)
                    next_id += 1
                    t.update(d)
                    tracks.append(t)
                    missed[t.id] = 0

            # Draw + log (toggle support).
            for trk in tracks:
                if len(trk.trace) < 3:
                    continue
                status, sim = analyze_direction(trk.trace, centers)

                # NOTE(review): with the toggle on, non-WRONG vehicles are also
                # left out of the log and the summary count - confirm intended.
                if show_only_wrong and status != "WRONG":
                    continue

                x, y = map(int, trk.trace[-1])
                if status == "OK":
                    color = (0, 255, 0)
                elif status == "WRONG":
                    color = (0, 0, 255)
                else:
                    color = (200, 200, 200)
                cv2.circle(frame, (x, y), 4, color, -1)
                cv2.putText(frame, f"ID:{trk.id} {status}", (x - 20, y - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
                for (x0, y0), (x1, y1) in zip(trk.trace, trk.trace[1:]):
                    cv2.line(frame, (int(x0), int(y0)), (int(x1), int(y1)), color, 1)

                # Log once per unique vehicle, at the first frame where its
                # trace is longer than five points.
                if len(trk.trace) > 5 and trk.id not in logged_ids:
                    logged_ids.add(trk.id)
                    log.append({"id": trk.id, "status": status, "cos_sim": round(sim, 3)})

            out.write(frame)
    finally:
        # Release the capture and writer even if detection or drawing fails.
        cap.release()
        out.release()

    # Unique summary.
    summary = {"vehicles_analyzed": len(logged_ids)}

    # Create ZIP bundle.
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp:
        zip_path = tmp.name
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.write(out_path, arcname="violation_output.mp4")
        zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
        zf.writestr("summary.json", json.dumps(summary, indent=2))

    return out_path, log, summary, zip_path
171
+
172
 
173
  # ------------------------------------------------------------
174
+ # 🖥️ Gradio interface
175
  # ------------------------------------------------------------
176
def run_app(video, flow_file, show_only_wrong):
    """Gradio callback: forward the three inputs to ``process_video``.

    Returns the annotated video path, the per-vehicle log, the summary and
    the ZIP path, in that order.
    """
    annotated, per_vehicle, stats, bundle = process_video(
        video, flow_file, show_only_wrong
    )
    return annotated, per_vehicle, stats, bundle
179
+
180
+
181
description_text = """
### 🚦 Wrong-Direction Detection (Stage 3)
Upload your traffic video and the **flow_stats.json** from Stage 2.
You can toggle whether to display all detections or only WRONG-direction vehicles.
"""

# Disable analytics / flagging. These must be in place BEFORE the Interface is
# constructed: the UI is built inside gr.Interface(...), so assigning
# demo.flagging_mode / demo.cache_examples afterwards (as was done before) does
# not change the already-built app.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"

demo = gr.Interface(
    fn=run_app,
    inputs=[
        gr.Video(label="Upload Traffic Video (.mp4)"),
        gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
        gr.Checkbox(label="Show Only Wrong Labels", value=False),
    ],
    outputs=[
        gr.Video(label="Violation Output Video"),
        gr.JSON(label="Per-Vehicle Log"),
        gr.JSON(label="Summary"),
        gr.File(label="⬇️ Download All Outputs (ZIP)"),
    ],
    title="🚗 Wrong-Direction Detection – Stage 3 (Toggle + ZIP)",
    description=description_text,
    examples=None,
    cache_examples=False,
    flagging_mode="never",
    analytics_enabled=False,
)

# SSR is switched off via the launch() argument below.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)