MogensR committed on
Commit
db85143
·
1 Parent(s): 37f2d16

agent 2.1

Browse files
Files changed (1) hide show
  1. models/matanyone_loader.py +130 -66
models/matanyone_loader.py CHANGED
@@ -49,6 +49,95 @@ class MatAnyError(RuntimeError):
49
  pass
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
53
  """
54
  Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
@@ -903,85 +992,62 @@ def process_stream(
903
 
904
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
905
  """
906
- Take an in-memory batch of frames (list of np.uint8 BGR), run MatAnyone on GPU,
907
- then write alpha/fg frames via the provided writers. Clears GPU memory on exit.
908
  """
909
- # Initialize variables for cleanup
910
- alpha_n1hw, fg_n3hw, frames_04chw = None, None, None
911
-
912
- try:
913
- device = self.device
914
- use_fp16 = (device.type == "cuda") and getattr(self, 'use_fp16', True)
915
- mode = _select_matany_mode(self._core)
916
 
917
- # Move input frames to device in a batched, pinned way
918
- frames_04chw = _to_device_batch(frames_bgr, device,
919
- dtype=torch.float16 if use_fp16 else torch.float32)
920
 
921
- # Move seed mask to device if provided
922
- seed_tensor = None
923
- if seed_1hw is not None:
924
- seed_tensor = torch.from_numpy(seed_1hw).to(device)
925
 
926
- # Process with CUDA stream if available
927
  if device.type == "cuda":
928
  stream = torch.cuda.Stream()
929
  with torch.cuda.stream(stream):
930
  with torch.autocast(device_type="cuda", enabled=use_fp16):
931
- alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_tensor)
932
- torch.cuda.synchronize()
933
  else:
934
- alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_tensor)
 
 
935
 
936
- # Write out results (convert back to CPU uint8)
937
- alpha_cpu = (alpha_n1hw.clamp(0, 1) * 255.0).byte().squeeze(1).contiguous().cpu().numpy() # [N,H,W]
938
-
939
  for i in range(alpha_cpu.shape[0]):
940
- # Write alpha mask
941
- alpha_uint8 = alpha_cpu[i]
942
- if len(alpha_uint8.shape) == 2: # Ensure 3 channels for writer
943
- alpha_uint8 = cv2.cvtColor(alpha_uint8, cv2.COLOR_GRAY2BGR)
944
- alpha_writer.write(alpha_uint8)
945
-
946
- # Write foreground (frame * alpha)
947
- alpha_expanded = alpha_cpu[i] / 255.0
948
- if alpha_expanded.ndim == 2:
949
- alpha_expanded = alpha_expanded[..., None] # [H,W,1]
950
- fg = (frames_bgr[i] * alpha_expanded).astype(np.uint8)
951
- fg_writer.write(fg)
952
-
953
- # Keep seed for temporal methods that need it
954
  if hasattr(self._core, "last_mask"):
955
  self._last_alpha_1hw = self._core.last_mask
956
 
957
  except torch.cuda.OutOfMemoryError as e:
958
- # Downshift strategy: smaller chunk or resolution; propagate with context
959
  snap = _cuda_snapshot()
960
  _safe_empty_cache()
961
- raise MatAnyError(f"CUDA OOM in _flush_chunk (before retry). Snapshot: {snap}") from e
 
962
 
963
  except Exception as e:
964
- # Convert unexpected exceptions to MatAnyError with context
965
  snap = _cuda_snapshot()
966
  raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e
967
 
968
  finally:
969
- # Hard cleanup to avoid lingering allocations between chunks
970
- for var in [alpha_n1hw, fg_n3hw, frames_04chw]:
971
- try:
972
- del var
973
- except Exception:
974
- pass
975
  _safe_empty_cache()
976
-
977
  def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
978
  """
