File size: 2,838 Bytes
ebc3bf5
 
 
 
 
 
 
 
6ea3105
ebc3bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ea3105
 
 
 
 
 
 
 
 
ebc3bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4cbe0b
ebc3bf5
 
c4cbe0b
 
ebc3bf5
 
 
 
 
 
 
6ea3105
ebc3bf5
 
 
 
 
6ea3105
ebc3bf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39003c5
ebc3bf5
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import torch
import gc
import config

# LLM slot: causal LM, its paired tokenizer, and the repo id both came from.
# All stay None until _load_llm runs and are reset by _unload_llm.
_llm = _tokenizer = _current_model_id = None
# Standalone tokenizer cached by get_tokenizer_only() when no LLM is loaded.
_tokenizer_only = None
# Embedder slot and the repo id it was loaded from.
_embedder = _current_embedder_id = None


def get_current_model_id() -> str | None:
    return _current_model_id


def get_current_tokenizer_id() -> str | None:
    # Tokenizer is always loaded from the same HF repo as the model.
    return _current_model_id


def get_current_embedder_id() -> str | None:
    return _current_embedder_id


def get_tokenizer_only():
    """Return a tokenizer without forcing the full LLM to load.

    If an LLM is already loaded, its paired tokenizer is returned. Otherwise
    a standalone tokenizer for ``config.LLM_MODEL`` is loaded on first call
    and cached in ``_tokenizer_only`` for later calls.
    """
    global _tokenizer_only
    if _tokenizer is None and _tokenizer_only is None:
        _tokenizer_only = AutoTokenizer.from_pretrained(config.LLM_MODEL)
    return _tokenizer if _tokenizer is not None else _tokenizer_only


def get_llm():
    """Return ``(model, tokenizer)``, loading ``config.LLM_MODEL`` on first use."""
    needs_load = _llm is None
    if needs_load:
        _load_llm(config.LLM_MODEL)
    model, tokenizer = _llm, _tokenizer
    return model, tokenizer


def switch_llm(model_id: str) -> str:
    """Make ``model_id`` the active LLM and return a status message.

    This is a no-op when ``model_id`` is already loaded. Otherwise the
    current model is unloaded first to free memory, and then the new one
    is loaded.
    """
    if model_id != _current_model_id:
        _unload_llm()
        _load_llm(model_id)
        return f"Loaded: {model_id}"
    return f"Already using {model_id}"


def _load_llm(model_id: str):
    """Load model + its paired tokenizer. Both come from the same model_id.

    The tokenizer and model are built into local variables first. The module
    globals are updated only after *both* loads succeed. If either
    ``from_pretrained`` call raises, ``_llm``, ``_tokenizer`` and
    ``_current_model_id`` keep their previous values. The module therefore
    never pairs a freshly loaded tokenizer with a missing or stale model.

    Args:
        model_id: Hugging Face model id to load both the tokenizer and the
            causal LM from.

    Raises:
        Any exception raised by ``AutoTokenizer.from_pretrained`` or
        ``AutoModelForCausalLM.from_pretrained`` propagates unchanged.
    """
    global _llm, _tokenizer, _current_model_id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",  # uses model's native dtype (bfloat16 for Qwen2.5)
        device_map=None,     # load to CPU; @spaces.GPU functions move it on demand
    )
    model.eval()
    # Commit only now that both loads have succeeded.
    _llm, _tokenizer, _current_model_id = model, tokenizer, model_id


def _unload_llm():
    """Free GPU/CPU memory before loading a different model.

    Drops every module reference to the current LLM, its paired tokenizer
    and the cached standalone tokenizer, and clears the model id. It then
    runs the garbage collector and, when CUDA is available, empties the
    CUDA cache.
    """
    global _llm, _tokenizer, _current_model_id, _tokenizer_only
    # Rebinding to None releases the module's references; no `del` needed.
    _llm = _tokenizer = _tokenizer_only = None
    _current_model_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_embedder():
    """Return the embedder, loading ``config.EMBEDDER_MODEL`` on first use."""
    needs_load = _embedder is None
    if needs_load:
        _load_embedder(config.EMBEDDER_MODEL)
    embedder = _embedder
    return embedder


def switch_embedder(model_id: str) -> str:
    """Make ``model_id`` the active embedder and return a status message.

    This is a no-op when ``model_id`` is already loaded. Otherwise the
    current embedder is unloaded first, and then the new one is loaded.
    """
    if model_id != _current_embedder_id:
        _unload_embedder()
        _load_embedder(model_id)
        return f"Loaded: {model_id}"
    return f"Already using {model_id}"


def _load_embedder(model_id: str):
    """Create a CPU SentenceTransformer for ``model_id`` and record its id.

    The globals are assigned only after construction returns, so a failed
    load leaves them unchanged.
    """
    global _embedder, _current_embedder_id
    embedder = SentenceTransformer(model_id, device="cpu")
    _embedder, _current_embedder_id = embedder, model_id


def _unload_embedder():
    """Drop the module's reference to the embedder and clear its id.

    Afterwards it runs the garbage collector and, when CUDA is available,
    empties the CUDA cache.
    """
    global _embedder, _current_embedder_id
    # Rebinding to None releases the module's reference; no `del` needed.
    _embedder = _current_embedder_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()