# app/models/initializer.py
import textwrap
from typing import TypedDict, Union
import onnxruntime as ort
import torch
import torch.serialization
from huggingface_hub import hf_hub_download
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizer)
import config
from models.engines.model_extensions import CustomModel
class ModelDict(TypedDict):
    """Typed layout of the model cache: one LLM and one reranker, each with its tokenizer."""

    # --- text generation ---
    # Either the project's CustomModel or any other torch module.
    llm: Union[CustomModel, torch.nn.Module]
    # Tokenizer paired with ``llm``.
    llm_tokenizer: PreTrainedTokenizer

    # --- reranking ---
    # ONNX Runtime session wrapping the reranker model.
    reranker: ort.InferenceSession
    # Tokenizer paired with ``reranker``.
    reranker_tokenizer: PreTrainedTokenizer
# Module-level cache of loaded models and tokenizers. Starts empty and is
# filled by initialize_models(); get_models() returns it.
# NOTE(review): annotated as dict[str, ModelDict], but it is populated with the
# ModelDict keys directly and returned as a ModelDict — confirm the intended type.
_MODELS: dict[str, ModelDict] = {}
# Module-level cache of pre-tokenized prompt fragments (name -> input_ids).
# Starts empty and is rebound by initialize_prefixes().
_PREFIX_CACHE: dict[str, torch.Tensor] = {}
def download_llm() -> tuple[str, str]:
    """
    Fetch the LLM weights file and its config file from the Hugging Face Hub.

    Both files are requested from the ``config.HF_MODEL_HUB`` repository,
    passing ``config.HF_TOKEN`` as the access token.

    Returns:
        ``(model_path, config_path)`` — the local paths reported by
        ``hf_hub_download`` for the weights file and the config file.
    """
    # Arguments shared by both requests.
    hub_kwargs = {"repo_id": config.HF_MODEL_HUB, "token": config.HF_TOKEN}

    # Weights are requested first, then the config.
    weights_path = hf_hub_download(filename=config.HF_LLM_FILENAME, **hub_kwargs)
    config_path = hf_hub_download(filename=config.HF_CONFIG_FILENAME, **hub_kwargs)
    return weights_path, config_path
def download_reranker() -> str:
    """
    Fetch the reranker ONNX file from the Hugging Face Hub.

    Returns:
        The local path reported by ``hf_hub_download`` for
        ``config.HF_RERANKER_FILENAME`` in the ``config.HF_MODEL_HUB`` repository.
    """
    reranker_path = hf_hub_download(
        repo_id=config.HF_MODEL_HUB,
        filename=config.HF_RERANKER_FILENAME,
        token=config.HF_TOKEN,
    )
    return reranker_path
def load_llm(local_model_path: str, local_config_path: str) -> CustomModel:
    """
    Build a ``CustomModel`` from a config file and load saved weights into it.

    The model is constructed from ``AutoConfig.from_pretrained(local_config_path)``
    and then populated from the state dict stored at ``local_model_path``.
    The state dict is always loaded manually onto the CPU; there is no
    ``from_pretrained`` shortcut in this function.

    Args:
        local_model_path: Path to the serialized state dict.
        local_config_path: Path handed to ``AutoConfig.from_pretrained``.

    Returns:
        The ``CustomModel`` instance with the state dict applied.
    """
    # torch.load(..., weights_only=True) only accepts allow-listed types, so the
    # torchao tensor class is registered first (presumably because the
    # checkpoint contains quantized tensors of that type).
    torch.serialization.add_safe_globals([AffineQuantizedTensor])

    llm = CustomModel(AutoConfig.from_pretrained(local_config_path))
    llm.load_state_dict(
        torch.load(local_model_path, map_location="cpu", weights_only=True)
    )
    return llm
def load_reranker(local_model_path: str) -> ort.InferenceSession:
    """
    Open the reranker model as an ONNX Runtime session.

    Args:
        local_model_path: Path to the ONNX model file.

    Returns:
        An ``ort.InferenceSession`` created with the CPU execution provider only.
    """
    cpu_only = ["CPUExecutionProvider"]
    session = ort.InferenceSession(local_model_path, providers=cpu_only)
    return session
def load_llm_tokenizer() -> PreTrainedTokenizer:
    """
    Load the tokenizer that belongs to the LLM.

    Returns:
        The tokenizer obtained from ``AutoTokenizer.from_pretrained`` for the
        ``config.HF_LLM_REPO`` repository, using ``config.HF_TOKEN``.
    """
    llm_tokenizer = AutoTokenizer.from_pretrained(
        config.HF_LLM_REPO,
        token=config.HF_TOKEN,
    )
    return llm_tokenizer
def load_reranker_tokenizer() -> PreTrainedTokenizer:
    """
    Load the tokenizer that belongs to the reranker.

    Returns:
        The tokenizer obtained from ``AutoTokenizer.from_pretrained`` for the
        ``config.HF_RERANKER_REPO`` repository, using ``config.HF_TOKEN``.
    """
    reranker_tokenizer = AutoTokenizer.from_pretrained(
        config.HF_RERANKER_REPO,
        token=config.HF_TOKEN,
    )
    return reranker_tokenizer
