Files changed (1) hide show
  1. app.py +134 -30
app.py CHANGED
@@ -23,6 +23,61 @@ model = YOLO(MODEL_PATH)
23
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # ---------------------------------------------------------
27
  # 🔍 SIMPLE KALMAN TRACKER
28
  # ---------------------------------------------------------
@@ -38,23 +93,28 @@ class Track:
38
  [0,1,0,0]])
39
  self.kf.P *= 1000.0
40
  self.kf.R *= 10.0
41
- self.kf.x[:2] = np.array(self.get_centroid(bbox)).reshape(2,1)
 
 
 
42
  self.trace = []
 
43
 
44
- def get_centroid(self,bbox):
45
- x1,y1,x2,y2 = bbox
46
- return [(x1+x2)/2,(y1+y2)/2]
47
 
48
  def predict(self):
49
  self.kf.predict()
50
  return self.kf.x[:2].reshape(2)
51
 
52
- def update(self,bbox):
 
 
53
  z = np.array(self.get_centroid(bbox)).reshape(2,1)
54
  self.kf.update(z)
55
- cx,cy = self.kf.x[:2].reshape(2)
56
- self.trace.append((float(cx),float(cy)))
57
- return (cx,cy)
58
 
59
 
60
  # ---------------------------------------------------------
@@ -75,7 +135,13 @@ def process_video(video_path):
75
  frame_count = 0
76
 
77
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
78
- pbar = tqdm(total=total_frames if total_frames>0 else 100, desc="Processing")
 
 
 
 
 
 
79
  while True:
80
  ret, frame = cap.read()
81
  if not ret:
@@ -94,22 +160,49 @@ def process_video(video_path):
94
  predicted = [trk.predict() for trk in tracks]
95
  predicted = np.array(predicted) if predicted else np.empty((0,2))
96
 
97
- # --- ASSIGN DETECTIONS ---
98
  assigned = set()
 
 
99
  if len(predicted) > 0 and len(detections) > 0:
100
- cost = np.zeros((len(predicted), len(detections)))
101
- for i, trk in enumerate(predicted):
 
 
 
102
  for j, det in enumerate(detections):
