ayushsaun committed on
Commit
b4c228c
·
1 Parent(s): 0ac3063

fixed inference bug

Browse files
Files changed (1) hide show
  1. inference.py +122 -69
inference.py CHANGED
@@ -3,6 +3,7 @@ import joblib
3
  import os
4
  import numpy as np
5
 
 
6
  class CameraMotionCompensator:
7
  def __init__(self):
8
  self.prev_frame = None
@@ -41,6 +42,7 @@ class CameraMotionCompensator:
41
  self.prev_desc = desc
42
  return M
43
 
 
44
  class ImprovedSlidingWindowTracker:
45
  def __init__(self, scale_factor=2.0, overlap=0.3):
46
  self.scale_factor = scale_factor
@@ -73,26 +75,35 @@ class ImprovedSlidingWindowTracker:
73
  return windows
74
 
75
  def score_window(self, gray, window, template, template_desc):
76
- x,y,w,h = map(int,window)
77
- roi = gray[y:y+h,x:x+w]
78
- if roi.shape[0]<20 or roi.shape[1]<20:
79
- return 0
80
- roi = cv2.resize(roi,(template.shape[1],template.shape[0]))
81
- _,desc = self.sift.detectAndCompute(roi,None)
82
- if desc is None or template_desc is None:
83
- return 0
84
- matches = self.flann.knnMatch(template_desc,desc,k=2)
85
- good = [m for m,n in matches if m.distance < 0.7*n.distance]
86
- if not good:
 
 
 
 
 
 
87
  return 0
88
- return len(good)*(1-np.mean([m.distance for m in good])/512)
89
 
90
  class ObjectTrackerInference:
91
  def __init__(self, model_dir):
 
92
  self.position_model = joblib.load(os.path.join(model_dir,'position_model.joblib'))
93
  self.size_model = joblib.load(os.path.join(model_dir,'size_model.joblib'))
94
  self.position_scaler = joblib.load(os.path.join(model_dir,'position_scaler.joblib'))
95
  self.size_scaler = joblib.load(os.path.join(model_dir,'size_scaler.joblib'))
 
 
96
  self.window_tracker = ImprovedSlidingWindowTracker()
97
  self.motion = CameraMotionCompensator()
98
  self.template = None
@@ -122,21 +133,24 @@ class ObjectTrackerInference:
122
 
123
  if self.template is None:
124
  x,y,w,h = map(int,prev_bbox)
125
- self.template = gray[y:y+h,x:x+w]
 
 
126
  _,self.template_desc = self.window_tracker.sift.detectAndCompute(self.template,None)
127
 
128
  best_score = -1
129
- best_window = None
130
  for w in windows:
131
  s = self.window_tracker.score_window(gray,w,self.template,self.template_desc)
132
  if s > best_score:
133
  best_score = s
134
  best_window = w
135
 
136
- if best_window is None:
137
- x,y,w,h = map(int,prev_bbox)
138
- else:
139
- x,y,w,h = map(int,best_window)
 
140
 
141
  roi = cv2.resize(gray[y:y+h,x:x+w],(64,64))
142
  hog = cv2.HOGDescriptor((64,64),(16,16),(8,8),(8,8),9).compute(roi).flatten()[:64]
@@ -161,60 +175,99 @@ class ObjectTrackerInference:
161
  inter=(xr-xl)*(yb-yt)
162
  return inter/(w1*h1+w2*h2-inter)
163
 
