File size: 1,430 Bytes
2aa7bf4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# rag/models/initializer.py
from transformers import AutoTokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from fastapi import FastAPI
from config import HF_MODEL_REPO_ID, EMBED_MODEL, EMBED_DIR, RERANK_MODEL, RERANK_DIR

def initialize_models(app: FastAPI) -> None:
    """Load the embedder and reranker models and share them via ``app.state``.

    For each model this downloads the ONNX weights from the Hugging Face Hub,
    loads the matching tokenizer, and creates a CPU ONNX Runtime inference
    session. Results are stored on ``app.state`` so every request handler can
    reuse the same loaded instances.

    Args:
        app: FastAPI application whose ``state`` receives
            ``embedder_tokenizer`` / ``embedder_sess`` and
            ``reranker_tokenizer`` / ``reranker_sess``.
    """
    # Embedder
    app.state.embedder_tokenizer, app.state.embedder_sess = _load_model(
        EMBED_DIR, EMBED_MODEL
    )
    # Reranker
    app.state.reranker_tokenizer, app.state.reranker_sess = _load_model(
        RERANK_DIR, RERANK_MODEL
    )


def _load_model(subfolder: str, filename: str):
    """Load one model's tokenizer and a CPU ONNX session from the Hub repo.

    Args:
        subfolder: Directory inside the Hub repo holding this model's files
            (tokenizer files live there too, hence the ``subfolder`` kwarg).
        filename: Name of the ONNX model file within ``subfolder``.

    Returns:
        Tuple of ``(tokenizer, ort.InferenceSession)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        HF_MODEL_REPO_ID,
        subfolder=subfolder,
    )
    model_path = hf_hub_download(
        repo_id=HF_MODEL_REPO_ID,
        filename=filename,
        subfolder=subfolder,
    )
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return tokenizer, sess