103
- cx, cy = ((det[0]+det[2])/2, (det[1]+det[3])/2)
104
- cost[i, j] = np.linalg.norm(trk - np.array([cx, cy]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  row_ind, col_ind = linear_sum_assignment(cost)
106
  for r, c in zip(row_ind, col_ind):
107
- if cost[r, c] < 80: # distance threshold
 
108
  assigned.add(c)
109
  tracks[r].update(detections[c])
110
 
111
- # --- NEW TRACKS ---
112
- for j, det in enumerate(detections):
113
  if j not in assigned:
114
  trk = Track(det, next_id)
115
  next_id += 1
@@ -118,15 +211,17 @@ def process_video(video_path):
118
 
119
  # --- DRAW OUTPUT ---
120
  for trk in tracks:
121
- if len(trk.trace) < 2:
122
  continue
123
- x,y = map(int,trk.trace[-1])
124
- cv2.circle(frame,(x,y),3,(0,255,0),-1)
125
- cv2.putText(frame,f"ID:{trk.id}",(x-10,y-10),cv2.FONT_HERSHEY_SIMPLEX,0.4,(0,255,0),1)
126
- for i in range(1,len(trk.trace)):
127
- cv2.line(frame,(int(trk.trace[i-1][0]),int(trk.trace[i-1][1])),
128
- (int(trk.trace[i][0]),int(trk.trace[i][1])),
129
- (0,255,0),1)
 
 
130
  trajectories[trk.id] = trk.trace
131
 
132
  out.write(frame)
@@ -161,13 +256,17 @@ def run_app(video_file):
161
  out_path, json_path = process_video(temp_path)
162
  end = time.time()
163
 
 
 
 
 
164
  summary = {
165
- "total_time_sec": round(end-start,1),
166
- "num_tracks": len(json.load(open(json_path))),
167
- "avg_fps": round(cv2.VideoCapture(temp_path).get(cv2.CAP_PROP_FPS),2)
168
  }
169
 
170
- return out_path, json.load(open(json_path)), summary
171
 
172
 
173
  # ---------------------------------------------------------
@@ -180,6 +279,11 @@ This app detects & tracks vehicles using YOLOv8 + Kalman Filter, and outputs:
180
  - Annotated tracking video
181
  - JSON trajectories
182
  - Summary stats for dominant-flow analysis
 
 
 
 
 
183
  """
184
 
185
  example_video = "assets/examples/sample1.mp4" if os.path.exists("assets/examples/sample1.mp4") else None
@@ -198,4 +302,4 @@ demo = gr.Interface(
198
  )
199
 
200
  if __name__ == "__main__":
201
- demo.launch()
 
23
  VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
24
 
25
 
26
+ # ---------------------------------------------------------
27
+ # 🔧 HELPER FUNCTIONS
28
+ # ---------------------------------------------------------
29
def bbox_centroid(bbox):
    """Return the (cx, cy) midpoint of an xyxy bounding box as a tuple."""
    x1, y1, x2, y2 = bbox
    cx = (x1 + x2) / 2.0
    cy = (y1 + y2) / 2.0
    return (cx, cy)
33
+
34
def iou(boxA, boxB):
    """Intersection-over-union of two xyxy boxes.

    Returns 0.0 for disjoint, merely touching, or degenerate
    (zero-area) boxes; otherwise a value in (0, 1].
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    inter = max(0, right - left) * max(0, bottom - top)
    if inter <= 0:
        return 0.0

    # Clamp widths/heights at 0 so inverted (malformed) boxes contribute no area.
    area_a = max(0, boxA[2] - boxA[0]) * max(0, boxA[3] - boxA[1])
    area_b = max(0, boxB[2] - boxB[0]) * max(0, boxB[3] - boxB[1])

    union = float(area_a + area_b - inter)
    return inter / union if union > 0 else 0.0
56
+
57
+
58
def direction_penalty(track, det_cx, det_cy, lambda_dir=30.0):
    """Cost term discouraging matches that imply a sharp motion reversal.

    Compares the track's most recent heading (last two trace points) with
    the step toward the candidate detection. Returns ~0 when aligned and
    up to ~2 * lambda_dir when the detection lies in the opposite
    direction. Returns 0.0 when the track has fewer than two trace points
    or either vector is negligibly short.
    """
    trace = track.trace
    if len(trace) < 2:
        return 0.0

    (px, py), (lx, ly) = trace[-2], trace[-1]
    heading = np.array([lx - px, ly - py], dtype=np.float32)
    step = np.array([det_cx - lx, det_cy - ly], dtype=np.float32)

    h_len = np.linalg.norm(heading)
    s_len = np.linalg.norm(step)
    if h_len < 1e-3 or s_len < 1e-3:
        return 0.0

    # Small epsilon guards against division by ~0; cos_sim lies in [-1, 1].
    cos_sim = float(np.dot(heading, step) / (h_len * s_len + 1e-6))
    return (1.0 - cos_sim) * lambda_dir
79
+
80
+
81
  # ---------------------------------------------------------
82
  # 🔍 SIMPLE KALMAN TRACKER
83
  # ---------------------------------------------------------
 
93
  [0,1,0,0]])
94
  self.kf.P *= 1000.0
95
  self.kf.R *= 10.0
96
+
97
+ cx, cy = bbox_centroid(bbox)
98
+ self.kf.x[:2] = np.array([[cx],[cy]])
99
+
100
  self.trace = []
101
+ self.bbox = np.array(bbox, dtype=np.float32) # store last bbox
102
 
103
+ def get_centroid(self, bbox):
104
+ return bbox_centroid(bbox)
 
105
 
106
  def predict(self):
107
  self.kf.predict()
108
  return self.kf.x[:2].reshape(2)
109
 
110
+ def update(self, bbox):
111
+ """Update KF with new bbox measurement and store trace + bbox."""
112
+ self.bbox = np.array(bbox, dtype=np.float32)
113
  z = np.array(self.get_centroid(bbox)).reshape(2,1)
114
  self.kf.update(z)
115
+ cx, cy = self.kf.x[:2].reshape(2)
116
+ self.trace.append((float(cx), float(cy)))
117
+ return (cx, cy)
118
 
119
 
120
  # ---------------------------------------------------------
 
135
  frame_count = 0
136
 
137
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
138
+ pbar = tqdm(total=total_frames if total_frames > 0 else 100, desc="Processing")
139
+
140
+ # Matching hyperparameters
141
+ MAX_DIST = 120.0 # hard gate on centroid distance
142
+ LAMBDA_IOU = 20.0 # weight for IoU term in cost
143
+ MIN_IOU_FOR_BONUS = 0.05 # if IoU below this, essentially no bonus
144
+
145
  while True:
