MogensR committed on
Commit
b502144
·
1 Parent(s): a9f51ee
Files changed (1) hide show
  1. models/matanyone_loader.py +60 -16
models/matanyone_loader.py CHANGED
@@ -807,27 +807,71 @@ def process_stream(
807
  raise MatAnyError(error_msg) from e
808
 
809
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
810
- """Process a chunk of frames with MatAnyone."""
 
 
 
 
 
 
 
 
 
 
811
  if not frames_bgr:
812
  return
813
 
814
- # Prepare inputs
815
- frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
816
- frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
817
- mask_t = torch.from_numpy(seed_1hw).to(self.device) if seed_1hw is not None else None
 
 
 
818
 
819
- with torch.no_grad(), self._maybe_amp():
820
  try:
821
- # Try direct tensor processing first (newer versions)
822
- if hasattr(self._core, '_process_tensor_video'):
823
- alphas = self._core._process_tensor_video(frames_t, mask_t)
824
- else:
825
- # Fall back to file-based processing if tensor API not available
826
- with tempfile.TemporaryDirectory() as tmpdir:
827
- # Save frames to temp directory
828
- frame_paths = []
829
- for i, frame in enumerate(frames_bgr):
830
- path = os.path.join(tmpdir, f'frame_{i:06d}.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
831
  cv2.imwrite(path, frame)
832
  frame_paths.append(path)
833
 
 
807
  raise MatAnyError(error_msg) from e
808
 
809
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
810
+ """Process a chunk of frames with MatAnyone.
811
+
812
+ Args:
813
+ frames_bgr: List of frames in BGR format
814
+ seed_1hw: Seed mask in 1HW format or None
815
+ alpha_writer: VideoWriter for alpha channel output
816
+ fg_writer: VideoWriter for foreground output
817
+
818
+ Raises:
819
+ MatAnyError: If there's an error processing the frames
820
+ """
821
  if not frames_bgr:
822
  return
823
 
824
+ try:
825
+ # Prepare inputs
826
+ frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
827
+ frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
828
+ mask_t = None
829
+ if seed_1hw is not None:
830
+ mask_t = torch.from_numpy(seed_1hw).to(self.device)
831
 
 
832
  try:
833
+ with torch.no_grad(), self._maybe_amp():
834
+ # Process frames in batch
835
+ if self._api_mode == "process_frame":
836
+ alphas = []
837
+ for i in range(len(frames_t)):
838
+ # Only use mask on first frame if provided
839
+ current_mask = mask_t if (i == 0 and mask_t is not None) else None
840
+ alpha = self._core.process_frame(frames_t[i].unsqueeze(0), current_mask)
841
+ alphas.append(alpha.squeeze(0))
842
+ alphas = torch.stack(alphas)
843
+ elif hasattr(self._core, '_process_tensor_video'):
844
+ # Try direct tensor processing (newer versions)
845
+ alphas = self._core._process_tensor_video(frames_t, mask_t)
846
+ else: # step mode
847
+ alphas = self._core.step(frames_t, mask_t)
848
+
849
+ # Convert to numpy and write frames
850
+ alphas_np = alphas.cpu().numpy()
851
+ for i, alpha in enumerate(alphas_np):
852
+ # Convert alpha to uint8 and write
853
+ alpha_uint8 = (alpha * 255).astype(np.uint8)
854
+ if len(alpha_uint8.shape) == 2: # If single channel, convert to 3 channels
855
+ alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
856
+ alpha_writer.write(alpha_uint8)
857
+
858
+ # Write foreground (frame * alpha)
859
+ fg = frames_bgr[i] * (alpha[..., None] if alpha.ndim == 2 else alpha[0:1].permute(1, 2, 0))
860
+ fg = fg.astype(np.uint8)
861
+ fg_writer.write(fg)
862
+
863
+ except RuntimeError as e:
864
+ if "out of memory" in str(e).lower():
865
+ # Clear CUDA cache and retry once
866
+ torch.cuda.empty_cache()
867
+ log.warning("CUDA out of memory, retrying after cache clear")
868
+ return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
869
+ raise
870
+
871
+ except Exception as e:
872
+ error_msg = f"Error processing frame chunk: {str(e)}"
873
+ log.error(error_msg, exc_info=True)
874
+ raise MatAnyError(error_msg) from e
875
  cv2.imwrite(path, frame)
876
  frame_paths.append(path)
877