MogensR commited on
Commit
37f2d16
·
1 Parent(s): 96de0be
Files changed (1) hide show
  1. models/matanyone_loader.py +212 -59
models/matanyone_loader.py CHANGED
@@ -44,9 +44,98 @@ def _emit_progress(cb, pct: float, msg: str):
44
  except TypeError:
45
  pass # ignore if cb is incompatible
46
 
47
- class MatAnyError(Exception):
 
48
  pass
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
51
  """Read mask image, convert to float32 [0,1], resize to target (H,W)."""
52
  if not Path(mask_path).exists():
@@ -60,6 +149,7 @@ def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
60
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
61
  return maskf
62
 
 
63
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
64
  """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
65
  # OpenCV gives BGR; convert to RGB
@@ -68,13 +158,16 @@ def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
68
  chw = np.transpose(rgbf, (2, 0, 1)) # C,H,W
69
  return chw
70
 
 
71
  def _mask_to_1hw(mask_hw01: np.ndarray) -> np.ndarray:
72
  """HW float32 [0,1] -> 1HW float32 [0,1]."""
73
  return np.expand_dims(mask_hw01, axis=0)
74
 
 
75
  def _ensure_dir(p: Path) -> None:
76
  p.mkdir(parents=True, exist_ok=True)
77
 
 
78
  def _open_video_writers(out_dir: Path, fps: float, size: Tuple[int, int]) -> Tuple[cv2.VideoWriter, cv2.VideoWriter]:
79
  """Return (alpha_writer, fg_writer). size=(W,H)."""
80
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
@@ -88,10 +181,12 @@ def _open_video_writers(out_dir: Path, fps: float, size: Tuple[int, int]) -> Tup
88
  raise MatAnyError("Failed to open VideoWriter for alpha/fg outputs.")
89
  return alpha_writer, fg_writer
90
 
 
91
  def _validate_nonempty(file_path: Path) -> None:
92
  if not file_path.exists() or file_path.stat().st_size == 0:
93
  raise MatAnyError(f"Output file missing/empty: {file_path}")
94
 
 
95
  class MatAnyoneSession:
96
  """
97
  Unified, streaming wrapper over MatAnyone variants.
