File size: 8,326 Bytes
eee126e cc4f3fb c9f07e1 eee126e cc4f3fb eee126e c9f07e1 cc4f3fb c9f07e1 eee126e cc4f3fb eee126e 3814cb0 eee126e 3814cb0 cc4f3fb eee126e 00e27de cc4f3fb eee126e 3814cb0 eee126e cc4f3fb eee126e cc4f3fb 3814cb0 eee126e 3814cb0 cc4f3fb eee126e cc4f3fb eee126e 3814cb0 eee126e 57682af c9f07e1 eee126e c9f07e1 57682af eee126e 57682af c9f07e1 eee126e 57682af c9f07e1 57682af eee126e c9f07e1 57682af f72657c c9f07e1 cc4f3fb eee126e c9f07e1 57682af cc4f3fb 30b6279 cc4f3fb f72657c cc4f3fb eee126e cc4f3fb eee126e 3814cb0 eee126e 57682af c9f07e1 3814cb0 f72657c c9f07e1 f72657c c9f07e1 cc4f3fb eee126e c9f07e1 57682af cc4f3fb 30b6279 cc4f3fb f72657c cc4f3fb eee126e cc4f3fb eee126e cc4f3fb eee126e cc4f3fb eee126e c9f07e1 eee126e cc4f3fb eee126e c9f07e1 3814cb0 eee126e cc4f3fb eee126e f72657c eee126e cc4f3fb eee126e 3814cb0 eee126e cc4f3fb eee126e cc4f3fb 3814cb0 f72657c eee126e f72657c eee126e cc4f3fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
#!/usr/bin/env python3
"""
Model Loading and Memory Management
Handles lazy loading of SAM2 and MatAnyone models with caching.
(Enhanced logging, error handling, and memory safety)
"""
import os
import gc
import logging
import streamlit as st
import torch
import psutil
from contextlib import contextmanager
# Configure root logging at import time. Note that basicConfig() does nothing
# if the root logger already has handlers installed by the host application.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)
@contextmanager
def torch_memory_manager():
    """Context manager that releases Python and CUDA memory on exit.

    Usage::

        with torch_memory_manager():
            run_inference()

    The cleanup runs in a ``finally`` clause, so it also happens when the
    wrapped block raises; the exception then propagates unchanged.
    """
    try:
        logger.info("[torch_memory_manager] Enter")
        yield
    finally:
        # Collect garbage first: tensors that are only kept alive by reference
        # cycles are released here, which lets empty_cache() below hand their
        # blocks back to the driver instead of leaving them in the allocator.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("[torch_memory_manager] Exit, cleaned up")
def get_memory_usage():
    """Return a snapshot of GPU (when available) and system RAM usage in GB.

    Returns:
        dict: Always contains ``ram_used`` and ``ram_available``. When CUDA is
        available it also contains ``gpu_allocated``, ``gpu_reserved`` and
        ``gpu_free``. All values are floats in units of 1e9 bytes.

    ``gpu_free`` estimates how much memory this process can still allocate on
    the current CUDA device: the free memory reported by the driver plus the
    part of PyTorch's cache that is reserved but not in use by tensors.
    """
    memory_info = {}
    if torch.cuda.is_available():
        # Query one device consistently; memory_allocated()/memory_reserved()
        # default to the current device, so gpu_free must use it as well.
        device = torch.cuda.current_device()
        allocated = torch.cuda.memory_allocated(device)
        reserved = torch.cuda.memory_reserved(device)
        # Driver-level free memory accounts for other processes and for the
        # CUDA context, which "total - allocated" silently ignored.
        driver_free, _total = torch.cuda.mem_get_info(device)
        memory_info['gpu_allocated'] = allocated / 1e9
        memory_info['gpu_reserved'] = reserved / 1e9
        # Cached-but-unused blocks (reserved - allocated) are reusable by
        # PyTorch, so they count as free for our purposes.
        memory_info['gpu_free'] = (driver_free + reserved - allocated) / 1e9
    # Take a single snapshot so both RAM figures describe the same instant.
    ram = psutil.virtual_memory()
    memory_info['ram_used'] = ram.used / 1e9
    memory_info['ram_available'] = ram.available / 1e9
    logger.info(f"[get_memory_usage] {memory_info}")
    return memory_info
