pragmatic-agent / models /initializer.py
m97j's picture
feat(initializer): store prefix cache as input_ids tensors
ac17ed0
# app/models/initializer.py
import textwrap
from typing import TypedDict, Union
import onnxruntime as ort
import torch
import torch.serialization
from huggingface_hub import hf_hub_download
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizer)
import config
from models.engines.model_extensions import CustomModel
class ModelDict(TypedDict):
llm: Union[CustomModel, torch.nn.Module]
llm_tokenizer: PreTrainedTokenizer
reranker: ort.InferenceSession
reranker_tokenizer: PreTrainedTokenizer
_MODELS: dict[str, ModelDict] = {}
_PREFIX_CACHE = {}
def download_llm() -> tuple[str, str]:
"""
Download the quantized LLM file from Hugging Face Hub
(e.g., model_quantized.pt or model.bin).
Returns the local path to the model and config files.
"""
local_model_path = hf_hub_download(
repo_id=config.HF_MODEL_HUB,
filename=config.HF_LLM_FILENAME,
token=config.HF_TOKEN
)
local_config_path = hf_hub_download(
repo_id=config.HF_MODEL_HUB,
filename=config.HF_CONFIG_FILENAME,
token=config.HF_TOKEN
)
return local_model_path, local_config_path
def download_reranker() -> str:
"""
Download the reranker ONNX file from Hugging Face Hub.
Returns the local path to the reranker file.
"""
return hf_hub_download(
repo_id=config.HF_MODEL_HUB,
filename=config.HF_RERANKER_FILENAME,
token=config.HF_TOKEN
)
def load_llm(local_model_path: str, local_config_path: str) -> CustomModel:
"""
Load the quantized LLM into PyTorch.
If the model file is named 'pytorch_model.bin', from_pretrained will load it automatically.
Otherwise, fall back to manual state_dict loading.
"""
torch.serialization.add_safe_globals([AffineQuantizedTensor])
_config = AutoConfig.from_pretrained(local_config_path)
model = CustomModel(_config)
state_dict = torch.load(local_model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
return model
def load_reranker(local_model_path: str) -> ort.InferenceSession:
"""
Load reranker model with ONNX Runtime.
"""
return ort.InferenceSession(local_model_path, providers=["CPUExecutionProvider"])
def load_llm_tokenizer() -> PreTrainedTokenizer:
"""
Load tokenizer for LLM
"""
return AutoTokenizer.from_pretrained(
config.HF_LLM_REPO,
token=config.HF_TOKEN
)
def load_reranker_tokenizer() -> PreTrainedTokenizer:
"""
Load tokenizer for reranker
"""
return AutoTokenizer.from_pretrained(
config.HF_RERANKER_REPO,
token=config.HF_TOKEN
)
def load_llm_from_pretrained() -> CustomModel:
"""
Load the official LLM (e.g., 4B model) directly from Hugging Face Hub
using from_pretrained. This bypasses local quantized state_dict loading.
"""
model = AutoModelForCausalLM.from_pretrained(
config.HF_LLM_REPO,
token=config.HF_TOKEN,
dtype=torch.float16,
device_map="cpu"
)
return model
def initialize_models() -> None:
"""
Download and load models on first run, then save to global cache.
"""
global _MODELS
if not _MODELS:
# llm_path, config_path = download_llm()
reranker_path = download_reranker()
# _MODELS["llm"] = load_llm(llm_path, config_path)
_MODELS["llm"] = load_llm_from_pretrained()
_MODELS["llm_tokenizer"] = load_llm_tokenizer()
_MODELS["reranker"] = load_reranker(reranker_path)
_MODELS["reranker_tokenizer"] = load_reranker_tokenizer()
def get_models() -> ModelDict:
"""
Retrieve models and tokenizers from cache.
"""
global _MODELS
if not _MODELS:
initialize_models()
return _MODELS
def initialize_prefixes() -> dict[str, torch.Tensor]:
"""
Initialize prefix cache once and store globally.
Each entry is stored as a torch.Tensor of input_ids.
"""
global _PREFIX_CACHE
if not _PREFIX_CACHE:
models = get_models()
tokenizer = models["llm_tokenizer"]
_PREFIX_CACHE = {
"instruct": tokenizer("/no_think\n", return_tensors="pt")["input_ids"],
"think": tokenizer("/think\n", return_tensors="pt")["input_ids"],
"summarize": tokenizer(textwrap.dedent("""\
Instruction: Summarize the following document in relation to the query
Constraints:
- Keep the summary under 300 words
- Focus only on information relevant to the query
- Maintain the original language of the document
"""), return_tensors="pt")["input_ids"],
"refine": tokenizer(textwrap.dedent("""\
Instruction: Combine and refine these summaries to answer the query
Constraints:
- Provide the final answer in a single coherent paragraph
- Ensure the answer directly addresses the query
- Keep the length under 500 words
- Preserve the language style of the input summaries
"""), return_tensors="pt")["input_ids"],
"query": tokenizer("query: \n", return_tensors="pt")["input_ids"],
"document": tokenizer("document: \n", return_tensors="pt")["input_ids"],
"summaries": tokenizer("summaries:\n", return_tensors="pt")["input_ids"],
"summarize_reminder": tokenizer("Reminder: Keep the summary concise and under 300 words.",
return_tensors="pt")["input_ids"],
"refine_reminder": tokenizer("Reminder: Final answer must be a single coherent paragraph under 500 words.",
return_tensors="pt")["input_ids"],
"newline": tokenizer("\n", return_tensors="pt")["input_ids"],
}
return _PREFIX_CACHE
def get_prefixes() -> dict[str, torch.Tensor]:
"""
Retrieve prefix cache from global storage.
"""
global _PREFIX_CACHE
if not _PREFIX_CACHE:
initialize_prefixes()
return _PREFIX_CACHE