ayushsaun committed on
Commit
0ac3063
·
1 Parent(s): 8cb54e4

updated inference.py

Browse files
Files changed (1) hide show
  1. inference.py +163 -346
inference.py CHANGED
@@ -1,14 +1,8 @@
1
- """
2
- UAV Object Tracker - Inference Script (FIXED)
3
- Properly uses sliding window search and template matching during inference.
4
- """
5
-
6
  import cv2
7
  import joblib
8
  import os
9
  import numpy as np
10
 
11
-
12
  class CameraMotionCompensator:
13
  def __init__(self):
14
  self.prev_frame = None
@@ -16,388 +10,211 @@ class CameraMotionCompensator:
16
  self.prev_desc = None
17
  self.orb = cv2.ORB_create(nfeatures=1000)
18
  self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
19
-
20
  def estimate_motion(self, frame):
21
  if frame is None:
22
  return np.eye(2, 3, dtype=np.float32)
23
-
24
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
25
  kp, desc = self.orb.detectAndCompute(gray, None)
26
-
27
- if self.prev_frame is None:
28
  self.prev_frame = gray
29
  self.prev_kp = kp
30
  self.prev_desc = desc
31
  return np.eye(2, 3, dtype=np.float32)
32
-
33
- if desc is None or self.prev_desc is None or len(desc) < 4 or len(self.prev_desc) < 4:
34
- return np.eye(2, 3, dtype=np.float32)
35
-
36
  matches = self.matcher.match(self.prev_desc, desc)
37
-
38
  if len(matches) < 4:
39
  return np.eye(2, 3, dtype=np.float32)
