File size: 6,265 Bytes
69c12a2
 
 
 
 
 
662eb29
69c12a2
662eb29
bf2f314
 
69c12a2
 
 
 
 
 
 
 
 
 
 
 
 
 
9b58d8f
69c12a2
 
 
9b58d8f
69c12a2
 
 
4e65de4
9b58d8f
69c12a2
 
9b58d8f
69c12a2
4e65de4
9b58d8f
69c12a2
9b58d8f
69c12a2
 
 
 
 
 
 
 
4e65de4
9b58d8f
69c12a2
 
9b58d8f
69c12a2
 
 
 
 
662eb29
 
9b58d8f
7d823a8
f6e3bea
7d823a8
69c12a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b942de
 
 
 
 
 
 
 
bf2f314
 
7b942de
 
 
69c12a2
 
 
 
 
 
7b942de
69c12a2
 
7b942de
 
69c12a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac17ed0
69c12a2
 
 
 
 
 
ac17ed0
 
69c12a2
 
 
 
 
 
ac17ed0
69c12a2
 
 
 
 
 
 
ac17ed0
 
 
 
 
 
 
 
 
69c12a2
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# app/models/initializer.py
import textwrap
from typing import TypedDict, Union

import onnxruntime as ort
import torch
import torch.serialization
from huggingface_hub import hf_hub_download
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizer)

import config
from models.engines.model_extensions import CustomModel


class ModelDict(TypedDict):
    """
    Shape of the shared model cache filled by ``initialize_models``.

    Keys:
        llm: The causal language model. Either a ``transformers`` model from
            ``load_llm_from_pretrained`` or a ``CustomModel`` from ``load_llm``.
        llm_tokenizer: Tokenizer used with ``llm``.
        reranker: ONNX Runtime session running the reranker model.
        reranker_tokenizer: Tokenizer used with ``reranker``.
    """

    llm: Union[CustomModel, torch.nn.Module]
    llm_tokenizer: PreTrainedTokenizer
    reranker: ort.InferenceSession
    reranker_tokenizer: PreTrainedTokenizer


# The model cache itself is a ModelDict (keyed "llm", "reranker", ...), not a
# mapping of ModelDicts. It starts empty and is filled lazily, so the empty
# literal does not yet satisfy the TypedDict's required keys.
_MODELS: ModelDict = {}  # type: ignore[typeddict-item]
# Tokenized prompt fragments: fragment name -> input_ids tensor.
_PREFIX_CACHE: dict[str, torch.Tensor] = {}

def download_llm() -> tuple[str, str]:
    """
    Fetch the LLM weights file and its config file from the Hugging Face Hub.

    Both files live in ``config.HF_MODEL_HUB``. The weights file is
    ``config.HF_LLM_FILENAME`` (e.g. ``model_quantized.pt`` or ``model.bin``)
    and the config is ``config.HF_CONFIG_FILENAME``.

    Returns:
        ``(model_path, config_path)``, the local paths of the two downloads.
    """
    wanted_files = (config.HF_LLM_FILENAME, config.HF_CONFIG_FILENAME)
    # The weights are downloaded first, then the config (same order as listed).
    model_path, config_path = [
        hf_hub_download(
            repo_id=config.HF_MODEL_HUB,
            filename=file_name,
            token=config.HF_TOKEN,
        )
        for file_name in wanted_files
    ]
    return model_path, config_path

def download_reranker() -> str:
    """
    Fetch the reranker's ONNX file from the Hugging Face Hub.

    Returns:
        Local path of ``config.HF_RERANKER_FILENAME`` downloaded from
        ``config.HF_MODEL_HUB``.
    """
    reranker_file = hf_hub_download(
        repo_id=config.HF_MODEL_HUB,
        filename=config.HF_RERANKER_FILENAME,
        token=config.HF_TOKEN,
    )
    return reranker_file

def load_llm(local_model_path: str, local_config_path: str) -> CustomModel:
    """
    Build a ``CustomModel`` from a config file and load a saved state_dict into it.

    The architecture comes from ``local_config_path`` via ``AutoConfig``. The
    weights at ``local_model_path`` are read with
    ``torch.load(..., weights_only=True)`` onto the CPU and applied with
    ``load_state_dict``. The default strict mode means missing or unexpected
    keys raise.

    Args:
        local_model_path: Path to the saved state_dict (e.g. the weights path
            returned by ``download_llm``).
        local_config_path: Path to the model's config file.

    Returns:
        The populated model, switched to eval mode.
    """
    # weights_only=True rejects arbitrary pickled classes, so torchao's
    # quantized tensor type must be allow-listed for this checkpoint to load.
    torch.serialization.add_safe_globals([AffineQuantizedTensor])

    model_config = AutoConfig.from_pretrained(local_config_path)
    model = CustomModel(model_config)
    state_dict = torch.load(local_model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)

    # A freshly constructed module starts in training mode (dropout active).
    # Switch to eval so it matches the model returned by from_pretrained().
    model.eval()
    return model

def load_reranker(local_model_path: str) -> ort.InferenceSession:
    """
    Open the reranker ONNX model as an ONNX Runtime inference session.

    Args:
        local_model_path: Path to the ONNX file (e.g. from ``download_reranker``).

    Returns:
        A session restricted to the CPU execution provider.
    """
    cpu_only = ["CPUExecutionProvider"]
    session = ort.InferenceSession(local_model_path, providers=cpu_only)
    return session

