model-gateway-api / models /lazy_model.py
autumnssuns's picture
🐛 Correctly unset loaded state after unloading
9f66c7c
import gc
import os
from functools import wraps
from typing import Callable

import torch
# Lazy loading is opt-in: it is enabled only when the LAZY_LOAD environment
# variable is set to "true" (compared case-insensitively); unset means eager.
LAZY_LOAD_ENABLED = os.environ.get("LAZY_LOAD", "false").lower() == "true"
class LazyModel:
unload_func = None
init_func: Callable | None = None
is_loaded = False
def __init__(self, model_id: str):
self.model_id = model_id
def load(self):
def decorator(init_func):
if not LAZY_LOAD_ENABLED:
# Even if eager loading, the model should only be initialized once.
if not self.is_loaded:
init_func()
self.is_loaded = True
self.init_func = init_func
return init_func
def wrapper():
global current_model
if current_model is not None and current_model != self.model_id:
print(
f"Unloading currently loaded model '{current_model}' before loading '{self.model_id}'..."
)
_unload()
if current_model == self.model_id and self.is_loaded:
print(
f"Model '{self.model_id}' is already loaded. Skipping initialization."
)
return
print(f"Loading model '{self.model_id}'...")
init_func()
self.is_loaded = True
current_model = self
print(f"Model '{self.model_id}' loaded successfully.")
# Ensure the init_func also loads lazily
self.init_func = wrapper
return wrapper
return decorator
def unload(self):
# Create a decorator to set the unload callback function for this model. This allows the lazy loading mechanism to call the specified function when unloading the model, ensuring proper cleanup of resources.
def decorator(func):
def wrapper():
print(f"Unloading model '{self.model_id}'...")
func()
self.is_loaded = False
print(f"Model '{self.model_id}' unloaded successfully.")
self.unload_func = wrapper
return wrapper
return decorator
def entry(self):
def decorator(func):
def wrapper(*args, **kwargs):
if not self.init_func:
raise RuntimeError(
f"Model '{self.model_id}' does not have an initialization function defined."
)
# Ensure the model is loaded before executing the main function
if self.init_func and not self.is_loaded:
print(f"Model '{self.model_id}' is not loaded. Loading now...")
self.init_func()
print(f"Executing main function for model '{self.model_id}'...")
return func(*args, **kwargs)
return wrapper
return decorator
def _unload():
    """Release the currently tracked model, if any, and reclaim memory.

    Calls the tracked model's registered unload callback when there is one,
    resets `current_model` to None, then runs the garbage collector and, when
    CUDA is available, empties the CUDA cache.
    """
    global current_model
    active = current_model
    unload_callback = active.unload_func if active else None
    if unload_callback:
        unload_callback()
    current_model = None
    # Collect Python garbage first, then clear the CUDA cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Global variable tracking the currently loaded LazyModel instance, or None
# when no model is tracked. The lazy-loading wrapper reads it to decide whether
# another model must be unloaded first, and _unload() resets it to None.
current_model: LazyModel | None = None