def clear_model_cache():
    """Manual/debug only: Clear Streamlit resource cache and free memory.

    Clears every ``st.cache_resource`` entry (when that API exists in the
    installed Streamlit), then reclaims Python garbage and, if CUDA is
    available, releases PyTorch's cached GPU blocks.
    """
    logger.info("[clear_model_cache] Clearing all model caches...")
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    # Collect before emptying the CUDA cache: the models just dropped from the
    # Streamlit cache must actually be destroyed before their GPU blocks can
    # be returned to the driver by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("[clear_model_cache] Model cache cleared")
@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Load SAM2 image predictor, choosing model size based on available GPU memory.

    Selection order:
      1. On CUDA with less than 4 GB of free GPU memory: the Hugging Face
         "tiny" model, falling back to "small" if that fails.
      2. Otherwise, the local large checkpoint when both the checkpoint and
         its config file exist on disk.
      3. Otherwise, the Hugging Face "large" model.

    Returns:
        The SAM2ImagePredictor instance, or None if loading failed (the error
        is logged and its traceback printed).

    NOTE(review): this function is wrapped in ``st.cache_resource``, so the
    returned value -- presumably including a ``None`` failure result -- stays
    cached until the resource cache is cleared (see ``clear_model_cache``).
    """
    try:
        logger.info("[load_sam2_predictor] Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"[load_sam2_predictor] Using device: {device}")

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"
        min_gpu_free_gb = 4.0  # below this the large model is not attempted

        has_local_files = os.path.exists(checkpoint_path) and os.path.exists(model_cfg)
        # The memory check applies no matter where the weights come from;
        # previously it was skipped whenever local checkpoints were missing,
        # so the large model was downloaded even on a nearly full GPU.
        gpu_free = get_memory_usage().get('gpu_free', 0) if device == "cuda" else 0

        if device == "cuda" and gpu_free < min_gpu_free_gb:
            logger.warning(f"[load_sam2_predictor] Limited GPU memory ({gpu_free:.1f}GB), using smaller SAM2 model.")
            try:
                predictor = SAM2ImagePredictor.from_pretrained(
                    "facebook/sam2-hiera-tiny",
                    device=device
                )
            except Exception as e:
                logger.warning(f"[load_sam2_predictor] Tiny model failed, trying small. {e}")
                predictor = SAM2ImagePredictor.from_pretrained(
                    "facebook/sam2-hiera-small",
                    device=device
                )
        elif has_local_files:
            logger.info("[load_sam2_predictor] Using local large model")
            sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
            predictor = SAM2ImagePredictor(sam2_model)
        else:
            logger.warning("[load_sam2_predictor] Local checkpoints not found, using Hugging Face.")
            predictor = SAM2ImagePredictor.from_pretrained(
                "facebook/sam2-hiera-large",
                device=device
            )

        # Make sure the underlying model is on the chosen device and in
        # inference mode, whichever construction path was taken.
        if hasattr(predictor, 'model'):
            predictor.model.to(device)
            predictor.model.eval()
            logger.info(f"[load_sam2_predictor] SAM2 model moved to {device} and set to eval mode")
        logger.info(f"✅ SAM2 loaded successfully on {device}!")
        return predictor
    except Exception as e:
        # Top-level boundary: callers expect None on failure, not an exception.
        logger.error(f"❌ Failed to load SAM2 predictor: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None
def load_sam2():
    """Legacy alias: return whatever load_sam2_predictor() returns."""
    return load_sam2_predictor()
@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Build the MatAnyone InferenceCore on CUDA when available, else on CPU.

    Construction is attempted twice with identical arguments: a failure of the
    first attempt is logged as a warning and the call is simply repeated.

    Returns:
        The processor object, or None when loading failed (the error is
        logged and its traceback printed).
    """
    try:
        logger.info("[load_matanyone_processor] Loading MatAnyone processor...")
        from matanyone import InferenceCore

        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        logger.info(f"[load_matanyone_processor] MatAnyone using device: {device}")

        def _build_core():
            # Single place that knows how the inference core is constructed.
            return InferenceCore("PeiqingYang/MatAnyone", device=device)

        try:
            core = _build_core()
        except Exception as e:
            logger.warning(f"[load_matanyone_processor] Path warning caught: {e}")
            core = _build_core()  # Retry

        # Only touch the wrapped model when the core actually exposes one.
        if hasattr(core, 'model'):
            core.model.to(device)
            core.model.eval()
            logger.info(f"[load_matanyone_processor] MatAnyone model explicitly moved to {device}")

        # Give the core a `device` attribute if it does not define one itself.
        if not hasattr(core, 'device'):
            core.device = device
            logger.info(f"[load_matanyone_processor] Set processor.device to {device}")

        logger.info(f"✅ MatAnyone loaded successfully on {device}!")
        return core
    except Exception as e:
        logger.error(f"❌ Failed to load MatAnyone: {e}", exc_info=True)
        import traceback
        traceback.print_exc()
        return None
