pls-rag / models /embedder.py
m97j's picture
Initial codes commit
2aa7bf4
raw
history blame
884 Bytes
# rag/models/embedder.py
from typing import List
import numpy as np
import onnxruntime as ort
from fastapi import Request
def _l2_normalize(vec: np.ndarray) -> List[float]:
norm = np.linalg.norm(vec) or 1.0
return (vec / norm).tolist()
def _pool_sentence_vector(output, attention_mask=None) -> np.ndarray:
    """Reduce the model's first output tensor to one embedding vector.

    Supported layouts of *output*:

    * ``[dim]``: a single unbatched vector, returned as-is.
    * ``[batch, dim]``: an already-pooled sentence embedding. Row 0 is
      returned, which matches the previous ``ort_outs[0][0]`` behavior.
    * ``[batch, seq, dim]``: per-token hidden states (e.g. an exported
      ``last_hidden_state``). The tokens of row 0 are mean-pooled. When
      *attention_mask* (``[batch, seq]``) is given, padding positions are
      excluded from the mean. NOTE(review): mean pooling is assumed here;
      confirm it matches how the model was trained (some models use CLS).

    Raises:
        ValueError: if *output* has any other rank.
    """
    arr = np.asarray(output)
    if arr.ndim == 1:
        return arr
    if arr.ndim == 2:
        return arr[0]
    if arr.ndim == 3:
        tokens = arr[0]  # [seq, dim] for the first (only) input text
        if attention_mask is None:
            return tokens.mean(axis=0)
        # [seq, 1] weights so that padding tokens contribute nothing.
        mask = np.asarray(attention_mask)[0].astype(tokens.dtype)[:, None]
        denom = max(float(mask.sum()), 1e-9)  # guard against an all-zero mask
        return (tokens * mask).sum(axis=0) / denom
    raise ValueError(f"Unsupported embedding output shape: {arr.shape}")


def get_embedding(request: Request, text: str) -> List[float]:
    """Embed *text* with the ONNX embedding model stored on the app state.

    Reads from ``request.app.state``:
        embedder_sess: ONNX Runtime ``InferenceSession`` for the embedder.
        embedder_tokenizer: tokenizer, called with the HF-style keyword
            arguments ``return_tensors="np"``, ``padding``, ``truncation``
            and ``max_length``. It must return a mapping of input name to
            array.

    Args:
        request: incoming FastAPI request, used only to reach the app state.
        text: text to embed. Inputs longer than 256 tokens are truncated.

    Returns:
        The L2-normalized embedding as a list of Python floats.
    """
    tokenizer = request.app.state.embedder_tokenizer
    sess: ort.InferenceSession = request.app.state.embedder_sess

    inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
    encoded = dict(inputs.items())

    # Feed only the tensors the graph actually declares. Tokenizers often also
    # emit e.g. token_type_ids, and ONNX Runtime rejects unknown feed names.
    model_input_names = {node.name for node in sess.get_inputs()}
    ort_inputs = {k: v for k, v in encoded.items() if k in model_input_names}

    ort_outs = sess.run(None, ort_inputs)

    # The first output is either a pooled [batch, dim] embedding or
    # per-token [batch, seq, dim] states. Reduce either to a single vector.
    vec = _pool_sentence_vector(ort_outs[0], encoded.get("attention_mask"))
    return _l2_normalize(vec)