jing96963 commited on
Commit
caba23b
·
1 Parent(s): 8ad0a83

Revert model offloading (was not requested, slows computation)

Browse files
Files changed (1) hide show
  1. gradio_app.py +1 -14
gradio_app.py CHANGED
@@ -142,11 +142,7 @@ def generate_masks_sam3(frames, text_prompt):
142
  else:
143
  combined = np.zeros((H,W), dtype=np.uint8)
144
  raw_masks[model_out.frame_idx] = combined
145
- result = [dilate_mask(raw_masks.get(i, np.zeros((H,W),dtype=np.uint8))) for i in range(len(frames))]
146
- # Offload SAM3 to CPU after use to free VRAM
147
- sam3_model.to("cpu")
148
- torch.cuda.empty_cache()
149
- return result
150
 
151
  # ── First-frame helpers ──
152
 
@@ -204,7 +200,6 @@ def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter):
204
  print(f" BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")
205
 
206
  print(f"[2/6] SAM3 detecting '{text_prompt}'...")
207
- sam3_model.to(device) # move back to GPU for inference
208
  dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
209
 
210
  if filter_bbox:
@@ -251,18 +246,10 @@ def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter):
251
  repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
252
  t0 = time.time()
253
  if use_propainter:
254
- propainter.fix_raft.to(device)
255
- propainter.fix_flow_complete.to(device)
256
- propainter.model.to(device)
257
  propainter.forward(crop_video_path, crop_mask_path, priori_path,
258
  resize_ratio=1.0,
259
  video_length=video_length, ref_stride=10,
260
  neighbor_length=10, subvideo_length=50, mask_dilation=8)
261
- # Offload ProPainter to CPU after use
262
- propainter.fix_raft.to("cpu")
263
- propainter.fix_flow_complete.to("cpu")
264
- propainter.model.to("cpu")
265
- torch.cuda.empty_cache()
266
  else:
267
  import shutil
268
  shutil.copy2(crop_video_path, priori_path)
 
142
  else:
143
  combined = np.zeros((H,W), dtype=np.uint8)
144
  raw_masks[model_out.frame_idx] = combined
145
+ return [dilate_mask(raw_masks.get(i, np.zeros((H,W),dtype=np.uint8))) for i in range(len(frames))]
 
 
 
 
146
 
147
  # ── First-frame helpers ──
148
 
 
200
  print(f" BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")
201
 
202
  print(f"[2/6] SAM3 detecting '{text_prompt}'...")
 
203
  dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
204
 
205
  if filter_bbox:
 
246
  repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
247
  t0 = time.time()
248
  if use_propainter:
 
 
 
249
  propainter.forward(crop_video_path, crop_mask_path, priori_path,
250
  resize_ratio=1.0,
251
  video_length=video_length, ref_stride=10,
252
  neighbor_length=10, subvideo_length=50, mask_dilation=8)
 
 
 
 
 
253
  else:
254
  import shutil
255
  shutil.copy2(crop_video_path, priori_path)