nishanth-saka committed on
Commit
1ba0e3f
·
verified ·
1 Parent(s): d257659
Files changed (1) hide show
  1. app.py +193 -0
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2, os, numpy as np, tempfile, time, json
3
+ from ultralytics import YOLO
4
+ from filterpy.kalman import KalmanFilter
5
+ from scipy.optimize import linear_sum_assignment
6
+ from tqdm import tqdm
7
+
8
# ---------------------------------------------------------
# ⚙️ INIT
# ---------------------------------------------------------
# Weights file for the small YOLOv8 detector; ultralytics downloads it on
# first use if it is not already present locally.
MODEL_PATH = "yolov8n.pt"
# Loaded once at import time so every request reuses the same model instance.
model = YOLO(MODEL_PATH)

# Vehicle classes from COCO
VEHICLE_CLASSES = [2, 3, 5, 7] # car, motorcycle, bus, truck
16
+
17
+
18
+ # ---------------------------------------------------------
19
+ # 🔍 SIMPLE KALMAN TRACKER
20
+ # ---------------------------------------------------------
21
class Track:
    """One tracked vehicle: a constant-velocity Kalman filter plus the
    history of its filtered centroid positions."""

    def __init__(self, bbox, track_id):
        """Start a track centred on *bbox* with identifier *track_id*.

        The state vector is [cx, cy, vx, vy]; only the centroid (cx, cy)
        is ever measured.
        """
        self.id = track_id
        self.kf = KalmanFilter(dim_x=4, dim_z=2)
        # Constant-velocity transition: position advances by velocity each step.
        self.kf.F = np.array([
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        # The measurement picks out the two position components of the state.
        self.kf.H = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0],
        ])
        self.kf.P *= 1000.0  # very uncertain initial state
        self.kf.R *= 10.0    # measurement noise
        centre = np.array(self.get_centroid(bbox)).reshape(2, 1)
        self.kf.x[:2] = centre
        # (cx, cy) float tuples, one appended per successful update().
        self.trace = []

    def get_centroid(self, bbox):
        """Return the [cx, cy] centre of an (x1, y1, x2, y2) box."""
        x1, y1, x2, y2 = bbox
        return [(x1 + x2) / 2, (y1 + y2) / 2]

    def predict(self):
        """Advance the filter one step; return the predicted (2,) centroid."""
        self.kf.predict()
        return self.kf.x[:2].reshape(2)

    def update(self, bbox):
        """Correct the filter with a detection box and log the new centroid."""
        measurement = np.array(self.get_centroid(bbox)).reshape(2, 1)
        self.kf.update(measurement)
        cx, cy = self.kf.x[:2].reshape(2)
        self.trace.append((float(cx), float(cy)))
        return (cx, cy)