164
- def track_video(self, video_path, init_bbox, output):
165
- cap=cv2.VideoCapture(video_path)
166
- w,h=int(cap.get(3)),int(cap.get(4))
167
- out=cv2.VideoWriter(output,cv2.VideoWriter_fourcc(*'mp4v'),30,(w,h))
168
- cur=init_bbox
169
- frame_idx=0
170
-
171
- while True:
172
- ret,frame=cap.read()
173
- if not ret:
174
- break
175
-
176
- M=self.motion.estimate_motion(frame)
177
- feats,search_bbox,windows=self.extract_features(frame,cur,M)
178
-
179
- pos=self.position_model.predict(self.position_scaler.transform(feats))
180
- size=self.size_model.predict(self.size_scaler.transform(feats))
181
- pred=[int(pos[0,0]),int(pos[0,1]),int(size[0,0]),int(size[0,1])]
182
-
183
- self.template_update_counter+=1
184
- if self.template_update_counter>=5 and self.prev_bbox is not None:
185
- if self.calculate_iou(self.prev_bbox,pred)>0.6:
186
- g=cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)
187
- x,y,w1,h1=pred
188
- self.template=g[y:y+h1,x:x+w1]
189
- _,self.template_desc=self.window_tracker.sift.detectAndCompute(self.template,None)
190
- self.template_update_counter=0
191
-
192
- for wx,wy,ww,wh in windows:
193
- cv2.rectangle(frame,(wx,wy),(wx+ww,wy+wh),(0,255,255),1)
194
-
195
- hh,ww=frame.shape[:2]
196
- for yy in range(0,hh,32):
197
- for xx in range(0,ww,32):
198
- sp=np.array([xx,yy,1])
199
- ep=np.dot(M,sp)
200
- if abs(ep[0]-xx)>1 or abs(ep[1]-yy)>1:
201
- cv2.arrowedLine(frame,(xx,yy),(int(ep[0]),int(ep[1])),(0,255,0),1,tipLength=0.2)
202
-
203
- x,y,w1,h1=pred
204
- cv2.rectangle(frame,(x,y),(x+w1,y+h1),(0,255,0),2)
205
- cv2.putText(frame,f'Frame: {frame_idx}',(10,30),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),2)
206
-
207
- out.write(frame)
208
- self.prev_bbox=pred
209
- cur=pred
210
- frame_idx+=1
211
-
212
- cap.release()
213
- out.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
def main():
    """Demo entry point: track a fixed initial box through the bundled video."""
    ObjectTrackerInference('models').track_video('input_video.mp4', [100, 100, 50, 50], 'tracked_output.mp4')
 
 
218
 
219
  if __name__=="__main__":
220
  main()
 
3
  import os
4
  import numpy as np
5
 
6
+
7
  class CameraMotionCompensator:
8
  def __init__(self):
9
  self.prev_frame = None
 
42
  self.prev_desc = desc
43
  return M
44
 
45
+
46
  class ImprovedSlidingWindowTracker:
47
  def __init__(self, scale_factor=2.0, overlap=0.3):
48
  self.scale_factor = scale_factor
 
75
  return windows
76
 
77
  def score_window(self, gray, window, template, template_desc):
78
+ try:
79
+ x,y,w,h = map(int,window)
80
+ if y+h > gray.shape[0] or x+w > gray.shape[1]:
81
+ return 0
82
+ roi = gray[y:y+h,x:x+w]
83
+ if roi.shape[0]<20 or roi.shape[1]<20:
84
+ return 0
85
+ roi = cv2.resize(roi,(template.shape[1],template.shape[0]))
86
+ _,desc = self.sift.detectAndCompute(roi,None)
87
+ if desc is None or template_desc is None or len(desc)==0:
88
+ return 0
89
+ matches = self.flann.knnMatch(template_desc,desc,k=2)
90
+ good = [m for match_pair in matches if len(match_pair)==2 for m,n in [match_pair] if m.distance < 0.7*n.distance]
91
+ if not good:
92
+ return 0
93
+ return len(good)*(1-np.mean([m.distance for m in good])/512)
94
+ except:
95
  return 0
96
+
97
 
98
  class ObjectTrackerInference:
99
  def __init__(self, model_dir):
100
+ print(f"Loading models from {model_dir}...")
101
  self.position_model = joblib.load(os.path.join(model_dir,'position_model.joblib'))
102
  self.size_model = joblib.load(os.path.join(model_dir,'size_model.joblib'))
103
  self.position_scaler = joblib.load(os.path.join(model_dir,'position_scaler.joblib'))
104
  self.size_scaler = joblib.load(os.path.join(model_dir,'size_scaler.joblib'))
105
+ print("Models loaded successfully!")
106
+
107
  self.window_tracker = ImprovedSlidingWindowTracker()
108
  self.motion = CameraMotionCompensator()
109
  self.template = None
 
133
 
134
  if self.template is None:
135
  x,y,w,h = map(int,prev_bbox)
136
+ x = max(0, min(x, gray.shape[1]-w))
137
+ y = max(0, min(y, gray.shape[0]-h))
138
+ self.template = gray[y:y+h,x:x+w].copy()
139
  _,self.template_desc = self.window_tracker.sift.detectAndCompute(self.template,None)
140
 
141
  best_score = -1
142
+ best_window = prev_bbox
143
  for w in windows:
144
  s = self.window_tracker.score_window(gray,w,self.template,self.template_desc)
145
  if s > best_score:
146
  best_score = s
147
  best_window = w
148
 
149
+ x,y,w,h = map(int,best_window)
150
+ x = max(0, min(x, gray.shape[1]-10))
151
+ y = max(0, min(y, gray.shape[0]-10))
152
+ w = min(w, gray.shape[1]-x)
153
+ h = min(h, gray.shape[0]-y)
154
 
