MogensR committed on
Commit
fb9272e
·
1 Parent(s): e7f8adb
Files changed (2) hide show
  1. models/matanyone_loader.py +63 -1
  2. models/sam2_loader.py +61 -1
models/matanyone_loader.py CHANGED
@@ -15,6 +15,7 @@
15
  - Added VRAM logging in process_stream (MATANY_LOG_VRAM=1)
16
  - Enhanced _safe_empty_cache with memory_summary
17
  - Added MatAnyone version logging
 
18
  """
19
 
20
  from __future__ import annotations
@@ -422,4 +423,65 @@ def process_stream(
422
  _emit_progress(progress_cb, 1.0, "MatAnyone: done")
423
  elapsed = time.time() - start
424
  log.info(f"MatAnyone completed: {idx} frames in {elapsed:.1f}s")
425
- return alpha_path, fg_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  - Added VRAM logging in process_stream (MATANY_LOG_VRAM=1)
16
  - Enhanced _safe_empty_cache with memory_summary
17
  - Added MatAnyone version logging
18
+ - Added MatAnyoneModel wrapper class for app_hf.py compatibility
19
  """
20
 
21
  from __future__ import annotations
 
423
  _emit_progress(progress_cb, 1.0, "MatAnyone: done")
424
  elapsed = time.time() - start
425
  log.info(f"MatAnyone completed: {idx} frames in {elapsed:.1f}s")
426
+ return alpha_path, fg_path
427
+
428
+ # ============================================================================
429
+ # MatAnyoneModel Wrapper Class for app_hf.py compatibility
430
+ # ============================================================================
431
+
432
class MatAnyoneModel:
    """Wrapper class for MatAnyone to match app_hf.py interface.

    The underlying ``MatAnyoneSession`` is created eagerly in ``__init__``.
    A failed load is logged and recorded in ``loaded`` instead of being
    raised, so callers must check ``loaded`` (or be ready for
    ``MatAnyError`` from ``replace_background``).

    Attributes:
        device: Device string passed to ``MatAnyoneSession``.
        session: The ``MatAnyoneSession`` instance, or ``None`` if it could
            not be created.
        loaded: ``True`` once the session has been created successfully.
    """

    def __init__(self, device="cuda"):
        """Store *device* and immediately try to create the session."""
        self.device = device
        self.session = None
        self.loaded = False
        log.info(f"Initializing MatAnyoneModel on device: {device}")

        # Initialize the session
        self._load_model()

    def _load_model(self):
        """Create the MatAnyone session; log and flag failure instead of raising."""
        try:
            self.session = MatAnyoneSession(device=self.device, precision="auto")
            self.loaded = True
            log.info("MatAnyoneModel loaded successfully")
        except Exception as e:
            # Deliberate best-effort: the wrapper stays constructible and
            # reports the failure through ``loaded``.
            log.error(f"Error loading MatAnyoneModel: {e}")
            self.loaded = False

    def replace_background(self, video_path, masks, background_path):
        """Run MatAnyone on *video_path* and return the foreground video path.

        Args:
            video_path: Path (str or path-like) of the input video.
            masks: Currently unused; the session is called with
                ``seed_mask_path=None``.
            background_path: Currently unused; no compositing is performed
                yet, only the foreground output is returned.

        Returns:
            ``str`` of the foreground path returned by
            ``session.process_stream``. The output directory is a fresh
            temporary directory that is intentionally NOT deleted on
            success; the caller owns it and should remove it when done.

        Raises:
            MatAnyError: If the model is not loaded or processing fails.
        """
        if not self.loaded:
            raise MatAnyError("MatAnyoneModel not loaded")

        output_dir = None
        try:
            from pathlib import Path
            import tempfile

            # Convert paths to Path objects
            video_path = Path(video_path)

            # A TemporaryDirectory context manager would delete the outputs
            # as soon as this method returns, leaving the caller with a path
            # to a file that no longer exists. mkdtemp() keeps the directory
            # alive until the caller removes it.
            output_dir = Path(tempfile.mkdtemp(prefix="matanyone_"))

            # Process the video stream
            _alpha_path, fg_path = self.session.process_stream(
                video_path=video_path,
                seed_mask_path=None,  # We'll rely on SAM2 integration
                out_dir=output_dir,
                progress_cb=None,
            )

            # For now, return the foreground video.
            # In a full implementation, you'd composite with background_path.
            return str(fg_path)

        except Exception as e:
            # Nothing usable was produced: do not leave the temp dir behind.
            if output_dir is not None:
                import shutil

                shutil.rmtree(output_dir, ignore_errors=True)
            log.error(f"Error in replace_background: {e}")
            raise MatAnyError(f"Background replacement failed: {e}") from e
