ayushsaun commited on
Commit
363db49
·
1 Parent(s): 4fcaf7f

added inference changes

Browse files
Files changed (1) hide show
  1. inference.py +138 -32
inference.py CHANGED
@@ -4,6 +4,109 @@ import joblib
4
  import numpy as np
5
  from pathlib import Path
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class ObjectTrackerInference:
9
  def __init__(self, model_dir='models'):
@@ -23,6 +126,10 @@ class ObjectTrackerInference:
23
  self.prev_frame = None
24
  self.prev_kp = None
25
  self.prev_desc = None
 
 
 
 
26
 
27
  def estimate_camera_motion(self, frame):
28
  if frame is None:
@@ -164,8 +271,6 @@ class ObjectTrackerInference:
164
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
165
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
166
 
167
- print(f"Video: {frame_width}x{frame_height}, {total_frames} frames")
168
-
169
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
171
 
@@ -176,49 +281,50 @@ class ObjectTrackerInference:
176
  current_bbox = initial_bbox
177
  frame_idx = 0
178
 
179
- print("Tracking object...")
180
-
181
  while True:
182
  ret, frame = cap.read()
183
  if not ret:
184
  break
185
 
186
  transform_matrix = self.estimate_camera_motion(frame)
187
-
188
- features = self.extract_features(frame, current_bbox, transform_matrix)
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  if features is not None:
191
- predicted_bbox = self.predict_bbox(features)
192
- current_bbox = predicted_bbox
193
-
 
194
  x, y, w, h = map(int, current_bbox)
195
- cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 0), 2)
196
  cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
197
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
198
-
199
  out.write(frame)
200
  frame_idx += 1
201
-
202
- if frame_idx % 30 == 0:
203
- print(f"Processed {frame_idx}/{total_frames} frames")
204
 
205
  cap.release()
206
  out.release()
207
-
208
- print(f"Tracking complete! Video saved to: {output_path}")
209
  return output_path
210
-
211
-
212
- def main():
213
- tracker = ObjectTrackerInference(model_dir='models')
214
-
215
- video_path = 'input_video.mp4'
216
- initial_bbox = [100, 100, 50, 50]
217
- output_path = 'tracked_output.mp4'
218
-
219
- result = tracker.track_video(video_path, initial_bbox, output_path)
220
- print(f"Done! Output: {result}")
221
-
222
-
223
- if __name__ == "__main__":
224
- main()
 
4
  import numpy as np
5
  from pathlib import Path
6
 
7
class CameraMotionVisualizer:
    """Overlays a sparse arrow field visualizing estimated camera motion."""

    @staticmethod
    def draw_motion_grid(frame, transform_matrix, grid_size=32):
        """Draw green motion arrows on `frame` at every `grid_size` pixels.

        Each grid point is mapped through `transform_matrix` (applied to the
        homogeneous point [x, y, 1]); an arrow is drawn only where the mapped
        point moves by more than one pixel on either axis. Returns `frame`
        (modified in place). If `transform_matrix` is None the frame is
        returned untouched.
        """
        if transform_matrix is None:
            return frame

        height, width = frame.shape[:2]
        for gy in range(0, height, grid_size):
            for gx in range(0, width, grid_size):
                mapped = np.dot(transform_matrix, np.array([gx, gy, 1]))
                # Sub-pixel motion at this point: nothing worth drawing.
                if abs(mapped[0] - gx) <= 1 and abs(mapped[1] - gy) <= 1:
                    continue
                cv2.arrowedLine(
                    frame,
                    (int(gx), int(gy)),
                    (int(mapped[0]), int(mapped[1])),
                    (0, 255, 0),
                    1,
                    tipLength=0.2,
                )
        return frame