146
  ret, frame = cap.read()
147
  if not ret:
 
160
  predicted = [trk.predict() for trk in tracks]
161
  predicted = np.array(predicted) if predicted else np.empty((0,2))
162
 
 
163
  assigned = set()
164
+
165
+ # --- ASSIGN DETECTIONS TO TRACKS ---
166
  if len(predicted) > 0 and len(detections) > 0:
167
+ detections = np.array(detections, dtype=np.float32)
168
+ cost = np.full((len(predicted), len(detections)), 1e6, dtype=np.float32)
169
+
170
+ for i, pred_centroid in enumerate(predicted):
171
+ trk = tracks[i]
172
  for j, det in enumerate(detections):
173
+ cx, cy = bbox_centroid(det)
174
+ dist = np.linalg.norm(pred_centroid - np.array([cx, cy], dtype=np.float32))
175
+
176
+ # Hard distance gate: don't allow crazy jumps
177
+ if dist > MAX_DIST:
178
+ continue
179
+
180
+ # IoU term – prefer boxes overlapping the previous one
181
+ if trk.bbox is not None:
182
+ iou_val = iou(trk.bbox, det)
183
+ else:
184
+ iou_val = 0.0
185
+
186
+ if iou_val < MIN_IOU_FOR_BONUS:
187
+ iou_val = 0.0
188
+
189
+ dir_pen = direction_penalty(trk, cx, cy, lambda_dir=30.0)
190
+
191
+ # Final cost: lower is better
192
+ # - dist drives proximity
193
+ # - (1 - iou_val) penalizes mismatched shapes/positions
194
+ # - dir_pen penalizes sudden direction flips
195
+ cost[i, j] = dist + (1.0 - iou_val) * LAMBDA_IOU + dir_pen
196
+
197
  row_ind, col_ind = linear_sum_assignment(cost)
198
  for r, c in zip(row_ind, col_ind):
199
+ # Reject matches that are still effectively "too bad"
200
+ if cost[r, c] < 1e5: # anything left at 1e6 was invalid
201
  assigned.add(c)
202
  tracks[r].update(detections[c])
203
 
204
+ # --- NEW TRACKS FOR UNASSIGNED DETECTIONS ---
205
+ for j, det in enumerate(detections if len(predicted) > 0 else detections):
206
  if j not in assigned:
207
  trk = Track(det, next_id)
208
  next_id += 1
 
211
 
212
  # --- DRAW OUTPUT ---
213
  for trk in tracks:
214
+ if len(trk.trace) < 2:
215
  continue
216
+ x, y = map(int, trk.trace[-1])
217
+ cv2.circle(frame, (x, y), 3, (0, 255, 0), -1)
218
+ cv2.putText(frame, f"ID:{trk.id}", (x - 10, y - 10),
219
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
220
+ for i in range(1, len(trk.trace)):
221
+ cv2.line(frame,
222
+ (int(trk.trace[i-1][0]), int(trk.trace[i-1][1])),
223
+ (int(trk.trace[i][0]), int(trk.trace[i][1])),
224
+ (0, 255, 0), 1)
225
  trajectories[trk.id] = trk.trace
226
 
227
  out.write(frame)
 
256
  out_path, json_path = process_video(temp_path)
257
  end = time.time()
258
 
259
+ with open(json_path, "r") as f:
260
+ traj_data = json.load(f)
261
+
262
+ # avg_fps here = original video FPS (processing FPS will differ)
263
  summary = {
264
+ "total_time_sec": round(end - start, 1),
265
+ "num_tracks": len(traj_data),
266
+ "avg_fps": round(cv2.VideoCapture(temp_path).get(cv2.CAP_PROP_FPS) or 25, 2)
267
  }
268
 
269
+ return out_path, traj_data, summary
270
 
271
 
272
  # ---------------------------------------------------------
 
279
  - Annotated tracking video
280
  - JSON trajectories
281
  - Summary stats for dominant-flow analysis
282
+
283
+ 🔧 Tracking is enhanced with:
284
+ - Kalman motion model
285
+ - Distance + IoU + direction-aware matching
286
+ to reduce ID swaps when vehicles overtake or are very close.
287
  """
288
 
289
  example_video = "assets/examples/sample1.mp4" if os.path.exists("assets/examples/sample1.mp4") else None
 
302
  )
303
 
304
  if __name__ == "__main__":
305
+ demo.launch()