40
-
41
- matches = sorted(matches, key=lambda x: x.distance)
42
- good_matches = matches[:min(len(matches), 50)]
43
-
44
- src_pts = np.float32([self.prev_kp[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
45
- dst_pts = np.float32([kp[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
46
-
47
- transform_matrix, _ = cv2.estimateAffinePartial2D(src_pts, dst_pts)
48
-
49
- if transform_matrix is None:
50
- transform_matrix = np.eye(2, 3, dtype=np.float32)
51
-
52
  self.prev_frame = gray
53
  self.prev_kp = kp
54
  self.prev_desc = desc
55
-
56
- return transform_matrix
57
-
58
 
59
  class ImprovedSlidingWindowTracker:
60
  def __init__(self, scale_factor=2.0, overlap=0.3):
61
  self.scale_factor = scale_factor
62
  self.overlap = overlap
63
  self.sift = cv2.SIFT_create(nfeatures=2000)
64
-
65
- FLANN_INDEX_KDTREE = 1
66
- index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
67
- search_params = dict(checks=50)
68
- self.flann = cv2.FlannBasedMatcher(index_params, search_params)
69
-
70
  self.scale_levels = 3
71
  self.scale_step = 1.2
72
-
73
- def generate_multiscale_windows(self, img_shape, prev_bbox, transform_matrix=None):
74
- x, y, w, h = map(int, prev_bbox)
75
-
76
- if transform_matrix is not None:
77
- center = np.array([[x + w/2, y + h/2, 1]], dtype=np.float32).T
78
- transformed_center = np.dot(transform_matrix, center)
79
- x = int(transformed_center[0] - w/2)
80
- y = int(transformed_center[1] - h/2)
81
-
82
- windows = []
83
-
84
- for scale in np.linspace(1/self.scale_step, self.scale_step, self.scale_levels):
85
- window_w = int(w * self.scale_factor * scale)
86
- window_h = int(h * self.scale_factor * scale)
87
-
88
- center_x = x + w // 2
89
- center_y = y + h // 2
90
-
91
- step_x = int(window_w * (1 - self.overlap))
92
- step_y = int(window_h * (1 - self.overlap))
93
-
94
- for dy in range(-step_y, step_y + 1, max(1, step_y // 2)):
95
- for dx in range(-step_x, step_x + 1, max(1, step_x // 2)):
96
- win_x = max(0, min(center_x - window_w // 2 + dx, img_shape[1] - window_w))
97
- win_y = max(0, min(center_y - window_h // 2 + dy, img_shape[0] - window_h))
98
-
99
- # Ensure window is within bounds
100
- if win_x + window_w > img_shape[1]:
101
- window_w = img_shape[1] - win_x
102
- if win_y + window_h > img_shape[0]:
103
- window_h = img_shape[0] - win_y
104
-
105
- if window_w > 10 and window_h > 10:
106
- windows.append((win_x, win_y, window_w, window_h))
107
-
108
  return windows
109
 
110
- def score_window(self, img, window, template, template_desc):
111
- x, y, w, h = map(int, window)
112
-
113
- if x < 0 or y < 0 or x + w > img.shape[1] or y + h > img.shape[0]:
114
  return 0
115
-
116
- roi = img[y:y+h, x:x+w]
117
-
118
- min_size = 20
119
- if roi.shape[0] < min_size or roi.shape[1] < min_size:
120
  return 0
121
-
122
- roi = cv2.resize(roi, (template.shape[1], template.shape[0]))
123
-
124
- kp, desc = self.sift.detectAndCompute(roi, None)
125
-
126
- if desc is None or template_desc is None or len(desc) == 0 or len(template_desc) == 0:
127
  return 0
128
-
129
- try:
130
- matches = self.flann.knnMatch(template_desc, desc, k=2)
131
-
132
- good_matches = []
133
- for match_group in matches:
134
- if len(match_group) == 2:
135
- m, n = match_group
136
- if m.distance < 0.7 * n.distance:
137
- good_matches.append(m)
138
-
139
- if len(good_matches) == 0:
140
- return 0
141
-
142
- avg_distance = np.mean([m.distance for m in good_matches])
143
- score = len(good_matches) * (1 - avg_distance/512)
144
-
145
- return score
146
-
147
- except Exception:
148
- return 0
149
-
150
 
151
  class ObjectTrackerInference:
152
- def __init__(self, model_dir='models'):
153
- self.model_dir = model_dir
154
-
155
- print("Loading pre-trained models...")
156
- self.position_model = joblib.load(os.path.join(model_dir, 'position_model.joblib'))
157
- self.size_model = joblib.load(os.path.join(model_dir, 'size_model.joblib'))
158
- self.position_scaler = joblib.load(os.path.join(model_dir, 'position_scaler.joblib'))
159
- self.size_scaler = joblib.load(os.path.join(model_dir, 'size_scaler.joblib'))
160
- print("Models loaded successfully!")
161
-
162
  self.window_tracker = ImprovedSlidingWindowTracker()
163
- self.motion_compensator = CameraMotionCompensator()
164
-
165
  self.template = None
166
- self.template_descriptors = None
167
-
168
- def local_binary_pattern(self, image, n_points=8, radius=1):
169
- rows, cols = image.shape
170
- output = np.zeros((rows, cols))
171
-
172
- for i in range(radius, rows-radius):
173
- for j in range(radius, cols-radius):
174
- center = image[i, j]
175
- pattern = 0
176
-
177
- for k in range(n_points):
178
- angle = 2 * np.pi * k / n_points
179
- x = j + radius * np.cos(angle)
180
- y = i - radius * np.sin(angle)
181
- x1, x2 = int(np.floor(x)), int(np.ceil(x))
182
- y1, y2 = int(np.floor(y)), int(np.ceil(y))
183
-
184
- f11 = image[y1, x1]
185
- f12 = image[y1, x2]
186
- f21 = image[y2, x1]
187
- f22 = image[y2, x2]
188
-
189
- x_weight = x - x1
190
- y_weight = y - y1
191
-
192
- pixel_value = (f11 * (1-x_weight) * (1-y_weight) +
193
- f21 * (1-x_weight) * y_weight +
194
- f12 * x_weight * (1-y_weight) +
195
- f22 * x_weight * y_weight)
196
-
197
- pattern |= (pixel_value > center) << k
198
-
199
- output[i, j] = pattern
200
-
201
- return output
202
-
203
- def extract_features(self, frame, prev_bbox, transform_matrix):
204
- if frame is None:
205
- return None, prev_bbox
206
-
207
  gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
208
-
209
- # Use sliding window to find best match
210
- windows = self.window_tracker.generate_multiscale_windows(
211
- frame.shape, prev_bbox, transform_matrix
212
- )
213
-
214
- # Initialize template on first frame
215
  if self.template is None:
216
- x, y, w, h = map(int, prev_bbox)
217
- x = max(0, min(x, gray.shape[1] - w))
218
- y = max(0, min(y, gray.shape[0] - h))
219
- w = min(w, gray.shape[1] - x)
220
- h = min(h, gray.shape[0] - y)
221
-
222
- self.template = gray[y:y+h, x:x+w].copy()
223
- _, self.template_descriptors = self.window_tracker.sift.detectAndCompute(self.template, None)
224
-
225
- # Find best matching window
226
  best_score = -1
227
- best_window = prev_bbox
228
-
229
- for window in windows:
230
- score = self.window_tracker.score_window(
231
- gray, window, self.template, self.template_descriptors
232
- )
233
-
234
- if score > best_score:
235
- best_score = score
236
- best_window = window
237
-
238
- # Use best window for feature extraction
239
- x, y, w, h = map(int, best_window)
240
-
241
- # Ensure bbox is within bounds
242
- x = max(0, min(x, gray.shape[1] - 10))
243
- y = max(0, min(y, gray.shape[0] - 10))
244
- w = min(w, gray.shape[1] - x)
245
- h = min(h, gray.shape[0] - y)
246
- w = max(10, w)
247
- h = max(10, h)
248
-
249
- roi = gray[y:y+h, x:x+w]
250
- roi = cv2.resize(roi, (64, 64))
251
-
252
- features = []
253
-
254
- # HOG features
255
- hog = cv2.HOGDescriptor((64,64), (16,16), (8,8), (8,8), 9)
256
- hog_features = hog.compute(roi)
257
- features.extend(hog_features.flatten()[:64])
258
-
259
- # LBP features
260
- lbp = self.local_binary_pattern(roi, n_points=8, radius=1)
261
- features.extend([
262
- np.mean(lbp),
263
- np.std(lbp),
264
- *np.percentile(lbp, [25, 50, 75])
265
- ])
266
-
267
- # Motion features
268
- features.extend([
269
- transform_matrix[0,0],
270
- transform_matrix[1,1],
271
- transform_matrix[0,2],
272
- transform_matrix[1,2]
273
- ])
274
-
275
- # Position and size
276
- features.extend([x, y, w, h])
277
-
278
- return np.array(features).reshape(1, -1), (x, y, w, h)
279
-
280
- def predict_bbox(self, features):
281
- features_position = self.position_scaler.transform(features)
282
- features_size = self.size_scaler.transform(features)
283
-
284
- position_pred = self.position_model.predict(features_position)
285
- size_pred = self.size_model.predict(features_size)
286
-
287
- bbox = np.hstack([position_pred, size_pred])[0]
288
-
289
- return bbox
290
-
291
- def calculate_iou(self, bbox1, bbox2):
292
- x1, y1, w1, h1 = bbox1
293
- x2, y2, w2, h2 = bbox2
294
-
295
- x_left = max(x1, x2)
296
- y_top = max(y1, y2)
297
- x_right = min(x1 + w1, x2 + w2)
298
- y_bottom = min(y1 + h1, y2 + h2)
299
-
300
- if x_right < x_left or y_bottom < y_top:
301
- return 0.0
302
-
303
- intersection_area = (x_right - x_left) * (y_bottom - y_top)
304
- bbox1_area = w1 * h1
305
- bbox2_area = w2 * h2
306
-
307
- iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
308
- return max(0.0, min(1.0, iou))
309
-
310
- def track_video(self, video_path, initial_bbox, output_path='output_tracked.mp4', fps=30):
311
- print(f"Processing video: {video_path}")
312
-
313
- cap = cv2.VideoCapture(video_path)
314
- if not cap.isOpened():
315
- raise ValueError(f"Could not open video: {video_path}")
316
-
317
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
318
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
319
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
320
-
321
- print(f"Video: {frame_width}x{frame_height}, {total_frames} frames")
322
-
323
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
324
- out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
325
-
326
- # Reset state
327
- self.motion_compensator.prev_frame = None
328
- self.template = None
329
- self.template_descriptors = None
330
-
331
- current_bbox = initial_bbox
332
- frame_idx = 0
333
- template_update_counter = 0
334
- prev_predicted_bbox = None
335
-
336
- print("Tracking object...")
337
-
338
  while True:
339
- ret, frame = cap.read()
340
  if not ret:
341
  break
342
-
343
- transform_matrix = self.motion_compensator.estimate_motion(frame)
344
-
345
- features, search_bbox = self.extract_features(frame, current_bbox, transform_matrix)
346
-
347
- if features is not None:
348
- predicted_bbox = self.predict_bbox(features)
349
-
350
- # Clamp bbox to frame bounds
351
- x, y, w, h = predicted_bbox
352
- x = max(0, min(int(x), frame_width - 10))
353
- y = max(0, min(int(y), frame_height - 10))
354
- w = max(10, min(int(w), frame_width - x))
355
- h = max(10, min(int(h), frame_height - y))
356
- predicted_bbox = [x, y, w, h]
357
-
358
- # Adaptive template update
359
- template_update_counter += 1
360
- if template_update_counter >= 5 and prev_predicted_bbox is not None:
361
- iou = self.calculate_iou(prev_predicted_bbox, predicted_bbox)
362
- if iou > 0.6:
363
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
364
- x, y, w, h = map(int, predicted_bbox)
365
- self.template = gray[y:y+h, x:x+w].copy()
366
- _, self.template_descriptors = self.window_tracker.sift.detectAndCompute(self.template, None)
367
- template_update_counter = 0
368
-
369
- current_bbox = predicted_bbox
370
- prev_predicted_bbox = predicted_bbox
371
-
372
- # Draw bounding box
373
- x, y, w, h = map(int, current_bbox)
374
- cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
375
- cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
376
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
377
-
378
  out.write(frame)
379
- frame_idx += 1
380
-
381
- if frame_idx % 30 == 0:
382
- print(f"Processed {frame_idx}/{total_frames} frames")
383
-
384
  cap.release()
385
  out.release()
386
-
387
- print(f"Tracking complete! Video saved to: {output_path}")
388
- return output_path
389
-
390
 
391
  def main():
392
- tracker = ObjectTrackerInference(model_dir='models')
393
-
394
- video_path = 'input_video.mp4'
395
- initial_bbox = [100, 100, 50, 50] # [x, y, width, height]
396
- output_path = 'tracked_output.mp4'
397
-
398
- result = tracker.track_video(video_path, initial_bbox, output_path)
399
- print(f"Done! Output: {result}")
400
-
401
 
402
- if __name__ == "__main__":
403
  main()
 
 
 
 
 
 
1
  import cv2
2
  import joblib
3
  import os
4
  import numpy as np
5
 
 
6
class CameraMotionCompensator:
    """Estimate global camera (ego) motion between consecutive frames using
    ORB features, so the tracker can compensate for UAV camera movement."""

    def __init__(self):
        self.prev_frame = None
        self.prev_kp = None
        self.prev_desc = None
        self.orb = cv2.ORB_create(nfeatures=1000)
        # Hamming distance is the correct metric for ORB's binary descriptors.
        self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    def estimate_motion(self, frame):
        """Return a 2x3 affine matrix mapping the previous frame onto *frame*.

        Falls back to the identity transform whenever motion cannot be
        estimated (missing frame, first frame, too few features or matches).
        """
        identity = np.eye(2, 3, dtype=np.float32)
        if frame is None:
            return identity

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        kp, desc = self.orb.detectAndCompute(gray, None)

        # First frame, or not enough descriptors on either side to match.
        if (self.prev_frame is None or desc is None or self.prev_desc is None
                or len(desc) < 4 or len(self.prev_desc) < 4):
            self._remember(gray, kp, desc)
            return identity

        matches = self.matcher.match(self.prev_desc, desc)
        if len(matches) < 4:
            # BUGFIX: still advance the reference frame here; previously the
            # state was left untouched, so every later call kept matching
            # against an increasingly stale frame.
            self._remember(gray, kp, desc)
            return identity

        # Keep only the 50 best matches (smallest descriptor distance).
        matches = sorted(matches, key=lambda m: m.distance)[:50]
        src = np.float32([self.prev_kp[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst = np.float32([kp[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)

        M, _ = cv2.estimateAffinePartial2D(src, dst)
        if M is None:  # robust estimation can fail even with enough matches
            M = identity

        self._remember(gray, kp, desc)
        return M

    def _remember(self, gray, kp, desc):
        # Store the current frame's data as the reference for the next call.
        self.prev_frame = gray
        self.prev_kp = kp
        self.prev_desc = desc
 
 
43
 
44
class ImprovedSlidingWindowTracker:
    """Multi-scale sliding-window search with SIFT template matching."""

    def __init__(self, scale_factor=2.0, overlap=0.3):
        self.scale_factor = scale_factor  # search-window size relative to bbox
        self.overlap = overlap            # fractional overlap between windows
        self.sift = cv2.SIFT_create(nfeatures=2000)
        self.scale_levels = 3
        self.scale_step = 1.2
        # FLANN KD-tree matcher (algorithm=1) for float SIFT descriptors.
        index_params = dict(algorithm=1, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

    def generate_multiscale_windows(self, img_shape, prev_bbox, transform_matrix):
        """Return candidate (x, y, w, h) windows around the motion-compensated
        previous bbox at several scales, all clamped inside the image."""
        x, y, w, h = map(int, prev_bbox)

        # Project the bbox centre through the camera-motion transform.
        center = np.dot(transform_matrix,
                        np.array([[x + w / 2, y + h / 2, 1]], dtype=np.float32).T)
        cx, cy = int(center[0]), int(center[1])

        img_h, img_w = img_shape[0], img_shape[1]
        windows = []
        for s in np.linspace(1 / self.scale_step, self.scale_step, self.scale_levels):
            ww = int(w * self.scale_factor * s)
            hh = int(h * self.scale_factor * s)
            # BUGFIX: clamp the window size to the image so no window can
            # extend past the right/bottom edge when the search region is
            # larger than the frame.
            ww = min(ww, img_w)
            hh = min(hh, img_h)
            if ww <= 10 or hh <= 10:
                continue
            step_x = max(1, int(ww * (1 - self.overlap) // 2))
            step_y = max(1, int(hh * (1 - self.overlap) // 2))
            for dy in range(-step_y, step_y + 1, step_y):
                for dx in range(-step_x, step_x + 1, step_x):
                    wx = max(0, min(cx - ww // 2 + dx, img_w - ww))
                    wy = max(0, min(cy - hh // 2 + dy, img_h - hh))
                    windows.append((wx, wy, ww, hh))
        return windows

    def score_window(self, gray, window, template, template_desc):
        """Score how well *window* matches the template via Lowe-ratio-filtered
        SIFT matches. Higher is better; 0 means no evidence of a match."""
        x, y, w, h = map(int, window)
        # BUGFIX: reject windows that fall (partly) outside the image; this
        # guard was dropped in the rewrite.
        if x < 0 or y < 0 or x + w > gray.shape[1] or y + h > gray.shape[0]:
            return 0
        roi = gray[y:y + h, x:x + w]
        if roi.shape[0] < 20 or roi.shape[1] < 20:
            return 0
        roi = cv2.resize(roi, (template.shape[1], template.shape[0]))
        _, desc = self.sift.detectAndCompute(roi, None)
        if desc is None or template_desc is None:
            return 0
        try:
            matches = self.flann.knnMatch(template_desc, desc, k=2)
        except cv2.error:
            # FLANN can raise on degenerate descriptor sets; treat as no match.
            return 0
        # BUGFIX: knnMatch may return groups with fewer than 2 entries; the
        # previous `for m, n in matches` unpacking crashed on those.
        good = [g[0] for g in matches
                if len(g) == 2 and g[0].distance < 0.7 * g[1].distance]
        if not good:
            return 0
        # More matches and smaller distances -> higher score
        # (512 approximates the maximum SIFT descriptor distance).
        return len(good) * (1 - np.mean([m.distance for m in good]) / 512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
class ObjectTrackerInference:
    """Run the pre-trained position/size regressors over a video, combining
    camera-motion compensation, sliding-window template search and HOG/LBP
    features."""

    def __init__(self, model_dir='models'):
        # BUGFIX: restore the default of 'models' that the previous version
        # exposed; existing callers may rely on it.
        self.position_model = joblib.load(os.path.join(model_dir, 'position_model.joblib'))
        self.size_model = joblib.load(os.path.join(model_dir, 'size_model.joblib'))
        self.position_scaler = joblib.load(os.path.join(model_dir, 'position_scaler.joblib'))
        self.size_scaler = joblib.load(os.path.join(model_dir, 'size_scaler.joblib'))

        self.window_tracker = ImprovedSlidingWindowTracker()
        self.motion = CameraMotionCompensator()
        self.template = None
        self.template_desc = None
        self.prev_bbox = None
        self.template_update_counter = 0
        # Built once instead of being re-created on every frame.
        self._hog = cv2.HOGDescriptor((64, 64), (16, 16), (8, 8), (8, 8), 9)

    @staticmethod
    def _clamp_bbox(bbox, shape, min_size=10):
        """Clamp an (x, y, w, h) box inside *shape* = (rows, cols, ...) while
        enforcing a minimum size, so slicing/drawing never goes out of range."""
        x, y, w, h = map(int, bbox)
        x = max(0, min(x, shape[1] - min_size))
        y = max(0, min(y, shape[0] - min_size))
        w = max(min_size, min(w, shape[1] - x))
        h = max(min_size, min(h, shape[0] - y))
        return x, y, w, h

    def local_binary_pattern(self, image):
        """Compute a coarse LBP map (8 neighbours, radius 1).

        NOTE(review): the neighbour intensity is approximated by the mean of
        the 4 surrounding integer pixels rather than true bilinear
        interpolation — confirm this approximation matches training.
        """
        radius, n_points = 1, 8
        out = np.zeros(image.shape)
        for i in range(radius, image.shape[0] - radius):
            for j in range(radius, image.shape[1] - radius):
                center = image[i, j]
                code = 0
                for k in range(n_points):
                    angle = 2 * np.pi * k / n_points
                    x = j + radius * np.cos(angle)
                    y = i - radius * np.sin(angle)
                    x1, x2 = int(np.floor(x)), int(np.ceil(x))
                    y1, y2 = int(np.floor(y)), int(np.ceil(y))
                    # BUGFIX: cast to int before summing — adding four uint8
                    # pixels overflowed and wrapped modulo 256.
                    val = (int(image[y1, x1]) + int(image[y1, x2]) +
                           int(image[y2, x1]) + int(image[y2, x2])) / 4
                    code |= (val > center) << k
                out[i, j] = code
        return out

    def extract_features(self, frame, prev_bbox, M):
        """Build the feature vector for the current frame.

        Returns (features, search_bbox, windows): *search_bbox* is the
        best-scoring candidate window and *windows* the full candidate list
        (used by the caller for visualisation).
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        windows = self.window_tracker.generate_multiscale_windows(frame.shape, prev_bbox, M)

        # Initialise the appearance template from the first bbox.
        if self.template is None:
            x, y, w, h = self._clamp_bbox(prev_bbox, gray.shape)
            self.template = gray[y:y + h, x:x + w].copy()
            _, self.template_desc = self.window_tracker.sift.detectAndCompute(self.template, None)

        # Pick the candidate window that best matches the template.
        best_score = -1
        best_window = None
        for win in windows:
            score = self.window_tracker.score_window(gray, win, self.template, self.template_desc)
            if score > best_score:
                best_score = score
                best_window = win

        # BUGFIX: clamp the chosen box before slicing — an unclamped box
        # could produce an empty ROI and crash cv2.resize.
        x, y, w, h = self._clamp_bbox(best_window if best_window is not None else prev_bbox,
                                      gray.shape)

        roi = cv2.resize(gray[y:y + h, x:x + w], (64, 64))
        hog = self._hog.compute(roi).flatten()[:64]
        lbp = self.local_binary_pattern(roi)

        feat = list(hog) + [
            np.mean(lbp), np.std(lbp),
            *np.percentile(lbp, [25, 50, 75]),
            M[0, 0], M[1, 1], M[0, 2], M[1, 2],  # camera-motion features
            x, y, w, h,                          # position/size features
        ]
        return np.array(feat).reshape(1, -1), (x, y, w, h), windows

    def calculate_iou(self, a, b):
        """Intersection-over-union of two (x, y, w, h) boxes, in [0, 1]."""
        x1, y1, w1, h1 = a
        x2, y2, w2, h2 = b
        left, top = max(x1, x2), max(y1, y2)
        right, bottom = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
        if right < left or bottom < top:
            return 0
        inter = (right - left) * (bottom - top)
        return inter / (w1 * h1 + w2 * h2 - inter)

    def track_video(self, video_path, init_bbox, output):
        """Track *init_bbox* through *video_path*, writing an annotated video
        to *output*.

        Raises:
            ValueError: if the input video cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        # BUGFIX: fail fast instead of silently writing an empty video.
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_w, frame_h))

        # BUGFIX: reset per-video state so one tracker instance can be
        # reused on a second video.
        self.motion.prev_frame = None
        self.template = None
        self.template_desc = None
        self.prev_bbox = None
        self.template_update_counter = 0

        cur = init_bbox
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            M = self.motion.estimate_motion(frame)
            feats, search_bbox, windows = self.extract_features(frame, cur, M)

            pos = self.position_model.predict(self.position_scaler.transform(feats))
            size = self.size_model.predict(self.size_scaler.transform(feats))
            # BUGFIX: clamp the regressed bbox to the frame so drawing and
            # template cropping never index outside the image.
            pred = list(self._clamp_bbox((pos[0, 0], pos[0, 1], size[0, 0], size[0, 1]),
                                         (frame_h, frame_w)))

            # Adaptive template update: refresh the appearance template every
            # >= 5 frames when consecutive predictions agree (IoU > 0.6).
            self.template_update_counter += 1
            if self.template_update_counter >= 5 and self.prev_bbox is not None:
                if self.calculate_iou(self.prev_bbox, pred) > 0.6:
                    g = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    x, y, bw, bh = pred
                    self.template = g[y:y + bh, x:x + bw].copy()
                    _, self.template_desc = self.window_tracker.sift.detectAndCompute(self.template, None)
                    self.template_update_counter = 0

            # Visualise candidate search windows (yellow) ...
            for wx, wy, ww, wh in windows:
                cv2.rectangle(frame, (wx, wy), (wx + ww, wy + wh), (0, 255, 255), 1)
            # ... and the estimated camera-motion field (green arrows).
            for yy in range(0, frame_h, 32):
                for xx in range(0, frame_w, 32):
                    ep = np.dot(M, np.array([xx, yy, 1]))
                    if abs(ep[0] - xx) > 1 or abs(ep[1] - yy) > 1:
                        cv2.arrowedLine(frame, (xx, yy), (int(ep[0]), int(ep[1])),
                                        (0, 255, 0), 1, tipLength=0.2)

            # Draw the tracked bbox and the frame counter.
            x, y, bw, bh = pred
            cv2.rectangle(frame, (x, y), (x + bw, y + bh), (0, 255, 0), 2)
            cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            out.write(frame)
            self.prev_bbox = pred
            cur = pred
            frame_idx += 1

        cap.release()
        out.release()
 
 
 
 
214
 
215
def main():
    """Entry point: run the tracker on a hard-coded demo video."""
    model_directory = 'models'
    input_video = 'input_video.mp4'
    initial_box = [100, 100, 50, 50]  # x, y, width, height
    result_video = 'tracked_output.mp4'

    tracker = ObjectTrackerInference(model_directory)
    tracker.track_video(input_video, initial_box, result_video)


if __name__ == "__main__":
    main()