def load_llm_from_pretrained() -> torch.nn.Module:
    """
    Load the LLM directly from the Hugging Face Hub via ``from_pretrained``.

    This path does not use the locally downloaded state dict; the weights come
    from the ``config.HF_LLM_REPO`` repository instead.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``, loaded
        with ``torch.float16`` weights and placed on the CPU. The concrete
        class is chosen by the Auto class from the repository's config, so the
        return type is the generic ``torch.nn.Module`` rather than
        ``CustomModel``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        config.HF_LLM_REPO,
        token=config.HF_TOKEN,
        dtype=torch.float16,
        device_map="cpu",
    )
    return model
def initialize_models() -> None:
    """
    Download and load every model once, then publish them to the module cache.

    Does nothing when ``_MODELS`` is already populated. Otherwise the reranker
    file is downloaded, the LLM, both tokenizers and the reranker session are
    loaded, and all four entries are added to ``_MODELS`` together.

    The entries are staged in a local dict first. If any download or load step
    raises, ``_MODELS`` stays empty, so the next call retries from scratch
    instead of leaving a partially filled cache that ``get_models()`` would
    treat as ready.

    Raises:
        Whatever the underlying download/load helpers raise; the exception is
        propagated unchanged.
    """
    if _MODELS:
        return

    # NOTE: the quantized path (download_llm() + load_llm()) is currently
    # disabled; the LLM is loaded through load_llm_from_pretrained() instead.
    reranker_path = download_reranker()

    # Dict literal values are evaluated top to bottom, which keeps the load
    # order: LLM, LLM tokenizer, reranker, reranker tokenizer.
    loaded: ModelDict = {
        "llm": load_llm_from_pretrained(),
        "llm_tokenizer": load_llm_tokenizer(),
        "reranker": load_reranker(reranker_path),
        "reranker_tokenizer": load_reranker_tokenizer(),
    }

    # Update in place so existing references to the cache object stay valid.
    _MODELS.update(loaded)
def get_models() -> ModelDict:
    """
    Return the cached models and tokenizers, initializing them if needed.

    Returns:
        The module-level ``_MODELS`` cache. When it is empty on entry,
        ``initialize_models()`` is called first.
    """
    if _MODELS:
        return _MODELS

    # Cache is empty: load everything, then hand back the cache.
    initialize_models()
    return _MODELS
def initialize_prefixes() -> dict[str, torch.Tensor]:
    """
    Tokenize the fixed prompt fragments once and cache their ``input_ids``.

    When the cache is empty, each fragment is passed through the LLM tokenizer
    (from ``get_models()``) with ``return_tensors="pt"`` and only the
    ``"input_ids"`` entry is kept. When the cache is already populated it is
    returned as is.

    Returns:
        The module-level ``_PREFIX_CACHE`` mapping fragment name to input ids.
    """
    global _PREFIX_CACHE
    if _PREFIX_CACHE:
        return _PREFIX_CACHE

    llm_tokenizer = get_models()["llm_tokenizer"]

    def encode(text: str) -> torch.Tensor:
        # Keep only the token ids from the tokenizer output.
        return llm_tokenizer(text, return_tensors="pt")["input_ids"]

    summarize_instruction = textwrap.dedent("""\
        Instruction: Summarize the following document in relation to the query
        Constraints:
        - Keep the summary under 300 words
        - Focus only on information relevant to the query
        - Maintain the original language of the document
        """)
    refine_instruction = textwrap.dedent("""\
        Instruction: Combine and refine these summaries to answer the query
        Constraints:
        - Provide the final answer in a single coherent paragraph
        - Ensure the answer directly addresses the query
        - Keep the length under 500 words
        - Preserve the language style of the input summaries
        """)

    # Fragment name -> raw text. Insertion order is the tokenization order.
    fragments = {
        "instruct": "/no_think\n",
        "think": "/think\n",
        "summarize": summarize_instruction,
        "refine": refine_instruction,
        "query": "query: \n",
        "document": "document: \n",
        "summaries": "summaries:\n",
        "summarize_reminder": "Reminder: Keep the summary concise and under 300 words.",
        "refine_reminder": "Reminder: Final answer must be a single coherent paragraph under 500 words.",
        "newline": "\n",
    }

    # Rebind the global only after every fragment has been tokenized.
    _PREFIX_CACHE = {name: encode(text) for name, text in fragments.items()}
    return _PREFIX_CACHE
def get_prefixes() -> dict[str, torch.Tensor]:
    """
    Return the cached prompt-fragment token ids, building them if needed.

    Returns:
        The module-level ``_PREFIX_CACHE``. When it is empty on entry,
        ``initialize_prefixes()`` is called first.
    """
    if _PREFIX_CACHE:
        return _PREFIX_CACHE

    # Cache is empty: build it, then read the freshly rebound global.
    initialize_prefixes()
    return _PREFIX_CACHE
|