File size: 3,502 Bytes
e4b3020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f66c7c
 
 
 
 
 
 
 
e4b3020
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import functools
import gc
import os
from typing import Callable

import torch

# Lazy loading is opt-in: it is enabled only when the LAZY_LOAD environment
# variable is set to "true" (compared case-insensitively). Unset or any other
# value leaves it disabled.
LAZY_LOAD_ENABLED = os.environ.get("LAZY_LOAD", "false").lower() == "true"


class LazyModel:
    unload_func = None
    init_func: Callable | None = None
    is_loaded = False

    def __init__(self, model_id: str):
        self.model_id = model_id

    def load(self):
        def decorator(init_func):
            if not LAZY_LOAD_ENABLED:
                # Even if eager loading, the model should only be initialized once.
                if not self.is_loaded:
                    init_func()
                    self.is_loaded = True
                self.init_func = init_func
                return init_func

            def wrapper():
                global current_model
                if current_model is not None and current_model != self.model_id:
                    print(
                        f"Unloading currently loaded model '{current_model}' before loading '{self.model_id}'..."
                    )
                    _unload()

                if current_model == self.model_id and self.is_loaded:
                    print(
                        f"Model '{self.model_id}' is already loaded. Skipping initialization."
                    )
                    return

                print(f"Loading model '{self.model_id}'...")
                init_func()
                self.is_loaded = True
                current_model = self
                print(f"Model '{self.model_id}' loaded successfully.")

            # Ensure the init_func also loads lazily
            self.init_func = wrapper
            return wrapper

        return decorator

    def unload(self):
        # Create a decorator to set the unload callback function for this model. This allows the lazy loading mechanism to call the specified function when unloading the model, ensuring proper cleanup of resources.
        def decorator(func):
            def wrapper():
                print(f"Unloading model '{self.model_id}'...")
                func()
                self.is_loaded = False
                print(f"Model '{self.model_id}' unloaded successfully.")

            self.unload_func = wrapper
            return wrapper

        return decorator

    def entry(self):
        def decorator(func):
            def wrapper(*args, **kwargs):
                if not self.init_func:
                    raise RuntimeError(
                        f"Model '{self.model_id}' does not have an initialization function defined."
                    )

                # Ensure the model is loaded before executing the main function
                if self.init_func and not self.is_loaded:
                    print(f"Model '{self.model_id}' is not loaded. Loading now...")
                    self.init_func()

                print(f"Executing main function for model '{self.model_id}'...")
                return func(*args, **kwargs)

            return wrapper

        return decorator


def _unload():
    """Unload the currently tracked model, if any, and release memory.

    Calls the tracked model's registered unload callback when one exists,
    resets ``current_model`` to ``None``, runs the garbage collector and,
    when CUDA is available, empties the CUDA cache.
    """
    global current_model
    active = current_model
    callback = active.unload_func if active else None
    if callback:
        callback()
    current_model = None

    # Ensure garbage collection and CUDA cache clearing
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


# Global variable that keeps track of the currently loaded LazyModel instance.
# This allows the lazy loading mechanism to determine if a model is already
# loaded and to manage unloading of other models when necessary.
current_model: LazyModel | None = None