50
+
51
+
52
+ # ---------------------------------------------------------
53
+ # 🎥 MAIN PROCESSOR
54
+ # ---------------------------------------------------------
55
def process_video(video_path, max_missed=30):
    """Detect and track vehicles in the video at *video_path*.

    Runs YOLOv8 on every frame, associates detections to Kalman tracks with
    the Hungarian algorithm, draws the live tracks onto each frame, and
    writes an annotated video plus a JSON file of per-track centroid
    trajectories.

    Args:
        video_path: Path to the input video file.
        max_missed: Number of consecutive frames a track may go without a
            matched detection before it is dropped. Prevents the unbounded
            track growth of the original implementation.

    Returns:
        Tuple of (annotated_video_path, trajectories_json_path), both
        pointing at temp files the caller is responsible for.
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25  # some containers report 0 fps
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # We only need the temp file's *path*; close the handle immediately.
    # Leaving it open leaks an fd and blocks VideoWriter on Windows.
    temp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    temp_out.close()
    out = cv2.VideoWriter(temp_out.name, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))

    tracks = []
    next_id = 0
    trajectories = {}
    frame_count = 0

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    pbar = tqdm(total=total_frames if total_frames > 0 else 100, desc="Processing")
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1

        # --- YOLO DETECTION: keep confident vehicle boxes only ---
        results = model(frame, verbose=False)[0]
        detections = []
        for box in results.boxes:
            cls = int(box.cls)
            if cls in VEHICLE_CLASSES and box.conf > 0.3:
                detections.append(box.xyxy[0].cpu().numpy())

        # --- PREDICT EXISTING TRACKS ---
        predicted = [trk.predict() for trk in tracks]
        predicted = np.array(predicted) if predicted else np.empty((0, 2))

        # --- ASSIGN DETECTIONS (Hungarian on centroid distance) ---
        assigned = set()
        matched_tracks = set()
        if len(predicted) > 0 and len(detections) > 0:
            cost = np.zeros((len(predicted), len(detections)))
            for i, trk in enumerate(predicted):
                for j, det in enumerate(detections):
                    cx, cy = ((det[0] + det[2]) / 2, (det[1] + det[3]) / 2)
                    cost[i, j] = np.linalg.norm(trk - np.array([cx, cy]))
            row_ind, col_ind = linear_sum_assignment(cost)
            for r, c in zip(row_ind, col_ind):
                if cost[r, c] < 80:  # distance threshold (pixels)
                    assigned.add(c)
                    matched_tracks.add(r)
                    tracks[r].update(detections[c])

        # --- PRUNE STALE TRACKS ---
        # Fix: the original kept every track forever, so dead tracks were
        # drawn at a fixed point, cost compute each frame, and could steal
        # assignments from real vehicles. The miss counter lives on the
        # track object (via getattr) so no change to Track is required.
        survivors = []
        for i, trk in enumerate(tracks):
            if i in matched_tracks:
                trk.misses = 0
            else:
                trk.misses = getattr(trk, "misses", 0) + 1
            if trk.misses <= max_missed:
                survivors.append(trk)
            elif trk.trace:
                # Preserve the finished track's trajectory before dropping it.
                trajectories[trk.id] = trk.trace
        tracks = survivors

        # --- NEW TRACKS for unmatched detections ---
        for j, det in enumerate(detections):
            if j not in assigned:
                trk = Track(det, next_id)
                next_id += 1
                trk.update(det)
                tracks.append(trk)

        # --- DRAW OUTPUT ---
        for trk in tracks:
            if len(trk.trace) < 2:
                continue
            x, y = map(int, trk.trace[-1])
            cv2.circle(frame, (x, y), 3, (0, 255, 0), -1)
            cv2.putText(frame, f"ID:{trk.id}", (x - 10, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 255, 0), 1)
            for i in range(1, len(trk.trace)):
                cv2.line(frame,
                         (int(trk.trace[i - 1][0]), int(trk.trace[i - 1][1])),
                         (int(trk.trace[i][0]), int(trk.trace[i][1])),
                         (0, 255, 0), 1)

        out.write(frame)
        pbar.update(1)

    cap.release()
    out.release()
    pbar.close()

    # Record trajectories of tracks still alive at end of video. (The
    # original only captured tracks with >= 2 trace points, silently
    # dropping vehicles seen in a single frame.)
    for trk in tracks:
        if trk.trace:
            trajectories[trk.id] = trk.trace

    # Save trajectories JSON; close the temp handle before reopening by name.
    traj_json = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
    traj_json.close()
    with open(traj_json.name, "w") as f:
        json.dump(trajectories, f)

    return temp_out.name, traj_json.name
137
+
138
+
139
+ # ---------------------------------------------------------
140
+ # 📤 WRAPPER FOR GRADIO
141
+ # ---------------------------------------------------------
142
def run_app(video_file):
    """Gradio entry point.

    Copies the uploaded video to a private temp file, runs the tracking
    pipeline, and returns the annotated video path, the trajectories dict,
    and a small summary-stats dict.
    """
    # Copy the upload to our own temp path so downstream code never touches
    # Gradio's temp file. Close the handle right away — we only need the path.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    tmp.close()
    temp_path = tmp.name
    # Older Gradio versions hand uploads over as a dict with a "name" key.
    if isinstance(video_file, dict) and "name" in video_file:
        src_path = video_file["name"]
    else:
        src_path = video_file
    with open(src_path, "rb") as src, open(temp_path, "wb") as dst:
        dst.write(src.read())

    start = time.time()
    out_path, json_path = process_video(temp_path)
    end = time.time()

    # Fix: load the trajectories JSON once with a managed handle (the
    # original parsed it twice through open() calls that were never closed).
    with open(json_path) as f:
        trajectories = json.load(f)

    # Fix: release the capture used only to read the source fps
    # (the original leaked a VideoCapture here).
    cap = cv2.VideoCapture(temp_path)
    source_fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()

    summary = {
        "total_time_sec": round(end - start, 1),
        "num_tracks": len(trajectories),
        "avg_fps": round(source_fps, 2),
    }

    return out_path, trajectories, summary
163
+
164
+
165
+ # ---------------------------------------------------------
166
+ # 🖥️ GRADIO INTERFACE
167
+ # ---------------------------------------------------------
168
description_text = """
### 🚦 Dominant Flow Tracker (Stage 1)
Upload or select a sample traffic video below.
This app detects & tracks vehicles using YOLOv8 + Kalman Filter, and outputs:
- Annotated tracking video
- JSON trajectories
- Summary stats for dominant-flow analysis
"""

# Offer the bundled sample clip as an example only when it actually exists.
_sample_path = "assets/examples/sample1.mp4"
if os.path.exists(_sample_path):
    example_video = _sample_path
else:
    example_video = None

# Output components, in the order run_app returns its values.
_output_components = [
    gr.Video(label="Tracked Output"),
    gr.JSON(label="Vehicle Trajectories (Preview)"),
    gr.JSON(label="Summary Stats"),
]

demo = gr.Interface(
    fn=run_app,
    inputs=gr.Video(label="Upload or use sample video (.mp4)", type="filepath"),
    outputs=_output_components,
    title="🚗 Dominant Flow Tracker – Stage 1",
    description=description_text,
    examples=None if example_video is None else [[example_video]],
)

if __name__ == "__main__":
    demo.launch()