Alfred Liu committed
Commit 391e2f4 · Parent: 3403979

Fix OOM during evaluation (#39)

Files changed (1):
  models/sparsebev.py  +21 -19
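The patch below fixes the OOM by caching per-frame backbone features in a dict keyed by image filename, with a FIFO queue that evicts the oldest entries once 16 frames are cached, so memory stays bounded during evaluation. As context, here is a minimal, self-contained sketch of that caching pattern; FeatureCache, max_frames, and compute_fn are illustrative names introduced here, not identifiers from the repository.

import queue

class FeatureCache:
    """Memoize expensive per-frame results with FIFO eviction (sketch)."""

    def __init__(self, max_frames=16):
        self.memory = {}            # key -> cached result
        self.queue = queue.Queue()  # keys in insertion order, for eviction
        self.max_frames = max_frames

    def get_or_compute(self, key, compute_fn):
        if key in self.memory:
            # found in memory: reuse instead of recomputing
            return self.memory[key]
        result = compute_fn()       # e.g. a model's feature extractor
        self.memory[key] = result
        self.queue.put(key)
        # evict the oldest entries once the cap is reached (avoid OOM)
        while self.queue.qsize() >= self.max_frames:
            self.memory.pop(self.queue.get())
        return result

cache = FeatureCache(max_frames=16)
feats = cache.get_or_compute('frame_0001.jpg', lambda: [0.0] * 4)

Plain FIFO (rather than LRU) is presumably sufficient here because evaluation walks frames in a sliding temporal window, so the oldest cached frame is also the least likely to be requested again. Note that the patch also moves eviction inside the cache-miss branch, so the size check only runs when a new entry was just added.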
models/sparsebev.py CHANGED

@@ -268,8 +268,9 @@ class SparseBEV(MVXTwoStageDetector):
         img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))]
         img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))]
 
-        img_feats_large, img_metas_large = [], []
+        img_feats_list, img_metas_list = [], []
 
+        # extract feature frame by frame
         for i in range(num_frames):
             img_indices = list(np.arange(i * 6, (i + 1) * 6))
 
@@ -279,41 +280,42 @@ class SparseBEV(MVXTwoStageDetector):
                 img_metas_curr[0][k] = [img_metas[0][k][i] for i in img_indices]
 
             if img_filenames[img_indices[0]] in self.memory:
+                # found in memory
                 img_feats_curr = self.memory[img_filenames[img_indices[0]]]
             else:
-                img_curr_large = img[:, i]  # [B, 6, C, H, W]
-                img_feats_curr = self.extract_feat(img_curr_large, img_metas_curr)
+                # extract feature and put into memory
+                img_feats_curr = self.extract_feat(img[:, i], img_metas_curr)
                 self.memory[img_filenames[img_indices[0]]] = img_feats_curr
                 self.queue.put(img_filenames[img_indices[0]])
+                while self.queue.qsize() >= 16:  # avoid OOM
+                    pop_key = self.queue.get()
+                    self.memory.pop(pop_key)
 
-            img_feats_large.append(img_feats_curr)
-            img_metas_large.append(img_metas_curr)
+            img_feats_list.append(img_feats_curr)
+            img_metas_list.append(img_metas_curr)
 
         # reorganize
-        feat_levels = len(img_feats_large[0])
-        img_feats_large_reorganized = []
+        feat_levels = len(img_feats_list[0])
+        img_feats_reorganized = []
         for j in range(feat_levels):
-            feat_l = torch.cat([img_feats_large[i][j] for i in range(len(img_feats_large))], dim=0)
+            feat_l = torch.cat([img_feats_list[i][j] for i in range(len(img_feats_list))], dim=0)
             feat_l = feat_l.flatten(0, 1)[None, ...]
-            img_feats_large_reorganized.append(feat_l)
+            img_feats_reorganized.append(feat_l)
 
-        img_metas_large_reorganized = img_metas_large[0]
-        for i in range(1, len(img_metas_large)):
-            for k, v in img_metas_large[i][0].items():
+        img_metas_reorganized = img_metas_list[0]
+        for i in range(1, len(img_metas_list)):
+            for k, v in img_metas_list[i][0].items():
                 if isinstance(v, list):
-                    img_metas_large_reorganized[0][k].extend(v)
+                    img_metas_reorganized[0][k].extend(v)
 
-        img_feats = img_feats_large_reorganized
-        img_metas = img_metas_large_reorganized
+        img_feats = img_feats_reorganized
+        img_metas = img_metas_reorganized
         img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)
 
+        # run detector
         bbox_list = [dict() for _ in range(1)]
         bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
         for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
            result_dict['pts_bbox'] = pts_bbox
 
-        while self.queue.qsize() >= 16:
-            pop_key = self.queue.get()
-            self.memory.pop(pop_key)
-
         return bbox_list
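For reference, the unchanged "reorganize" step concatenates each feature-pyramid level across frames, then folds the frame and camera axes into a single view axis. A dummy-tensor sketch of that step follows; the shapes (1 batch, 6 cameras, two pyramid levels) are illustrative stand-ins for the model's real feature maps, not values taken from the repository.

import torch

num_frames, num_cams, channels = 3, 6, 8
# two pyramid levels per frame, each [B=1, N=6, C, H, W]
img_feats_list = [
    [torch.randn(1, num_cams, channels, 16, 16),
     torch.randn(1, num_cams, channels, 8, 8)]
    for _ in range(num_frames)
]

feat_levels = len(img_feats_list[0])
img_feats_reorganized = []
for j in range(feat_levels):
    # stack frames along the batch axis: [num_frames, N, C, H, W]
    feat_l = torch.cat([feats[j] for feats in img_feats_list], dim=0)
    # merge frames and cameras into one view axis: [1, num_frames * N, C, H, W]
    feat_l = feat_l.flatten(0, 1)[None, ...]
    img_feats_reorganized.append(feat_l)

print(img_feats_reorganized[0].shape)  # torch.Size([1, 18, 8, 16, 16])

The [None, ...] restores a batch axis of size 1 after flatten(0, 1) merges frames and cameras, which matches the single-sample layout that the downstream simple_test_pts call consumes in this method.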