agent 2.1
Browse files- models/matanyone_loader.py +130 -66
models/matanyone_loader.py
CHANGED
|
@@ -49,6 +49,95 @@ class MatAnyError(RuntimeError):
|
|
| 49 |
pass
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
|
| 53 |
"""
|
| 54 |
Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
|
|
@@ -903,85 +992,62 @@ def process_stream(
|
|
| 903 |
|
| 904 |
def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
|
| 905 |
"""
|
| 906 |
-
|
| 907 |
-
|
| 908 |
"""
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
try:
|
| 913 |
-
device = self.device
|
| 914 |
-
use_fp16 = (device.type == "cuda") and getattr(self, 'use_fp16', True)
|
| 915 |
-
mode = _select_matany_mode(self._core)
|
| 916 |
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
if seed_1hw is not None:
|
| 924 |
-
seed_tensor = torch.from_numpy(seed_1hw).to(device)
|
| 925 |
|
| 926 |
-
# Process with CUDA stream if available
|
| 927 |
if device.type == "cuda":
|
| 928 |
stream = torch.cuda.Stream()
|
| 929 |
with torch.cuda.stream(stream):
|
| 930 |
with torch.autocast(device_type="cuda", enabled=use_fp16):
|
| 931 |
-
alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw,
|
| 932 |
-
|
| 933 |
else:
|
| 934 |
-
alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw,
|
|
|
|
|
|
|
| 935 |
|
| 936 |
-
# Write out results (convert back to CPU uint8)
|
| 937 |
-
alpha_cpu = (alpha_n1hw.clamp(0, 1) * 255.0).byte().squeeze(1).contiguous().cpu().numpy() # [N,H,W]
|
| 938 |
-
|
| 939 |
for i in range(alpha_cpu.shape[0]):
|
| 940 |
-
#
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
|
| 944 |
-
alpha_writer.write(alpha_uint8)
|
| 945 |
-
|
| 946 |
-
# Write foreground (frame * alpha)
|
| 947 |
-
alpha_expanded = alpha_cpu[i] / 255.0
|
| 948 |
-
if alpha_expanded.ndim == 2:
|
| 949 |
-
alpha_expanded = alpha_expanded[..., None] # [H,W,1]
|
| 950 |
-
fg = (frames_bgr[i] * alpha_expanded).astype(np.uint8)
|
| 951 |
-
fg_writer.write(fg)
|
| 952 |
-
|
| 953 |
-
# Keep seed for temporal methods that need it
|
| 954 |
if hasattr(self._core, "last_mask"):
|
| 955 |
self._last_alpha_1hw = self._core.last_mask
|
| 956 |
|
| 957 |
except torch.cuda.OutOfMemoryError as e:
|
| 958 |
-
# Downshift strategy: smaller chunk or resolution; propagate with context
|
| 959 |
snap = _cuda_snapshot()
|
| 960 |
_safe_empty_cache()
|
| 961 |
-
raise
|
|
|
|
| 962 |
|
| 963 |
except Exception as e:
|
| 964 |
-
# Convert unexpected exceptions to MatAnyError with context
|
| 965 |
snap = _cuda_snapshot()
|
| 966 |
raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e
|
| 967 |
|
| 968 |
finally:
|
| 969 |
-
#
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
pass
|
| 975 |
_safe_empty_cache()
|
| 976 |
-
|
| 977 |
def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
|
| 978 |
"""
|
| 979 |
-
|
| 980 |
-
|
| 981 |
"""
|
| 982 |
frames_buf = []
|
| 983 |
-
last_error = None
|
| 984 |
-
|
| 985 |
try:
|
| 986 |
for f in frames_iterable:
|
| 987 |
frames_buf.append(f)
|
|
@@ -989,39 +1055,37 @@ def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chu
|
|
| 989 |
try:
|
| 990 |
self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
|
| 991 |
frames_buf.clear()
|
| 992 |
-
except
|
| 993 |
-
#
|
|
|
|
|
|
|
|
|
|
| 994 |
if chunk_size > 4:
|
| 995 |
half = max(4, chunk_size // 2)
|
| 996 |
-
# Split and try smaller batches
|
| 997 |
for i in range(0, len(frames_buf), half):
|
| 998 |
sub = frames_buf[i:i+half]
|
| 999 |
self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
|
| 1000 |
frames_buf.clear()
|
| 1001 |
else:
|
| 1002 |
-
|
| 1003 |
-
break
|
| 1004 |
|
| 1005 |
-
# Flush remainder
|
| 1006 |
if frames_buf:
|
| 1007 |
self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
|
| 1008 |
frames_buf.clear()
|
| 1009 |
|
| 1010 |
except torch.cuda.OutOfMemoryError as e:
|
| 1011 |
-
|
| 1012 |
-
|
|
|
|
|
|
|
| 1013 |
except Exception as e:
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
finally:
|
| 1017 |
frames_buf.clear()
|
| 1018 |
_safe_empty_cache()
|
| 1019 |
-
|
| 1020 |
-
if hasattr(getattr(self, '_core', None), 'reset'):
|
| 1021 |
try:
|
| 1022 |
self._core.reset()
|
| 1023 |
except Exception:
|
| 1024 |
pass
|
| 1025 |
-
|
| 1026 |
-
if last_error:
|
| 1027 |
-
raise last_error
|
|
|
|
| 49 |
pass
|
| 50 |
|
| 51 |
|
| 52 |
+
def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
|
| 53 |
+
"""
|
| 54 |
+
frames_bgr_np: list or np.ndarray of shape [N,H,W,3], dtype=uint8, BGR
|
| 55 |
+
Returns torch tensor [N,3,H,W] on device, normalized to 0..1
|
| 56 |
+
"""
|
| 57 |
+
if isinstance(frames_bgr_np, list):
|
| 58 |
+
frames_bgr_np = np.stack(frames_bgr_np, axis=0)
|
| 59 |
+
frames_rgb = frames_bgr_np[..., ::-1].copy(order="C") # BGR->RGB
|
| 60 |
+
pin = torch.from_numpy(frames_rgb).pin_memory() # [N,H,W,3]
|
| 61 |
+
t = pin.permute(0, 3, 1, 2).contiguous().to(device, non_blocking=True)
|
| 62 |
+
t = t.to(dtype=dtype) / 255.0
|
| 63 |
+
return t # [N,3,H,W]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _select_matany_mode(core):
|
| 67 |
+
"""Pick best available API."""
|
| 68 |
+
if hasattr(core, "process_frame"):
|
| 69 |
+
return "process_frame"
|
| 70 |
+
if hasattr(core, "_process_tensor_video"):
|
| 71 |
+
return "_process_tensor_video"
|
| 72 |
+
if hasattr(core, "step"):
|
| 73 |
+
return "step"
|
| 74 |
+
raise MatAnyError("MatAnyone core has no supported API (process_frame/_process_tensor_video/step).")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _matany_run(core, mode, frames_04chw, seed_1hw=None, use_fp16=False):
|
| 78 |
+
"""
|
| 79 |
+
Returns (alpha [N,1,H,W], fg [N,3,H,W]) on current device.
|
| 80 |
+
"""
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
if mode == "process_frame":
|
| 83 |
+
alphas, fgs = [], []
|
| 84 |
+
for i in range(frames_04chw.shape[0]):
|
| 85 |
+
f = frames_04chw[i:i+1] # [1,3,H,W]
|
| 86 |
+
if seed_1hw is not None and seed_1hw.ndim == 3:
|
| 87 |
+
a, fg = core.process_frame(f, seed_1hw.unsqueeze(0))
|
| 88 |
+
else:
|
| 89 |
+
a, fg = core.process_frame(f)
|
| 90 |
+
alphas.append(a) # [1,1,H,W]
|
| 91 |
+
fgs.append(fg) # [1,3,H,W]
|
| 92 |
+
alpha = torch.cat(alphas, dim=0)
|
| 93 |
+
fg = torch.cat(fgs, dim=0)
|
| 94 |
+
return alpha, fg
|
| 95 |
+
|
| 96 |
+
elif mode == "_process_tensor_video":
|
| 97 |
+
# Many repos expect float32 for this path
|
| 98 |
+
return core._process_tensor_video(frames_04chw.float(), seed_1hw)
|
| 99 |
+
|
| 100 |
+
elif mode == "step":
|
| 101 |
+
alphas, fgs = [], []
|
| 102 |
+
for i in range(frames_04chw.shape[0]):
|
| 103 |
+
f = frames_04chw[i:i+1]
|
| 104 |
+
if i == 0 and seed_1hw is not None:
|
| 105 |
+
a, fg = core.step(f, seed_1hw)
|
| 106 |
+
else:
|
| 107 |
+
a, fg = core.step(f)
|
| 108 |
+
alphas.append(a)
|
| 109 |
+
fgs.append(fg)
|
| 110 |
+
alpha = torch.cat(alphas, dim=0)
|
| 111 |
+
fg = torch.cat(fgs, dim=0)
|
| 112 |
+
return alpha, fg
|
| 113 |
+
|
| 114 |
+
raise MatAnyError(f"Unsupported MatAnyone mode: {mode}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _cuda_snapshot():
|
| 118 |
+
if not torch.cuda.is_available():
|
| 119 |
+
return "CUDA: N/A"
|
| 120 |
+
i = torch.cuda.current_device()
|
| 121 |
+
return (f"device={i}, name={torch.cuda.get_device_name(i)}, "
|
| 122 |
+
f"alloc={torch.cuda.memory_allocated(i)/1e9:.2f}GB, "
|
| 123 |
+
f"reserved={torch.cuda.memory_reserved(i)/1e9:.2f}GB")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _safe_empty_cache():
|
| 127 |
+
if torch.cuda.is_available():
|
| 128 |
+
try:
|
| 129 |
+
torch.cuda.synchronize()
|
| 130 |
+
except Exception:
|
| 131 |
+
pass
|
| 132 |
+
torch.cuda.empty_cache()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _to_uint8_cpu(alpha_n1hw, fg_n3hw):
|
| 136 |
+
alpha_cpu = (alpha_n1hw.clamp(0, 1) * 255.0).byte().squeeze(1).contiguous().cpu().numpy() # [N,H,W]
|
| 137 |
+
fg_cpu = (fg_n3hw.clamp(0, 1) * 255.0).byte().permute(0, 2, 3, 1).contiguous().cpu().numpy() # [N,H,W,3] RGB
|
| 138 |
+
return alpha_cpu, fg_cpu
|
| 139 |
+
|
| 140 |
+
|
| 141 |
def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
|
| 142 |
"""
|
| 143 |
Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
|
|
|
|
| 992 |
|
| 993 |
def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
    """
    Process an in-memory batch (list of uint8 BGR frames), write results via writers.
    Strong CUDA guards + cleanup.
    """
    device = self.device
    use_fp16 = (device.type == "cuda") and getattr(self, "use_fp16", True)
    mode = _select_matany_mode(self._core)

    # Pre-declare so the finally block can delete them unconditionally.
    frames_04chw = None
    alpha_n1hw = None
    fg_n3hw = None

    try:
        batch_dtype = torch.float16 if use_fp16 else torch.float32
        frames_04chw = _to_device_batch(frames_bgr, device, dtype=batch_dtype)

        if device.type == "cuda":
            # Run on a side stream; autocast only when fp16 is requested.
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                with torch.autocast(device_type="cuda", enabled=use_fp16):
                    alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)
            stream.synchronize()
        else:
            alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)

        alpha_cpu, fg_cpu = _to_uint8_cpu(alpha_n1hw, fg_n3hw)

        for i in range(alpha_cpu.shape[0]):
            alpha_writer.write(alpha_cpu[i])              # [H,W] uint8
            fg_writer.write(fg_cpu[i][..., ::-1].copy())  # RGB -> BGR for the writer

        # Keep the last mask for temporal cores that expose it.
        if hasattr(self._core, "last_mask"):
            self._last_alpha_1hw = self._core.last_mask

    except torch.cuda.OutOfMemoryError as e:
        snap = _cuda_snapshot()
        _safe_empty_cache()
        # Re-raise with context for pipeline to catch
        raise MatAnyError(f"CUDA OOM in _flush_chunk | {snap}") from e

    except Exception as e:
        snap = _cuda_snapshot()
        raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e

    finally:
        # ensure we release heavy tensors
        try:
            del alpha_n1hw, fg_n3hw, frames_04chw
        except Exception:
            pass
        _safe_empty_cache()
