Spaces:
Sleeping
Sleeping
fixed inference bug
Browse files- inference.py +122 -69
inference.py
CHANGED
|
@@ -3,6 +3,7 @@ import joblib
|
|
| 3 |
import os
|
| 4 |
import numpy as np
|
| 5 |
|
|
|
|
| 6 |
class CameraMotionCompensator:
|
| 7 |
def __init__(self):
|
| 8 |
self.prev_frame = None
|
|
@@ -41,6 +42,7 @@ class CameraMotionCompensator:
|
|
| 41 |
self.prev_desc = desc
|
| 42 |
return M
|
| 43 |
|
|
|
|
| 44 |
class ImprovedSlidingWindowTracker:
|
| 45 |
def __init__(self, scale_factor=2.0, overlap=0.3):
|
| 46 |
self.scale_factor = scale_factor
|
|
@@ -73,26 +75,35 @@ class ImprovedSlidingWindowTracker:
|
|
| 73 |
return windows
|
| 74 |
|
| 75 |
def score_window(self, gray, window, template, template_desc):
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
return 0
|
| 88 |
-
|
| 89 |
|
| 90 |
class ObjectTrackerInference:
|
| 91 |
def __init__(self, model_dir):
|
|
|
|
| 92 |
self.position_model = joblib.load(os.path.join(model_dir,'position_model.joblib'))
|
| 93 |
self.size_model = joblib.load(os.path.join(model_dir,'size_model.joblib'))
|
| 94 |
self.position_scaler = joblib.load(os.path.join(model_dir,'position_scaler.joblib'))
|
| 95 |
self.size_scaler = joblib.load(os.path.join(model_dir,'size_scaler.joblib'))
|
|
|
|
|
|
|
| 96 |
self.window_tracker = ImprovedSlidingWindowTracker()
|
| 97 |
self.motion = CameraMotionCompensator()
|
| 98 |
self.template = None
|
|
@@ -122,21 +133,24 @@ class ObjectTrackerInference:
|
|
| 122 |
|
| 123 |
if self.template is None:
|
| 124 |
x,y,w,h = map(int,prev_bbox)
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
_,self.template_desc = self.window_tracker.sift.detectAndCompute(self.template,None)
|
| 127 |
|
| 128 |
best_score = -1
|
| 129 |
-
best_window =
|
| 130 |
for w in windows:
|
| 131 |
s = self.window_tracker.score_window(gray,w,self.template,self.template_desc)
|
| 132 |
if s > best_score:
|
| 133 |
best_score = s
|
| 134 |
best_window = w
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
roi = cv2.resize(gray[y:y+h,x:x+w],(64,64))
|
| 142 |
hog = cv2.HOGDescriptor((64,64),(16,16),(8,8),(8,8),9).compute(roi).flatten()[:64]
|
|
@@ -161,60 +175,99 @@ class ObjectTrackerInference:
|
|
| 161 |
inter=(xr-xl)*(yb-yt)
|
| 162 |
return inter/(w1*h1+w2*h2-inter)
|
| 163 |
|
| 164 |
-
def track_video(self, video_path, init_bbox,
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
def main():
|
| 216 |
tracker=ObjectTrackerInference('models')
|
| 217 |
-
tracker.track_video('input_video.mp4',[100,100,50,50],'tracked_output.mp4')
|
|
|
|
|
|
|
| 218 |
|
| 219 |
if __name__=="__main__":
|
| 220 |
main()
|
|
|
|
| 3 |
import os
|
| 4 |
import numpy as np
|
| 5 |
|
| 6 |
+
|
| 7 |
class CameraMotionCompensator:
|
| 8 |
def __init__(self):
|
| 9 |
self.prev_frame = None
|
|
|
|
| 42 |
self.prev_desc = desc
|
| 43 |
return M
|
| 44 |
|
| 45 |
+
|
| 46 |
class ImprovedSlidingWindowTracker:
|
| 47 |
def __init__(self, scale_factor=2.0, overlap=0.3):
|
| 48 |
self.scale_factor = scale_factor
|
|
|
|
| 75 |
return windows
|
| 76 |
|
| 77 |
def score_window(self, gray, window, template, template_desc):
    """Score how well a candidate window matches the tracked template.

    The window is cropped from ``gray``, resized to the template's size, and
    its SIFT descriptors are matched against ``template_desc`` with a 0.7
    ratio test. The score grows with the number of good matches and shrinks
    with their mean descriptor distance.

    Args:
        gray: Image array indexed as ``[row, col]`` that the window is
            cropped from.
        window: ``(x, y, w, h)`` candidate box; values are truncated to int.
        template: Template image; only its shape is used (resize target).
        template_desc: Descriptors of the template, or ``None``.

    Returns:
        A numeric score. ``0`` when the window falls outside the image, is
        smaller than 20 pixels on either side, yields no descriptors or no
        good matches, or when scoring fails for any other reason.
    """
    try:
        x, y, w, h = map(int, window)
        img_h, img_w = gray.shape[:2]

        # A negative origin would wrap around under NumPy slicing and crop
        # an unrelated region, so treat it as out of bounds.
        if x < 0 or y < 0 or y + h > img_h or x + w > img_w:
            return 0

        roi = gray[y:y + h, x:x + w]
        if roi.shape[0] < 20 or roi.shape[1] < 20:
            return 0

        roi = cv2.resize(roi, (template.shape[1], template.shape[0]))
        _, desc = self.sift.detectAndCompute(roi, None)
        if desc is None or template_desc is None or len(desc) == 0:
            return 0

        matches = self.flann.knnMatch(template_desc, desc, k=2)

        # Ratio test: keep a match only when the best neighbour is clearly
        # better than the runner-up. Pairs with fewer than two neighbours
        # cannot be tested and are dropped.
        good = [
            pair[0]
            for pair in matches
            if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
        ]
        if not good:
            return 0

        mean_distance = np.mean([m.distance for m in good])
        return len(good) * (1 - mean_distance / 512)
    except Exception:
        # Scoring is best-effort: a window that cannot be scored simply loses.
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate.
        return 0
|
| 96 |
+
|
| 97 |
|
| 98 |
class ObjectTrackerInference:
|
| 99 |
def __init__(self, model_dir):
|
| 100 |
+
print(f"Loading models from {model_dir}...")
|
| 101 |
self.position_model = joblib.load(os.path.join(model_dir,'position_model.joblib'))
|
| 102 |
self.size_model = joblib.load(os.path.join(model_dir,'size_model.joblib'))
|
| 103 |
self.position_scaler = joblib.load(os.path.join(model_dir,'position_scaler.joblib'))
|
| 104 |
self.size_scaler = joblib.load(os.path.join(model_dir,'size_scaler.joblib'))
|
| 105 |
+
print("Models loaded successfully!")
|
| 106 |
+
|
| 107 |
self.window_tracker = ImprovedSlidingWindowTracker()
|
| 108 |
self.motion = CameraMotionCompensator()
|
| 109 |
self.template = None
|
|
|
|
| 133 |
|
| 134 |
if self.template is None:
|
| 135 |
x,y,w,h = map(int,prev_bbox)
|
| 136 |
+
x = max(0, min(x, gray.shape[1]-w))
|
| 137 |
+
y = max(0, min(y, gray.shape[0]-h))
|
| 138 |
+
self.template = gray[y:y+h,x:x+w].copy()
|
| 139 |
_,self.template_desc = self.window_tracker.sift.detectAndCompute(self.template,None)
|
| 140 |
|
| 141 |
best_score = -1
|
| 142 |
+
best_window = prev_bbox
|
| 143 |
for w in windows:
|
| 144 |
s = self.window_tracker.score_window(gray,w,self.template,self.template_desc)
|
| 145 |
if s > best_score:
|
| 146 |
best_score = s
|
| 147 |
best_window = w
|
| 148 |
|
| 149 |
+
x,y,w,h = map(int,best_window)
|
| 150 |
+
x = max(0, min(x, gray.shape[1]-10))
|
| 151 |
+
y = max(0, min(y, gray.shape[0]-10))
|
| 152 |
+
w = min(w, gray.shape[1]-x)
|
| 153 |
+
h = min(h, gray.shape[0]-y)
|
| 154 |
|
| 155 |
roi = cv2.resize(gray[y:y+h,x:x+w],(64,64))
|
| 156 |
hog = cv2.HOGDescriptor((64,64),(16,16),(8,8),(8,8),9).compute(roi).flatten()[:64]
|
|
|
|
| 175 |
inter=(xr-xl)*(yb-yt)
|
| 176 |
return inter/(w1*h1+w2*h2-inter)
|
| 177 |
|
| 178 |
+
def track_video(self, video_path, init_bbox, output_path='tracked_output.mp4'):
    """Track an object through a video and write an annotated copy.

    For each frame the camera motion is estimated, features are extracted
    around the current box, and the position/size models predict the next
    box. The frame is annotated with the search windows, a grid of motion
    arrows, the predicted box and the frame index, then written out.

    Args:
        video_path: Path of the input video.
        init_bbox: Initial ``(x, y, w, h)`` box used for the first frame.
        output_path: Destination of the annotated video.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If the input video or the output writer cannot be opened.
        Exception: Any error raised while tracking is printed and re-raised.
    """
    print(f"Opening video: {video_path}")

    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video: {w}x{h}, {total_frames} frames")

        # Keep the source frame rate so playback speed is preserved; fall
        # back to 30 when the backend reports 0, a negative value or NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Prefer H.264 ('avc1'); fall back to 'mp4v' when it is unavailable.
        for codec in ('avc1', 'mp4v'):
            out = cv2.VideoWriter(
                output_path, cv2.VideoWriter_fourcc(*codec), fps, (w, h)
            )
            if out.isOpened():
                break
            out.release()
        else:
            # Without this check every write would be silently dropped and
            # the function would still report success.
            raise ValueError(f"Cannot open video writer: {output_path}")

        # Reset state
        self.motion.prev_frame = None
        self.template = None
        self.template_desc = None
        self.prev_bbox = None
        self.template_update_counter = 0

        cur = init_bbox
        frame_idx = 0

        print("Starting tracking...")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            M = self.motion.estimate_motion(frame)
            feats, search_bbox, windows = self.extract_features(frame, cur, M)

            pos = self.position_model.predict(self.position_scaler.transform(feats))
            size = self.size_model.predict(self.size_scaler.transform(feats))
            pred = [int(pos[0, 0]), int(pos[0, 1]), int(size[0, 0]), int(size[0, 1])]

            # Refresh the template at most every 5 frames, and only when the
            # prediction is stable relative to the previous box (IoU > 0.6).
            self.template_update_counter += 1
            if self.template_update_counter >= 5 and self.prev_bbox is not None:
                if self.calculate_iou(self.prev_bbox, pred) > 0.6:
                    g = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    x, y, w1, h1 = pred
                    # Clamp to the frame: an unclamped negative or oversized
                    # box would slice an empty or wrapped-around template.
                    x0, y0 = max(0, x), max(0, y)
                    x1 = min(g.shape[1], x + w1)
                    y1 = min(g.shape[0], y + h1)
                    if x1 > x0 and y1 > y0:
                        self.template = g[y0:y1, x0:x1].copy()
                        _, self.template_desc = self.window_tracker.sift.detectAndCompute(
                            self.template, None
                        )
                        self.template_update_counter = 0

            # Draw yellow search windows
            for wx, wy, ww, wh in windows:
                # Drawing functions need integer pixel coordinates.
                wx, wy, ww, wh = int(wx), int(wy), int(ww), int(wh)
                cv2.rectangle(frame, (wx, wy), (wx + ww, wy + wh), (0, 255, 255), 1)

            # Draw green motion arrows on a 32-pixel grid, skipping points
            # that the estimated transform moves by at most one pixel.
            frame_h, frame_w = frame.shape[:2]
            for yy in range(0, frame_h, 32):
                for xx in range(0, frame_w, 32):
                    sp = np.array([xx, yy, 1])
                    ep = np.dot(M, sp)
                    if abs(ep[0] - xx) > 1 or abs(ep[1] - yy) > 1:
                        cv2.arrowedLine(
                            frame, (xx, yy), (int(ep[0]), int(ep[1])),
                            (0, 255, 0), 1, tipLength=0.2,
                        )

            # Draw tracked bounding box
            x, y, w1, h1 = pred
            cv2.rectangle(frame, (x, y), (x + w1, y + h1), (0, 255, 0), 2)
            cv2.putText(frame, f'Frame: {frame_idx}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            out.write(frame)
            self.prev_bbox = pred
            cur = pred
            frame_idx += 1

            if frame_idx % 30 == 0:
                print(f"Processed {frame_idx}/{total_frames} frames")

    except Exception as e:
        print(f"Error during tracking: {str(e)}")
        raise
    finally:
        # Always free the capture and writer, including on failure, so the
        # files are not left locked or half-written.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()

    print(f"Tracking complete! Saved to: {output_path}")
    return output_path
|
| 264 |
+
|
| 265 |
|
| 266 |
def main():
    """Run the tracker on the bundled sample video and print the output path."""
    inference = ObjectTrackerInference('models')
    initial_box = [100,100,50,50]
    saved_to = inference.track_video('input_video.mp4', initial_box, 'tracked_output.mp4')
    print(f"Output: {saved_to}")


# Only run the demo when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|