import gc
import json
import os
import subprocess
import threading
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional, Union

import spacy
import torch
from loguru import logger
from sentence_transformers import SentenceTransformer
from transformers import (AutoModelForCausalLM,
                          AutoModelForMaskedLM,
                          AutoModelForSequenceClassification,
                          AutoTokenizer,
                          GPT2LMHeadModel,
                          GPT2Tokenizer,
                          pipeline)

from config.model_config import (MODEL_REGISTRY,
                                 ModelConfig,
                                 ModelType,
                                 get_model_config,
                                 get_required_models)
from config.settings import settings


class ModelCache:
    """
    LRU cache for models with size limit
    """
    def __init__(self, max_size: int = 5):
        self.max_size = max_size
        self.cache: OrderedDict = OrderedDict()
        self.lock = threading.Lock()

    def get(self, key: str) -> Optional[Any]:
        """
        Get model from cache
        """
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
                logger.debug(f"Cache hit for model: {key}")

                return self.cache[key]

            logger.debug(f"Cache miss for model: {key}")

            return None

    def put(self, key: str, model: Any):
        """
        Add model to cache
        """
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)

            else:
                if (len(self.cache) >= self.max_size):
                    removed_key = next(iter(self.cache))
                    removed_model = self.cache.pop(removed_key)

                    if hasattr(removed_model, 'to'):
                        removed_model.to('cpu')

                    del removed_model

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    logger.info(f"Evicted model from cache: {removed_key}")

            self.cache[key] = model

            logger.info(f"Added model to cache: {key}")

    def clear(self):
        """
        Clear all cached models
        """
        with self.lock:
            for model in self.cache.values():
                if hasattr(model, 'to'):
                    model.to('cpu')
                del model

            self.cache.clear()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Cleared model cache")

    def size(self) -> int:
        """
        Get current cache size
        """
        return len(self.cache)

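# Illustrative sketch of the LRU behaviour above (comments only, nothing runs
# at import time). The keys and the max_size value are made up for the example.
#
#     cache = ModelCache(max_size = 2)
#     cache.put(key = "embedder",   model = object())
#     cache.put(key = "generator",  model = object())
#     cache.get(key = "embedder")                        # "embedder" becomes most recently used
#     cache.put(key = "classifier", model = object())    # evicts "generator", the least recently used
#     assert cache.get(key = "generator") is None
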
class ModelManager:
    """
    Central model management system
    """
    def __init__(self):
        self.cache = ModelCache(max_size = settings.MAX_CACHED_MODELS)
        self.device = torch.device(settings.DEVICE if torch.cuda.is_available() else "cpu")
        self.cache_dir = settings.MODEL_CACHE_DIR

        self.cache_dir.mkdir(parents  = True,
                             exist_ok = True,
                             )

        self.metadata_file = self.cache_dir / "model_metadata.json"
        self.metadata = self._load_metadata()

        logger.info(f"ModelManager initialized with device: {self.device}")
        logger.info(f"Model cache directory: {self.cache_dir}")

    def _load_metadata(self) -> Dict:
        """
        Load model metadata from disk
        """
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    return json.load(f)

            except Exception as e:
                logger.warning(f"Failed to load metadata: {repr(e)}")

        return {}

    def _save_metadata(self):
        """
        Save model metadata to disk
        """
        try:
            with open(self.metadata_file, 'w') as f:
                json.dump(obj    = self.metadata,
                          fp     = f,
                          indent = 4,
                          )

        except Exception as e:
            logger.error(f"Failed to save metadata: {repr(e)}")

    def _update_metadata(self, model_name: str, model_config: ModelConfig):
        """
        Update metadata for a model
        """
        self.metadata[model_name] = {"model_id"      : model_config.model_id,
                                     "model_type"    : model_config.model_type.value,
                                     "downloaded_at" : datetime.now().isoformat(),
                                     "size_mb"       : model_config.size_mb,
                                     "last_used"     : datetime.now().isoformat(),
                                     }
        self._save_metadata()

    def is_model_downloaded(self, model_name: str) -> bool:
        """
        Check if model is already downloaded
        """
        model_config = get_model_config(model_name = model_name)

        if not model_config:
            return False

        model_path = self.cache_dir / model_config.model_id.replace("/", "_")

        return model_path.exists() and model_name in self.metadata

    def load_model(self, model_name: str, force_download: bool = False) -> Any:
        """
        Load a model by name

        Arguments:
        ----------
            model_name     { str }  : Name from MODEL_REGISTRY

            force_download { bool } : Force re-download even if cached

        Returns:
        --------
            { Any } : Model instance
        """
        if not force_download:
            cached = self.cache.get(key = model_name)

            if cached is not None:
                return cached

        model_config = get_model_config(model_name = model_name)

        if not model_config:
            raise ValueError(f"Unknown model: {model_name}")

        logger.info(f"Loading model: {model_name} ({model_config.model_id})")

        try:
            if (model_config.model_type == ModelType.SENTENCE_TRANSFORMER):
                model = self._load_sentence_transformer(config = model_config)

            elif (model_config.model_type == ModelType.GPT):
                model = self._load_gpt_model(config = model_config)

            elif (model_config.model_type == ModelType.CLASSIFIER):
                model = self._load_classifier(config = model_config)

            elif (model_config.model_type == ModelType.SEQUENCE_CLASSIFICATION):
                model = self._load_sequence_classifier(config = model_config)

            elif (model_config.model_type == ModelType.TRANSFORMER):
                model = self._load_transformer(config = model_config)

            elif (model_config.model_type == ModelType.CAUSAL_LM):
                model = self._load_causal_lm(config = model_config)

            elif (model_config.model_type == ModelType.MASKED_LM):
                model = self._load_masked_lm(config = model_config)

            elif (model_config.model_type == ModelType.RULE_BASED):
                if model_config.additional_params.get("is_spacy_model", False):
                    model = self._load_spacy_model(config = model_config)

                else:
                    raise ValueError(f"Unknown rule-based model type: {model_name}")

            else:
                raise ValueError(f"Unsupported model type: {model_config.model_type}")

            self._update_metadata(model_name   = model_name,
                                  model_config = model_config,
                                  )

            if model_config.cache_model:
                self.cache.put(key   = model_name,
                               model = model,
                               )

            logger.success(f"Successfully loaded model: {model_name}")

            return model

        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {repr(e)}")
            raise

    def load_tokenizer(self, model_name: str) -> Any:
        """
        Load tokenizer for a model

        Arguments:
        ----------
            model_name { str } : Name from MODEL_REGISTRY

        Returns:
        --------
            { Any } : Tokenizer instance
        """
        model_config = get_model_config(model_name = model_name)

        if not model_config:
            raise ValueError(f"Unknown model: {model_name}")

        logger.info(f"Loading tokenizer for: {model_name}")

        try:
            if (model_config.model_type in [ModelType.GPT,
                                            ModelType.CLASSIFIER,
                                            ModelType.SEQUENCE_CLASSIFICATION,
                                            ModelType.TRANSFORMER,
                                            ModelType.CAUSAL_LM,
                                            ModelType.MASKED_LM]):

                tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                          cache_dir                     = str(self.cache_dir),
                                                          )

                logger.success(f"Successfully loaded tokenizer for: {model_name}")
                return tokenizer

            else:
                raise ValueError(f"Model type {model_config.model_type} doesn't require a separate tokenizer")

        except Exception as e:
            logger.error(f"Failed to load tokenizer for {model_name}: {repr(e)}")
            raise

    def _load_sentence_transformer(self, config: ModelConfig) -> SentenceTransformer:
        """
        Load SentenceTransformer model
        """
        model = SentenceTransformer(model_name_or_path = config.model_id,
                                    cache_folder       = str(self.cache_dir),
                                    device             = str(self.device),
                                    )
        return model

    def _load_gpt_model(self, config: ModelConfig) -> tuple:
        """
        Load GPT-style model with tokenizer
        """
        model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                cache_dir                     = str(self.cache_dir),
                                                )

        tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                  cache_dir                     = str(self.cache_dir),
                                                  )

        model = model.to(self.device)
        model.eval()

        if (settings.USE_QUANTIZATION and config.quantizable):
            model = self._quantize_model(model = model)

        return (model, tokenizer)

    def _load_causal_lm(self, config: ModelConfig) -> tuple:
        """
        Load causal language model (like GPT-2) for text generation
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                     cache_dir                     = str(self.cache_dir),
                                                     )

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                  cache_dir                     = str(self.cache_dir),
                                                  )

        model = model.to(self.device)
        model.eval()

        if (settings.USE_QUANTIZATION and config.quantizable):
            model = self._quantize_model(model = model)

        return (model, tokenizer)

    def _load_masked_lm(self, config: ModelConfig) -> tuple:
        """
        Load masked language model (like RoBERTa) for fill-mask tasks
        """
        model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                     cache_dir                     = str(self.cache_dir),
                                                     )

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                  cache_dir                     = str(self.cache_dir),
                                                  )

        model = model.to(self.device)
        model.eval()

        if (settings.USE_QUANTIZATION and config.quantizable):
            model = self._quantize_model(model = model)

        return (model, tokenizer)

    def _load_classifier(self, config: ModelConfig) -> Any:
        """
        Load classification model (for zero-shot, etc.)
        """
        pipe = pipeline("zero-shot-classification",
                        model        = config.model_id,
                        device       = 0 if self.device.type == "cuda" else -1,
                        model_kwargs = {"cache_dir": str(self.cache_dir)},
                        )

        return pipe

    def _load_sequence_classifier(self, config: ModelConfig) -> Any:
        """
        Load sequence classification model (for domain classification)
        """
        model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                                   cache_dir                     = str(self.cache_dir),
                                                                   num_labels                    = config.additional_params.get('num_labels', 2),
                                                                   )

        model = model.to(self.device)
        model.eval()

        if (settings.USE_QUANTIZATION and config.quantizable):
            model = self._quantize_model(model = model)

        return model

    def _load_transformer(self, config: ModelConfig) -> tuple:
        """
        Load masking transformer model
        """
        model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                     cache_dir                     = str(self.cache_dir),
                                                     )

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path = config.model_id,
                                                  cache_dir                     = str(self.cache_dir),
                                                  )

        model = model.to(self.device)
        model.eval()

        if (settings.USE_QUANTIZATION and config.quantizable):
            model = self._quantize_model(model = model)

        return (model, tokenizer)

    def _quantize_model(self, model):
        """
        Apply INT8 dynamic quantization to the model's linear layers
        """
        try:
            if hasattr(torch.quantization, 'quantize_dynamic'):
                quantized_model = torch.quantization.quantize_dynamic(model        = model,
                                                                      qconfig_spec = {torch.nn.Linear},
                                                                      dtype        = torch.qint8,
                                                                      )
                logger.info("Applied INT8 quantization to model")

                return quantized_model

            # Fall back to the original model if dynamic quantization is unavailable
            logger.warning("torch.quantization.quantize_dynamic not available, using original model")

            return model

        except Exception as e:
            logger.warning(f"Quantization failed: {repr(e)}, using original model")

            return model
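    # A note on the quantization path above (hedged, since exact behaviour depends
    # on the installed PyTorch build): torch.quantization.quantize_dynamic swaps
    # the listed module types (here torch.nn.Linear) for dynamically quantized
    # equivalents that store weights as INT8 and quantize activations on the fly
    # at inference time; it targets CPU inference. A minimal standalone sketch:
    #
    #     quantized = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype = torch.qint8)
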
    def load_pipeline(self, model_name: str, task: str) -> Any:
        """
        Load a Hugging Face pipeline
        """
        model_config = get_model_config(model_name = model_name)

        if not model_config:
            raise ValueError(f"Unknown model: {model_name}")

        logger.info(f"Loading pipeline: {task} with {model_name}")

        pipe = pipeline(task         = task,
                        model        = model_config.model_id,
                        device       = 0 if self.device.type == "cuda" else -1,
                        model_kwargs = {"cache_dir": str(self.cache_dir)},
                        )

        return pipe

    def _load_spacy_model(self, config: ModelConfig):
        """
        Load spaCy model, downloading it first if it is not installed
        """
        try:
            model = spacy.load(config.model_id)
            logger.info(f"Loaded spaCy model: {config.model_id}")

            return model

        except OSError:
            # Model package not installed yet: download it, then load it
            logger.info(f"Downloading spaCy model: {config.model_id}")

            subprocess.run(["python", "-m", "spacy", "download", config.model_id], check = True)
            model = spacy.load(config.model_id)

            return model

    def download_model(self, model_name: str) -> bool:
        """
        Download model without loading it into memory

        Arguments:
        ----------
            model_name { str } : Name from MODEL_REGISTRY

        Returns:
        --------
            { bool } : True if successful, False otherwise
        """
        model_config = get_model_config(model_name)

        if not model_config:
            logger.error(f"Unknown model: {model_name}")
            return False

        if self.is_model_downloaded(model_name):
            logger.info(f"Model already downloaded: {model_name}")
            return True

        logger.info(f"Downloading model: {model_name} ({model_config.model_id})")

        try:
            if (model_config.model_type == ModelType.SENTENCE_TRANSFORMER):
                SentenceTransformer(model_name_or_path = model_config.model_id,
                                    cache_folder       = str(self.cache_dir),
                                    )

            elif (model_config.model_type == ModelType.GPT):
                GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                cache_dir                     = str(self.cache_dir),
                                                )

                GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                              cache_dir                     = str(self.cache_dir),
                                              )

            elif (model_config.model_type == ModelType.SEQUENCE_CLASSIFICATION):
                AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                                   cache_dir                     = str(self.cache_dir),
                                                                   )

                AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                              cache_dir                     = str(self.cache_dir),
                                              )

            elif (model_config.model_type == ModelType.CAUSAL_LM):
                AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                     cache_dir                     = str(self.cache_dir),
                                                     )

                AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                              cache_dir                     = str(self.cache_dir),
                                              )

            elif (model_config.model_type == ModelType.MASKED_LM):
                AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                     cache_dir                     = str(self.cache_dir),
                                                     )

                AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                              cache_dir                     = str(self.cache_dir),
                                              )

            elif (model_config.model_type == ModelType.RULE_BASED):
                if model_config.additional_params.get("is_spacy_model", False):
                    subprocess.run(["python", "-m", "spacy", "download", model_config.model_id], check = True)

                else:
                    logger.warning(f"Cannot pre-download rule-based model: {model_name}")

                    return True

            else:
                AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                                                   cache_dir                     = str(self.cache_dir),
                                                                   )

                AutoTokenizer.from_pretrained(pretrained_model_name_or_path = model_config.model_id,
                                              cache_dir                     = str(self.cache_dir),
                                              )

            self._update_metadata(model_name, model_config)

            logger.success(f"Successfully downloaded: {model_name}")

            return True

        except Exception as e:
            logger.error(f"Failed to download {model_name}: {repr(e)}")
            return False

    def download_all_required(self) -> Dict[str, bool]:
        """
        Download all required models

        Returns:
        --------
            { dict } : Dict mapping model names to success status
        """
        required_models = get_required_models()
        results = dict()

        logger.info(f"Downloading {len(required_models)} required models...")

        for model_name in required_models:
            results[model_name] = self.download_model(model_name = model_name)

        success_count = sum(1 for v in results.values() if v)

        logger.info(f"Downloaded {success_count}/{len(required_models)} required models")

        return results

    def get_model_info(self, model_name: str) -> Optional[Dict]:
        """
        Get information about a model
        """
        return self.metadata.get(model_name)

    def list_downloaded_models(self) -> list:
        """
        List all downloaded models
        """
        return list(self.metadata.keys())

    def clear_cache(self):
        """
        Clear model cache
        """
        self.cache.clear()
        logger.info("Model cache cleared")

    def unload_model(self, model_name: str):
        """
        Unload a specific model from cache
        """
        with self.cache.lock:
            if model_name in self.cache.cache:
                model = self.cache.cache.pop(model_name)
                if hasattr(model, 'to'):
                    model.to('cpu')

                del model

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                logger.info(f"Unloaded model: {model_name}")

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Get current memory usage statistics
        """
        stats = {"cached_models" : self.cache.size(),
                 "device"        : str(self.device),
                 }

        if torch.cuda.is_available():
            stats.update({"gpu_allocated_mb"     : torch.cuda.memory_allocated() / 1024**2,
                          "gpu_reserved_mb"      : torch.cuda.memory_reserved() / 1024**2,
                          "gpu_max_allocated_mb" : torch.cuda.max_memory_allocated() / 1024**2,
                          })

        return stats

    def optimize_memory(self):
        """
        Optimize memory usage
        """
        logger.info("Optimizing memory...")

        self.cache.clear()

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("Memory optimization complete")
        logger.info(f"Memory usage: {self.get_memory_usage()}")


_model_manager_instance: Optional[ModelManager] = None
_manager_lock = threading.Lock()


def get_model_manager() -> ModelManager:
    """
    Get singleton ModelManager instance
    """
    global _model_manager_instance

    if _model_manager_instance is None:
        with _manager_lock:
            if _model_manager_instance is None:
                _model_manager_instance = ModelManager()

    return _model_manager_instance

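# Rough usage sketch (kept as comments so nothing runs at import time). The
# model names below are hypothetical and would have to exist in MODEL_REGISTRY;
# whether load_model() returns a bare model or a (model, tokenizer) tuple
# depends on the ModelType configured for that entry.
#
#     manager = get_model_manager()
#     manager.download_all_required()
#
#     model, tokenizer = manager.load_model(model_name = "gpt2_generator")      # hypothetical CAUSAL_LM entry
#     embedder         = manager.load_model(model_name = "sentence_embedder")   # hypothetical SENTENCE_TRANSFORMER entry
#
#     logger.info(manager.get_memory_usage())
#     manager.optimize_memory()
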
__all__ = ["ModelManager",
           "ModelCache",
           "get_model_manager",
           ]