MogensR committed on
Commit
ec9ba45
·
verified ·
1 Parent(s): a43d32f

Delete pipeline/two_stage_pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline/two_stage_pipeline.py +0 -388
pipeline/two_stage_pipeline.py DELETED
@@ -1,388 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- two_stage_pipeline.py — Ephemeral SAM2 stage + MatAnyone stage
4
- - Stage 1: SAM2 -> lossless mask stream (FFV1 .mkv) + meta.json, then unload SAM2
5
- - Stage 2: read mask stream -> (optional) MatAnyone refine -> composite -> mux audio
6
- """
7
-
8
- import os, sys, gc, json, cv2, time, uuid, torch, shutil, logging, subprocess, threading
9
- import numpy as np
10
- from pathlib import Path
11
- from typing import Optional, Callable, Tuple, Dict, Any
12
- from PIL import Image
13
-
14
- logger = logging.getLogger("backgroundfx_pro.two_stage")
15
- if not logger.handlers:
16
- h = logging.StreamHandler()
17
- h.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s: %(message)s"))
18
- logger.addHandler(h)
19
- logger.setLevel(logging.INFO)
20
-
21
- # ---------------------------
22
- # Env & CUDA helpers
23
- # ---------------------------
24
def setup_env():
    """Configure process-wide environment variables and torch settings.

    * Sets allocator and BLAS/OpenMP thread-count environment variables, but
      only when they are not already present in the environment.
    * Disables autograd globally (this module only runs inference).
    * Enables cuDNN autotuning and TF32 matmuls.
    * When CUDA is available, caps the per-process memory fraction using the
      ``CUDA_MEMORY_FRACTION`` environment variable (default ``0.88``).

    The torch tuning steps are best-effort: a failure is logged and never
    raised, so the pipeline keeps working on builds lacking a given knob.
    """
    env_defaults = {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7",
        "OMP_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
    }
    for key, value in env_defaults.items():
        os.environ.setdefault(key, value)

    torch.set_grad_enabled(False)
    try:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    except Exception as e:
        # Previously swallowed silently; keep it non-fatal but leave a trace.
        logger.debug(f"Skipping torch backend tuning: {e}")

    if torch.cuda.is_available():
        raw_fraction = os.getenv("CUDA_MEMORY_FRACTION", "0.88")
        try:
            torch.cuda.set_per_process_memory_fraction(float(raw_fraction))
        except Exception as e:
            # Covers both an unparsable value and a rejected fraction.
            logger.warning(f"Ignoring CUDA_MEMORY_FRACTION={raw_fraction!r}: {e}")
42
-
43
def free_cuda():
    """Release CUDA IPC handles and cached allocator memory.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()
47
-
48
def unload_sam2_modules():
    """Aggressively unload SAM2 python modules to reduce RSS.

    Removes every entry of ``sys.modules`` whose name starts with ``"sam2"``,
    invalidates the import caches, runs the garbage collector and frees CUDA
    caches. Any failure is logged as a warning rather than raised.
    """
    try:
        import importlib

        loaded_names = list(sys.modules)
        for name in loaded_names:
            if name.startswith("sam2"):
                sys.modules.pop(name, None)
        importlib.invalidate_caches()
        gc.collect()
        free_cuda()
        logger.info("SAM2 modules unloaded.")
    except Exception as e:
        logger.warning(f"Unloading SAM2 modules: {e}")
