MogensR committed on
Commit
6095d82
·
1 Parent(s): b4ba9ec

Update models/loaders/matanyone_loader.py

Browse files
Files changed (1) hide show
  1. models/loaders/matanyone_loader.py +150 -188
models/loaders/matanyone_loader.py CHANGED
@@ -1,10 +1,9 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
- MatAnyone Loader - Official InferenceCore API Implementation
5
- ============================================================
6
- Fixed to use official MatAnyone API to resolve tensor dimension issues.
7
- No manual tensor manipulation - let InferenceCore handle everything internally.
8
  """
9
 
10
  import os
@@ -22,11 +21,117 @@
22
  logger = logging.getLogger(__name__)
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class MatAnyoneLoader:
26
  """
27
- Official MatAnyone loader using InferenceCore API.
28
- This fixes the tensor dimension mismatch by using the official API
29
- which handles all tensor dimensions internally.
30
  """
31
 
32
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
@@ -35,6 +140,7 @@ def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyo
35
  os.makedirs(self.cache_dir, exist_ok=True)
36
 
37
  self.processor = None
 
38
  self.model_id = "PeiqingYang/MatAnyone"
39
  self.load_time = 0.0
40
  self.loaded = False
@@ -50,10 +156,10 @@ def _select_device(self, pref: str) -> str:
50
  return "cpu"
51
  return "cuda" if torch.cuda.is_available() else "cpu"
52
 
53
- def load(self): # <-- CHANGED: No return type hint, returns processor
54
- """Load MatAnyone using official InferenceCore API."""
55
- if self.loaded:
56
- return self.processor # <-- CHANGED: Return processor, not True
57
 
58
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
59
  t0 = time.time()
@@ -62,174 +168,32 @@ def load(self): # <-- CHANGED: No return type hint, returns processor
62
  # Import the official API
63
  from matanyone.inference.inference_core import InferenceCore
64
 
65
- # Use official API - this handles ALL tensor dimensions internally
66
- # No manual tensor reshaping needed!
67
  self.processor = InferenceCore(self.model_id)
68
 
 
 
 
69
  self.loaded = True
70
  self.load_time = time.time() - t0
71
- logger.info(f"MatAnyone loaded successfully via InferenceCore API in {self.load_time:.2f}s")
72
- return self.processor # <-- CHANGED: Return processor, not True
73
 
74
  except ImportError as e:
75
  self.load_error = f"MatAnyone not installed: {e}"
76
  logger.error(f"Failed to import MatAnyone. Install with: pip install git+https://github.com/pq-yang/MatAnyone.git@main")
77
- return None # <-- CHANGED: Return None on failure
78
 
79
  except Exception as e:
80
  self.load_error = str(e)
81
  logger.error(f"Failed to load MatAnyone: {e}")
82
  logger.debug(traceback.format_exc())
83
- return None # <-- CHANGED: Return None on failure
84
-
85
- def process_video(self, video_path: str, mask_path: str, output_dir: Optional[str] = None,
86
- max_size: int = 720, save_frames: bool = False) -> Tuple[Optional[str], Optional[str]]:
87
- """
88
- Process video using official MatAnyone API.
89
-
90
- Args:
91
- video_path: Path to input video
92
- mask_path: Path to first frame mask
93
- output_dir: Output directory (uses temp if None)
94
- max_size: Maximum resolution (-1 for original)
95
- save_frames: Whether to save individual frames
96
-
97
- Returns:
98
- (foreground_path, alpha_path) or (None, None) on error
99
- """
100
- if not self.loaded:
101
- if not self.load():
102
- logger.error(f"MatAnyone not loaded: {self.load_error}")
103
- return None, None
104
-
105
- if output_dir is None:
106
- output_dir = str(self.temp_dir)
107
-
108
- try:
109
- # Use official API - no tensor manipulation needed!
110
- # The API handles all dimension requirements internally
111
- foreground_path, alpha_path = self.processor.process_video(
112
- input_path=str(video_path),
113
- mask_path=str(mask_path),
114
- output_path=str(output_dir),
115
- max_size=max_size,
116
- save_frames=save_frames
117
- )
118
-
119
- logger.info(f"MatAnyone processing complete: fg={foreground_path}, alpha={alpha_path}")
120
- return foreground_path, alpha_path
121
-
122
- except Exception as e:
123
- logger.error(f"MatAnyone processing failed: {e}")
124
- logger.debug(traceback.format_exc())
125
- return None, None
126
-
127
- def process_frames_to_alpha(self, frames: np.ndarray, initial_mask: np.ndarray,
128
- output_dir: Optional[str] = None) -> Optional[np.ndarray]:
129
- """
130
- Process video frames and return alpha masks.
131
- This is a compatibility wrapper for frame-based processing.
132
-
133
- Args:
134
- frames: Video frames as numpy array (T, H, W, C) or list
135
- initial_mask: First frame mask (H, W) with values 0-255
136
- output_dir: Optional output directory
137
-
138
- Returns:
139
- Alpha masks array (T, H, W) or None on error
140
- """
141
- if not self.loaded:
142
- if not self.load():
143
- return None
144
-
145
- if output_dir is None:
146
- output_dir = str(self.temp_dir)
147
-
148
- # Save frames as temporary video
149
- temp_video_path = Path(output_dir) / "temp_input.mp4"
150
- temp_mask_path = Path(output_dir) / "temp_mask.png"
151
-
152
- try:
153
- # Convert frames to video
154
- if isinstance(frames, list):
155
- frames = np.stack(frames)
156
-
157
- # Ensure correct format
158
- if frames.ndim == 5: # (B, C, T, H, W) or similar
159
- # Take first batch, rearrange to (T, H, W, C)
160
- frames = frames[0]
161
- if frames.shape[0] == 3: # Channels first
162
- frames = frames.transpose(1, 2, 3, 0)
163
- elif frames.ndim == 4 and frames.shape[1] == 3: # (T, C, H, W)
164
- frames = frames.transpose(0, 2, 3, 1)
165
-
166
- # Write video
167
- fps = 30
168
- height, width = frames.shape[1:3]
169
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
170
- out = cv2.VideoWriter(str(temp_video_path), fourcc, fps, (width, height))
171
-
172
- for frame in frames:
173
- if frame.dtype in (np.float32, np.float64):
174
- frame = (frame * 255).astype(np.uint8)
175
- if frame.shape[-1] == 3:
176
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
177
- out.write(frame)
178
- out.release()
179
-
180
- # Save mask
181
- if initial_mask.dtype in (np.float32, np.float64):
182
- initial_mask = (initial_mask * 255).astype(np.uint8)
183
- cv2.imwrite(str(temp_mask_path), initial_mask)
184
-
185
- # Process with official API
186
- _, alpha_path = self.process_video(
187
- str(temp_video_path),
188
- str(temp_mask_path),
189
- str(output_dir)
190
- )
191
-
192
- if alpha_path:
193
- # Load alpha video and return as array
194
- return self._load_alpha_video(alpha_path)
195
-
196
- return None
197
-
198
- except Exception as e:
199
- logger.error(f"Frame processing failed: {e}")
200
- return None
201
- finally:
202
- # Cleanup temp files
203
- if temp_video_path.exists():
204
- temp_video_path.unlink()
205
- if temp_mask_path.exists():
206
- temp_mask_path.unlink()
207
-
208
- def _load_alpha_video(self, alpha_video_path: str) -> Optional[np.ndarray]:
209
- """Load alpha video and return as numpy array."""
210
- try:
211
- cap = cv2.VideoCapture(str(alpha_video_path))
212
- frames = []
213
-
214
- while True:
215
- ret, frame = cap.read()
216
- if not ret:
217
- break
218
- # Convert to grayscale if needed
219
- if len(frame.shape) == 3:
220
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
221
- frames.append(frame / 255.0) # Normalize to 0-1
222
-
223
- cap.release()
224
- return np.array(frames) if frames else None
225
-
226
- except Exception as e:
227
- logger.error(f"Failed to load alpha video: {e}")
228
  return None
229
 
230
  def cleanup(self):
231
  """Cleanup temporary files and release resources."""
232
  self.processor = None
 
233
 
234
  # Clean temp directory
235
  if self.temp_dir.exists():
@@ -242,45 +206,43 @@ def cleanup(self):
242
 
243
  def get_info(self) -> Dict[str, Any]:
244
  """Get model information."""
245
- return {
246
  "loaded": self.loaded,
247
  "model_id": self.model_id,
248
  "device": str(self.device),
249
  "load_time": self.load_time,
250
  "error": self.load_error,
251
- "api": "InferenceCore (official)"
252
  }
 
 
 
 
 
 
 
 
253
 
254
  def reset(self):
255
  """Reset the processor for a new video."""
256
- # The official API handles session management internally
257
- # Just log that reset was called
258
- logger.info("MatAnyone session reset requested (handled by InferenceCore)")
259
 
260
- # Compatibility method for existing code that might call this
261
  def __call__(self, image, mask=None, **kwargs):
262
- """
263
- Direct call compatibility wrapper.
264
- For single frame processing or backwards compatibility.
265
- """
266
- if isinstance(image, (list, np.ndarray)) and mask is not None:
267
- # Process as frames
268
- if not isinstance(image, np.ndarray):
269
- image = np.array(image)
270
- if image.ndim == 3: # Single frame
271
- image = image[np.newaxis, ...]
272
-
273
- alphas = self.process_frames_to_alpha(image, mask)
274
- if alphas is not None and len(alphas) > 0:
275
- return alphas[0] if alphas.shape[0] == 1 else alphas
276
 
277
- # Fallback
278
- logger.warning("Direct call to MatAnyoneLoader not fully supported with official API")
279
- return mask if mask is not None else np.zeros(image.shape[:2], dtype=np.float32)
280
-
281
 
282
- # For backwards compatibility - expose session class name even though we don't use it
283
- _MatAnyoneSession = MatAnyoneLoader # Alias for compatibility
284
 
 
 
285
 
286
- __all__ = ["MatAnyoneLoader", "_MatAnyoneSession"]
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
  """
