Files changed (1) hide show
  1. app.py +134 -86
app.py CHANGED
@@ -1,24 +1,24 @@
1
- import gradio as gr
2
- import cv2, os, numpy as np, json, tempfile, time
3
  from ultralytics import YOLO
4
  from filterpy.kalman import KalmanFilter
5
  from scipy.optimize import linear_sum_assignment
6
 
7
  # ------------------------------------------------------------
8
- # ⚙️ Load YOLO and flow centers
 
 
 
 
 
 
9
  # ------------------------------------------------------------
10
  MODEL_PATH = "yolov8n.pt"
11
  model = YOLO(MODEL_PATH)
12
- VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
13
 
14
- def load_flow_centers(flow_json):
15
- data = json.load(open(flow_json))
16
- centers = np.array(data["flow_centers"])
17
- centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
18
- return centers
19
 
20
  # ------------------------------------------------------------
21
- # 🧩 Simple Kalman Tracker
22
  # ------------------------------------------------------------
23
  class Track:
24
  def __init__(self, bbox, tid):
@@ -31,133 +31,181 @@ class Track:
31
  self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
  self.trace = []
33
 
34
- def centroid(self,b):
35
- x1,y1,x2,y2=b
36
- return [(x1+x2)/2,(y1+y2)/2]
37
- def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
38
- def update(self,b):
39
- z=np.array(self.centroid(b)).reshape(2,1)
 
 
 
 
40
  self.kf.update(z)
41
- cx,cy=self.kf.x[:2].reshape(2)
42
- self.trace.append((float(cx),float(cy)))
43
- return (cx,cy)
 
44
 
45
  # ------------------------------------------------------------
46
- # 🚦 Wrong-Direction Analyzer
47
  # ------------------------------------------------------------
48
  def analyze_direction(trace, centers):
49
- if len(trace)<3: return "NA",1.0
50
- v = np.array(trace[-1]) - np.array(trace[-3]) # motion vector
51
- if np.linalg.norm(v)<1e-6: return "NA",1.0
 
 
52
  v = v / np.linalg.norm(v)
53
  sims = np.dot(centers, v)
54
  max_sim = np.max(sims)
55
- if max_sim < 0: return "WRONG", float(max_sim)
 
56
  return "OK", float(max_sim)
57
 
 
 
 
 
 
 
 
 
 
 
 
58
  # ------------------------------------------------------------
59
- # 🎥 Process Video
60
  # ------------------------------------------------------------
61
- def process_video(video_path, flow_json):
62
  centers = load_flow_centers(flow_json)
63
  cap = cv2.VideoCapture(video_path)
64
  fps = cap.get(cv2.CAP_PROP_FPS) or 25
65
- w,h = int(cap.get(3)), int(cap.get(4))
66
 
67
- tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
68
- out = cv2.VideoWriter(tmp_out.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
 
69
 
70
- tracks, next_id = [], 0
71
- log = []
72
 
73
  while True:
74
  ret, frame = cap.read()
75
- if not ret: break
 
 
76
  results = model(frame, verbose=False)[0]
77
- detections=[]
78
  for box in results.boxes:
79
- if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
80
  detections.append(box.xyxy[0].cpu().numpy())
81
 
82
- # predict existing tracks
83
- predicted = [trk.predict() for trk in tracks]
84
- predicted = np.array(predicted) if len(predicted)>0 else np.empty((0,2))
85
-
86
- assigned=set()
87
- if len(predicted)>0 and len(detections)>0:
88
- cost=np.zeros((len(predicted),len(detections)))
89
- for i,p in enumerate(predicted):
90
- for j,d in enumerate(detections):
91
- cx,cy=((d[0]+d[2])/2,(d[1]+d[3])/2)
92
- cost[i,j]=np.linalg.norm(p-np.array([cx,cy]))
93
- r,c=linear_sum_assignment(cost)
94
- for i,j in zip(r,c):
95
- if cost[i,j]<80:
 
96
  assigned.add(j)
97
  tracks[i].update(detections[j])
98
 
99
- # new tracks
100
- for j,d in enumerate(detections):
101
  if j not in assigned:
102
- trk=Track(d,next_id); next_id+=1
103
- trk.update(d)
104
- tracks.append(trk)
 
105
 
106
- # draw
107
  for trk in tracks:
108
- if len(trk.trace)<3: continue
 
109
  status, sim = analyze_direction(trk.trace, centers)
110
- x,y=map(int,trk.trace[-1])
111
- color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (255,255,255))
 
 
 
 
 
112
  cv2.circle(frame,(x,y),4,color,-1)
113
- cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
 
114
  for i in range(1,len(trk.trace)):
