Files changed (1) hide show
  1. app.py +58 -45
app.py CHANGED
@@ -1,24 +1,25 @@
1
- import gradio as gr
2
- import cv2, os, numpy as np, json, tempfile, time
3
  from ultralytics import YOLO
4
  from filterpy.kalman import KalmanFilter
5
  from scipy.optimize import linear_sum_assignment
6
 
7
  # ------------------------------------------------------------
8
- # โš™๏ธ Load YOLO and flow centers
 
 
 
 
 
 
 
9
  # ------------------------------------------------------------
10
  MODEL_PATH = "yolov8n.pt"
11
  model = YOLO(MODEL_PATH)
12
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
13
 
14
- def load_flow_centers(flow_json):
15
- data = json.load(open(flow_json))
16
- centers = np.array(data["flow_centers"])
17
- centers = centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
18
- return centers
19
 
20
  # ------------------------------------------------------------
21
- # ๐Ÿงฉ Simple Kalman Tracker
22
  # ------------------------------------------------------------
23
  class Track:
24
  def __init__(self, bbox, tid):
@@ -31,7 +32,7 @@ class Track:
31
  self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
32
  self.trace = []
33
 
34
- def centroid(self,b):
35
  x1,y1,x2,y2=b
36
  return [(x1+x2)/2,(y1+y2)/2]
37
  def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
@@ -42,47 +43,60 @@ class Track:
42
  self.trace.append((float(cx),float(cy)))
43
  return (cx,cy)
44
 
 
45
  # ------------------------------------------------------------
46
- # ๐Ÿšฆ Wrong-Direction Analyzer
47
  # ------------------------------------------------------------
48
  def analyze_direction(trace, centers):
49
- if len(trace)<3: return "NA",1.0
50
- v = np.array(trace[-1]) - np.array(trace[-3]) # motion vector
51
- if np.linalg.norm(v)<1e-6: return "NA",1.0
 
 
52
  v = v / np.linalg.norm(v)
53
  sims = np.dot(centers, v)
54
  max_sim = np.max(sims)
55
- if max_sim < 0: return "WRONG", float(max_sim)
 
56
  return "OK", float(max_sim)
57
 
 
58
  # ------------------------------------------------------------
59
- # ๐ŸŽฅ Process Video
60
  # ------------------------------------------------------------
 
 
 
 
 
 
61
  def process_video(video_path, flow_json):
62
  centers = load_flow_centers(flow_json)
63
  cap = cv2.VideoCapture(video_path)
64
  fps = cap.get(cv2.CAP_PROP_FPS) or 25
65
  w,h = int(cap.get(3)), int(cap.get(4))
66
 
67
- tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
68
- out = cv2.VideoWriter(tmp_out.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
 
69
 
70
- tracks, next_id = [], 0
71
- log = []
72
 
73
  while True:
74
  ret, frame = cap.read()
75
- if not ret: break
 
76
  results = model(frame, verbose=False)[0]
77
  detections=[]
78
  for box in results.boxes:
79
  if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
80
  detections.append(box.xyxy[0].cpu().numpy())
81
 
82
- # predict existing tracks
83
- predicted = [trk.predict() for trk in tracks]
84
- predicted = np.array(predicted) if len(predicted)>0 else np.empty((0,2))
85
 
 
86
  assigned=set()
87
  if len(predicted)>0 and len(detections)>0:
88
  cost=np.zeros((len(predicted),len(detections)))
@@ -96,21 +110,21 @@ def process_video(video_path, flow_json):
96
  assigned.add(j)
97
  tracks[i].update(detections[j])
98
 
99
- # new tracks
100
  for j,d in enumerate(detections):
101
  if j not in assigned:
102
- trk=Track(d,next_id); next_id+=1
103
- trk.update(d)
104
- tracks.append(trk)
105
 
106
- # draw
107
  for trk in tracks:
108
  if len(trk.trace)<3: continue
109
  status, sim = analyze_direction(trk.trace, centers)
110
  x,y=map(int,trk.trace[-1])
111
- color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (255,255,255))
112
  cv2.circle(frame,(x,y),4,color,-1)
113
- cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
 
114
  for i in range(1,len(trk.trace)):
115
  cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
