Spaces:
Sleeping
Sleeping
Revert model offloading (was not requested, slows computation)
Browse files- gradio_app.py +1 -14
gradio_app.py
CHANGED
|
@@ -142,11 +142,7 @@ def generate_masks_sam3(frames, text_prompt):
|
|
| 142 |
else:
|
| 143 |
combined = np.zeros((H,W), dtype=np.uint8)
|
| 144 |
raw_masks[model_out.frame_idx] = combined
|
| 145 |
-
|
| 146 |
-
# Offload SAM3 to CPU after use to free VRAM
|
| 147 |
-
sam3_model.to("cpu")
|
| 148 |
-
torch.cuda.empty_cache()
|
| 149 |
-
return result
|
| 150 |
|
| 151 |
# ── First-frame helpers ──
|
| 152 |
|
|
@@ -204,7 +200,6 @@ def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter):
|
|
| 204 |
print(f" BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")
|
| 205 |
|
| 206 |
print(f"[2/6] SAM3 detecting '{text_prompt}'...")
|
| 207 |
-
sam3_model.to(device) # move back to GPU for inference
|
| 208 |
dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
|
| 209 |
|
| 210 |
if filter_bbox:
|
|
@@ -251,18 +246,10 @@ def infer(input_video, text_prompt, x1, y1, x2, y2, use_propainter):
|
|
| 251 |
repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
|
| 252 |
t0 = time.time()
|
| 253 |
if use_propainter:
|
| 254 |
-
propainter.fix_raft.to(device)
|
| 255 |
-
propainter.fix_flow_complete.to(device)
|
| 256 |
-
propainter.model.to(device)
|
| 257 |
propainter.forward(crop_video_path, crop_mask_path, priori_path,
|
| 258 |
resize_ratio=1.0,
|
| 259 |
video_length=video_length, ref_stride=10,
|
| 260 |
neighbor_length=10, subvideo_length=50, mask_dilation=8)
|
| 261 |
-
# Offload ProPainter to CPU after use
|
| 262 |
-
propainter.fix_raft.to("cpu")
|
| 263 |
-
propainter.fix_flow_complete.to("cpu")
|
| 264 |
-
propainter.model.to("cpu")
|
| 265 |
-
torch.cuda.empty_cache()
|
| 266 |
else:
|
| 267 |
import shutil
|
| 268 |
shutil.copy2(crop_video_path, priori_path)
|
|
|
|
| 142 |
else:
|
| 143 |
combined = np.zeros((H,W), dtype=np.uint8)
|
| 144 |
raw_masks[model_out.frame_idx] = combined
|
| 145 |
+
return [dilate_mask(raw_masks.get(i, np.zeros((H,W),dtype=np.uint8))) for i in range(len(frames))]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
# ── First-frame helpers ──
|
| 148 |
|
|
|
|
| 200 |
print(f" BBox filter: x1={bx1} y1={by1} x2={bx2} y2={by2}")
|
| 201 |
|
| 202 |
print(f"[2/6] SAM3 detecting '{text_prompt}'...")
|
|
|
|
| 203 |
dilated_masks = generate_masks_sam3(frames, text_prompt.strip())
|
| 204 |
|
| 205 |
if filter_bbox:
|
|
|
|
| 246 |
repaired_path = os.path.join(save_path, f"repaired_{ts}.mp4")
|
| 247 |
t0 = time.time()
|
| 248 |
if use_propainter:
|
|
|
|
|
|
|
|
|
|
| 249 |
propainter.forward(crop_video_path, crop_mask_path, priori_path,
|
| 250 |
resize_ratio=1.0,
|
| 251 |
video_length=video_length, ref_stride=10,
|
| 252 |
neighbor_length=10, subvideo_length=50, mask_dilation=8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
else:
|
| 254 |
import shutil
|
| 255 |
shutil.copy2(crop_video_path, priori_path)
|