def load_matanyone():
    """Legacy alias: return whatever load_matanyone_processor() returns."""
    return load_matanyone_processor()
def test_models():
    """Admin/diagnostic helper: try to load both models and report the outcome.

    Returns:
        dict: ``{'sam2': {...}, 'matanyone': {...}}`` where each entry has
        ``loaded`` (bool) and ``error`` (None, or a message string when the
        loader returned None or raised).
    """
    # (result key, label used in log messages, loader, message for a None result)
    probes = (
        ('sam2', 'SAM2', load_sam2_predictor, "Predictor returned None"),
        ('matanyone', 'MatAnyone', load_matanyone_processor, "Processor returned None"),
    )
    results = {key: {'loaded': False, 'error': None} for key, _, _, _ in probes}

    for key, label, loader, none_message in probes:
        status = results[key]
        try:
            if loader() is None:
                status['error'] = none_message
            else:
                status['loaded'] = True
        except Exception as e:
            status['error'] = str(e)
            logger.error(f"[test_models] {label} error: {e}", exc_info=True)

    logger.info(f"[test_models] Results: {results}")
    return results
def log_memory_usage(stage=""):
    """Print and log a one-line memory summary, optionally tagged with a stage.

    Args:
        stage: Optional label inserted in parentheses after "Memory usage".

    Returns:
        dict: The memory snapshot obtained from get_memory_usage().
    """
    snapshot = get_memory_usage()

    pieces = ["Memory usage"]
    if stage:
        pieces.append(f" ({stage})")
    pieces.append(":")
    # GPU figures are only present in the snapshot when CUDA is available.
    if 'gpu_allocated' in snapshot:
        pieces.append(
            f" GPU {snapshot['gpu_allocated']:.1f}GB allocated, {snapshot['gpu_free']:.1f}GB free"
        )
    pieces.append(f" | RAM {snapshot['ram_used']:.1f}GB used")

    summary = "".join(pieces)
    print(summary, flush=True)
    logger.info(summary)
    return snapshot
def check_memory_available(required_gb=2.0):
    """Report whether at least ``required_gb`` of GPU memory is free.

    Args:
        required_gb: Minimum free GPU memory, in GB, considered sufficient.

    Returns:
        tuple: ``(enough, free_gb)``. Without CUDA this is always
        ``(False, 0.0)``, regardless of ``required_gb``.
    """
    if torch.cuda.is_available():
        usage = get_memory_usage()
        free_gb = usage.get('gpu_free', 0)
        logger.info(f"[check_memory_available] free_gb={free_gb}, required={required_gb}")
        enough = free_gb >= required_gb
        return enough, free_gb
    return False, 0.0
def free_memory_aggressive():
    """For emergency/manual use only! Do NOT call after every video or from UI!

    Calls clear_model_cache() (which clears the Streamlit resource cache),
    reclaims Python garbage, then waits for pending CUDA work, releases
    PyTorch's cached GPU blocks and collects CUDA IPC memory. Finishes by
    logging memory usage via log_memory_usage("after cleanup").
    """
    logger.info("[free_memory_aggressive] Performing aggressive memory cleanup...")
    print("Performing aggressive memory cleanup...", flush=True)
    clear_model_cache()
    # Collect before touching the CUDA cache so that tensors freed by the
    # collector can actually be returned to the driver by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        # Let in-flight kernels finish before their memory is released.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        try:
            torch.cuda.ipc_collect()
        except Exception as e:
            # Best effort only: IPC collection failing must not abort cleanup,
            # but leave a trace instead of swallowing the error silently.
            logger.debug(f"[free_memory_aggressive] ipc_collect failed: {e}")
    print("Memory cleanup complete", flush=True)
    logger.info("Memory cleanup complete")
    log_memory_usage("after cleanup")
|