ayushsaun committed on
Commit
8cb54e4
·
1 Parent(s): 0fc551f

updated inference.py

Browse files
Files changed (1) hide show
  1. inference.py +226 -47
inference.py CHANGED
@@ -1,30 +1,23 @@
1
- import os
 
 
 
 
2
  import cv2
3
  import joblib
 
4
  import numpy as np
5
- from pathlib import Path
6
 
7
 
8
- class ObjectTrackerInference:
9
- def __init__(self, model_dir='models'):
10
- self.model_dir = model_dir
11
-
12
- print("Loading pre-trained models...")
13
- self.position_model = joblib.load(os.path.join(model_dir, 'position_model.joblib'))
14
- self.size_model = joblib.load(os.path.join(model_dir, 'size_model.joblib'))
15
- self.position_scaler = joblib.load(os.path.join(model_dir, 'position_scaler.joblib'))
16
- self.size_scaler = joblib.load(os.path.join(model_dir, 'size_scaler.joblib'))
17
- print("Models loaded successfully!")
18
-
19
- self.sift = cv2.SIFT_create(nfeatures=2000)
20
-
21
- self.orb = cv2.ORB_create(nfeatures=1000)
22
- self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
23
  self.prev_frame = None
24
  self.prev_kp = None
25
  self.prev_desc = None
26
-
27
- def estimate_camera_motion(self, frame):
 
 
28
  if frame is None:
29
  return np.eye(2, 3, dtype=np.float32)
30
 
@@ -61,7 +54,117 @@ class ObjectTrackerInference:
61
  self.prev_desc = desc
62
 
63
  return transform_matrix
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def local_binary_pattern(self, image, n_points=8, radius=1):
66
  rows, cols = image.shape
67
  output = np.zeros((rows, cols))
@@ -97,30 +200,63 @@ class ObjectTrackerInference:
97
 
98
  return output
99
 
100
- def extract_features(self, frame, bbox, transform_matrix=None):
101
  if frame is None:
102
- return None
103
 
104
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
105
- x, y, w, h = map(int, bbox)
106
 
107
- x = max(0, min(x, gray.shape[1] - w))
108
- y = max(0, min(y, gray.shape[0] - h))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  w = min(w, gray.shape[1] - x)
110
  h = min(h, gray.shape[0] - y)
 
 
111
 
112
  roi = gray[y:y+h, x:x+w]
113
- if roi.size == 0:
114
- roi = gray
115
-
116
  roi = cv2.resize(roi, (64, 64))
117
 
118
  features = []
119
 
 
120
  hog = cv2.HOGDescriptor((64,64), (16,16), (8,8), (8,8), 9)
121
  hog_features = hog.compute(roi)
122
  features.extend(hog_features.flatten()[:64])
123
 
 
124
  lbp = self.local_binary_pattern(roi, n_points=8, radius=1)
125
  features.extend([
126
  np.mean(lbp),
@@ -128,19 +264,18 @@ class ObjectTrackerInference:
128
  *np.percentile(lbp, [25, 50, 75])
129
  ])
130
 
131
- if transform_matrix is not None:
132
- features.extend([
133
- transform_matrix[0,0],
134
- transform_matrix[1,1],
135
- transform_matrix[0,2],
136
- transform_matrix[1,2]
137
- ])
138
- else:
139
- features.extend([1, 1, 0, 0])
140
-
141
  features.extend([x, y, w, h])
142
 
143
- return np.array(features).reshape(1, -1)
144
 
145
  def predict_bbox(self, features):
146
  features_position = self.position_scaler.transform(features)
@@ -153,13 +288,32 @@ class ObjectTrackerInference:
153
 
154
  return bbox
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def track_video(self, video_path, initial_bbox, output_path='output_tracked.mp4', fps=30):
157
  print(f"Processing video: {video_path}")
158
 
159
  cap = cv2.VideoCapture(video_path)
160
  if not cap.isOpened():
161
  raise ValueError(f"Could not open video: {video_path}")
162
-
163
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
164
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
165
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -169,12 +323,15 @@ class ObjectTrackerInference:
169
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
171
 
172
- self.prev_frame = None
173
- self.prev_kp = None
174
- self.prev_desc = None
 
175
 
176
  current_bbox = initial_bbox
177
  frame_idx = 0
 
 
178
 
179
  print("Tracking object...")
180
 
@@ -183,14 +340,36 @@ class ObjectTrackerInference:
183
  if not ret:
184
  break
185
 
186
- transform_matrix = self.estimate_camera_motion(frame)
 
 
187
 