979
- Public entry that buffers frames from an iterator and processes them in chunks.
980
- Ensures cleanup and graceful degradation on OOM.
981
  """
982
  frames_buf = []
983
- last_error = None
984
-
985
  try:
986
  for f in frames_iterable:
987
  frames_buf.append(f)
@@ -989,39 +1055,37 @@ def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chu
989
  try:
990
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
991
  frames_buf.clear()
992
- except MatAnyError as e:
993
- # Attempt one downshift: halve the chunk and retry once
 
 
 
994
  if chunk_size > 4:
995
  half = max(4, chunk_size // 2)
996
- # Split and try smaller batches
997
  for i in range(0, len(frames_buf), half):
998
  sub = frames_buf[i:i+half]
999
  self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
1000
  frames_buf.clear()
1001
  else:
1002
- last_error = e
1003
- break
1004
 
1005
- # Flush remainder
1006
  if frames_buf:
1007
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1008
  frames_buf.clear()
1009
 
1010
  except torch.cuda.OutOfMemoryError as e:
1011
- last_error = MatAnyError(f"CUDA OOM in process_stream: {_cuda_snapshot()}") from e
1012
-
 
 
1013
  except Exception as e:
1014
- last_error = MatAnyError(f"Unexpected error in process_stream: {e}") from e
1015
-
1016
  finally:
1017
  frames_buf.clear()
1018
  _safe_empty_cache()
1019
- # Optional: if core has a reset method, call it
1020
- if hasattr(getattr(self, '_core', None), 'reset'):
1021
  try:
1022
  self._core.reset()
1023
  except Exception:
1024
  pass
1025
-
1026
- if last_error:
1027
- raise last_error
 
49
  pass
50
 
51
 
52
def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
    """
    Convert a batch of BGR uint8 frames into a normalized RGB tensor on ``device``.

    Args:
        frames_bgr_np: list of [H,W,3] arrays, or one np.ndarray of shape
            [N,H,W,3]; dtype uint8, channel order BGR.
        device: target device (``torch.device`` or device string).
        dtype: floating dtype of the returned tensor.

    Returns:
        torch tensor [N,3,H,W] on ``device``, RGB order, scaled to 0..1.

    NOTE(review): this module appears to contain a second ``def _to_device_batch``
    further down; the later definition would shadow this one at import time --
    confirm which of the two is meant to survive.
    """
    if isinstance(frames_bgr_np, list):
        frames_bgr_np = np.stack(frames_bgr_np, axis=0)

    # BGR->RGB; the copy materializes the reversed view so from_numpy accepts it.
    frames_rgb = frames_bgr_np[..., ::-1].copy(order="C")

    # Lay the data out as [N,3,H,W] *before* pinning: .contiguous() allocates a
    # fresh pageable buffer, so pinning earlier would be thrown away.
    host = torch.from_numpy(frames_rgb).permute(0, 3, 1, 2).contiguous()

    target = torch.device(device)
    # Pinned memory only exists (and only helps) for host->CUDA copies; calling
    # pin_memory() on a machine without CUDA raises.
    use_pinned = target.type == "cuda" and torch.cuda.is_available()
    if use_pinned:
        host = host.pin_memory()

    t = host.to(target, non_blocking=use_pinned)
    t = t.to(dtype=dtype) / 255.0
    return t  # [N,3,H,W]
64
+
65
+
66
def _select_matany_mode(core):
    """
    Pick the best available inference API exposed by ``core``.

    Candidates are probed in priority order: ``process_frame``,
    ``_process_tensor_video``, ``step``. An attribute only counts when it is
    callable, so a placeholder such as ``process_frame = None`` is skipped
    instead of being selected and failing later at call time.

    Returns:
        The name of the selected API as a string.

    Raises:
        MatAnyError: if ``core`` exposes none of the supported callables.
    """
    for api_name in ("process_frame", "_process_tensor_video", "step"):
        if callable(getattr(core, api_name, None)):
            return api_name
    raise MatAnyError("MatAnyone core has no supported API (process_frame/_process_tensor_video/step).")
75
+
76
+
77
def _matany_run(core, mode, frames_04chw, seed_1hw=None, use_fp16=False):
    """
    Run the MatAnyone core over a batch of frames through the API named by ``mode``.

    Args:
        core: object exposing the method named by ``mode``.
        mode: one of "process_frame", "_process_tensor_video", "step".
        frames_04chw: frame tensor batched along dim 0; it is fed to the
            per-frame APIs one ``[i:i+1]`` slice at a time.
        seed_1hw: optional seed mask. May be a tensor or anything
            ``torch.as_tensor`` accepts (e.g. a NumPy array); it is moved to the
            device of ``frames_04chw`` before use.
        use_fp16: accepted for call-site compatibility; not consulted here.

    Returns:
        (alpha [N,1,H,W], fg [N,3,H,W]) on current device.

    Raises:
        MatAnyError: if ``mode`` is not one of the supported names.
    """
    # Callers may hand over the seed as a NumPy array; the per-frame paths below
    # use tensor-only methods (.unsqueeze), so normalize it once up front.
    if seed_1hw is not None:
        seed_1hw = torch.as_tensor(seed_1hw).to(frames_04chw.device)

    with torch.no_grad():
        if mode == "_process_tensor_video":
            # Many repos expect float32 for this path
            return core._process_tensor_video(frames_04chw.float(), seed_1hw)

        if mode not in ("process_frame", "step"):
            raise MatAnyError(f"Unsupported MatAnyone mode: {mode}")

        alphas, fgs = [], []
        for i in range(frames_04chw.shape[0]):
            f = frames_04chw[i:i + 1]  # [1,3,H,W]
            if mode == "process_frame":
                # process_frame receives the seed on every frame, with a batch dim added.
                if seed_1hw is not None and seed_1hw.ndim == 3:
                    a, fg = core.process_frame(f, seed_1hw.unsqueeze(0))
                else:
                    a, fg = core.process_frame(f)
            elif i == 0 and seed_1hw is not None:
                # step only receives the seed on the first frame of the batch.
                a, fg = core.step(f, seed_1hw)
            else:
                a, fg = core.step(f)
            alphas.append(a)  # [1,1,H,W]
            fgs.append(fg)    # [1,3,H,W]

        return torch.cat(alphas, dim=0), torch.cat(fgs, dim=0)
115
+
116
+
117
def _cuda_snapshot():
    """
    Return a one-line, human-readable summary of the current CUDA device.

    The summary holds the device index, device name, and allocated/reserved
    memory in GB. When CUDA is not available the string "CUDA: N/A" is returned.

    This helper is used while building error messages inside ``except``
    handlers, so it never raises: if querying the device fails, a placeholder
    string describing the failure is returned instead.
    """
    if not torch.cuda.is_available():
        return "CUDA: N/A"
    try:
        i = torch.cuda.current_device()
        return (f"device={i}, name={torch.cuda.get_device_name(i)}, "
                f"alloc={torch.cuda.memory_allocated(i)/1e9:.2f}GB, "
                f"reserved={torch.cuda.memory_reserved(i)/1e9:.2f}GB")
    except Exception as exc:
        # Diagnostics must not mask the error that is currently being reported.
        return f"CUDA: snapshot unavailable ({exc})"
124
+
125
+
126
def _safe_empty_cache():
    """
    Best-effort release of cached CUDA memory; a no-op when CUDA is unavailable.

    Synchronizes the device and then empties the caching allocator. The helper
    is called from ``finally``/``except`` blocks, so neither step is allowed to
    raise: an exception escaping from here would replace the error that is
    already propagating.
    """
    if not torch.cuda.is_available():
        return
    for cleanup_step in (torch.cuda.synchronize, torch.cuda.empty_cache):
        try:
            cleanup_step()
        except Exception:
            # Deliberate best-effort: cleanup failures must not mask the real error.
            pass
133
+
134
+
135
def _to_uint8_cpu(alpha_n1hw, fg_n3hw):
    """
    Convert float alpha/foreground tensors to uint8 NumPy arrays on the CPU.

    Values are clamped to 0..1, scaled by 255 and rounded to the nearest integer
    before the uint8 cast (a bare cast truncates, mapping e.g. 0.999 to 254).

    Args:
        alpha_n1hw: alpha tensor [N,1,H,W].
        fg_n3hw: foreground tensor [N,3,H,W].

    Returns:
        (alpha_cpu [N,H,W] uint8, fg_cpu [N,H,W,3] uint8, same channel order as the input).
    """
    alpha_u8 = (alpha_n1hw.detach().clamp(0, 1) * 255.0).round().byte()
    fg_u8 = (fg_n3hw.detach().clamp(0, 1) * 255.0).round().byte()

    alpha_cpu = alpha_u8.squeeze(1).contiguous().cpu().numpy()        # [N,H,W]
    fg_cpu = fg_u8.permute(0, 2, 3, 1).contiguous().cpu().numpy()     # [N,H,W,3] RGB
    return alpha_cpu, fg_cpu
139
+
140
+
141
  def _to_device_batch(frames_bgr_np, device, dtype=torch.float16):
142
  """
