# Source: pls-rag / models / initializer.py
# Uploaded by m97j, commit 2aa7bf4 ("Initial codes commit"), 1.43 kB.
# (The "raw" and "history blame" entries were file-viewer links, not content.)
# rag/models/initializer.py
from transformers import AutoTokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from fastapi import FastAPI
from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR
def initialize_models(app: FastAPI):
    """Load the embedder and reranker and publish them on ``app.state``.

    For each of the two models, the tokenizer is loaded from a subfolder of
    ``HF_MODEL_REPO_ID``, the model file is fetched from that same subfolder
    with ``hf_hub_download``, and an ONNX Runtime ``InferenceSession`` is
    created from the downloaded path with ``CPUExecutionProvider`` as the
    only requested provider.

    Args:
        app: FastAPI application whose ``state`` receives the attributes
            ``embedder_tokenizer``, ``embedder_sess``, ``reranker_tokenizer``
            and ``reranker_sess``.
    """

    def _load(model_file, model_dir):
        # The tokenizer files live inside the model's own subfolder of the
        # repo, so the subfolder has to be given explicitly.
        tokenizer = AutoTokenizer.from_pretrained(
            HF_MODEL_REPO_ID,
            subfolder=model_dir,
        )
        local_path = hf_hub_download(
            repo_id=HF_MODEL_REPO_ID,
            filename=model_file,
            subfolder=model_dir,
        )
        session = ort.InferenceSession(
            local_path,
            providers=["CPUExecutionProvider"],
        )
        return tokenizer, session

    # Embedder first, then reranker.
    embed_tok, embed_session = _load(EMBED_MODEL, EMBED_DIR)
    rerank_tok, rerank_session = _load(RERANK_MODEL, RERANK_DIR)

    # Store on FastAPI's app.state so the objects are shared globally.
    # Nothing is assigned until both models have loaded.
    app.state.embedder_tokenizer = embed_tok
    app.state.embedder_sess = embed_session
    app.state.reranker_tokenizer = rerank_tok
    app.state.reranker_sess = rerank_session