Spaces:
Running on Zero
Running on Zero
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from sentence_transformers import SentenceTransformer | |
| import torch | |
| import gc | |
| import config | |
# Lazily-populated module singletons; None means "nothing loaded yet".
# LLM side: the model, its paired tokenizer, and a standalone tokenizer that
# get_tokenizer_only() caches when no LLM has been loaded.
_llm = _tokenizer = _tokenizer_only = None
# Embedding model (a SentenceTransformer once _load_embedder has run).
_embedder = None
# Repo ids of whatever is currently held in _llm / _embedder.
_current_model_id = _current_embedder_id = None
| def get_current_model_id() -> str | None: | |
| return _current_model_id | |
| def get_current_tokenizer_id() -> str | None: | |
| # Tokenizer is always loaded from the same HF repo as the model. | |
| return _current_model_id | |
| def get_current_embedder_id() -> str | None: | |
| return _current_embedder_id | |
def get_tokenizer_only():
    """Return a tokenizer without forcing the full LLM to load.

    If an LLM is loaded, its paired tokenizer is returned. Otherwise a
    standalone tokenizer for ``config.LLM_MODEL`` is loaded on first call,
    cached, and reused on later calls.
    """
    global _tokenizer_only
    if _tokenizer is None:
        if _tokenizer_only is None:
            _tokenizer_only = AutoTokenizer.from_pretrained(config.LLM_MODEL)
        return _tokenizer_only
    return _tokenizer
def get_llm():
    """Return ``(model, tokenizer)``, loading ``config.LLM_MODEL`` on first use."""
    if _llm is not None:
        return _llm, _tokenizer
    # Nothing loaded yet: fall back to the configured default model.
    _load_llm(config.LLM_MODEL)
    return _llm, _tokenizer
def switch_llm(model_id: str) -> str:
    """Make ``model_id`` the active LLM and return a short status message.

    This is a no-op if ``model_id`` is already active. Otherwise the current
    model is unloaded first, then the new model and tokenizer are loaded.
    """
    if model_id != _current_model_id:
        _unload_llm()
        _load_llm(model_id)
        return f"Loaded: {model_id}"
    return f"Already using {model_id}"
def _load_llm(model_id: str):
    """Load a causal LM and its paired tokenizer, both from ``model_id``.

    Both objects are loaded into locals first. The module globals are updated
    only after both loads succeed. If either ``from_pretrained`` call raises,
    the exception propagates and ``_llm``, ``_tokenizer`` and
    ``_current_model_id`` keep their previous values. Previously, a failed
    model load could leave the new model's tokenizer published with no
    matching model or model id.
    """
    global _llm, _tokenizer, _current_model_id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",  # uses model's native dtype (bfloat16 for Qwen2.5)
        device_map=None,  # load to CPU; @spaces.GPU functions move it on demand
    )
    model.eval()
    # Commit only after both loads succeeded, so the three globals stay consistent.
    _tokenizer = tokenizer
    _llm = model
    _current_model_id = model_id
def _unload_llm():
    """Drop every LLM-related reference and reclaim memory.

    Intended to run before a different model is loaded. It clears the model,
    its paired tokenizer, the cached standalone tokenizer and the current
    model id. It then runs the garbage collector and, when CUDA is available,
    releases PyTorch's cached GPU blocks.
    """
    global _llm, _tokenizer, _current_model_id, _tokenizer_only
    # Rebinding to None releases the module's references (targets assigned left to right).
    _llm = _tokenizer = _current_model_id = _tokenizer_only = None
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def get_embedder():
    """Return the sentence embedder, loading ``config.EMBEDDER_MODEL`` on first use."""
    embedder = _embedder
    if embedder is None:
        _load_embedder(config.EMBEDDER_MODEL)
        embedder = _embedder
    return embedder
def switch_embedder(model_id: str) -> str:
    """Make ``model_id`` the active embedder and return a short status message.

    This is a no-op if ``model_id`` is already active. Otherwise the current
    embedder is unloaded first, then the new one is loaded.
    """
    if model_id != _current_embedder_id:
        _unload_embedder()
        _load_embedder(model_id)
        return f"Loaded: {model_id}"
    return f"Already using {model_id}"
def _load_embedder(model_id: str):
    """Instantiate a SentenceTransformer for ``model_id`` on CPU and make it current."""
    global _embedder, _current_embedder_id
    model = SentenceTransformer(model_id, device="cpu")
    # Tuple assignment binds left to right: embedder first, then its id.
    _embedder, _current_embedder_id = model, model_id
def _unload_embedder():
    """Drop the embedder reference and reclaim memory.

    Clears the embedder and its id, runs the garbage collector and, when CUDA
    is available, releases PyTorch's cached GPU blocks.
    """
    global _embedder, _current_embedder_id
    _embedder = _current_embedder_id = None
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()