143
  Convert a list/array of BGR uint8 frames [N,H,W,3] to a normalized
 
992
 
993
  def _flush_chunk(self, frames_bgr, seed_1hw, alpha_writer, fg_writer):
994
  """
995
+ Process an in-memory batch (list of uint8 BGR frames), write results via writers.
996
+ Strong CUDA guards + cleanup.
997
  """
998
+ device = self.device
999
+ use_fp16 = (device.type == "cuda") and getattr(self, "use_fp16", True)
1000
+ mode = _select_matany_mode(self._core)
 
 
 
 
1001
 
1002
+ frames_04chw = None
1003
+ alpha_n1hw = None
1004
+ fg_n3hw = None
1005
 
1006
+ try:
1007
+ frames_04chw = _to_device_batch(frames_bgr, device, dtype=torch.float16 if use_fp16 else torch.float32)
 
 
1008
 
 
1009
  if device.type == "cuda":
1010
  stream = torch.cuda.Stream()
1011
  with torch.cuda.stream(stream):
1012
  with torch.autocast(device_type="cuda", enabled=use_fp16):
1013
+ alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)
1014
+ stream.synchronize()
1015
  else:
1016
+ alpha_n1hw, fg_n3hw = _matany_run(self._core, mode, frames_04chw, seed_1hw, use_fp16)
1017
+
1018
+ alpha_cpu, fg_cpu = _to_uint8_cpu(alpha_n1hw, fg_n3hw)
1019
 
 
 
 
1020
  for i in range(alpha_cpu.shape[0]):