155
  roi = cv2.resize(gray[y:y+h,x:x+w],(64,64))
156
  hog = cv2.HOGDescriptor((64,64),(16,16),(8,8),(8,8),9).compute(roi).flatten()[:64]
 
175
  inter=(xr-xl)*(yb-yt)
176
  return inter/(w1*h1+w2*h2-inter)
177
 
178
+ def track_video(self, video_path, init_bbox, output_path='tracked_output.mp4'):
179
+ print(f"Opening video: {video_path}")
180
+
181
+ try:
182
+ cap=cv2.VideoCapture(video_path)
183
+ if not cap.isOpened():
184
+ raise ValueError(f"Cannot open video: {video_path}")
185
+
186
+ w,h=int(cap.get(3)),int(cap.get(4))
187
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
188
+ print(f"Video: {w}x{h}, {total_frames} frames")
189
+
190
+ fourcc = cv2.VideoWriter_fourcc(*'avc1')
191
+ out=cv2.VideoWriter(output_path, fourcc, 30, (w,h))
192
+
193
+ if not out.isOpened():
194
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
195
+ out=cv2.VideoWriter(output_path, fourcc, 30, (w,h))
196
+
197
+ # Reset state
198
+ self.motion.prev_frame = None
199
+ self.template = None
200
+ self.template_desc = None
201
+ self.prev_bbox = None
202
+ self.template_update_counter = 0
203
+
204
+ cur=init_bbox
205
+ frame_idx=0
206
+
207
+ print("Starting tracking...")
208
+ while True:
209
+ ret,frame=cap.read()
210
+ if not ret:
211
+ break
212
+
213
+ M=self.motion.estimate_motion(frame)
214
+ feats,search_bbox,windows=self.extract_features(frame,cur,M)
215
+
216
+ pos=self.position_model.predict(self.position_scaler.transform(feats))
217
+ size=self.size_model.predict(self.size_scaler.transform(feats))
218
+ pred=[int(pos[0,0]),int(pos[0,1]),int(size[0,0]),int(size[0,1])]
219
+
220
+ self.template_update_counter+=1
221
+ if self.template_update_counter>=5 and self.prev_bbox is not None:
222
+ if self.calculate_iou(self.prev_bbox,pred)>0.6:
223
+ g=cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)
224
+ x,y,w1,h1=pred
225
+ self.template=g[y:y+h1,x:x+w1].copy()
226
+ _,self.template_desc=self.window_tracker.sift.detectAndCompute(self.template,None)
227
+ self.template_update_counter=0
228
+
229
+ # Draw yellow search windows
230
+ for wx,wy,ww,wh in windows:
231
+ cv2.rectangle(frame,(wx,wy),(wx+ww,wy+wh),(0,255,255),1)
232
+
233
+ # Draw green motion arrows
234
+ hh,ww=frame.shape[:2]
235
+ for yy in range(0,hh,32):
236
+ for xx in range(0,ww,32):
237
+ sp=np.array([xx,yy,1])
238
+ ep=np.dot(M,sp)
239
+ if abs(ep[0]-xx)>1 or abs(ep[1]-yy)>1:
240
+ cv2.arrowedLine(frame,(xx,yy),(int(ep[0]),int(ep[1])),(0,255,0),1,tipLength=0.2)
241
+
242
+ # Draw tracked bounding box
243
+ x,y,w1,h1=pred
244
+ cv2.rectangle(frame,(x,y),(x+w1,y+h1),(0,255,0),2)
245
+ cv2.putText(frame,f'Frame: {frame_idx}',(10,30),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),2)
246
+
247
+ out.write(frame)
248
+ self.prev_bbox=pred
249
+ cur=pred
250
+ frame_idx+=1
251
+
252
+ if frame_idx % 30 == 0:
253
+ print(f"Processed {frame_idx}/{total_frames} frames")
254
+
255
+ cap.release()
256
+ out.release()
257
+
258
+ print(f"Tracking complete! Saved to: {output_path}")
259
+ return output_path
260
+
261
+ except Exception as e:
262
+ print(f"Error during tracking: {str(e)}")
263
+ raise
264
+
265
 
266
def main():
    """Run the tracker on the demo clip with a fixed initial box and report where the result was saved."""
    result = ObjectTrackerInference('models').track_video(
        'input_video.mp4', [100, 100, 50, 50], 'tracked_output.mp4')
    print(f"Output: {result}")
270
+
271
 
272
  if __name__=="__main__":
273
  main()