28
+
29
+
30
class SlidingWindowRefiner:
    """Refines an object bounding box by scoring SIFT-matched sliding
    windows around the previous location (optionally motion-compensated
    by a camera-motion transform)."""

    def __init__(self):
        # SIFT keypoints/descriptors used for template matching.
        self.sift = cv2.SIFT_create(nfeatures=2000)

        # FLANN matcher over KD-trees — the standard pairing for
        # float-valued SIFT descriptors.
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)

        # Search-window geometry.
        self.scale_levels = 3    # number of scales probed per frame
        self.scale_step = 1.2    # scale range is [1/step, step]
        self.scale_factor = 2.0  # window size relative to the tracked bbox
        self.overlap = 0.3       # fractional overlap between neighbor windows

        # Appearance template, captured once via initialize_template().
        self.template = None
        self.template_kp = None
        self.template_desc = None

    def initialize_template(self, gray, bbox):
        """Capture the appearance template from `gray` at `bbox` (x, y, w, h).

        Stores a copy of the patch plus its SIFT keypoints/descriptors.
        """
        x, y, w, h = map(int, bbox)
        self.template = gray[y:y+h, x:x+w].copy()
        self.template_kp, self.template_desc = self.sift.detectAndCompute(
            self.template, None
        )

    def generate_windows(self, img_shape, prev_bbox, transform_matrix=None):
        """Return candidate (x, y, w, h) windows around `prev_bbox`.

        img_shape: frame shape, (height, width, ...).
        prev_bbox: previous (x, y, w, h) box.
        transform_matrix: optional 2x3 affine camera-motion matrix used to
            shift the search center before window generation.
        """
        x, y, w, h = map(int, prev_bbox)

        if transform_matrix is not None:
            center = np.array([[x + w/2, y + h/2, 1]]).T
            # .ravel() + float() extract plain scalars; calling int() on a
            # size-1 array is deprecated (NumPy 1.25) and an error on 2.x.
            transformed = np.dot(transform_matrix, center).ravel()
            x = int(float(transformed[0]) - w/2)
            y = int(float(transformed[1]) - h/2)

        windows = []
        for scale in np.linspace(1/self.scale_step, self.scale_step, self.scale_levels):
            ww = int(w * self.scale_factor * scale)
            hh = int(h * self.scale_factor * scale)

            # Stride between window origins at the requested overlap.
            step_x = int(ww * (1 - self.overlap))
            step_y = int(hh * (1 - self.overlap))

            cx, cy = x + w // 2, y + h // 2

            for dy in range(-step_y, step_y + 1, max(1, step_y // 2)):
                for dx in range(-step_x, step_x + 1, max(1, step_x // 2)):
                    # Clamp so every window lies fully inside the image.
                    wx = max(0, min(cx - ww // 2 + dx, img_shape[1] - ww))
                    wy = max(0, min(cy - hh // 2 + dy, img_shape[0] - hh))
                    windows.append((wx, wy, ww, hh))
        return windows

    def score_window(self, gray, window):
        """Score how well `window` (x, y, w, h) in `gray` matches the template.

        Returns 0 when the window is unusable (no template, tiny ROI, too few
        descriptors, no good matches); otherwise a score that grows with the
        number of Lowe-ratio-passing matches and shrinks with their distance.
        """
        # knnMatch(k=2) needs at least two descriptors on the template side.
        if self.template_desc is None or len(self.template_desc) < 2:
            return 0

        x, y, w, h = map(int, window)
        roi = gray[y:y+h, x:x+w]

        if roi.shape[0] < 20 or roi.shape[1] < 20:
            return 0

        # Resize the ROI to the template size so descriptors are comparable.
        roi = cv2.resize(roi, self.template.shape[::-1])
        kp, desc = self.sift.detectAndCompute(roi, None)

        # knnMatch(k=2) also needs at least two candidate descriptors.
        if desc is None or len(desc) < 2:
            return 0

        matches = self.flann.knnMatch(self.template_desc, desc, k=2)
        # Lowe's ratio test. Guard against pairs with fewer than two
        # neighbours, which knnMatch can return and which would otherwise
        # break the (m, n) unpacking.
        good = [pair[0] for pair in matches
                if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance]

        if not good:
            return 0

        # More good matches and smaller average distance -> higher score.
        avg_dist = np.mean([m.distance for m in good])
        return len(good) * (1 - avg_dist / 512)
105
+
106
+
107
+ # ================================
108
+ # 🔹 ORIGINAL INFERENCE CLASS
109
+ # ================================
110
 
111
  class ObjectTrackerInference:
112
  def __init__(self, model_dir='models'):
 
126
  self.prev_frame = None
127
  self.prev_kp = None
128
  self.prev_desc = None
129
+
130
+ # 🔹 ADDITIVE
131
+ self.window_refiner = SlidingWindowRefiner()
132
+ self.template_initialized = False
133
 
134
  def estimate_camera_motion(self, frame):
135
  if frame is None:
 
271
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
272
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
273
 
 
 
274
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
275
  out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
276
 
 
281
  current_bbox = initial_bbox
282
  frame_idx = 0
283
 
 
 
284
  while True:
285
  ret, frame = cap.read()
286
  if not ret:
287
  break
288
 
289
  transform_matrix = self.estimate_camera_motion(frame)
290
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
 
291
 
292
+ if not self.template_initialized:
293
+ self.window_refiner.initialize_template(gray, current_bbox)
294
+ self.template_initialized = True
295
+
296
+ windows = self.window_refiner.generate_windows(
297
+ frame.shape, current_bbox, transform_matrix
298
+ )
299
+
300
+ best_score = -1
301
+ best_window = None
302
+
303
+ for win in windows:
304
+ score = self.window_refiner.score_window(gray, win)
305
+ xw, yw, ww, hh = map(int, win)
306
+ cv2.rectangle(frame, (xw, yw), (xw+ww, yw+hh), (0, 255, 255), 1)
307
+ if score > best_score:
308
+ best_score = score
309
+ best_window = win
310
+
311
+ if best_window is not None:
312
+ current_bbox = best_window
313
+
314
+ features = self.extract_features(frame, current_bbox, transform_matrix)
315
  if features is not None:
316
+ current_bbox = self.predict_bbox(features)
317
+
318
+ frame = CameraMotionVisualizer.draw_motion_grid(frame, transform_matrix)
319
+
320
  x, y, w, h = map(int, current_bbox)
321
+ cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), 2)
322
  cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
323
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
324
+
325
  out.write(frame)
326
  frame_idx += 1
 
 
 
327
 
328
  cap.release()
329
  out.release()
 
 
330
  return output_path