models/sam2_loader.py CHANGED
@@ -12,6 +12,7 @@
12
  - Added SAM2 version logging via importlib.metadata
13
  - Simplified config resolution to match __init__.py
14
  - Fixed missing sys and inspect imports
 
15
  """
16
 
17
  from __future__ import annotations
@@ -279,4 +280,63 @@ def run_sam2_mask(predictor: object,
279
  return m, True
280
  except Exception as e:
281
  logger.warning(f"SAM2 mask generation failed: {e}")
282
- return None, False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  - Added SAM2 version logging via importlib.metadata
13
  - Simplified config resolution to match __init__.py
14
  - Fixed missing sys and inspect imports
15
+ - Added SAM2Model wrapper class for app_hf.py compatibility
16
  """
17
 
18
  from __future__ import annotations
 
280
  return m, True
281
  except Exception as e:
282
  logger.warning(f"SAM2 mask generation failed: {e}")
283
+ return None, False
284
+
285
+ # --------------------------------------------------------------------------------------
286
+ # SAM2Model Wrapper Class for app_hf.py compatibility
287
+ # --------------------------------------------------------------------------------------
288
class SAM2Model:
    """Wrapper class for SAM2 model to match app_hf.py interface.

    The predictor is loaded eagerly in ``__init__`` via ``load_sam2()``.
    Load failures are logged and recorded in ``loaded`` rather than raised.

    Attributes:
        device: Device string supplied by the caller. NOTE(review): it is
            stored but not forwarded to ``load_sam2()``; confirm whether
            ``load_sam2`` picks the device on its own.
        predictor: First value returned by ``load_sam2()``, or ``None``.
        loaded: Second value returned by ``load_sam2()``; ``False`` on error.
    """

    def __init__(self, device="cuda"):
        """Store *device* and immediately try to load the predictor."""
        self.device = device
        self.predictor = None
        self.loaded = False
        logger.info(f"Initializing SAM2Model on device: {device}")

        # Load the model immediately
        self._load_model()

    def _load_model(self):
        """Load the SAM2 predictor; log and flag failure instead of raising."""
        try:
            # The third return value (metadata) is not used by this wrapper.
            self.predictor, self.loaded, _meta = load_sam2()
            if self.loaded:
                logger.info("SAM2Model loaded successfully")
            else:
                logger.error("Failed to load SAM2Model")
        except Exception as e:
            logger.error(f"Error loading SAM2Model: {e}")
            self.loaded = False

    def predict(self, video_path):
        """Generate a mask from the first frame of a video.

        Args:
            video_path: Video location. ``os.PathLike`` objects are converted
                with ``os.fspath`` before being passed to
                ``cv2.VideoCapture``; other values are passed unchanged.

        Returns:
            The mask returned by ``run_sam2_mask`` for the first frame, or
            ``None`` if the model is not loaded, the frame cannot be read,
            mask generation reports failure, or any exception occurs.
        """
        if not self.loaded:
            logger.error("SAM2Model not loaded")
            return None

        try:
            import os

            import cv2

            # Normalise pathlib.Path and friends to a plain path string.
            if isinstance(video_path, os.PathLike):
                video_path = os.fspath(video_path)

            # Read first frame of video to generate initial mask
            cap = cv2.VideoCapture(video_path)
            try:
                ret, frame = cap.read()
            finally:
                # Always release the capture handle, even if read() raises.
                cap.release()

            if not ret:
                logger.error(f"Could not read video: {video_path}")
                return None

            # Generate mask for the frame
            mask, success = run_sam2_mask(self.predictor, frame, auto=True)

            if not success:
                logger.error("Failed to generate mask from video")
                return None

            logger.info("Successfully generated mask from video")
            return mask

        except Exception as e:
            logger.error(f"Error predicting masks: {e}")
            return None