#!/usr/bin/env python3
"""
Model Loading and Memory Management
Handles lazy loading of SAM2 and MatAnyone models with caching
(Enhanced logging, error handling, and memory safety)
"""
import os
import gc
import logging
import streamlit as st
import torch
import psutil
import mediapipe as mp
from contextlib import contextmanager
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@contextmanager
def torch_memory_manager():
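    """Context manager that empties the CUDA cache and runs gc on exit."""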
try:
logger.info("[torch_memory_manager] Enter") # [LOG+SAFETY PATCH]
yield
finally:
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.info("[torch_memory_manager] Exit, cleaned up") # [LOG+SAFETY PATCH]
def get_memory_usage():
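    """Return GPU/RAM usage figures in GB; 'gpu_free' is approximated as total minus allocated."""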
memory_info = {}
if torch.cuda.is_available():
memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory -
torch.cuda.memory_allocated()) / 1e9
memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
logger.info(f"[get_memory_usage] {memory_info}") # [LOG+SAFETY PATCH]
return memory_info
def clear_model_cache():
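    """Clear Streamlit's cache_resource, empty the CUDA cache, and run gc."""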
logger.info("[clear_model_cache] Clearing all model caches...") # [LOG+SAFETY PATCH]
if hasattr(st, 'cache_resource'):
st.cache_resource.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.info("[clear_model_cache] Model cache cleared") # [LOG+SAFETY PATCH]
@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
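    """Load a SAM2 image predictor (cached by Streamlit).

    Prefers the local large checkpoint; otherwise downloads from Hugging Face,
    falling back to the tiny/small variants when free GPU memory is under 4 GB.
    Returns (predictor, device), or (None, None) on failure.
    """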
try:
logger.info("[load_sam2_predictor] Loading SAM2 image predictor...") # [LOG+SAFETY PATCH]
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"[load_sam2_predictor] Using device: {device}")
checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"
if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
predictor = SAM2ImagePredictor.from_pretrained(
"facebook/sam2-hiera-large",
device=device
)
else:
memory_info = get_memory_usage()
gpu_free = memory_info.get('gpu_free', 0)
if device == "cuda" and gpu_free < 4.0:
logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
try:
predictor = SAM2ImagePredictor.from_pretrained(
"facebook/sam2-hiera-tiny",
device=device
)
except Exception as e:
logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
predictor = SAM2ImagePredictor.from_pretrained(
"facebook/sam2-hiera-small",
device=device
)
else:
logger.info("[load_sam2_predictor] Using local large model")
sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
predictor = SAM2ImagePredictor(sam2_model)
if hasattr(predictor, 'model'):
predictor.model.to(device)
predictor.model.eval()
logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")
logger.info(f"βœ… SAM2 loaded successfully on {device}!")
return predictor, device
    except Exception as e:
        # exc_info=True already logs the full traceback; no separate print needed
        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
        return None, None
def load_sam2():
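    """Convenience wrapper that returns only the SAM2 predictor, discarding the device."""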
predictor, device = load_sam2_predictor()
return predictor
@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
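    """Load the MatAnyone InferenceCore (cached by Streamlit).

    Returns (processor, device), or (None, None) on failure.
    """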
try:
logger.info("[load_matanyone_processor] Loading MatAnyone processor...") # [LOG+SAFETY PATCH]
from matanyone import InferenceCore
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")
processor = InferenceCore("PeiqingYang/MatAnyone", device=device)
if hasattr(processor, 'model'):
processor.model.to(device)
processor.model.eval()
logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")
if not hasattr(processor, 'device'):
processor.device = device
logger.info(f"[load_matanyone_processor] Set processor.device to {device}")
logger.info(f"βœ… MatAnyone loaded successfully on {device}!")
return processor, device
    except Exception as e:
        # exc_info=True already logs the full traceback; no separate print needed
        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
        return None, None
def load_matanyone():
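    """Convenience wrapper that returns only the MatAnyone processor, discarding the device."""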
processor, device = load_matanyone_processor()
return processor
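# MediaPipe Pose is initialized eagerly at import time, unlike the cached,
# lazily loaded SAM2/MatAnyone models above.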
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
static_image_mode=False,
model_complexity=1,
enable_segmentation=True,
min_detection_confidence=0.5
)
logger.info("βœ… MediaPipe Pose initialized",) # [LOG+SAFETY PATCH]
def test_models():
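    """Try to load both models and report per-model load status, device, and any error."""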
results = {
'sam2': {'loaded': False, 'error': None, 'device': None},
'matanyone': {'loaded': False, 'error': None, 'device': None}
}
try:
sam2_predictor, sam2_device = load_sam2_predictor()
if sam2_predictor is not None:
results['sam2']['loaded'] = True
results['sam2']['device'] = sam2_device
else:
results['sam2']['error'] = "Predictor returned None"
except Exception as e:
results['sam2']['error'] = str(e)
logger.error(f"[test_models] SAM2 error: {e}", exc_info=True)
try:
matanyone_processor, matanyone_device = load_matanyone_processor()
if matanyone_processor is not None:
results['matanyone']['loaded'] = True
results['matanyone']['device'] = matanyone_device
else:
results['matanyone']['error'] = "Processor returned None"
except Exception as e:
results['matanyone']['error'] = str(e)
logger.error(f"[test_models] MatAnyone error: {e}", exc_info=True)
logger.info(f"[test_models] Results: {results}") # [LOG+SAFETY PATCH]
return results
def log_memory_usage(stage=""):
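    """Log a one-line memory summary, optionally tagged with a stage label."""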
memory_info = get_memory_usage()
log_msg = f"Memory usage"
if stage:
log_msg += f" ({stage})"
log_msg += ":"
if 'gpu_allocated' in memory_info:
log_msg += f" GPU {memory_info['gpu_allocated']:.1f}GB allocated, {memory_info['gpu_free']:.1f}GB free"
log_msg += f" | RAM {memory_info['ram_used']:.1f}GB used"
print(log_msg, flush=True)
logger.info(log_msg)
return memory_info
def check_memory_available(required_gb=2.0):
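    """Return (enough, free_gb); always (False, 0.0) when CUDA is unavailable."""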
if not torch.cuda.is_available():
return False, 0.0
memory_info = get_memory_usage()
free_gb = memory_info.get('gpu_free', 0)
logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}") # [LOG+SAFETY PATCH]
return free_gb >= required_gb, free_gb
def free_memory_aggressive():
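    """Drop cached models, then empty and synchronize the CUDA cache and run gc."""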
logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...") # [LOG+SAFETY PATCH]
print("Performing aggressive memory cleanup...", flush=True)
clear_model_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
try:
torch.cuda.ipc_collect()
except Exception:
pass
gc.collect()
print("Memory cleanup complete", flush=True)
logger.info("Memory cleanup complete")
log_memory_usage("after cleanup")