116
  (int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
@@ -119,12 +133,14 @@ def process_video(video_path, flow_json):
119
  out.write(frame)
120
 
121
  cap.release(); out.release()
 
122
  log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
123
  with open(log_path,"w") as f: json.dump(log,f,indent=2)
124
- return tmp_out.name, log_path
 
125
 
126
  # ------------------------------------------------------------
127
- # ๐Ÿ–ฅ๏ธ Gradio Interface
128
  # ------------------------------------------------------------
129
  def run_app(video, flow_file):
130
  out_path, log_path = process_video(video, flow_file)
@@ -135,7 +151,7 @@ def run_app(video, flow_file):
135
  description_text = """
136
  ### ๐Ÿšฆ Wrong-Direction Detection (Stage 3)
137
  Uploads your traffic video and the **flow_stats.json** from Stage 2.
138
- Outputs an annotated video with โœ… OK / ๐Ÿšซ WRONG labels per vehicle, plus a JSON log.
139
  """
140
 
141
  example_vid = "10.mp4" if os.path.exists("10.mp4") else None
@@ -143,19 +159,16 @@ example_flow = "flow_stats.json" if os.path.exists("flow_stats.json") else None
143
 
144
  demo = gr.Interface(
145
  fn=run_app,
146
- inputs=[
147
- gr.Video(label="Upload Traffic Video (.mp4)"),
148
- gr.File(label="Upload flow_stats.json (Stage 2 Output)")
149
- ],
150
- outputs=[
151
- gr.Video(label="Violation Output Video"),
152
- gr.JSON(label="Per-Vehicle Log"),
153
- gr.JSON(label="Summary")
154
- ],
155
  title="๐Ÿš— Wrong-Direction Detection โ€“ Stage 3",
156
  description=description_text,
157
- examples=[[example_vid, example_flow]] if example_vid and example_flow else None,
158
  )
159
 
160
  if __name__ == "__main__":
161
- demo.launch()
 
 
1
import os, cv2, json, time, tempfile, numpy as np, gradio as gr
from ultralytics import YOLO
from filterpy.kalman import KalmanFilter
from scipy.optimize import linear_sum_assignment

# ------------------------------------------------------------
# 🔧 Safe-load fix for PyTorch >= 2.6 (weights_only default flip)
# ------------------------------------------------------------
import torch
import ultralytics.nn.tasks as ultralytics_tasks

# add_safe_globals was introduced in torch 2.4; older builds load
# checkpoints without the allow-list, so skip the call there instead
# of crashing at import time.
if hasattr(torch.serialization, "add_safe_globals"):
    torch.serialization.add_safe_globals([ultralytics_tasks.DetectionModel])

# ------------------------------------------------------------
# ⚙️ Model + constants
# ------------------------------------------------------------
MODEL_PATH = "yolov8n.pt"
model = YOLO(MODEL_PATH)  # downloads weights on first run if not present
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorbike, bus, truck
19
 
 
 
 
 
 
20
 
21
  # ------------------------------------------------------------
22
+ # ๐Ÿงฉ Kalman Tracker
23
  # ------------------------------------------------------------
24
  class Track:
25
  def __init__(self, bbox, tid):
 
32
  self.kf.x[:2] = np.array(self.centroid(bbox)).reshape(2,1)
33
  self.trace = []
34
 
35
+ def centroid(self, b):
36
  x1,y1,x2,y2=b
37
  return [(x1+x2)/2,(y1+y2)/2]
38
  def predict(self): self.kf.predict(); return self.kf.x[:2].reshape(2)
 
43
  self.trace.append((float(cx),float(cy)))
44
  return (cx,cy)
45
 
46
+
47
  # ------------------------------------------------------------
48
+ # ๐Ÿงฎ Direction check
49
  # ------------------------------------------------------------
50
def analyze_direction(trace, centers):
    """Classify a track's recent motion against the learned flow directions.

    trace   -- list of (x, y) centroids, oldest first
    centers -- (k, 2) array of flow-direction vectors (assumed unit-length
               from load_flow_centers — TODO confirm at call sites)

    Returns (status, similarity): status is "NA" when there is not enough
    motion to judge, "WRONG" when the motion opposes every flow center,
    otherwise "OK"; similarity is the best cosine score as a float.
    """
    # Need at least three points to form a stable motion vector.
    if len(trace) < 3:
        return "NA", 1.0
    motion = np.asarray(trace[-1], dtype=float) - np.asarray(trace[-3], dtype=float)
    speed = np.linalg.norm(motion)
    # Effectively stationary — direction is meaningless.
    if speed < 1e-6:
        return "NA", 1.0
    best = float(np.max(centers @ (motion / speed)))
    # Negative similarity to ALL flow centers => wrong-way candidate.
    return ("WRONG", best) if best < 0 else ("OK", best)
62
 
63
+
64
  # ------------------------------------------------------------
65
+ # ๐Ÿšฆ Main Processing
66
  # ------------------------------------------------------------
67
def load_flow_centers(flow_json):
    """Load and L2-normalize the dominant flow directions from Stage 2.

    flow_json -- path to a JSON file containing a "flow_centers" key with
                 a list of [dx, dy] vectors.

    Returns a (k, 2) float array of (near-)unit vectors; the small epsilon
    in the denominator guards against zero-length centers.

    Raises KeyError if "flow_centers" is missing, plus whatever open/json
    raise for a bad path or malformed file.
    """
    # Context manager closes the handle promptly (the original leaked it
    # via json.load(open(...))).
    with open(flow_json) as f:
        data = json.load(f)
    centers = np.asarray(data["flow_centers"], dtype=float)
    return centers / (np.linalg.norm(centers, axis=1, keepdims=True) + 1e-6)
72
+
73
  def process_video(video_path, flow_json):
74
  centers = load_flow_centers(flow_json)
75
  cap = cv2.VideoCapture(video_path)
76
  fps = cap.get(cv2.CAP_PROP_FPS) or 25
77
  w,h = int(cap.get(3)), int(cap.get(4))
78
 
79
+ out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
80
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
81
+ out = cv2.VideoWriter(out_path, fourcc, fps, (w,h))
82
 
83
+ tracks, next_id, log = [], 0, []
 
84
 
85
  while True:
86
  ret, frame = cap.read()
87
+ if not ret:
88
+ break
89
  results = model(frame, verbose=False)[0]
90
  detections=[]
91
  for box in results.boxes:
92
  if int(box.cls) in VEHICLE_CLASSES and box.conf>0.3:
93
  detections.append(box.xyxy[0].cpu().numpy())
94
 
95
+ # Predict existing
96
+ predicted=[t.predict() for t in tracks]
97
+ predicted=np.array(predicted) if len(predicted)>0 else np.empty((0,2))
98
 
99
+ # Assign detections
100
  assigned=set()
101
  if len(predicted)>0 and len(detections)>0:
102
  cost=np.zeros((len(predicted),len(detections)))
 
110
  assigned.add(j)
111
  tracks[i].update(detections[j])
112
 
113
+ # New tracks
114
  for j,d in enumerate(detections):
115
  if j not in assigned:
116
+ t=Track(d,next_id); next_id+=1
117
+ t.update(d); tracks.append(t)
 
118
 
119
+ # Draw results
120
  for trk in tracks:
121
  if len(trk.trace)<3: continue
122
  status, sim = analyze_direction(trk.trace, centers)
123
  x,y=map(int,trk.trace[-1])
124
+ color=(0,255,0) if status=="OK" else ((0,0,255) if status=="WRONG" else (200,200,200))
125
  cv2.circle(frame,(x,y),4,color,-1)
126
+ cv2.putText(frame,f"ID:{trk.id} {status}",(x-20,y-10),
127
+ cv2.FONT_HERSHEY_SIMPLEX,0.5,color,1)
128
  for i in range(1,len(trk.trace)):
129
  cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
130
  (int(trk.trace[i][0]),int(trk.trace[i][1])),color,1)
 
