File size: 1,475 Bytes
2aa7bf4 801bba3 2aa7bf4 801bba3 2aa7bf4 801bba3 2aa7bf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# rag/models/initializer.py
from transformers import AutoTokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from fastapi import FastAPI
from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR
# Official (upstream) model repo IDs — used to load the original tokenizers.
# NOTE(review): the ONNX weights themselves come from HF_MODEL_REPO_ID (see
# initialize_models below); presumably that repo does not ship tokenizer files.
EMBEDDER_ORIGINAL_ID = "Qwen/Qwen3-Embedding-0.6B"
RERANKER_ORIGINAL_ID = "Qwen/Qwen3-Reranker-0.6B"
def initialize_models(app: FastAPI) -> None:
    """Load the embedder and reranker models and share them via ``app.state``.

    For each of the two models this:
      1. loads the tokenizer from the official Qwen repo
         (EMBEDDER_ORIGINAL_ID / RERANKER_ORIGINAL_ID), and
      2. downloads the exported ONNX file from the project repo
         (HF_MODEL_REPO_ID) and opens a CPU-only ONNX Runtime session.

    Results are stored on ``app.state`` (embedder_tokenizer, embedder_sess,
    reranker_tokenizer, reranker_sess) so they are globally shared across
    request handlers.

    Args:
        app: The FastAPI application whose ``state`` receives the loaded
            tokenizers and inference sessions.
    """
    # Embedder
    app.state.embedder_tokenizer, app.state.embedder_sess = _load_tokenizer_and_session(
        EMBEDDER_ORIGINAL_ID, EMBED_MODEL, EMBED_DIR
    )
    # Reranker
    app.state.reranker_tokenizer, app.state.reranker_sess = _load_tokenizer_and_session(
        RERANKER_ORIGINAL_ID, RERANK_MODEL, RERANK_DIR
    )


def _load_tokenizer_and_session(tokenizer_repo_id: str, filename: str, subfolder: str):
    """Return ``(tokenizer, onnx_session)`` for one model.

    The tokenizer is loaded from the official repo *tokenizer_repo_id*; the
    ONNX file ``subfolder/filename`` is fetched from HF_MODEL_REPO_ID and
    opened with the CPU execution provider only.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_repo_id,
        trust_remote_code=True,
    )
    model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=filename,
        subfolder=subfolder,
    )
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return tokenizer, session
|