|
|
|
|
|
from transformers import AutoTokenizer |
|
|
import onnxruntime as ort |
|
|
from huggingface_hub import hf_hub_download |
|
|
from fastapi import FastAPI |
|
|
from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR |
|
|
|
|
|
def _load_tokenizer_and_session(subfolder, model_filename, providers):
    """Load a tokenizer and an ONNX Runtime session from one repo subfolder.

    Both artifacts come from the ``HF_MODEL_REPO_ID`` Hugging Face repo:
    the tokenizer is loaded from ``subfolder``, and the ONNX graph file
    ``model_filename`` is downloaded from that same subfolder and opened
    with onnxruntime.

    Args:
        subfolder: Subfolder of the repo holding the tokenizer and model.
        model_filename: Name of the ONNX model file inside ``subfolder``.
        providers: Sequence of onnxruntime execution-provider names, in
            priority order (e.g. ``("CPUExecutionProvider",)``).

    Returns:
        A ``(tokenizer, session)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        HF_MODEL_REPO_ID,
        subfolder=subfolder,
    )
    model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=model_filename,
        subfolder=subfolder,
    )
    # onnxruntime expects a list of provider names; copy the sequence so a
    # tuple default (or any other iterable of names) is accepted.
    session = ort.InferenceSession(model_path, providers=list(providers))
    return tokenizer, session


def initialize_models(app: FastAPI, *, providers=("CPUExecutionProvider",)):
    """Load the embedder and reranker models and attach them to ``app.state``.

    Each model's tokenizer and ONNX session are loaded from
    ``HF_MODEL_REPO_ID``. The embedder comes from ``EMBED_DIR``/``EMBED_MODEL``
    and the reranker from ``RERANK_DIR``/``RERANK_MODEL``.

    Args:
        app: The FastAPI application whose ``state`` receives the models.
        providers: onnxruntime execution providers used for both sessions.
            Defaults to CPU only, matching the previous hard-coded behaviour.

    Side effects:
        Sets ``app.state.embedder_tokenizer``, ``app.state.embedder_sess``,
        ``app.state.reranker_tokenizer`` and ``app.state.reranker_sess``.

    Raises:
        Any exception from tokenizer loading, the hub download, or session
        creation is propagated unchanged.
    """
    embedder_tokenizer, embedder_sess = _load_tokenizer_and_session(
        EMBED_DIR, EMBED_MODEL, providers
    )
    reranker_tokenizer, reranker_sess = _load_tokenizer_and_session(
        RERANK_DIR, RERANK_MODEL, providers
    )

    # Publish to app.state only after both models loaded successfully, so a
    # failure part-way through never leaves a half-initialised set of models.
    app.state.embedder_tokenizer = embedder_tokenizer
    app.state.embedder_sess = embedder_sess
    app.state.reranker_tokenizer = reranker_tokenizer
    app.state.reranker_sess = reranker_sess
|
|
|