Spaces:
Sleeping
Sleeping
| # app/models/initializer.py | |
| import textwrap | |
| from typing import TypedDict, Union | |
| import onnxruntime as ort | |
| import torch | |
| import torch.serialization | |
| from huggingface_hub import hf_hub_download | |
| from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor | |
| from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, | |
| PreTrainedTokenizer) | |
| import config | |
| from models.engines.model_extensions import CustomModel | |
class ModelDict(TypedDict):
    """
    Typed layout of the model cache handed out by ``get_models()``.

    All four keys are populated by ``initialize_models()``.
    """

    # Causal LM: a ``CustomModel`` (manual state-dict path, ``load_llm``) or the
    # model returned by ``AutoModelForCausalLM.from_pretrained``.
    llm: Union[CustomModel, torch.nn.Module]
    # Tokenizer loaded from the LLM repository.
    llm_tokenizer: PreTrainedTokenizer
    # ONNX Runtime session wrapping the reranker model.
    reranker: ort.InferenceSession
    # Tokenizer loaded from the reranker repository.
    reranker_tokenizer: PreTrainedTokenizer
# Lazily populated model cache. It is used directly AS a ModelDict (keys "llm",
# "llm_tokenizer", "reranker", "reranker_tokenizer"), not as a mapping of names
# to ModelDicts. It starts empty, which a TypedDict literal normally forbids,
# hence the ignore, and initialize_models() fills it in place.
_MODELS: ModelDict = {}  # type: ignore[typeddict-item]
# Tokenized prompt fragments (fragment name -> input_ids tensor), built on first
# use by initialize_prefixes().
_PREFIX_CACHE: dict[str, torch.Tensor] = {}
def download_llm() -> tuple[str, str]:
    """
    Fetch the quantized LLM weights and its config file from the Hugging Face Hub.

    Both files live in ``config.HF_MODEL_HUB``. The weight file is
    ``config.HF_LLM_FILENAME`` (e.g. model_quantized.pt or model.bin) and the
    config file is ``config.HF_CONFIG_FILENAME``.

    Returns:
        ``(model_path, config_path)``: the local paths reported by ``hf_hub_download``.
    """
    def _fetch(filename: str) -> str:
        # Same repository and credentials for both files; only the filename varies.
        return hf_hub_download(
            repo_id=config.HF_MODEL_HUB,
            filename=filename,
            token=config.HF_TOKEN,
        )

    weights_path = _fetch(config.HF_LLM_FILENAME)
    config_path = _fetch(config.HF_CONFIG_FILENAME)
    return weights_path, config_path
def download_reranker() -> str:
    """
    Fetch the reranker's ONNX file (``config.HF_RERANKER_FILENAME``) from
    ``config.HF_MODEL_HUB``.

    Returns:
        The local path reported by ``hf_hub_download``.
    """
    onnx_path = hf_hub_download(
        repo_id=config.HF_MODEL_HUB,
        filename=config.HF_RERANKER_FILENAME,
        token=config.HF_TOKEN,
    )
    return onnx_path
def load_llm(local_model_path: str, local_config_path: str) -> CustomModel:
    """
    Build a ``CustomModel`` from a config file and load a saved state dict into it.

    Weights are always loaded manually with ``torch.load`` + ``load_state_dict``;
    ``from_pretrained`` is not used on this path (``load_llm_from_pretrained`` is
    the from_pretrained-based alternative).

    Args:
        local_model_path: Path to the serialized (quantized) state dict, e.g. the
            first path returned by ``download_llm``.
        local_config_path: Path to the model config file, read with
            ``AutoConfig.from_pretrained``.

    Returns:
        The model with its weights loaded, switched to eval mode.
    """
    # weights_only=True refuses arbitrary pickled classes, so the torchao
    # quantized tensor type must be allow-listed for the checkpoint to load.
    torch.serialization.add_safe_globals([AffineQuantizedTensor])
    model_config = AutoConfig.from_pretrained(local_config_path)
    model = CustomModel(model_config)
    state_dict = torch.load(local_model_path, map_location="cpu", weights_only=True)
    # NOTE(review): torchao's serialization docs load quantized state dicts with
    # load_state_dict(..., assign=True); without it the tensors are copied into
    # the freshly initialized parameters. Confirm the quantized weights survive.
    model.load_state_dict(state_dict)
    # A newly constructed module starts in training mode (dropout active), whereas
    # from_pretrained returns models in eval mode. Match that for inference.
    model.eval()
    return model
def load_reranker(local_model_path: str) -> ort.InferenceSession:
    """
    Open the reranker ONNX file at ``local_model_path`` in an ONNX Runtime session
    that is restricted to the CPU execution provider.
    """
    cpu_only = ["CPUExecutionProvider"]
    session = ort.InferenceSession(local_model_path, providers=cpu_only)
    return session
def load_llm_tokenizer() -> PreTrainedTokenizer:
    """
    Load the tokenizer published in the LLM repository (``config.HF_LLM_REPO``),
    authenticating with ``config.HF_TOKEN``.
    """
    repo_id = config.HF_LLM_REPO
    hub_token = config.HF_TOKEN
    return AutoTokenizer.from_pretrained(repo_id, token=hub_token)