def load_llm_tokenizer() -> PreTrainedTokenizer:
    """
    Load the LLM's tokenizer from ``config.HF_LLM_REPO``, authenticated with
    ``config.HF_TOKEN``.
    """
    repo_id = config.HF_LLM_REPO
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=config.HF_TOKEN)
    return tokenizer

def load_reranker_tokenizer() -> PreTrainedTokenizer:
    """
    Load the reranker's tokenizer from ``config.HF_RERANKER_REPO``,
    authenticated with ``config.HF_TOKEN``.
    """
    repo_id = config.HF_RERANKER_REPO
    tokenizer = AutoTokenizer.from_pretrained(repo_id, token=config.HF_TOKEN)
    return tokenizer

def load_llm_from_pretrained() -> torch.nn.Module:
    """
    Load the official LLM (e.g. the 4B model) directly from the Hugging Face Hub.

    Calls ``AutoModelForCausalLM.from_pretrained`` on ``config.HF_LLM_REPO``.
    This bypasses the local quantized state_dict path (``download_llm`` +
    ``load_llm``). Weights are requested in float16 and placed on the CPU.

    Returns:
        The loaded causal-LM model. This is a ``transformers`` model, not a
        ``CustomModel``, so it is annotated as ``torch.nn.Module``.
    """
    # NOTE(review): float16 on CPU can be slow, and some ops may lack Half
    # kernels depending on the torch version. Confirm this dtype is intended.
    # NOTE(review): ``dtype=`` is the newer transformers spelling of
    # ``torch_dtype=``. Verify the pinned transformers version honours it.
    model = AutoModelForCausalLM.from_pretrained(
        config.HF_LLM_REPO,
        token=config.HF_TOKEN,
        dtype=torch.float16,
        device_map="cpu",
    )
    return model

def initialize_models() -> None:
    """
    Download and load every model and tokenizer once, then store them in ``_MODELS``.

    Does nothing if ``_MODELS`` is already populated. All four entries are
    loaded into a local dict first and published to the cache in one
    ``update``. If any download or load raises, the cache stays empty and the
    next call retries, instead of leaving a half-filled cache that looks
    initialized.
    """
    if _MODELS:
        return

    # Alternative LLM source: the quantized checkpoint on the model hub,
    # i.e. load_llm(*download_llm()) instead of load_llm_from_pretrained().
    reranker_path = download_reranker()

    # Values are evaluated in order, so the load sequence is unchanged.
    loaded: ModelDict = {
        "llm": load_llm_from_pretrained(),
        "llm_tokenizer": load_llm_tokenizer(),
        "reranker": load_reranker(reranker_path),
        "reranker_tokenizer": load_reranker_tokenizer(),
    }
    # Mutate in place (no rebinding) so existing references see the result.
    _MODELS.update(loaded)

def get_models() -> ModelDict:
    """
    Return the shared model/tokenizer cache, loading it on first access.
    """
    if _MODELS:
        return _MODELS
    initialize_models()
    return _MODELS

def initialize_prefixes() -> dict[str, torch.Tensor]:
    """
    Tokenize the fixed prompt fragments once and cache them in ``_PREFIX_CACHE``.

    Each value is the ``input_ids`` tensor the LLM tokenizer returns for the
    fragment with ``return_tensors="pt"``. The cache is built only while it is
    empty. Later calls return the existing mapping unchanged.

    Returns:
        The prefix cache (fragment name -> input_ids tensor).
    """
    global _PREFIX_CACHE
    if _PREFIX_CACHE:
        return _PREFIX_CACHE

    tokenizer = get_models()["llm_tokenizer"]

    # Fragment name -> raw text. Insertion order is the tokenization order.
    fragments = {
        "instruct": "/no_think\n",
        "think": "/think\n",
        "summarize": textwrap.dedent("""\
            Instruction: Summarize the following document in relation to the query
            Constraints:
            - Keep the summary under 300 words
            - Focus only on information relevant to the query
            - Maintain the original language of the document
        """),
        "refine": textwrap.dedent("""\
            Instruction: Combine and refine these summaries to answer the query
            Constraints:
            - Provide the final answer in a single coherent paragraph
            - Ensure the answer directly addresses the query
            - Keep the length under 500 words
            - Preserve the language style of the input summaries
        """),
        "query": "query: \n",
        "document": "document: \n",
        "summaries": "summaries:\n",
        "summarize_reminder": "Reminder: Keep the summary concise and under 300 words.",
        "refine_reminder": "Reminder: Final answer must be a single coherent paragraph under 500 words.",
        "newline": "\n",
    }

    _PREFIX_CACHE = {
        name: tokenizer(text, return_tensors="pt")["input_ids"]
        for name, text in fragments.items()
    }
    return _PREFIX_CACHE

def get_prefixes() -> dict[str, torch.Tensor]:
    """
    Return the tokenized prompt-prefix cache, building it on first access.
    """
    if _PREFIX_CACHE:
        return _PREFIX_CACHE
    initialize_prefixes()
    # Read the module global again: initialize_prefixes() rebinds it.
    return _PREFIX_CACHE