115
- cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
116
- (int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
117
- log.append({"id":trk.id,"status":status,"cos_sim":round(sim,3)})
 
 
 
 
 
118
 
119
  out.write(frame)
120
 
121
- cap.release(); out.release()
122
- log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
123
- with open(log_path,"w") as f: json.dump(log,f,indent=2)
124
- return tmp_out.name, log_path
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  # ------------------------------------------------------------
127
- # 🖥️ Gradio Interface
128
  # ------------------------------------------------------------
129
- def run_app(video, flow_file):
130
- out_path, log_path = process_video(video, flow_file)
131
- log_data = json.load(open(log_path))
132
- summary = {"vehicles_analyzed": len(log_data)}
133
- return out_path, log_data, summary
134
 
135
  description_text = """
136
  ### 🚦 Wrong-Direction Detection (Stage 3)
137
- Uploads your traffic video and the **flow_stats.json** from Stage 2.
138
- Outputs an annotated video with OK / 🚫 WRONG labels per vehicle, plus a JSON log.
139
  """
140
 
141
- example_vid = "10.mp4" if os.path.exists("10.mp4") else None
142
- example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
143
-
144
  demo = gr.Interface(
145
  fn=run_app,
146
- inputs=[gr.Video(label="Upload Traffic Video (.mp4)"),
147
- gr.File(label="Upload flow_stats.json (Stage 2 Output)")],
148
- outputs=[gr.Video(label="Violation Output Video"),
149
- gr.JSON(label="Per-Vehicle Log"),
150
- gr.JSON(label="Summary")],
151
- title="🚗 Wrong-Direction Detection – Stage 3",
 
 
 
 
 
 
152
  description=description_text,
153
- examples=None, # disable example caching
154
  )
155
 
156
- # 🔧 disable all caching/flagging that causes _csv.Error
157
  demo.flagging_mode = "never"
158
  demo.cache_examples = False
 
159
 
160
  if __name__ == "__main__":
161
- os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
162
  demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)
163
-
 
1
+ import os, cv2, json, tempfile, zipfile, numpy as np, gradio as gr
 
2
  from ultralytics import YOLO
3
  from filterpy.kalman import KalmanFilter
4
  from scipy.optimize import linear_sum_assignment
5
 
6
  # ------------------------------------------------------------
7
+ # 🔧 Safe-load fix for PyTorch 2.6
8
+ # ------------------------------------------------------------
9
+ import torch, ultralytics.nn.tasks as ultralytics_tasks
10
+ torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])
11
+
12
+ # ------------------------------------------------------------
13
+ # ⚙️ YOLO setup
14
  # ------------------------------------------------------------
15
  MODEL_PATH = "yolov8n.pt"
16
  model = YOLO(MODEL_PATH)
17
+ VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
18
 
 
 
 
 
 
19
 
20
  # ------------------------------------------------------------
21
+ # 🧩 Kalman tracker
22
  # ------------------------------------------------------------
23
  class Track:
24
  def __init__(self, bbox, tid):
 
31
  self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
  self.trace = []
33
 
34
+ def centroid(self, b):
35
+ x1, y1, x2, y2 = b
36
+ return [(x1+x2)/2, (y1+y2)/2]
37
+
38
+ def predict(self):
39
+ self.kf.predict()
40
+ return self.kf.x[:2].reshape(2)
41
+
42
+ def update(self, b):
43
+ z = np.array(self.centroid(b)).reshape(2,1)
44
  self.kf.update(z)
45
+ cx, cy = self.kf.x[:2].reshape(2)
46
+ self.trace.append((float(cx), float(cy)))
47
+ return (cx, cy)
48
+
49
 
50
  # ------------------------------------------------------------
51
+ # 🧮 Direction analyzer
52
  # ------------------------------------------------------------
53
  def analyze_direction(trace, centers):
54
+ if len(trace) < 3:
55
+ return "NA", 1.0
56
+ v = np.array(trace[-1]) - np.array(trace[-3])
57
+ if np.linalg.norm(v) < 1e-6:
58
+ return "NA", 1.0
59
  v = v / np.linalg.norm(v)
60
  sims = np.dot(centers, v)
61
  max_sim = np.max(sims)
62
+ if max_sim < 0:
63
+ return "WRONG", float(max_sim)
64
  return "OK", float(max_sim)
65
 
66
+
67
+ # ------------------------------------------------------------
68
+ # 🧭 Load normalized flow centers
69
+ # ------------------------------------------------------------
70
+ def load_flow_centers(flow_json):
71
+ data = json.load(open(flow_json))
72
+ centers = np.array(data["flow_centers"])
73
+ centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
74
+ return centers
75
+
76
+
77
  # ------------------------------------------------------------
78
+ # 🎥 Process video
79
  # ------------------------------------------------------------
80
+ def process_video(video_path, flow_json, show_only_wrong=False):
81
  centers = load_flow_centers(flow_json)
82
  cap = cv2.VideoCapture(video_path)
83
  fps = cap.get(cv2.CAP_PROP_FPS) or 25
84
+ w, h = int(cap.get(3)), int(cap.get(4))
85
 
86
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
87
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
88
+ out = cv2.VideoWriter(out_path, fourcc, fps, (w, h))
89
 
90
+ tracks, next_id, log = [], 0, []
 
91
 
92
  while True:
93
  ret, frame = cap.read()
94
+ if not ret:
95
+ break
96
+
97
  results = model(frame, verbose=False)[0]
98
+ detections = []
99
  for box in results.boxes:
100
+ if int(box.cls) in VEHICLE_CLASSES and box.conf > 0.3:
101
  detections.append(box.xyxy[0].cpu().numpy())
102
 
103
+ # Predict existing
104
+ predicted = [t.predict() for t in tracks]
105
+ predicted = np.array(predicted) if len(predicted) > 0 else np.empty((0,2))
106
+
107
+ # Assign detections to tracks
108
+ assigned = set()
109
+ if len(predicted) > 0 and len(detections) > 0:
110
+ cost = np.zeros((len(predicted), len(detections)))
111
+ for i, p in enumerate(predicted):
112
+ for j, d in enumerate(detections):
113
+ cx, cy = ((d[0]+d[2])/2, (d[1]+d[3])/2)
114
+ cost[i,j] = np.linalg.norm(p - np.array([cx,cy]))
115
+ r, c = linear_sum_assignment(cost)
116
+ for i, j in zip(r, c):
117
+ if cost[i,j] < 80:
118
  assigned.add(j)
119
  tracks[i].update(detections[j])
120
 
121
+ # New tracks
122
+ for j, d in enumerate(detections):
123
  if j not in assigned:
124
+ t = Track(d, next_id)
125
+ next_id += 1
126
+ t.update(d)
127
+ tracks.append(t)
128
 
129
+ # --- 🧩 Draw + Log (toggle support) ---
130
  for trk in tracks:
131
+ if len(trk.trace) < 3:
132
+ continue
133
  status, sim = analyze_direction(trk.trace, centers)
134
+
135
+ # Skip OKs if toggle is enabled
136
+ if show_only_wrong and status != "WRONG":
137
+ continue
138
+
139
+ x, y = map(int, trk.trace[-1])
140
+ color = (0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (200,200,200))
141
  cv2.circle(frame,(x,y),4,color,-1)
142
+ cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),
143
+ cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
144
  for i in range(1,len(trk.trace)):
145
+ cv2.line(frame,
146
+ (int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
147
+ (int(trk.trace[i][0]),int(trk.trace[i][1])),
148
+ color,1)
149
+
150
+ # Log once per unique vehicle
151
+ if len(trk.trace) > 5 and not any(entry["id"] == trk.id for entry in log):
152
+ log.append({"id": trk.id, "status": status, "cos_sim": round(sim,3)})
153
 
154
  out.write(frame)
155
 
156
+ cap.release()
157
+ out.release()
158
+
159
+ # Unique summary
160
+ unique_ids = {entry["id"] for entry in log}
161
+ summary = {"vehicles_analyzed": len(unique_ids)}
162
+
163
+ # Create ZIP bundle
164
+ zip_path = tempfile.NamedTemporaryFile(suffix=".zip", delete=False).name
165
+ with zipfile.ZipFile(zip_path, "w") as zf:
166
+ zf.write(out_path, arcname="violation_output.mp4")
167
+ zf.writestr("per_vehicle_log.json", json.dumps(log, indent=2))
168
+ zf.writestr("summary.json", json.dumps(summary, indent=2))
169
+
170
+ return out_path, log, summary, zip_path
171
+
172
 
173
  # ------------------------------------------------------------
174
+ # 🖥️ Gradio interface
175
  # ------------------------------------------------------------
176
+ def run_app(video, flow_file, show_only_wrong):
177
+ vid, log_json, summary, zip_file = process_video(video, flow_file, show_only_wrong)
178
+ return vid, log_json, summary, zip_file
179
+
 
180
 
181
  description_text = """
182
  ### 🚦 Wrong-Direction Detection (Stage 3)
183
+ Upload your traffic video and the **flow_stats.json** from Stage 2.
184
+ You can toggle whether to display all detections or only WRONG-direction vehicles.
185
  """
186
 
 
 
 
187
  demo = gr.Interface(
188
  fn=run_app,
189
+ inputs=[
190
+ gr.Video(label="Upload Traffic Video (.mp4)"),
191
+ gr.File(label="Upload flow_stats.json (Stage 2 Output)"),
192
+ gr.Checkbox(label="Show Only Wrong Labels", value=False)
193
+ ],
194
+ outputs=[
195
+ gr.Video(label="Violation Output Video"),
196
+ gr.JSON(label="Per-Vehicle Log"),
197
+ gr.JSON(label="Summary"),
198
+ gr.File(label="⬇️ Download All Outputs (ZIP)")
199
+ ],
200
+ title="🚗 Wrong-Direction Detection – Stage 3 (Toggle + ZIP)",
201
  description=description_text,
202
+ examples=None,
203
  )
204
 
205
+ # Disable analytics / flagging / SSR
206
  demo.flagging_mode = "never"
207
  demo.cache_examples = False
208
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
209
 
210
  if __name__ == "__main__":
 
211
  demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False, show_api=False)