4
+ MatAnyone Loader - Wrapper for Official InferenceCore API
5
+ =========================================================
6
+ Creates a callable wrapper around InferenceCore to maintain compatibility.
 
7
  """
8
 
9
  import os
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
24
+ class MatAnyoneCallableWrapper:
25
+ """
26
+ Callable wrapper around InferenceCore to maintain API compatibility.
27
+ Makes the processor work like a callable session.
28
+ """
29
+
30
+ def __init__(self, inference_core):
31
+ self.core = inference_core
32
+ self.initialized = False
33
+
34
+ def __call__(self, image, mask=None, **kwargs):
35
+ """
36
+ Make this wrapper callable like the old session interface.
37
+
38
+ Args:
39
+ image: Input image as numpy array
40
+ mask: Optional mask for first frame
41
+
42
+ Returns:
43
+ Alpha mask as 2D numpy array
44
+ """
45
+ try:
46
+ # For MatAnyone, the first frame needs initialization with a mask
47
+ if not self.initialized:
48
+ if mask is None:
49
+ # Return a default mask if no mask provided for first frame
50
+ logger.warning("First frame called without mask, returning default")
51
+ if isinstance(image, np.ndarray):
52
+ h, w = image.shape[:2]
53
+ else:
54
+ h, w = 512, 512
55
+ return np.ones((h, w), dtype=np.float32) * 0.5
56
+
57
+ # Initialize with first frame and mask
58
+ # The exact API call depends on the InferenceCore implementation
59
+ # This is a placeholder - adjust based on actual API
60
+ if hasattr(self.core, 'step'):
61
+ result = self.core.step(image=image, mask=mask)
62
+ elif hasattr(self.core, 'process_frame'):
63
+ result = self.core.process_frame(image, mask)
64
+ else:
65
+ # Fallback
66
+ logger.warning("InferenceCore API unclear, returning input mask")
67
+ return mask if isinstance(mask, np.ndarray) else np.array(mask)
68
+
69
+ self.initialized = True
70
+ return self._extract_alpha(result)
71
+ else:
72
+ # Subsequent frames - no mask needed
73
+ if hasattr(self.core, 'step'):
74
+ result = self.core.step(image=image)
75
+ elif hasattr(self.core, 'process_frame'):
76
+ result = self.core.process_frame(image)
77
+ else:
78
+ # Fallback - return neutral mask
79
+ if isinstance(image, np.ndarray):
80
+ h, w = image.shape[:2]
81
+ else:
82
+ h, w = 512, 512
83
+ return np.ones((h, w), dtype=np.float32) * 0.5
84
+
85
+ return self._extract_alpha(result)
86
+
87
+ except Exception as e:
88
+ logger.error(f"MatAnyone wrapper call failed: {e}")
89
+ # Return a fallback mask
90
+ if mask is not None:
91
+ return mask if isinstance(mask, np.ndarray) else np.array(mask)
92
+ if isinstance(image, np.ndarray):
93
+ h, w = image.shape[:2]
94
+ else:
95
+ h, w = 512, 512
96
+ return np.ones((h, w), dtype=np.float32) * 0.5
97
+
98
+ def _extract_alpha(self, result):
99
+ """Extract alpha channel from result."""
100
+ if result is None:
101
+ return np.ones((512, 512), dtype=np.float32) * 0.5
102
+
103
+ if isinstance(result, np.ndarray):
104
+ if result.ndim == 2:
105
+ return result.astype(np.float32)
106
+ elif result.ndim == 3:
107
+ # Take first channel or average
108
+ return result[..., 0].astype(np.float32)
109
+ elif result.ndim == 4:
110
+ # Batch dimension - take first
111
+ return result[0, 0].astype(np.float32)
112
+
113
+ # Try to convert to numpy
114
+ try:
115
+ arr = np.array(result)
116
+ if arr.ndim >= 2:
117
+ return arr[..., 0] if arr.ndim > 2 else arr
118
+ except:
119
+ pass
120
+
121
+ return np.ones((512, 512), dtype=np.float32) * 0.5
122
+
123
+ def reset(self):
124
+ """Reset the session state."""
125
+ self.initialized = False
126
+ if hasattr(self.core, 'reset'):
127
+ self.core.reset()
128
+ elif hasattr(self.core, 'clear_memory'):
129
+ self.core.clear_memory()
130
+
131
+
132
  class MatAnyoneLoader:
133
  """