@@ -529,7 +624,7 @@ def process_stream(
529
  gpu_info = f" | GPU: {mem_alloc:.1f}/{mem_cached:.1f}MB"
530
 
531
  status = (f"Processing frame {idx+1}/{N} (ETA: {eta_str}, "
532
- f"{fps:.1f} FPS{gpu_info}")
533
  _emit_progress(progress_cb, min(0.99, current_progress), status)
534
  last_progress_update = time.time()
535
 
@@ -807,68 +902,126 @@ def process_stream(
807
  raise MatAnyError(error_msg) from e
808
 
809
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
810
- """Process a chunk of frames with MatAnyone.
 
 
 
 
 
811
 
812
- Args:
813
- frames_bgr: List of frames in BGR format
814
- seed_1hw: Seed mask in 1HW format or None
815
- alpha_writer: VideoWriter for alpha channel output
816
- fg_writer: VideoWriter for foreground output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
817
 
818
- Raises:
819
- MatAnyError: If there's an error processing the frames
820
- """
821
- if not frames_bgr:
822
- return
 
 
 
 
 
 
 
 
823
 
824
- # Prepare inputs
825
- frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
826
- frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
827
- mask_t = None
828
- if seed_1hw is not None:
829
- mask_t = torch.from_numpy(seed_1hw).to(self.device)
830
 
831
- try:
832
- with torch.no_grad(), self._maybe_amp():
833
- # Process frames in batch
834
- if self._api_mode == "process_frame":
835
- alphas = []
836
- for i in range(len(frames_t)):
837
- # Only use mask on first frame if provided
838
- current_mask = mask_t if (i == 0 and mask_t is not None) else None
839
- alpha = self._core.process_frame(frames_t[i].unsqueeze(0), current_mask)
840
- alphas.append(alpha.squeeze(0))
841
- alphas = torch.stack(alphas)
842
- elif hasattr(self._core, '_process_tensor_video'):
843
- # Try direct tensor processing (newer versions)
844
- alphas = self._core._process_tensor_video(frames_t, mask_t)
845
- else: # step mode
846
- alphas = self._core.step(frames_t, mask_t)
847
-
848
- # Convert to numpy and write frames
849
- alphas_np = alphas.cpu().numpy()
850
- for i, alpha in enumerate(alphas_np):
851
- # Convert alpha to uint8 and write
852
- alpha_uint8 = (alpha * 255).astype(np.uint8)
853
- if len(alpha_uint8.shape) == 2: # If single channel, convert to 3 channels
854
- alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
855
- alpha_writer.write(alpha_uint8)
856
-
857
- # Write foreground (frame * alpha)
858
- fg = frames_bgr[i] * (alpha[..., None] if alpha.ndim == 2 else alpha[0:1].permute(1, 2, 0))
859
- fg = fg.astype(np.uint8)
860
- fg_writer.write(fg)
861
 
862
- except RuntimeError as e:
863
- if "out of memory" in str(e).lower():
864
- # Clear CUDA cache and retry once
865
- torch.cuda.empty_cache()
866
- log.warning("CUDA out of memory, retrying after cache clear")
867
- return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
868
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
 
870
  except Exception as e:
871
- error_msg = f"Error processing frame chunk: {str(e)}"
872
- log.error(error_msg, exc_info=True)
873
- raise MatAnyError(error_msg) from e
 
 
 
 
 
 
 
 
874
 
 
 
 
44
  except TypeError:
45
  pass # ignore if cb is incompatible
46
 
47
class MatAnyError(RuntimeError):
    """Custom exception for MatAnyone processing errors."""
50
 
51
+
52
+ def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
53
+ """
54
+ Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
55
+ CHW tensor on device using pinned memory + non_blocking copies.
56
+ """
57
+ if isinstance(frames_bgr_np, list):
58
+ frames_bgr_np = np.stack(frames_bgr_np, axis=0) # [N,H,W,3]
59
+ # BGR -> RGB
60
+ frames_rgb = frames_bgr_np[..., ::-1].copy(order="C")
61
+ # to torch
62
+ pin = torch.from_numpy(frames_rgb).pin_memory() # uint8 [N,H,W,3]
63
+ # NCHW and normalize
64
+ t = pin.permute(0, 3, 1, 2).contiguous().to(device, non_blocking=True)
65
+ t = t.to(dtype=dtype) / 255.0
66
+ return t # [N,3,H,W]
67
+
68
+
69
+ def _select_matany_mode(core):
70
+ """
71
+ Pick the best-available MatAnyone API at runtime.
72
+ Priority: process_frame > _process_tensor_video > step
73
+ """
74
+ if hasattr(core, "process_frame"):
75
+ return "process_frame"
76
+ if hasattr(core, "_process_tensor_video"):
77
+ return "_process_tensor_video"
78
+ if hasattr(core, "step"):
79
+ return "step"
80
+ raise MatAnyError("No supported MatAnyone API on core (process_frame/_process_tensor_video/step).")
81
+
82
+
83
+ def _matany_run(core, mode, frames_04chw, seed_1hw=None):
84
+ """
85
+ Dispatch into the selected API. All tensors are on device.
86
+ Returns (alpha_1nhw, fg_n3hw) where alpha is [N,1,H,W], fg [N,3,H,W].
87
+ """
88
+ with torch.no_grad():
89
+ if mode == "process_frame":
90
+ alphas, fgs = [], []
91
+ # process_frame usually wants per-frame tensors in [1,3,H,W]
92
+ for i in range(frames_04chw.shape[0]):
93
+ f = frames_04chw[i:i+1] # [1,3,H,W]
94
+ if seed_1hw is not None and seed_1hw.ndim == 3:
95
+ a, fg = core.process_frame(f, seed_1hw.unsqueeze(0))
96
+ else:
97
+ a, fg = core.process_frame(f)
98
+ alphas.append(a) # [1,1,H,W]
99
+ fgs.append(fg) # [1,3,H,W]
100
+ alpha = torch.cat(alphas, dim=0)
101
+ fg = torch.cat(fgs, dim=0)
102
+ return alpha, fg
103
+
104
+ elif mode == "_process_tensor_video":
105
+ return core._process_tensor_video(frames_04chw.float(), seed_1hw)
106
+
107
+ elif mode == "step":
108
+ alphas, fgs = [], []
109
+ for i in range(frames_04chw.shape[0]):
110
+ f = frames_04chw[i:i+1]
111
+ if i == 0 and seed_1hw is not None:
112
+ a, fg = core.step(f, seed_1hw)
113
+ else:
114
+ a, fg = core.step(f)
115
+ alphas.append(a)
116
+ fgs.append(fg)
117
+ alpha = torch.cat(alphas, dim=0)
118
+ fg = torch.cat(fgs, dim=0)
119
+ return alpha, fg
120
+
121
+ raise MatAnyError(f"Unsupported mode: {mode}")
122
+
123
+
124
+ def _safe_empty_cache():
125
+ if torch.cuda.is_available():
126
+ torch.cuda.synchronize()
127
+ torch.cuda.empty_cache()
128
+
129
+
130
+ def _cuda_snapshot():
131
+ if not torch.cuda.is_available():
132
+ return "CUDA: N/A"
133
+ i = torch.cuda.current_device()
134
+ return (f"device={i}, name={torch.cuda.get_device_name(i)}, "
135
+ f"alloc={torch.cuda.memory_allocated(i)/1e9:.2f}GB, "
136
+ f"reserved={torch.cuda.memory_reserved(i)/1e9:.2f}GB")
137
+
138
+
139
  def _read_mask_hw(mask_path: Path, target_hw: Tuple[int, int]) -> np.ndarray:
140
  """Read mask image, convert to float32 [0,1], resize to target (H,W)."""
141
  if not Path(mask_path).exists():
 
149
  maskf = (mask.astype(np.float32) / 255.0).clip(0.0, 1.0)
150
  return maskf
151
 
152
+
153
  def _to_chw01(img_bgr: np.ndarray) -> np.ndarray:
154
  """BGR [H,W,3] uint8 -> CHW float32 [0,1] RGB."""
155
  # OpenCV gives BGR; convert to RGB
 
158
  chw = np.transpose(rgbf, (2, 0, 1)) # C,H,W
159
  return chw
160
 
161
+
162
  def _mask_to_1hw(mask_hw01: np.ndarray) -> np.ndarray:
163
  """HW float32 [0,1] -> 1HW float32 [0,1]."""
164
  return np.expand_dims(mask_hw01, axis=0)
165
 
166
+
167
  def _ensure_dir(p: Path) -> None:
168
  p.mkdir(parents=True, exist_ok=True)
169
 
170
+
171
  def _open_video_writers(out_dir: Path, fps: float, size: Tuple[int, int]) -> Tuple[cv2.VideoWriter, cv2.VideoWriter]:
172
  """Return (alpha_writer, fg_writer). size=(W,H)."""
173
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 
181
  raise MatAnyError("Failed to open VideoWriter for alpha/fg outputs.")
182
  return alpha_writer, fg_writer
183
 
184
+
185
  def _validate_nonempty(file_path: Path) -> None:
186
  if not file_path.exists() or file_path.stat().st_size == 0:
187
  raise MatAnyError(f"Output file missing/empty: {file_path}")
188
 
189
+
190
  class MatAnyoneSession:
191
  """
192
  Unified, streaming wrapper over MatAnyone variants.
 
624
  gpu_info = f" | GPU: {mem_alloc:.1f}/{mem_cached:.1f}MB"
625
 
626
  status = (f"Processing frame {idx+1}/{N} (ETA: {eta_str}, "
627
+ f"{fps:.1f} FPS{gpu_info}")
628
  _emit_progress(progress_cb, min(0.99, current_progress), status)
629
  last_progress_update = time.time()
630
 
 
902
  raise MatAnyError(error_msg) from e
903
 
904
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
905
+ """
906
+ Take an in-memory batch of frames (list of np.uint8 BGR), run MatAnyone on GPU,
907
+ then write alpha/fg frames via the provided writers. Clears GPU memory on exit.
908
+ """
909
+ # Initialize variables for cleanup
910
+ alpha_n1hw, fg_n3hw, frames_04chw = None, None, None
911
 
912
+ try:
913
+ device = self.device
914
+ use_fp16 = (device.type == "cuda") and getattr(self, 'use_fp16', True)
915
+ mode = _select_matany_mode(self._core)
916
+
917
+ # Move input frames to device in a batched, pinned way
918
+ frames_04chw = _to_device_batch(frames_bgr, device,
919
+ dtype=torch.float16 if use_fp16 else torch.float32)
920
+
921
+ # Move seed mask to device if provided
922
+ seed_tensor = None
923
+ if seed_1hw is not None:
924
+ seed_tensor = torch.from_numpy(seed_1hw).to(device)
925
+
926
+ # Process with CUDA stream if available
927
+ if device.type == "cuda":
928
+ stream = torch.cuda.Stream()
929
+ with torch.cuda.stream(stream):
930
+ with torch.autocast(device_type="cuda", enabled=use_fp16):
931
+ alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_tensor)
932
+ torch.cuda.synchronize()
933
+ else:
934
+ alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_tensor)
935
+
936
+ # Write out results (convert back to CPU uint8)
937
+ alpha_cpu = (alpha_n1hw.clamp(0, 1) * 255.0).byte().squeeze(1).contiguous().cpu().numpy() # [N,H,W]
938
 
939
+ for i in range(alpha_cpu.shape[0]):
940
+ # Write alpha mask
941
+ alpha_uint8 = alpha_cpu[i]
942
+ if len(alpha_uint8.shape) == 2: # Ensure 3 channels for writer
943
+ alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
944
+ alpha_writer.write(alpha_uint8)
945
+
946
+ # Write foreground (frame * alpha)
947
+ alpha_expanded = alpha_cpu[i] / 255.0
948
+ if alpha_expanded.ndim == 2:
949
+ alpha_expanded = alpha_expanded[..., None] # [H,W,1]
950
+ fg = (frames_bgr[i] * alpha_expanded).astype(np.uint8)
951
+ fg_writer.write(fg)
952
 
953
+ # Keep seed for temporal methods that need it
954
+ if hasattr(self._core, "last_mask"):
955
+ self._last_alpha_1hw = self._core.last_mask
 
 
 
956
 
957
+ except torch.cuda.OutOfMemoryError as e:
958
+ # Downshift strategy: smaller chunk or resolution; propagate with context
959
+ snap = _cuda_snapshot()
960
+ _safe_empty_cache()
961
+ raise MatAnyError(f"CUDA OOM in _flush_chunk (before retry). Snapshot: {snap}") from e
962
+
963
+ except Exception as e:
964
+ # Convert unexpected exceptions to MatAnyError with context
965
+ snap = _cuda_snapshot()
966
+ raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e
967
+
968
+ finally:
969
+ # Hard cleanup to avoid lingering allocations between chunks
970
+ for var in [alpha_n1hw, fg_n3hw, frames_04chw]:
971
+ try:
972
+ del var
973
+ except Exception:
974
+ pass
975
+ _safe_empty_cache()
 
 
 
 
 
 
 
 
 
 
 
976
 
977
+ def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
978
+ """
979
+ Public entry that buffers frames from an iterator and processes them in chunks.
980
+ Ensures cleanup and graceful degradation on OOM.
981
+ """
982
+ frames_buf = []
983
+ last_error = None
984
+
985
+ try:
986
+ for f in frames_iterable:
987
+ frames_buf.append(f)
988
+ if len(frames_buf) >= chunk_size:
989
+ try:
990
+ self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
991
+ frames_buf.clear()
992
+ except MatAnyError as e:
993
+ # Attempt one downshift: halve the chunk and retry once
994
+ if chunk_size > 4:
995
+ half = max(4, chunk_size // 2)
996
+ # Split and try smaller batches
997
+ for i in range(0, len(frames_buf), half):
998
+ sub = frames_buf[i:i+half]
999
+ self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
1000
+ frames_buf.clear()
1001
+ else:
1002
+ last_error = e
1003
+ break
1004
+
1005
+ # Flush remainder
1006
+ if frames_buf:
1007
+ self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1008
+ frames_buf.clear()
1009
+
1010
+ except torch.cuda.OutOfMemoryError as e:
1011
+ last_error = MatAnyError(f"CUDA OOM in process_stream: {_cuda_snapshot()}") from e
1012
 
1013
  except Exception as e:
1014
+ last_error = MatAnyError(f"Unexpected error in process_stream: {e}") from e
1015
+
1016
+ finally:
1017
+ frames_buf.clear()
1018
+ _safe_empty_cache()
1019
+ # Optional: if core has a reset method, call it
1020
+ if hasattr(getattr(self, '_core', None), 'reset'):
1021
+ try:
1022
+ self._core.reset()
1023
+ except Exception:
1024
+ pass
1025
 
1026
+ if last_error:
1027
+ raise last_error