Spaces:
Sleeping
Sleeping
| from typing import Callable | |
| import gc | |
| import torch | |
| import os | |
# Lazy loading is opt-in: it is enabled only when the LAZY_LOAD environment
# variable equals "true" (compared case-insensitively). When the variable is
# unset the default "false" applies.
LAZY_LOAD_ENABLED = (
    os.environ.get("LAZY_LOAD", "false").lower() == "true"
)
| class LazyModel: | |
| unload_func = None | |
| init_func: Callable | None = None | |
| is_loaded = False | |
| def __init__(self, model_id: str): | |
| self.model_id = model_id | |
| def load(self): | |
| def decorator(init_func): | |
| if not LAZY_LOAD_ENABLED: | |
| # Even if eager loading, the model should only be initialized once. | |
| if not self.is_loaded: | |
| init_func() | |
| self.is_loaded = True | |
| self.init_func = init_func | |
| return init_func | |
| def wrapper(): | |
| global current_model | |
| if current_model is not None and current_model != self.model_id: | |
| print( | |
| f"Unloading currently loaded model '{current_model}' before loading '{self.model_id}'..." | |
| ) | |
| _unload() | |
| if current_model == self.model_id and self.is_loaded: | |
| print( | |
| f"Model '{self.model_id}' is already loaded. Skipping initialization." | |
| ) | |
| return | |
| print(f"Loading model '{self.model_id}'...") | |
| init_func() | |
| self.is_loaded = True | |
| current_model = self | |
| print(f"Model '{self.model_id}' loaded successfully.") | |
| # Ensure the init_func also loads lazily | |
| self.init_func = wrapper | |
| return wrapper | |
| return decorator | |
| def unload(self): | |
| # Create a decorator to set the unload callback function for this model. This allows the lazy loading mechanism to call the specified function when unloading the model, ensuring proper cleanup of resources. | |
| def decorator(func): | |
| def wrapper(): | |
| print(f"Unloading model '{self.model_id}'...") | |
| func() | |
| self.is_loaded = False | |
| print(f"Model '{self.model_id}' unloaded successfully.") | |
| self.unload_func = wrapper | |
| return wrapper | |
| return decorator | |
| def entry(self): | |
| def decorator(func): | |
| def wrapper(*args, **kwargs): | |
| if not self.init_func: | |
| raise RuntimeError( | |
| f"Model '{self.model_id}' does not have an initialization function defined." | |
| ) | |
| # Ensure the model is loaded before executing the main function | |
| if self.init_func and not self.is_loaded: | |
| print(f"Model '{self.model_id}' is not loaded. Loading now...") | |
| self.init_func() | |
| print(f"Executing main function for model '{self.model_id}'...") | |
| return func(*args, **kwargs) | |
| return wrapper | |
| return decorator | |
def _unload():
    """Unload the currently tracked model, if any, and free cached memory.

    When a model is tracked in the module-level ``current_model`` and it has
    an unload callback registered, that callback is invoked. The tracker is
    then reset to ``None``. Afterwards a garbage collection pass is run and,
    when CUDA is available, the CUDA cache is emptied.
    """
    global current_model

    active = current_model
    if active:
        callback = active.unload_func
        if callback:
            callback()
    # The tracker is cleared whether or not a callback was registered.
    current_model = None

    # Ensure garbage collection and CUDA cache clearing
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Global variable tracking the currently loaded LazyModel instance. The lazy
# loading mechanism uses it to determine whether a model is already loaded and
# to manage unloading of other models when necessary.
current_model: LazyModel | None = None