134
+ Official MatAnyone loader using InferenceCore API with callable wrapper.
 
 
135
  """
136
 
137
  def __init__(self, device: str = "cuda", cache_dir: str = "./checkpoints/matanyone_cache"):
 
140
  os.makedirs(self.cache_dir, exist_ok=True)
141
 
142
  self.processor = None
143
+ self.wrapper = None
144
  self.model_id = "PeiqingYang/MatAnyone"
145
  self.load_time = 0.0
146
  self.loaded = False
 
156
  return "cpu"
157
  return "cuda" if torch.cuda.is_available() else "cpu"
158
 
159
+ def load(self):
160
+ """Load MatAnyone using official InferenceCore API and wrap it."""
161
+ if self.loaded and self.wrapper:
162
+ return self.wrapper
163
 
164
  logger.info(f"Loading MatAnyone from HF: {self.model_id} (device={self.device})")
165
  t0 = time.time()
 
168
  # Import the official API
169
  from matanyone.inference.inference_core import InferenceCore
170
 
171
+ # Create the InferenceCore processor
 
172
  self.processor = InferenceCore(self.model_id)
173
 
174
+ # Wrap it to make it callable
175
+ self.wrapper = MatAnyoneCallableWrapper(self.processor)
176
+
177
  self.loaded = True
178
  self.load_time = time.time() - t0
179
+ logger.info(f"MatAnyone loaded and wrapped successfully in {self.load_time:.2f}s")
180
+ return self.wrapper
181
 
182
  except ImportError as e:
183
  self.load_error = f"MatAnyone not installed: {e}"
184
  logger.error(f"Failed to import MatAnyone. Install with: pip install git+https://github.com/pq-yang/MatAnyone.git@main")
185
+ return None
186
 
187
  except Exception as e:
188
  self.load_error = str(e)
189
  logger.error(f"Failed to load MatAnyone: {e}")
190
  logger.debug(traceback.format_exc())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  return None
192
 
193
  def cleanup(self):
194
  """Cleanup temporary files and release resources."""
195
  self.processor = None
196
+ self.wrapper = None
197
 
198
  # Clean temp directory
199
  if self.temp_dir.exists():
 
206
 
207
  def get_info(self) -> Dict[str, Any]:
208
  """Get model information."""
209
+ info = {
210
  "loaded": self.loaded,
211
  "model_id": self.model_id,
212
  "device": str(self.device),
213
  "load_time": self.load_time,
214
  "error": self.load_error,
215
+ "api": "InferenceCore (wrapped)"
216
  }
217
+
218
+ # Add interface info
219
+ if self.processor:
220
+ info["has_step"] = hasattr(self.processor, 'step')
221
+ info["has_process_frame"] = hasattr(self.processor, 'process_frame')
222
+ info["has_process_video"] = hasattr(self.processor, 'process_video')
223
+
224
+ return info
225
 
226
  def reset(self):
227
  """Reset the processor for a new video."""
228
+ if self.wrapper:
229
+ self.wrapper.reset()
230
+ logger.info("MatAnyone session reset")
231
 
232
+ # Compatibility - make the loader itself callable
233
  def __call__(self, image, mask=None, **kwargs):
234
+ """Direct call compatibility."""
235
+ if not self.wrapper:
236
+ if not self.load():
237
+ # Fallback if loading fails
238
+ if mask is not None:
239
+ return mask if isinstance(mask, np.ndarray) else np.array(mask)
240
+ return np.zeros(image.shape[:2], dtype=np.float32)
 
 
 
 
 
 
 
241
 
242
+ return self.wrapper(image, mask, **kwargs)
 
 
 
243
 
 
 
244
 
245
+ # For backwards compatibility
246
+ _MatAnyoneSession = MatAnyoneCallableWrapper
247
 
248
+ __all__ = ["MatAnyoneLoader", "_MatAnyoneSession", "MatAnyoneCallableWrapper"]