Spaces:
Sleeping
Sleeping
added inference changes
Browse files- inference.py +138 -32
inference.py
CHANGED
|
@@ -4,6 +4,109 @@ import joblib
|
|
| 4 |
import numpy as np
|
| 5 |
from pathlib import Path
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class ObjectTrackerInference:
|
| 9 |
def __init__(self, model_dir='models'):
|
|
@@ -23,6 +126,10 @@ class ObjectTrackerInference:
|
|
| 23 |
self.prev_frame = None
|
| 24 |
self.prev_kp = None
|
| 25 |
self.prev_desc = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def estimate_camera_motion(self, frame):
|
| 28 |
if frame is None:
|
|
@@ -164,8 +271,6 @@ class ObjectTrackerInference:
|
|
| 164 |
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 165 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 166 |
|
| 167 |
-
print(f"Video: {frame_width}x{frame_height}, {total_frames} frames")
|
| 168 |
-
|
| 169 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 170 |
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
|
| 171 |
|
|
@@ -176,49 +281,50 @@ class ObjectTrackerInference:
|
|
| 176 |
current_bbox = initial_bbox
|
| 177 |
frame_idx = 0
|
| 178 |
|
| 179 |
-
print("Tracking object...")
|
| 180 |
-
|
| 181 |
while True:
|
| 182 |
ret, frame = cap.read()
|
| 183 |
if not ret:
|
| 184 |
break
|
| 185 |
|
| 186 |
transform_matrix = self.estimate_camera_motion(frame)
|
| 187 |
-
|
| 188 |
-
features = self.extract_features(frame, current_bbox, transform_matrix)
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
if features is not None:
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
| 194 |
x, y, w, h = map(int, current_bbox)
|
| 195 |
-
cv2.rectangle(frame, (x, y), (x+w, y+h), (0,
|
| 196 |
cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
|
| 197 |
-
cv2.FONT_HERSHEY_SIMPLEX, 1, (
|
| 198 |
-
|
| 199 |
out.write(frame)
|
| 200 |
frame_idx += 1
|
| 201 |
-
|
| 202 |
-
if frame_idx % 30 == 0:
|
| 203 |
-
print(f"Processed {frame_idx}/{total_frames} frames")
|
| 204 |
|
| 205 |
cap.release()
|
| 206 |
out.release()
|
| 207 |
-
|
| 208 |
-
print(f"Tracking complete! Video saved to: {output_path}")
|
| 209 |
return output_path
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def main():
|
| 213 |
-
tracker = ObjectTrackerInference(model_dir='models')
|
| 214 |
-
|
| 215 |
-
video_path = 'input_video.mp4'
|
| 216 |
-
initial_bbox = [100, 100, 50, 50]
|
| 217 |
-
output_path = 'tracked_output.mp4'
|
| 218 |
-
|
| 219 |
-
result = tracker.track_video(video_path, initial_bbox, output_path)
|
| 220 |
-
print(f"Done! Output: {result}")
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
if __name__ == "__main__":
|
| 224 |
-
main()
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from pathlib import Path
|
| 6 |
|
| 7 |
+
class CameraMotionVisualizer:
    """Debug overlay that visualizes estimated camera motion on a frame."""

    @staticmethod
    def draw_motion_grid(frame, transform_matrix, grid_size=32):
        """Draw green arrows on a regular grid showing where each grid point
        maps under *transform_matrix*.

        Points that move by at most 1 px in both axes are left unmarked.
        Returns *frame* unchanged (and untouched) when the matrix is None.
        NOTE(review): assumes a 2x3 affine matrix — confirm against the
        producer of transform_matrix.
        """
        if transform_matrix is None:
            return frame

        height, width = frame.shape[:2]
        for gy in range(0, height, grid_size):
            for gx in range(0, width, grid_size):
                src = np.array([gx, gy, 1])
                dst = np.dot(transform_matrix, src)
                # Only draw arrows for grid points that actually moved.
                moved = abs(dst[0] - gx) > 1 or abs(dst[1] - gy) > 1
                if moved:
                    cv2.arrowedLine(
                        frame,
                        (int(gx), int(gy)),
                        (int(dst[0]), int(dst[1])),
                        (0, 255, 0),
                        1,
                        tipLength=0.2,
                    )
        return frame
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SlidingWindowRefiner:
|
| 31 |
+
def __init__(self):
|
| 32 |
+
self.sift = cv2.SIFT_create(nfeatures=2000)
|
| 33 |
+
|
| 34 |
+
FLANN_INDEX_KDTREE = 1
|
| 35 |
+
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
|
| 36 |
+
search_params = dict(checks=50)
|
| 37 |
+
self.flann = cv2.FlannBasedMatcher(index_params, search_params)
|
| 38 |
+
|
| 39 |
+
self.scale_levels = 3
|
| 40 |
+
self.scale_step = 1.2
|
| 41 |
+
self.scale_factor = 2.0
|
| 42 |
+
self.overlap = 0.3
|
| 43 |
+
|
| 44 |
+
self.template = None
|
| 45 |
+
self.template_kp = None
|
| 46 |
+
self.template_desc = None
|
| 47 |
+
|
| 48 |
+
def initialize_template(self, gray, bbox):
|
| 49 |
+
x, y, w, h = map(int, bbox)
|
| 50 |
+
self.template = gray[y:y+h, x:x+w].copy()
|
| 51 |
+
self.template_kp, self.template_desc = self.sift.detectAndCompute(
|
| 52 |
+
self.template, None
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def generate_windows(self, img_shape, prev_bbox, transform_matrix=None):
|
| 56 |
+
x, y, w, h = map(int, prev_bbox)
|
| 57 |
+
|
| 58 |
+
if transform_matrix is not None:
|
| 59 |
+
center = np.array([[x + w/2, y + h/2, 1]]).T
|
| 60 |
+
transformed = np.dot(transform_matrix, center)
|
| 61 |
+
x = int(transformed[0] - w/2)
|
| 62 |
+
y = int(transformed[1] - h/2)
|
| 63 |
+
|
| 64 |
+
windows = []
|
| 65 |
+
for scale in np.linspace(1/self.scale_step, self.scale_step, self.scale_levels):
|
| 66 |
+
ww = int(w * self.scale_factor * scale)
|
| 67 |
+
hh = int(h * self.scale_factor * scale)
|
| 68 |
+
|
| 69 |
+
step_x = int(ww * (1 - self.overlap))
|
| 70 |
+
step_y = int(hh * (1 - self.overlap))
|
| 71 |
+
|
| 72 |
+
cx, cy = x + w // 2, y + h // 2
|
| 73 |
+
|
| 74 |
+
for dy in range(-step_y, step_y + 1, max(1, step_y // 2)):
|
| 75 |
+
for dx in range(-step_x, step_x + 1, max(1, step_x // 2)):
|
| 76 |
+
wx = max(0, min(cx - ww // 2 + dx, img_shape[1] - ww))
|
| 77 |
+
wy = max(0, min(cy - hh // 2 + dy, img_shape[0] - hh))
|
| 78 |
+
windows.append((wx, wy, ww, hh))
|
| 79 |
+
return windows
|
| 80 |
+
|
| 81 |
+
def score_window(self, gray, window):
|
| 82 |
+
if self.template_desc is None:
|
| 83 |
+
return 0
|
| 84 |
+
|
| 85 |
+
x, y, w, h = map(int, window)
|
| 86 |
+
roi = gray[y:y+h, x:x+w]
|
| 87 |
+
|
| 88 |
+
if roi.shape[0] < 20 or roi.shape[1] < 20:
|
| 89 |
+
return 0
|
| 90 |
+
|
| 91 |
+
roi = cv2.resize(roi, self.template.shape[::-1])
|
| 92 |
+
kp, desc = self.sift.detectAndCompute(roi, None)
|
| 93 |
+
|
| 94 |
+
if desc is None:
|
| 95 |
+
return 0
|
| 96 |
+
|
| 97 |
+
matches = self.flann.knnMatch(self.template_desc, desc, k=2)
|
| 98 |
+
good = [m for m, n in matches if m.distance < 0.7 * n.distance]
|
| 99 |
+
|
| 100 |
+
if not good:
|
| 101 |
+
return 0
|
| 102 |
+
|
| 103 |
+
avg_dist = np.mean([m.distance for m in good])
|
| 104 |
+
return len(good) * (1 - avg_dist / 512)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ================================
|
| 108 |
+
# 🔹 ORIGINAL INFERENCE CLASS
|
| 109 |
+
# ================================
|
| 110 |
|
| 111 |
class ObjectTrackerInference:
|
| 112 |
def __init__(self, model_dir='models'):
|
|
|
|
| 126 |
self.prev_frame = None
|
| 127 |
self.prev_kp = None
|
| 128 |
self.prev_desc = None
|
| 129 |
+
|
| 130 |
+
# 🔹 ADDITIVE
|
| 131 |
+
self.window_refiner = SlidingWindowRefiner()
|
| 132 |
+
self.template_initialized = False
|
| 133 |
|
| 134 |
def estimate_camera_motion(self, frame):
|
| 135 |
if frame is None:
|
|
|
|
| 271 |
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 272 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 273 |
|
|
|
|
|
|
|
| 274 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 275 |
out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
|
| 276 |
|
|
|
|
| 281 |
current_bbox = initial_bbox
|
| 282 |
frame_idx = 0
|
| 283 |
|
|
|
|
|
|
|
| 284 |
while True:
|
| 285 |
ret, frame = cap.read()
|
| 286 |
if not ret:
|
| 287 |
break
|
| 288 |
|
| 289 |
transform_matrix = self.estimate_camera_motion(frame)
|
| 290 |
+
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
|
|
|
|
| 291 |
|
| 292 |
+
if not self.template_initialized:
|
| 293 |
+
self.window_refiner.initialize_template(gray, current_bbox)
|
| 294 |
+
self.template_initialized = True
|
| 295 |
+
|
| 296 |
+
windows = self.window_refiner.generate_windows(
|
| 297 |
+
frame.shape, current_bbox, transform_matrix
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
best_score = -1
|
| 301 |
+
best_window = None
|
| 302 |
+
|
| 303 |
+
for win in windows:
|
| 304 |
+
score = self.window_refiner.score_window(gray, win)
|
| 305 |
+
xw, yw, ww, hh = map(int, win)
|
| 306 |
+
cv2.rectangle(frame, (xw, yw), (xw+ww, yw+hh), (0, 255, 255), 1)
|
| 307 |
+
if score > best_score:
|
| 308 |
+
best_score = score
|
| 309 |
+
best_window = win
|
| 310 |
+
|
| 311 |
+
if best_window is not None:
|
| 312 |
+
current_bbox = best_window
|
| 313 |
+
|
| 314 |
+
features = self.extract_features(frame, current_bbox, transform_matrix)
|
| 315 |
if features is not None:
|
| 316 |
+
current_bbox = self.predict_bbox(features)
|
| 317 |
+
|
| 318 |
+
frame = CameraMotionVisualizer.draw_motion_grid(frame, transform_matrix)
|
| 319 |
+
|
| 320 |
x, y, w, h = map(int, current_bbox)
|
| 321 |
+
cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), 2)
|
| 322 |
cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
|
| 323 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
| 324 |
+
|
| 325 |
out.write(frame)
|
| 326 |
frame_idx += 1
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
cap.release()
|
| 329 |
out.release()
|
|
|
|
|
|
|
| 330 |
return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|