ayushsaun commited on
Commit
a39d025
·
1 Parent(s): 09c9310

Deploy UAV object tracker

Browse files
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import tempfile

import gradio as gr

from inference import ObjectTrackerInference
4
+
5
+
6
+ tracker = ObjectTrackerInference(model_dir='models')
7
+
8
+
9
+ def track_object(video, x, y, width, height):
10
+ try:
11
+ if video is None:
12
+ return None
13
+
14
+ initial_bbox = [int(x), int(y), int(width), int(height)]
15
+
16
+ output_path = 'tracked_output.mp4'
17
+ result = tracker.track_video(video, initial_bbox, output_path, fps=30)
18
+
19
+ return result
20
+
21
+ except Exception as e:
22
+ print(f"Error: {str(e)}")
23
+ return None
24
+
25
+ with gr.Blocks(title="UAV Object Tracker") as demo:
26
+
27
+ gr.Markdown("# 🎯 UAV Single Object Tracker")
28
+ gr.Markdown("Upload a video and specify the initial bounding box to track an object.")
29
+
30
+ with gr.Row():
31
+ with gr.Column():
32
+ video_input = gr.Video(label="Upload Video")
33
+
34
+ gr.Markdown("### Initial Bounding Box Coordinates")
35
+ with gr.Row():
36
+ x_input = gr.Number(label="X (top-left)", value=100)
37
+ y_input = gr.Number(label="Y (top-left)", value=100)
38
+ with gr.Row():
39
+ w_input = gr.Number(label="Width", value=50)
40
+ h_input = gr.Number(label="Height", value=50)
41
+
42
+ track_btn = gr.Button("Track Object", variant="primary")
43
+
44
+ with gr.Column():
45
+ video_output = gr.Video(label="Tracked Output")
46
+
47
+ gr.Markdown("### 📖 Instructions")
48
+ gr.Markdown("""
49
+ 1. Upload your video file
50
+ 2. Enter the initial bounding box coordinates (x, y, width, height) for the first frame
51
+ 3. Click 'Track Object' to process
52
+ 4. Download the tracked video from the output
53
+ """)
54
+
55
+ track_btn.click(
56
+ fn=track_object,
57
+ inputs=[video_input, x_input, y_input, w_input, h_input],
58
+ outputs=video_output
59
+ )
60
+
61
+ if __name__ == "__main__":
62
+ demo.launch()
inference.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import joblib
4
+ import numpy as np
5
+ from pathlib import Path
6
+
7
+
8
+ class ObjectTrackerInference:
9
+ def __init__(self, model_dir='models'):
10
+ self.model_dir = model_dir
11
+
12
+ print("Loading pre-trained models...")
13
+ self.position_model = joblib.load(os.path.join(model_dir, 'position_model.joblib'))
14
+ self.size_model = joblib.load(os.path.join(model_dir, 'size_model.joblib'))
15
+ self.position_scaler = joblib.load(os.path.join(model_dir, 'position_scaler.joblib'))
16
+ self.size_scaler = joblib.load(os.path.join(model_dir, 'size_scaler.joblib'))
17
+ print("Models loaded successfully!")
18
+
19
+ self.sift = cv2.SIFT_create(nfeatures=2000)
20
+
21
+ self.orb = cv2.ORB_create(nfeatures=1000)
22
+ self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
23
+ self.prev_frame = None
24
+ self.prev_kp = None
25
+ self.prev_desc = None
26
+
27
+ def estimate_camera_motion(self, frame):
28
+ if frame is None:
29
+ return np.eye(2, 3, dtype=np.float32)
30
+
31
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
32
+ kp, desc = self.orb.detectAndCompute(gray, None)
33
+
34
+ if self.prev_frame is None:
35
+ self.prev_frame = gray
36
+ self.prev_kp = kp
37
+ self.prev_desc = desc
38
+ return np.eye(2, 3, dtype=np.float32)
39
+
40
+ if desc is None or self.prev_desc is None or len(desc) < 4 or len(self.prev_desc) < 4:
41
+ return np.eye(2, 3, dtype=np.float32)
42
+
43
+ matches = self.matcher.match(self.prev_desc, desc)
44
+
45
+ if len(matches) < 4:
46
+ return np.eye(2, 3, dtype=np.float32)
47
+
48
+ matches = sorted(matches, key=lambda x: x.distance)
49
+ good_matches = matches[:min(len(matches), 50)]
50
+
51
+ src_pts = np.float32([self.prev_kp[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
52
+ dst_pts = np.float32([kp[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
53
+
54
+ transform_matrix, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts)
55
+
56
+ if transform_matrix is None:
57
+ transform_matrix = np.eye(2, 3, dtype=np.float32)
58
+
59
+ self.prev_frame = gray
60
+ self.prev_kp = kp
61
+ self.prev_desc = desc
62
+
63
+ return transform_matrix
64
+
65
+ def local_binary_pattern(self, image, n_points=8, radius=1):
66
+ rows, cols = image.shape
67
+ output = np.zeros((rows, cols))
68
+
69
+ for i in range(radius, rows-radius):
70
+ for j in range(radius, cols-radius):
71
+ center = image[i, j]
72
+ pattern = 0
73
+
74
+ for k in range(n_points):
75
+ angle = 2 * np.pi * k / n_points
76
+ x = j + radius * np.cos(angle)
77
+ y = i - radius * np.sin(angle)
78
+ x1, x2 = int(np.floor(x)), int(np.ceil(x))
79
+ y1, y2 = int(np.floor(y)), int(np.ceil(y))
80
+
81
+ f11 = image[y1, x1]
82
+ f12 = image[y1, x2]
83
+ f21 = image[y2, x1]
84
+ f22 = image[y2, x2]
85
+
86
+ x_weight = x - x1
87
+ y_weight = y - y1
88
+
89
+ pixel_value = (f11 * (1-x_weight) * (1-y_weight) +
90
+ f21 * (1-x_weight) * y_weight +
91
+ f12 * x_weight * (1-y_weight) +
92
+ f22 * x_weight * y_weight)
93
+
94
+ pattern |= (pixel_value > center) << k
95
+
96
+ output[i, j] = pattern
97
+
98
+ return output
99
+
100
+ def extract_features(self, frame, bbox, transform_matrix=None):
101
+ if frame is None:
102
+ return None
103
+
104
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
105
+ x, y, w, h = map(int, bbox)
106
+
107
+ x = max(0, min(x, gray.shape[1] - w))
108
+ y = max(0, min(y, gray.shape[0] - h))
109
+ w = min(w, gray.shape[1] - x)
110
+ h = min(h, gray.shape[0] - y)
111
+
112
+ roi = gray[y:y+h, x:x+w]
113
+ if roi.size == 0:
114
+ roi = gray
115
+
116
+ roi = cv2.resize(roi, (64, 64))
117
+
118
+ features = []
119
+
120
+ hog = cv2.HOGDescriptor((64,64), (16,16), (8,8), (8,8), 9)
121
+ hog_features = hog.compute(roi)
122
+ features.extend(hog_features.flatten()[:64])
123
+
124
+ lbp = self.local_binary_pattern(roi, n_points=8, radius=1)
125
+ features.extend([
126
+ np.mean(lbp),
127
+ np.std(lbp),
128
+ *np.percentile(lbp, [25, 50, 75])
129
+ ])
130
+
131
+ if transform_matrix is not None:
132
+ features.extend([
133
+ transform_matrix[0,0],
134
+ transform_matrix[1,1],
135
+ transform_matrix[0,2],
136
+ transform_matrix[1,2]
137
+ ])
138
+ else:
139
+ features.extend([1, 1, 0, 0])
140
+
141
+ features.extend([x, y, w, h])
142
+
143
+ return np.array(features).reshape(1, -1)
144
+
145
+ def predict_bbox(self, features):
146
+ features_position = self.position_scaler.transform(features)
147
+ features_size = self.size_scaler.transform(features)
148
+
149
+ position_pred = self.position_model.predict(features_position)
150
+ size_pred = self.size_model.predict(features_size)
151
+
152
+ bbox = np.hstack([position_pred, size_pred])[0]
153
+
154
+ return bbox
155
+
156
+ def track_video(self, video_path, initial_bbox, output_path='output_tracked.mp4', fps=30):
157
+ print(f"Processing video: {video_path}")
158
+
159
+ cap = cv2.VideoCapture(video_path)
160
+ if not cap.isOpened():
161
+ raise ValueError(f"Could not open video: {video_path}")
162
+
163
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
164
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
165
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
166
+
167
+ print(f"Video: {frame_width}x{frame_height}, {total_frames} frames")
168
+
169
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
+ out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
171
+
172
+ self.prev_frame = None
173
+ self.prev_kp = None
174
+ self.prev_desc = None
175
+
176
+ current_bbox = initial_bbox
177
+ frame_idx = 0
178
+
179
+ print("Tracking object...")
180
+
181
+ while True:
182
+ ret, frame = cap.read()
183
+ if not ret:
184
+ break
185
+
186
+ transform_matrix = self.estimate_camera_motion(frame)
187
+
188
+ features = self.extract_features(frame, current_bbox, transform_matrix)
189
+
190
+ if features is not None:
191
+ predicted_bbox = self.predict_bbox(features)
192
+ current_bbox = predicted_bbox
193
+
194
+ x, y, w, h = map(int, current_bbox)
195
+ cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
196
+ cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
197
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
198
+
199
+ out.write(frame)
200
+ frame_idx += 1
201
+
202
+ if frame_idx % 30 == 0:
203
+ print(f"Processed {frame_idx}/{total_frames} frames")
204
+
205
+ cap.release()
206
+ out.release()
207
+
208
+ print(f"Tracking complete! Video saved to: {output_path}")
209
+ return output_path
210
+
211
+
212
+ def main():
213
+ tracker = ObjectTrackerInference(model_dir='models')
214
+
215
+ video_path = 'input_video.mp4'
216
+ initial_bbox = [100, 100, 50, 50]
217
+ output_path = 'tracked_output.mp4'
218
+
219
+ result = tracker.track_video(video_path, initial_bbox, output_path)
220
+ print(f"Done! Output: {result}")
221
+
222
+
223
+ if __name__ == "__main__":
224
+ main()
models/position_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c9ef28123b06500bda1e878ba63dc47eee66c607af9d3e714198f6b19ec60f0a
3
+ size 2423
models/position_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:648f411f27f0abf3f9065b45c5750f5c2b8cecfcff2d51e3813d943d0f97a7b0
3
+ size 2447
models/size_model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e4a605772090e0856d33794d5e62d03a821177cbe617586b6350bcdd8168cc2
3
+ size 98147577
models/size_scaler.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:648f411f27f0abf3f9065b45c5750f5c2b8cecfcff2d51e3813d943d0f97a7b0
3
+ size 2447
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ opencv-python-headless==4.8.1.78
2
+ scikit-learn==1.3.2
3
+ numpy==1.24.3
4
+ joblib==1.3.2
5
+ gradio==4.19.2
6
+ tqdm==4.66.1