1021
+ alpha_writer.write(alpha_cpu[i]) # [H,W] uint8
1022
+ fg_writer.write(fg_cpu[i][..., ::-1].copy()) # RGB->BGR
1023
+
 
 
 
 
 
 
 
 
 
 
 
1024
  if hasattr(self._core, "last_mask"):
1025
  self._last_alpha_1hw = self._core.last_mask
1026
 
1027
  except torch.cuda.OutOfMemoryError as e:
 
1028
  snap = _cuda_snapshot()
1029
  _safe_empty_cache()
1030
+ # Re-raise with context for pipeline to catch
1031
+ raise MatAnyError(f"CUDA OOM in _flush_chunk | {snap}") from e
1032
 
1033
  except Exception as e:
 
1034
  snap = _cuda_snapshot()
1035
  raise MatAnyError(f"MatAnyone failure in _flush_chunk: {e} | {snap}") from e
1036
 
1037
  finally:
1038
+ # ensure we release heavy tensors
1039
+ try:
1040
+ del alpha_n1hw, fg_n3hw, frames_04chw
1041
+ except Exception:
1042
+ pass
 
1043
  _safe_empty_cache()
1044
+
1045
  def process_stream(self, frames_iterable, seed_1hw, alpha_writer, fg_writer, chunk_size=32):
1046
  """
1047
+ Buffer frames from iterable and process in chunks.
1048
+ On OOM, retry once with half chunk size; otherwise bubble up MatAnyError.
1049
  """
1050
  frames_buf = []
 
 
1051
  try:
1052
  for f in frames_iterable:
1053
  frames_buf.append(f)
 
1055
  try:
1056
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1057
  frames_buf.clear()
1058
+ except torch.cuda.OutOfMemoryError:
1059
+ # should be wrapped above, but double-guard
1060
+ raise
1061
+ except MatAnyError as inner:
1062
+ # one-time downshift
1063
  if chunk_size > 4:
1064
  half = max(4, chunk_size // 2)
 
1065
  for i in range(0, len(frames_buf), half):
1066
  sub = frames_buf[i:i+half]
1067
  self._flush_chunk(sub, seed_1hw, alpha_writer, fg_writer)
1068
  frames_buf.clear()
1069
  else:
1070
+ raise inner
 
1071
 
 
1072
  if frames_buf:
1073
  self._flush_chunk(frames_buf, seed_1hw, alpha_writer, fg_writer)
1074
  frames_buf.clear()
1075
 
1076
  except torch.cuda.OutOfMemoryError as e:
1077
+ snap = _cuda_snapshot()
1078
+ _safe_empty_cache()
1079
+ raise MatAnyError(f"CUDA OOM in process_stream outer | {snap}") from e
1080
+
1081
  except Exception as e:
1082
+ raise MatAnyError(f"Unexpected error in process_stream: {e}") from e
1083
+
1084
  finally:
1085
  frames_buf.clear()
1086
  _safe_empty_cache()
1087
+ if hasattr(self._core, "reset"):
 
1088
  try:
1089
  self._core.reset()
1090
  except Exception:
1091
  pass