133
  out.write(frame)
134
 
135
  cap.release(); out.release()
136
+
137
  log_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
138
  with open(log_path,"w") as f: json.dump(log,f,indent=2)
139
+ return out_path, log_path
140
+
141
 
142
  # ------------------------------------------------------------
143
+ # ๐Ÿ–ฅ๏ธ Gradio UI
144
  # ------------------------------------------------------------
145
  def run_app(video, flow_file):
146
  out_path, log_path = process_video(video, flow_file)
 
151
  description_text = """
152
  ### ๐Ÿšฆ Wrong-Direction Detection (Stage 3)
153
  Uploads your traffic video and the **flow_stats.json** from Stage 2.
154
+ Outputs an annotated video with โœ… OK / ๐Ÿšซ WRONG labels and a JSON log.
155
  """
156
 
157
  example_vid = "10.mp4" if os.path.exists("10.mp4") else None
 
159
 
160
  demo = gr.Interface(
161
  fn=run_app,
162
+ inputs=[gr.Video(label="Upload Traffic Video (.mp4)"),
163
+ gr.File(label="Upload flow_stats.json (Stage 2 Output)")],
164
+ outputs=[gr.Video(label="Violation Output Video"),
165
+ gr.JSON(label="Per-Vehicle Log"),
166
+ gr.JSON(label="Summary")],
 
 
 
 
167
  title="๐Ÿš— Wrong-Direction Detection โ€“ Stage 3",
168
  description=description_text,
169
+ examples=None, # disables example caching
170
  )
171
 
172
if __name__ == "__main__":
    # Opt out of Gradio's usage telemetry before the server starts.
    os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
    # Bind on all interfaces so the app is reachable inside a container.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,
        show_api=False,
    )