61
-
62
- # ---------------------------
63
- # Video probing
64
- # ---------------------------
65
def probe_video(path: str) -> Tuple[int, int, float, int]:
    """Return ``(width, height, fps, frame_count)`` for the video at *path*.

    Raises:
        RuntimeError: if OpenCV cannot open the file.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {path}")
        fps = cap.get(cv2.CAP_PROP_FPS)
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        n = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Release the handle on the failure path too.
        cap.release()

    # Fall back to 25 fps for any unusable rate: zero, negative, NaN or inf.
    # (The chained comparison is False for NaN.)
    if not (0 < fps < float("inf")):
        fps = 25.0
    # Guard against a negative frame count leaking into progress logging.
    return w, h, float(fps), max(n, 0)
75
-
76
- # ---------------------------
77
- # FFmpeg mask writers/readers
78
- # ---------------------------
79
class MaskFFV1Writer:
    """Write uint8 binary/gray masks to FFV1 lossless .mkv via pipe.

    Use as a context manager: entering spawns ``ffmpeg`` reading raw ``gray``
    frames of ``w`` x ``h`` on stdin; leaving closes the pipe and waits for
    ffmpeg to finish.

    Raises (on exit):
        RuntimeError: if ffmpeg had to be killed or exited with a non-zero
            status and no other exception is already propagating.
    """

    def __init__(self, path: str, w: int, h: int, fps: float):
        self.path = path
        self.w, self.h, self.fps = w, h, fps
        self.proc = None

    def __enter__(self):
        cmd = [
            "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
            "-f", "rawvideo", "-pix_fmt", "gray", "-s", f"{self.w}x{self.h}", "-r", f"{self.fps}",
            "-i", "-",
            "-c:v", "ffv1", "-level", "3", "-g", "1", self.path,
        ]
        self.proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
        return self

    def write(self, mask_u8: np.ndarray):
        """Append one frame to the stream.

        The mask is cast to uint8 if needed and must contain exactly
        ``w * h`` elements; singleton leading axes are therefore accepted.

        Raises:
            ValueError: if the element count does not match one frame.
        """
        if mask_u8.dtype != np.uint8:
            mask_u8 = mask_u8.astype(np.uint8)
        # A frame of the wrong size would shift every later frame in the
        # raw stream, so reject it instead of corrupting the file.
        if mask_u8.size != self.w * self.h:
            raise ValueError(
                f"Mask of shape {mask_u8.shape} does not match frame size {self.w}x{self.h}"
            )
        # tobytes() always emits C-order bytes, so no explicit contiguity copy.
        self.proc.stdin.write(mask_u8.tobytes())

    def __exit__(self, exc_type, exc, tb):
        if not self.proc:
            return
        try:
            self.proc.stdin.flush()
            self.proc.stdin.close()
            self.proc.wait(timeout=120)
        except Exception:
            # ffmpeg died or hung: kill it and reap it so no zombie is left.
            self.proc.kill()
            self.proc.wait()
        # Do not mask an exception already in flight; otherwise surface a
        # failed encode so callers can fall back to another storage.
        if exc_type is None and self.proc.returncode != 0:
            raise RuntimeError(
                f"ffmpeg exited with code {self.proc.returncode} while writing {self.path}"
            )
110
-
111
class MaskFFV1Reader:
    """Read uint8 masks from FFV1 .mkv via pipe.

    Use as a context manager: entering spawns ``ffmpeg`` decoding the file to
    raw ``gray`` frames on stdout; each ``read()`` consumes ``w * h`` bytes.
    """

    def __init__(self, path: str, w: int, h: int):
        self.path = path
        self.w, self.h = w, h
        self.proc = None
        self.frame_bytes = w * h

    def __enter__(self):
        cmd = [
            "ffmpeg", "-hide_banner", "-loglevel", "error", "-i", self.path,
            "-f", "rawvideo", "-pix_fmt", "gray", "-",
        ]
        self.proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
        return self

    def read(self) -> Optional[np.ndarray]:
        """Return the next ``(h, w)`` uint8 mask, or ``None`` at end of stream.

        A truncated trailing frame is treated as end of stream. The returned
        array is a read-only view over the bytes read from the pipe.
        """
        buf = self.proc.stdout.read(self.frame_bytes)
        if not buf or len(buf) < self.frame_bytes:
            return None
        return np.frombuffer(buf, dtype=np.uint8).reshape(self.h, self.w)

    def __exit__(self, exc_type, exc, tb):
        if self.proc:
            try:
                self.proc.stdout.close()
                self.proc.wait(timeout=30)
            except Exception:
                # Kill a stuck decoder and reap it so no zombie is left.
                self.proc.kill()
                self.proc.wait()
140
-
141
- # Fallback: PNG sequence (disk heavy but simple & robust)
142
class MaskPNGWriter:
    """Write masks as a PNG sequence named ``000000.png``, ``000001.png``, ...

    The target directory is created on construction if it does not exist.
    """

    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.dir.mkdir(parents=True, exist_ok=True)
        self.idx = 0

    def write(self, mask_u8: np.ndarray):
        """Save one mask under the next sequence number.

        Raises:
            IOError: if OpenCV reports that the image was not written.
        """
        target = self.dir / f"{self.idx:06d}.png"
        # cv2.imwrite signals failure through its return value; ignoring it
        # would leave a silent gap in the sequence.
        if not cv2.imwrite(str(target), mask_u8):
            raise IOError(f"Failed to write mask PNG: {target}")
        self.idx += 1
148
-
149
class MaskPNGReader:
    """Read masks back from a ``000000.png``-style sequence, one per call."""

    def __init__(self, dirpath: Path):
        self.dir = dirpath
        self.idx = 0

    def read(self) -> Optional[np.ndarray]:
        """Return the next mask as grayscale, or ``None`` once the next file is missing."""
        frame_file = self.dir / f"{self.idx:06d}.png"
        if frame_file.exists():
            mask = cv2.imread(str(frame_file), cv2.IMREAD_GRAYSCALE)
            self.idx += 1
            return mask
        return None
158
-
159
- # ---------------------------
160
- # Stage 1 — SAM2 → mask dump
161
- # ---------------------------
162
def _select_object_mask(out_ids, out_logits, obj_id: int, h: int, w: int) -> np.ndarray:
    """Return the uint8 mask for *obj_id* from one propagation step, or zeros if absent."""
    idx = None
    if isinstance(out_ids, torch.Tensor):
        hits = (out_ids == obj_id).nonzero(as_tuple=False)
        if hits.numel() > 0:
            idx = hits[0].item()
    else:
        ids = list(out_ids)
        if obj_id in ids:
            idx = ids.index(obj_id)
    if idx is None:
        return np.zeros((h, w), np.uint8)

    mask = (out_logits[idx] > 0).detach()
    # Blocking device-to-host copy: a non_blocking copy followed directly by
    # .numpy() may read the host buffer before the transfer has completed.
    mask_u8 = mask.to(torch.uint8).mul_(255).to("cpu").numpy()
    # NOTE(review): presumably logits carry a singleton channel axis
    # (1, H, W) — confirm against models.sam2_loader. Dropping it keeps the
    # raw bytes identical and gives the PNG writer a 2-D image.
    if mask_u8.ndim == 3 and mask_u8.shape[0] == 1:
        mask_u8 = mask_u8[0]
    return mask_u8


def _propagate_masks(predictor, state, device, obj_id: int, h: int, w: int, write) -> None:
    """Run video propagation and pass one mask per yielded step to *write*."""
    autocast_dtype = torch.float16 if device.type == "cuda" else None
    with torch.inference_mode(), torch.autocast("cuda", dtype=autocast_dtype):
        for _, out_ids, out_logits in predictor.propagate_in_video(state):
            write(_select_object_mask(out_ids, out_logits, obj_id, h, w))


def stage1_dump_masks(video_path: str, out_dir: Path,
                      obj_point: Optional[Tuple[int, int]] = None) -> Dict[str, Any]:
    """Run only SAM2 and persist its masks plus ``meta.json`` in *out_dir*.

    Masks are written as an FFV1 ``mask.mkv`` (preferred); if that fails for
    any reason, propagation is re-run into a ``masks_png`` PNG sequence.

    Args:
        video_path: input video.
        out_dir: directory receiving the mask storage and ``meta.json``;
            created if missing.
        obj_point: ``(x, y)`` positive click on frame 0 identifying the
            object; defaults to the frame centre.

    Returns:
        The meta dict (video, width, height, fps, frames, storage, mask_path).
    """
    setup_env()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    w, h, fps, n = probe_video(video_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    meta = {"video": video_path, "width": w, "height": h, "fps": fps, "frames": n, "storage": None}
    logger.info(f"[Stage1] {w}x{h}@{fps:.2f} | frames={n}")

    # Load SAM2 (your wrapper)
    from models.sam2_loader import SAM2Predictor
    predictor = SAM2Predictor(device=device)
    state = None
    try:
        state = predictor.init_state(video_path=video_path)

        # Prompt: center positive if not provided
        if obj_point is None:
            obj_point = (w // 2, h // 2)
        pts = np.array([[obj_point[0], obj_point[1]]], dtype=np.float32)
        labels = np.array([1], dtype=np.int32)
        ann_obj_id = 1
        with torch.inference_mode():
            predictor.add_new_points(state, 0, ann_obj_id, pts, labels)

        # Preferred: FFV1 mask stream
        mask_mkv = out_dir / "mask.mkv"
        try:
            with MaskFFV1Writer(str(mask_mkv), w, h, fps) as writer:
                _propagate_masks(predictor, state, device, ann_obj_id, h, w, writer.write)
            meta["storage"] = "ffv1"
            meta["mask_path"] = str(mask_mkv)
            logger.info("[Stage1] Masks saved as FFV1 .mkv")
        except Exception as e:
            logger.warning(f"FFV1 writer failed ({e}), falling back to PNG sequence.")
            png_dir = out_dir / "masks_png"
            wr = MaskPNGWriter(png_dir)
            _propagate_masks(predictor, state, device, ann_obj_id, h, w, wr.write)
            meta["storage"] = "png"
            meta["mask_path"] = str(png_dir)

        # Persist meta
        with open(out_dir / "meta.json", "w") as f:
            json.dump(meta, f)
    finally:
        # Unload SAM2 completely, also when propagation failed, so the
        # model does not stay resident for the rest of the process.
        del predictor, state
        free_cuda()
        unload_sam2_modules()
    return meta
238
-
239
- # ---------------------------
240
- # Stage 2 — refine + compose
241
- # ---------------------------
242
def stage2_refine_and_compose(video_path: str, mask_dir: Path, background_image: Image.Image,
                              out_path: str, use_matany: bool = True) -> str:
    """Composite the video over *background_image* using stored masks, then mux audio.

    Args:
        video_path: original input video (also the audio source).
        mask_dir: directory containing ``meta.json`` as written by stage 1.
        background_image: replacement background; converted to RGB and
            resized to the video resolution.
        out_path: final output file path.
        use_matany: try to load a MatAnyone session and refine each mask.

    Returns:
        The final output path. If audio muxing fails, the silent composite
        is moved to that path instead.

    Raises:
        RuntimeError: if the input video or the output writer cannot be opened.
    """
    w, h, fps, n = probe_video(video_path)
    # Force 3-channel RGB: an RGBA / L / P background would not broadcast
    # against the HxWx3 frames in the blend below.
    bg = background_image.convert("RGB").resize((w, h), Image.LANCZOS)
    bg_np = np.array(bg).astype(np.float32)

    # Read meta
    with open(mask_dir / "meta.json", "r") as f:
        meta = json.load(f)
    storage = meta["storage"]
    mask_path = meta["mask_path"]

    # Optional MatAnyone
    session = None
    if use_matany:
        try:
            from models.matanyone_loader import MatAnyoneSession as _M
        except Exception:
            try:
                from models.matanyone_loader import MatAnyoneLoader as _M
            except Exception:
                _M = None
        if _M:
            session = _M(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
            if hasattr(session, "model") and session.model is not None:
                session.model.eval()

    # Open video + writer
    cap = cv2.VideoCapture(video_path)
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    tmp_out = str(Path(out_path).with_suffix(".noaudio.mp4"))
    writer = cv2.VideoWriter(tmp_out, fourcc, fps, (w, h))

    mreader = None
    i = 0
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        if not writer.isOpened():
            raise RuntimeError(f"Cannot open video writer: {tmp_out}")

        # Open mask reader inside the try so the capture/writer are released
        # even when spawning the decoder fails.
        if storage == "ffv1":
            mreader = MaskFFV1Reader(mask_path, w, h)
            mreader.__enter__()
        else:
            mreader = MaskPNGReader(Path(mask_path))

        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            mask_u8 = mreader.read()
            if mask_u8 is None:
                # out of masks; write original
                writer.write(frame_bgr)
                i += 1
                continue

            # Optional refine
            if session is not None:
                try:
                    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                    # Provide a float mask 0..1 to session; adapt if your API differs
                    mask_f = mask_u8.astype(np.float32) / 255.0
                    if hasattr(session, "refine_mask"):
                        mask_refined = session.refine_mask(frame_rgb, mask_f)
                    elif hasattr(session, "process_frame"):
                        mask_refined = session.process_frame(frame_rgb, mask_f)
                    else:
                        mask_refined = mask_f

                    refined_u8 = None
                    if isinstance(mask_refined, torch.Tensor):
                        refined_u8 = (mask_refined.detach().clamp(0, 1).mul(255).to("cpu").numpy()).astype(np.uint8)
                    elif isinstance(mask_refined, np.ndarray):
                        refined_u8 = (np.clip(mask_refined, 0, 1) * 255).astype(np.uint8)
                    if refined_u8 is not None:
                        # Only accept a refinement that maps onto one frame;
                        # otherwise keep the stored mask (handled below).
                        if refined_u8.size != h * w:
                            raise ValueError(f"refined mask has shape {refined_u8.shape}")
                        mask_u8 = refined_u8.reshape(h, w)
                except Exception as e:
                    logger.debug(f"MatAnyone refine failed @frame {i}: {e}")

            # Composite
            m = (mask_u8.astype(np.float32) / 255.0)[..., None]  # HxWx1
            fr = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).astype(np.float32)
            comp = fr * m + bg_np * (1.0 - m)
            comp_bgr = cv2.cvtColor(comp.astype(np.uint8), cv2.COLOR_RGB2BGR)
            writer.write(comp_bgr)

            if i % 50 == 0:
                logger.info(f"[Stage2] frame {i}/{n}")
            i += 1
    finally:
        cap.release()
        writer.release()
        if isinstance(mreader, MaskFFV1Reader):
            mreader.__exit__(None, None, None)

    # Mux audio
    final_out = str(Path(out_path))
    cmd = [
        "ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
        "-i", tmp_out, "-i", video_path,
        # Trailing "?" makes the audio map optional, so a source without an
        # audio track still muxes cleanly instead of failing.
        "-map", "0:v:0", "-map", "1:a:0?", "-c:v", "copy", "-c:a", "aac", "-shortest", final_out,
    ]
    try:
        r = subprocess.run(cmd, capture_output=True, text=True, timeout=180)
    except Exception as e:
        # ffmpeg missing or timed out: ship the silent composite.
        logger.warning(f"Audio mux failed: {e}")
        shutil.move(tmp_out, final_out)
        return final_out

    if r.returncode != 0:
        logger.warning(f"Audio mux failed: {r.stderr.strip()}")
        shutil.move(tmp_out, final_out)
    else:
        os.remove(tmp_out)
    return final_out
344
-
345
- # ---------------------------
346
- # Orchestrator
347
- # ---------------------------
348
def process_two_stage(
    video_path: str,
    background_image: Image.Image,
    workdir: Optional[Path] = None,
    progress: Optional[Callable[[str, float], None]] = None,
    use_matany: bool = True,
) -> str:
    """Run stage 1 (SAM2 mask dump) then stage 2 (refine + composite).

    Args:
        video_path: input video.
        background_image: replacement background image.
        workdir: job directory (``Path`` or path string); defaults to
            ``./tmp/job_<random>`` under the current working directory.
        progress: optional callback receiving ``(message, fraction)``.
        use_matany: forwarded to stage 2.

    Returns:
        Path of the final video, located inside *workdir*.
    """
    def _report(message: str, fraction: float) -> None:
        if progress:
            progress(message, fraction)

    setup_env()
    if workdir is None:
        workdir = Path.cwd() / "tmp" / f"job_{uuid.uuid4().hex[:8]}"
    # Accept plain strings as well; Path(Path) is a no-op.
    workdir = Path(workdir)
    workdir.mkdir(parents=True, exist_ok=True)

    # Stage 1
    _report("Stage 1: SAM2 mask pass", 0.05)
    mask_dir = workdir / "sam2_masks"
    stage1_dump_masks(video_path, mask_dir)
    _report("Stage 1 complete", 0.45)

    # Stage 2
    _report("Stage 2: refine + compose", 0.50)
    out_path = workdir / f"final_{int(time.time())}.mp4"
    final_video = stage2_refine_and_compose(
        video_path, mask_dir, background_image, str(out_path), use_matany=use_matany
    )
    _report("Done", 1.0)
    logger.info(f"Output: {final_video}")
    return final_video
373
-
374
- # ---------------------------
375
- # CLI
376
- # ---------------------------
377
if __name__ == "__main__":
    import argparse

    # Command-line entry point: prints the path of the produced video.
    cli = argparse.ArgumentParser(description="Two-stage BackgroundFX Pro")
    cli.add_argument("--video", required=True)
    cli.add_argument("--background", required=True)
    cli.add_argument("--outdir", default=None)
    cli.add_argument("--no-matany", action="store_true")
    options = cli.parse_args()

    background = Image.open(options.background).convert("RGB")
    job_dir = Path(options.outdir) if options.outdir else None
    result = process_two_stage(options.video, background, job_dir, use_matany=not options.no_matany)
    print(result)