def load_reranker_tokenizer() -> PreTrainedTokenizer:
    """
    Load the tokenizer published in the reranker repository
    (``config.HF_RERANKER_REPO``), authenticating with ``config.HF_TOKEN``.
    """
    repo_id = config.HF_RERANKER_REPO
    hub_token = config.HF_TOKEN
    return AutoTokenizer.from_pretrained(repo_id, token=hub_token)
def load_llm_from_pretrained() -> torch.nn.Module:
    """
    Load the official LLM (e.g. the 4B model) straight from the Hugging Face Hub.

    Uses ``AutoModelForCausalLM.from_pretrained`` on ``config.HF_LLM_REPO`` in
    float16 on the CPU, bypassing the local quantized state-dict path
    (``download_llm`` / ``load_llm``).

    Returns:
        The transformers model produced by ``from_pretrained``. It is not a
        ``CustomModel``, so the return type is ``torch.nn.Module``, matching the
        ``llm`` field of ``ModelDict``.
    """
    # NOTE(review): half precision on CPU is slow, and some ops lack Half kernels on
    # older torch builds. Confirm float16 is intended for a CPU-only deployment.
    # `dtype=` is the newer spelling of `torch_dtype=` in transformers. TODO
    # confirm the pinned transformers version accepts it.
    return AutoModelForCausalLM.from_pretrained(
        config.HF_LLM_REPO,
        token=config.HF_TOKEN,
        dtype=torch.float16,
        device_map="cpu",
    )
def initialize_models() -> None:
    """
    Download and load every model and tokenizer once, then publish them in ``_MODELS``.

    All four entries are loaded into a local dict first and copied into the cache
    only after every step has succeeded. Writing each entry into ``_MODELS`` as it
    loads would leave a non-empty but incomplete cache if a later step raised.
    Later ``get_models()`` calls would then skip initialization and hand out a
    dict with missing keys.

    Does nothing if the cache is already populated.
    """
    if _MODELS:
        return

    # The LLM is loaded directly via from_pretrained. The alternative quantized
    # path is download_llm() followed by load_llm(llm_path, config_path).
    # NOTE(review): no lock here. Concurrent first calls could each load the
    # models; confirm whether callers can race on startup.
    reranker_path = download_reranker()
    loaded: ModelDict = {
        "llm": load_llm_from_pretrained(),
        "llm_tokenizer": load_llm_tokenizer(),
        "reranker": load_reranker(reranker_path),
        "reranker_tokenizer": load_reranker_tokenizer(),
    }
    # Update in place rather than rebinding, so the module-level dict object is kept.
    _MODELS.update(loaded)
def get_models() -> ModelDict:
    """
    Return the model cache, running ``initialize_models()`` first if it is still empty.
    """
    if _MODELS:
        return _MODELS
    initialize_models()
    return _MODELS
def initialize_prefixes() -> dict[str, torch.Tensor]:
    """
    Tokenize the fixed prompt fragments once and publish them in ``_PREFIX_CACHE``.

    Each value is the ``input_ids`` tensor that the LLM tokenizer returns for the
    fragment with ``return_tensors="pt"``. When the cache is already non-empty it
    is returned without re-tokenizing.

    Returns:
        The prefix cache, keyed by fragment name.
    """
    global _PREFIX_CACHE
    if _PREFIX_CACHE:
        return _PREFIX_CACHE

    tokenizer = get_models()["llm_tokenizer"]

    # Raw text of every fragment, in the order it gets tokenized.
    # NOTE(review): "query: \n" and "document: \n" have a space before the newline
    # while "summaries:\n" does not. Confirm the difference is intentional.
    fragment_texts = {
        "instruct": "/no_think\n",
        "think": "/think\n",
        "summarize": textwrap.dedent("""\
            Instruction: Summarize the following document in relation to the query
            Constraints:
            - Keep the summary under 300 words
            - Focus only on information relevant to the query
            - Maintain the original language of the document
            """),
        "refine": textwrap.dedent("""\
            Instruction: Combine and refine these summaries to answer the query
            Constraints:
            - Provide the final answer in a single coherent paragraph
            - Ensure the answer directly addresses the query
            - Keep the length under 500 words
            - Preserve the language style of the input summaries
            """),
        "query": "query: \n",
        "document": "document: \n",
        "summaries": "summaries:\n",
        "summarize_reminder": "Reminder: Keep the summary concise and under 300 words.",
        "refine_reminder": "Reminder: Final answer must be a single coherent paragraph under 500 words.",
        "newline": "\n",
    }

    # NOTE(review): each fragment is tokenized on its own with the tokenizer's
    # default special-token handling. If the tokenizer prepends a BOS token, every
    # fragment will carry one. Confirm that fits how callers combine them.
    _PREFIX_CACHE = {
        name: tokenizer(text, return_tensors="pt")["input_ids"]
        for name, text in fragment_texts.items()
    }
    return _PREFIX_CACHE
def get_prefixes() -> dict[str, torch.Tensor]:
    """
    Return the tokenized prompt-fragment cache, building it via
    ``initialize_prefixes()`` on first access.
    """
    if _PREFIX_CACHE:
        return _PREFIX_CACHE
    initialize_prefixes()
    return _PREFIX_CACHE