188
- features = self.extract_features(frame, current_bbox, transform_matrix)
189
-
190
  if features is not None:
191
  predicted_bbox = self.predict_bbox(features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  current_bbox = predicted_bbox
 
193
 
 
194
  x, y, w, h = map(int, current_bbox)
195
  cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
196
  cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
@@ -213,7 +392,7 @@ def main():
213
  tracker = ObjectTrackerInference(model_dir='models')
214
 
215
  video_path = 'input_video.mp4'
216
- initial_bbox = [100, 100, 50, 50]
217
  output_path = 'tracked_output.mp4'
218
 
219
  result = tracker.track_video(video_path, initial_bbox, output_path)
 
1
+ """
2
+ UAV Object Tracker - Inference Script (FIXED)
3
+ Properly uses sliding window search and template matching during inference.
4
+ """
5
+
6
  import cv2
7
  import joblib
8
+ import os
9
  import numpy as np
 
10
 
11
 
12
+ class CameraMotionCompensator:
13
+ def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  self.prev_frame = None
15
  self.prev_kp = None
16
  self.prev_desc = None
17
+ self.orb = cv2.ORB_create(nfeatures=1000)
18
+ self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
19
+
20
+ def estimate_motion(self, frame):
21
  if frame is None:
22
  return np.eye(2, 3, dtype=np.float32)
23
 
 
54
  self.prev_desc = desc
55
 
56
  return transform_matrix
57
+
58
+
59
class ImprovedSlidingWindowTracker:
    """Multi-scale sliding-window search scored by SIFT descriptor matching.

    A search region around the previous bounding box is tiled with windows at
    several scales; each candidate window is scored against a fixed-size
    template using FLANN-matched SIFT descriptors filtered by Lowe's ratio
    test.
    """

    def __init__(self, scale_factor=2.0, overlap=0.3):
        # Search-window size relative to the previous bbox, and the fraction
        # of overlap between neighbouring windows.
        self.scale_factor = scale_factor
        self.overlap = overlap
        self.sift = cv2.SIFT_create(nfeatures=2000)

        # FLANN KD-tree matcher: appropriate for float (SIFT) descriptors.
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

        # Number of scales tried per frame and the ratio between extremes.
        self.scale_levels = 3
        self.scale_step = 1.2

    def generate_multiscale_windows(self, img_shape, prev_bbox, transform_matrix=None):
        """Return candidate ``(x, y, w, h)`` windows around ``prev_bbox``.

        Parameters
        ----------
        img_shape : tuple
            Image shape as ``(height, width, ...)``.
        prev_bbox : sequence
            Previous bounding box ``(x, y, w, h)``.
        transform_matrix : ndarray or None
            Optional 2x3 affine camera-motion matrix; when given, the search
            centre is shifted by the estimated camera motion first.
        """
        x, y, w, h = map(int, prev_bbox)

        if transform_matrix is not None:
            # Shift the bbox centre according to the estimated camera motion.
            center = np.array([[x + w/2, y + h/2, 1]], dtype=np.float32).T
            transformed_center = np.dot(transform_matrix, center)
            # Index the 2x1 result explicitly instead of calling int() on a
            # size-1 ndarray (deprecated scalar conversion in NumPy >= 1.25).
            x = int(float(transformed_center[0, 0]) - w/2)
            y = int(float(transformed_center[1, 0]) - h/2)

        windows = []

        for scale in np.linspace(1/self.scale_step, self.scale_step, self.scale_levels):
            window_w = int(w * self.scale_factor * scale)
            window_h = int(h * self.scale_factor * scale)

            center_x = x + w // 2
            center_y = y + h // 2

            step_x = int(window_w * (1 - self.overlap))
            step_y = int(window_h * (1 - self.overlap))

            for dy in range(-step_y, step_y + 1, max(1, step_y // 2)):
                for dx in range(-step_x, step_x + 1, max(1, step_x // 2)):
                    win_x = max(0, min(center_x - window_w // 2 + dx, img_shape[1] - window_w))
                    win_y = max(0, min(center_y - window_h // 2 + dy, img_shape[0] - window_h))

                    # Clip to image bounds using LOCAL copies. The original
                    # code shrank window_w/window_h in place, which silently
                    # shrank every subsequent window at this scale level.
                    clipped_w = min(window_w, img_shape[1] - win_x)
                    clipped_h = min(window_h, img_shape[0] - win_y)

                    if clipped_w > 10 and clipped_h > 10:
                        windows.append((win_x, win_y, clipped_w, clipped_h))

        return windows

    def score_window(self, img, window, template, template_desc):
        """Score ``window`` against the template; higher means a better match.

        Returns 0 for out-of-bounds or degenerate windows, when either side
        has no SIFT descriptors, or when no matches survive the ratio test.
        """
        x, y, w, h = map(int, window)

        if x < 0 or y < 0 or x + w > img.shape[1] or y + h > img.shape[0]:
            return 0

        roi = img[y:y+h, x:x+w]

        min_size = 20
        if roi.shape[0] < min_size or roi.shape[1] < min_size:
            return 0

        # Normalise to the template size so descriptor geometry is comparable.
        roi = cv2.resize(roi, (template.shape[1], template.shape[0]))

        kp, desc = self.sift.detectAndCompute(roi, None)

        if desc is None or template_desc is None or len(desc) == 0 or len(template_desc) == 0:
            return 0

        try:
            matches = self.flann.knnMatch(template_desc, desc, k=2)

            # Lowe's ratio test: keep matches clearly better than the runner-up.
            good_matches = []
            for match_group in matches:
                if len(match_group) == 2:
                    m, n = match_group
                    if m.distance < 0.7 * n.distance:
                        good_matches.append(m)

            if len(good_matches) == 0:
                return 0

            avg_distance = np.mean([m.distance for m in good_matches])
            score = len(good_matches) * (1 - avg_distance/512)

            # SIFT L2 distances can exceed 512, which would make the score
            # negative; a non-match should never beat "no match" (0).
            return max(0.0, score)

        except Exception:
            # knnMatch can raise when either descriptor set has < k entries.
            return 0
149
+
150
+
151
+ class ObjectTrackerInference:
152
+ def __init__(self, model_dir='models'):
153
+ self.model_dir = model_dir
154
+
155
+ print("Loading pre-trained models...")
156
+ self.position_model = joblib.load(os.path.join(model_dir, 'position_model.joblib'))
157
+ self.size_model = joblib.load(os.path.join(model_dir, 'size_model.joblib'))
158
+ self.position_scaler = joblib.load(os.path.join(model_dir, 'position_scaler.joblib'))
159
+ self.size_scaler = joblib.load(os.path.join(model_dir, 'size_scaler.joblib'))
160
+ print("Models loaded successfully!")
161
+
162
+ self.window_tracker = ImprovedSlidingWindowTracker()
163
+ self.motion_compensator = CameraMotionCompensator()
164
+
165
+ self.template = None
166
+ self.template_descriptors = None
167
+
168
  def local_binary_pattern(self, image, n_points=8, radius=1):
169
  rows, cols = image.shape
170
  output = np.zeros((rows, cols))
 
200
 
201
  return output
202
 
203
+ def extract_features(self, frame, prev_bbox, transform_matrix):
204
  if frame is None:
205
+ return None, prev_bbox
206
 
207
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
 
208
 
209
+ # Use sliding window to find best match
210
+ windows = self.window_tracker.generate_multiscale_windows(
211
+ frame.shape, prev_bbox, transform_matrix
212
+ )
213
+
214
+ # Initialize template on first frame
215
+ if self.template is None:
216
+ x, y, w, h = map(int, prev_bbox)
217
+ x = max(0, min(x, gray.shape[1] - w))
218
+ y = max(0, min(y, gray.shape[0] - h))
219
+ w = min(w, gray.shape[1] - x)
220
+ h = min(h, gray.shape[0] - y)
221
+
222
+ self.template = gray[y:y+h, x:x+w].copy()
223
+ _, self.template_descriptors = self.window_tracker.sift.detectAndCompute(self.template, None)
224
+
225
+ # Find best matching window
226
+ best_score = -1
227
+ best_window = prev_bbox
228
+
229
+ for window in windows:
230
+ score = self.window_tracker.score_window(
231
+ gray, window, self.template, self.template_descriptors
232
+ )
233
+
234
+ if score > best_score:
235
+ best_score = score
236
+ best_window = window
237
+
238
+ # Use best window for feature extraction
239
+ x, y, w, h = map(int, best_window)
240
+
241
+ # Ensure bbox is within bounds
242
+ x = max(0, min(x, gray.shape[1] - 10))
243
+ y = max(0, min(y, gray.shape[0] - 10))
244
  w = min(w, gray.shape[1] - x)
245
  h = min(h, gray.shape[0] - y)
246
+ w = max(10, w)
247
+ h = max(10, h)
248
 
249
  roi = gray[y:y+h, x:x+w]
 
 
 
250
  roi = cv2.resize(roi, (64, 64))
251
 
252
  features = []
253
 
254
+ # HOG features
255
  hog = cv2.HOGDescriptor((64,64), (16,16), (8,8), (8,8), 9)
256
  hog_features = hog.compute(roi)
257
  features.extend(hog_features.flatten()[:64])
258
 
259
+ # LBP features
260
  lbp = self.local_binary_pattern(roi, n_points=8, radius=1)
261
  features.extend([
262
  np.mean(lbp),
 
264
  *np.percentile(lbp, [25, 50, 75])
265
  ])
266
 
267
+ # Motion features
268
+ features.extend([
269
+ transform_matrix[0,0],
270
+ transform_matrix[1,1],
271
+ transform_matrix[0,2],
272
+ transform_matrix[1,2]
273
+ ])
274
+
275
+ # Position and size
 
276
  features.extend([x, y, w, h])
277
 
278
+ return np.array(features).reshape(1, -1), (x, y, w, h)
279
 
280
  def predict_bbox(self, features):
281
  features_position = self.position_scaler.transform(features)
 
288
 
289
  return bbox
290
 
291
+ def calculate_iou(self, bbox1, bbox2):
292
+ x1, y1, w1, h1 = bbox1
293
+ x2, y2, w2, h2 = bbox2
294
+
295
+ x_left = max(x1, x2)
296
+ y_top = max(y1, y2)
297
+ x_right = min(x1 + w1, x2 + w2)
298
+ y_bottom = min(y1 + h1, y2 + h2)
299
+
300
+ if x_right < x_left or y_bottom < y_top:
301
+ return 0.0
302
+
303
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
304
+ bbox1_area = w1 * h1
305
+ bbox2_area = w2 * h2
306
+
307
+ iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
308
+ return max(0.0, min(1.0, iou))
309
+
310
  def track_video(self, video_path, initial_bbox, output_path='output_tracked.mp4', fps=30):
311
  print(f"Processing video: {video_path}")
312
 
313
  cap = cv2.VideoCapture(video_path)
314
  if not cap.isOpened():
315
  raise ValueError(f"Could not open video: {video_path}")
316
+
317
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
318
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
319
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
323
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
324
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
325
 
326
+ # Reset state
327
+ self.motion_compensator.prev_frame = None
328
+ self.template = None
329
+ self.template_descriptors = None
330
 
331
  current_bbox = initial_bbox
332
  frame_idx = 0
333
+ template_update_counter = 0
334
+ prev_predicted_bbox = None
335
 
336
  print("Tracking object...")
337
 
 
340
  if not ret:
341
  break
342
 
343
+ transform_matrix = self.motion_compensator.estimate_motion(frame)
344
+
345
+ features, search_bbox = self.extract_features(frame, current_bbox, transform_matrix)
346
 
 
 
347
  if features is not None:
348
  predicted_bbox = self.predict_bbox(features)
349
+
350
+ # Clamp bbox to frame bounds
351
+ x, y, w, h = predicted_bbox
352
+ x = max(0, min(int(x), frame_width - 10))
353
+ y = max(0, min(int(y), frame_height - 10))
354
+ w = max(10, min(int(w), frame_width - x))
355
+ h = max(10, min(int(h), frame_height - y))
356
+ predicted_bbox = [x, y, w, h]
357
+
358
+ # Adaptive template update
359
+ template_update_counter += 1
360
+ if template_update_counter >= 5 and prev_predicted_bbox is not None:
361
+ iou = self.calculate_iou(prev_predicted_bbox, predicted_bbox)
362
+ if iou > 0.6:
363
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
364
+ x, y, w, h = map(int, predicted_bbox)
365
+ self.template = gray[y:y+h, x:x+w].copy()
366
+ _, self.template_descriptors = self.window_tracker.sift.detectAndCompute(self.template, None)
367
+ template_update_counter = 0
368
+
369
  current_bbox = predicted_bbox
370
+ prev_predicted_bbox = predicted_bbox
371
 
372
+ # Draw bounding box
373
  x, y, w, h = map(int, current_bbox)
374
  cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
375
  cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
 
392
  tracker = ObjectTrackerInference(model_dir='models')
393
 
394
  video_path = 'input_video.mp4'
395
+ initial_bbox = [100, 100, 50, 50] # [x, y, width, height]
396
  output_path = 'tracked_output.mp4'
397
 
398
  result = tracker.track_video(video_path, initial_bbox, output_path)