MogensR committed on
Commit
6259fe6
·
1 Parent(s): b502144
Files changed (1) hide show
  1. models/matanyone_loader.py +44 -59
models/matanyone_loader.py CHANGED
@@ -821,71 +821,56 @@ def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
821
  if not frames_bgr:
822
  return
823
 
824
- try:
825
- # Prepare inputs
826
- frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
827
- frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
828
- mask_t = None
829
- if seed_1hw is not None:
830
- mask_t = torch.from_numpy(seed_1hw).to(self.device)
831
 
832
- try:
833
- with torch.no_grad(), self._maybe_amp():
834
- # Process frames in batch
835
- if self._api_mode == "process_frame":
836
- alphas = []
837
- for i in range(len(frames_t)):
838
- # Only use mask on first frame if provided
839
- current_mask = mask_t if (i == 0 and mask_t is not None) else None
840
- alpha = self._core.process_frame(frames_t[i].unsqueeze(0), current_mask)
841
- alphas.append(alpha.squeeze(0))
842
- alphas = torch.stack(alphas)
843
- elif hasattr(self._core, '_process_tensor_video'):
844
- # Try direct tensor processing (newer versions)
845
- alphas = self._core._process_tensor_video(frames_t, mask_t)
846
- else: # step mode
847
- alphas = self._core.step(frames_t, mask_t)
848
-
849
- # Convert to numpy and write frames
850
- alphas_np = alphas.cpu().numpy()
851
- for i, alpha in enumerate(alphas_np):
852
- # Convert alpha to uint8 and write
853
- alpha_uint8 = (alpha * 255).astype(np.uint8)
854
- if len(alpha_uint8.shape) == 2: # If single channel, convert to 3 channels
855
- alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
856
- alpha_writer.write(alpha_uint8)
857
-
858
- # Write foreground (frame * alpha)
859
- fg = frames_bgr[i] * (alpha[..., None] if alpha.ndim == 2 else alpha[0:1].permute(1, 2, 0))
860
- fg = fg.astype(np.uint8)
861
- fg_writer.write(fg)
 
 
 
 
 
 
 
 
862
 
863
- except RuntimeError as e:
864
- if "out of memory" in str(e).lower():
865
- # Clear CUDA cache and retry once
866
- torch.cuda.empty_cache()
867
- log.warning("CUDA out of memory, retrying after cache clear")
868
- return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
869
- raise
870
-
871
  except Exception as e:
872
  error_msg = f"Error processing frame chunk: {str(e)}"
873
  log.error(error_msg, exc_info=True)
874
  raise MatAnyError(error_msg) from e
875
- cv2.imwrite(path, frame)
876
- frame_paths.append(path)
877
-
878
- # Process video from frames
879
- alphas = self._core.process_video(tmpdir,
880
- mask_path=seed_1hw_path if seed_1hw is not None else None)
881
-
882
- # Ensure alphas is a tensor
883
- if not isinstance(alphas, torch.Tensor):
884
- alphas = torch.from_numpy(alphas).to(self.device)
885
-
886
- except Exception as e:
887
- log.error(f"Error in _flush_chunk: {str(e)}")
888
- raise
889
 
890
  # Normalize to numpy list of HW float32 [0,1]
891
  if isinstance(alphas, torch.Tensor):
 
821
  if not frames_bgr:
822
  return
823
 
824
+ # Prepare inputs
825
+ frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
826
+ frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
827
+ mask_t = None
828
+ if seed_1hw is not None:
829
+ mask_t = torch.from_numpy(seed_1hw).to(self.device)
 
830
 
831
+ try:
832
+ with torch.no_grad(), self._maybe_amp():
833
+ # Process frames in batch
834
+ if self._api_mode == "process_frame":
835
+ alphas = []
836
+ for i in range(len(frames_t)):
837
+ # Only use mask on first frame if provided
838
+ current_mask = mask_t if (i == 0 and mask_t is not None) else None
839
+ alpha = self._core.process_frame(frames_t[i].unsqueeze(0), current_mask)
840
+ alphas.append(alpha.squeeze(0))
841
+ alphas = torch.stack(alphas)
842
+ elif hasattr(self._core, '_process_tensor_video'):
843
+ # Try direct tensor processing (newer versions)
844
+ alphas = self._core._process_tensor_video(frames_t, mask_t)
845
+ else: # step mode
846
+ alphas = self._core.step(frames_t, mask_t)
847
+
848
+ # Convert to numpy and write frames
849
+ alphas_np = alphas.cpu().numpy()
850
+ for i, alpha in enumerate(alphas_np):
851
+ # Convert alpha to uint8 and write
852
+ alpha_uint8 = (alpha * 255).astype(np.uint8)
853
+ if len(alpha_uint8.shape) == 2: # If single channel, convert to 3 channels
854
+ alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
855
+ alpha_writer.write(alpha_uint8)
856
+
857
+ # Write foreground (frame * alpha)
858
+ fg = frames_bgr[i] * (alpha[..., None] if alpha.ndim == 2 else alpha[0:1].permute(1, 2, 0))
859
+ fg = fg.astype(np.uint8)
860
+ fg_writer.write(fg)
861
+
862
+ except RuntimeError as e:
863
+ if "out of memory" in str(e).lower():
864
+ # Clear CUDA cache and retry once
865
+ torch.cuda.empty_cache()
866
+ log.warning("CUDA out of memory, retrying after cache clear")
867
+ return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
868
+ raise
869
 
 
 
 
 
 
 
 
 
870
  except Exception as e:
871
  error_msg = f"Error processing frame chunk: {str(e)}"
872
  log.error(error_msg, exc_info=True)
873
  raise MatAnyError(error_msg) from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874
 
875
  # Normalize to numpy list of HW float32 [0,1]
876
  if isinstance(alphas, torch.Tensor):