|
| 1044 |
+
|
| 1045 |
def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
|
| 1046 |
"""
|
| 1047 |
+
Buffer frames from iterable and process in chunks.
|
| 1048 |
+
On OOM, retry once with half chunk size; otherwise bubble up MatAnyError.
|
| 1049 |
"""
|
| 1050 |
frames_buf = []
|
|
|
|
|
|
|
| 1051 |
try:
|
| 1052 |
for f in frames_iterable:
|
| 1053 |
frames_buf.append(f)
|
|
|
|
| 1055 |
try:
|
| 1056 |
self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
|
| 1057 |
frames_buf.clear()
|
| 1058 |
+
except torch.cuda.OutOfMemoryError:
|
| 1059 |
+
# should be wrapped above, but double-guard
|
| 1060 |
+
raise
|
| 1061 |
+
except MatAnyError as inner:
|
| 1062 |
+
# one-time downshift
|
| 1063 |
if chunk_size > 4:
|
| 1064 |
half = max(4, chunk_size // 2)
|
|
|
|
| 1065 |
for i in range(0, len(frames_buf), half):
|
| 1066 |
sub = frames_buf[i:i+half]
|
| 1067 |
self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
|
| 1068 |
frames_buf.clear()
|
| 1069 |
else:
|
| 1070 |
+
raise inner
|
|
|
|
| 1071 |
|
|
|
|
| 1072 |
if frames_buf:
|
| 1073 |
self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
|
| 1074 |
frames_buf.clear()
|
| 1075 |
|
| 1076 |
except torch.cuda.OutOfMemoryError as e:
|
| 1077 |
+
snap = _cuda_snapshot()
|
| 1078 |
+
_safe_empty_cache()
|
| 1079 |
+
raise MatAnyError(f"CUDA OOM in process_stream outer | {snap}") from e
|
| 1080 |
+
|
| 1081 |
except Exception as e:
|
| 1082 |
+
raise MatAnyError(f"Unexpected error in process_stream: {e}") from e
|
| 1083 |
+
|
| 1084 |
finally:
|
| 1085 |
frames_buf.clear()
|
| 1086 |
_safe_empty_cache()
|
| 1087 |
+
if hasattr(self._core, "reset"):
|
|
|
|
| 1088 |
try:
|
| 1089 |
self._core.reset()
|
| 1090 |
except Exception:
|
| 1091 |
pass
|
|
|
|
|
|
|
|
|