agent 1.8
Browse files- models/matanyone_loader.py +44 -59
models/matanyone_loader.py
CHANGED
|
@@ -821,71 +821,56 @@ def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
|
|
| 821 |
if not frames_bgr:
|
| 822 |
return
|
| 823 |
|
| 824 |
-
|
| 825 |
-
|
| 826 |
-
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
mask_t = torch.from_numpy(seed_1hw).to(self.device)
|
| 831 |
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
|
| 863 |
-
except RuntimeError as e:
|
| 864 |
-
if "out of memory" in str(e).lower():
|
| 865 |
-
# Clear CUDA cache and retry once
|
| 866 |
-
torch.cuda.empty_cache()
|
| 867 |
-
log.warning("CUDA out of memory, retrying after cache clear")
|
| 868 |
-
return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
|
| 869 |
-
raise
|
| 870 |
-
|
| 871 |
except Exception as e:
|
| 872 |
error_msg = f"Error processing frame chunk: {str(e)}"
|
| 873 |
log.error(error_msg, exc_info=True)
|
| 874 |
raise MatAnyError(error_msg) from e
|
| 875 |
-
cv2.imwrite(path, frame)
|
| 876 |
-
frame_paths.append(path)
|
| 877 |
-
|
| 878 |
-
# Process video from frames
|
| 879 |
-
alphas = self._core.process_video(tmpdir,
|
| 880 |
-
mask_path=seed_1hw_path if seed_1hw is not None else None)
|
| 881 |
-
|
| 882 |
-
# Ensure alphas is a tensor
|
| 883 |
-
if not isinstance(alphas, torch.Tensor):
|
| 884 |
-
alphas = torch.from_numpy(alphas).to(self.device)
|
| 885 |
-
|
| 886 |
-
except Exception as e:
|
| 887 |
-
log.error(f"Error in _flush_chunk: {str(e)}")
|
| 888 |
-
raise
|
| 889 |
|
| 890 |
# Normalize to numpy list of HW float32 [0,1]
|
| 891 |
if isinstance(alphas, torch.Tensor):
|
|
|
|
| 821 |
if not frames_bgr:
|
| 822 |
return
|
| 823 |
|
| 824 |
+
# Prepare inputs
|
| 825 |
+
frames_chw = [_to_chw01(f) for f in frames_bgr] # list of CHW
|
| 826 |
+
frames_t = torch.from_numpy(np.stack(frames_chw)).to(self.device) # T,C,H,W
|
| 827 |
+
mask_t = None
|
| 828 |
+
if seed_1hw is not None:
|
| 829 |
+
mask_t = torch.from_numpy(seed_1hw).to(self.device)
|
|
|
|
| 830 |
|
| 831 |
+
try:
|
| 832 |
+
with torch.no_grad(), self._maybe_amp():
|
| 833 |
+
# Process frames in batch
|
| 834 |
+
if self._api_mode == "process_frame":
|
| 835 |
+
alphas = []
|
| 836 |
+
for i in range(len(frames_t)):
|
| 837 |
+
# Only use mask on first frame if provided
|
| 838 |
+
current_mask = mask_t if (i == 0 and mask_t is not None) else None
|
| 839 |
+
alpha = self._core.process_frame(frames_t[i].unsqueeze(0), current_mask)
|
| 840 |
+
alphas.append(alpha.squeeze(0))
|
| 841 |
+
alphas = torch.stack(alphas)
|
| 842 |
+
elif hasattr(self._core, '_process_tensor_video'):
|
| 843 |
+
# Try direct tensor processing (newer versions)
|
| 844 |
+
alphas = self._core._process_tensor_video(frames_t, mask_t)
|
| 845 |
+
else: # step mode
|
| 846 |
+
alphas = self._core.step(frames_t, mask_t)
|
| 847 |
+
|
| 848 |
+
# Convert to numpy and write frames
|
| 849 |
+
alphas_np = alphas.cpu().numpy()
|
| 850 |
+
for i, alpha in enumerate(alphas_np):
|
| 851 |
+
# Convert alpha to uint8 and write
|
| 852 |
+
alpha_uint8 = (alpha * 255).astype(np.uint8)
|
| 853 |
+
if len(alpha_uint8.shape) == 2: # If single channel, convert to 3 channels
|
| 854 |
+
alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
|
| 855 |
+
alpha_writer.write(alpha_uint8)
|
| 856 |
+
|
| 857 |
+
# Write foreground (frame * alpha)
|
| 858 |
+
fg = frames_bgr[i] * (alpha[..., None] if alpha.ndim == 2 else alpha[0:1].permute(1, 2, 0))
|
| 859 |
+
fg = fg.astype(np.uint8)
|
| 860 |
+
fg_writer.write(fg)
|
| 861 |
+
|
| 862 |
+
except RuntimeError as e:
|
| 863 |
+
if "out of memory" in str(e).lower():
|
| 864 |
+
# Clear CUDA cache and retry once
|
| 865 |
+
torch.cuda.empty_cache()
|
| 866 |
+
log.warning("CUDA out of memory, retrying after cache clear")
|
| 867 |
+
return self._flush_chunk(frames_bgr, seed_1hw, alpha_writer, fg_writer)
|
| 868 |
+
raise
|
| 869 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
except Exception as e:
|
| 871 |
error_msg = f"Error processing frame chunk: {str(e)}"
|
| 872 |
log.error(error_msg, exc_info=True)
|
| 873 |
raise MatAnyError(error_msg) from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 874 |
|
| 875 |
# Normalize to numpy list of HW float32 [0,1]
|
| 876 |
if isinstance(alphas, torch.Tensor):
|