pls-rag / modules /embedder.py
m97j's picture
Initial codes commit
4fdc679
raw
history blame
635 Bytes
# rag/modules/embedder.py
import math
from typing import List
from huggingface_hub import InferenceClient
from config import EMBED_MODEL, HF_TOKEN
# Inference API client for the configured embedding model. It is built once
# when the module is imported and reused by every get_embedding() call.
_client = InferenceClient(
    token=HF_TOKEN,
    model=EMBED_MODEL,
)
def _l2_normalize(vec: List[float]) -> List[float]:
norm = math.sqrt(sum(x * x for x in vec)) or 1.0
return [x / norm for x in vec]
def get_embedding(text: str) -> List[float]:
    """Embed *text* with the module's Inference API client and L2-normalize it.

    Args:
        text: The input string to embed.

    Returns:
        The unit-length embedding as a list of built-in Python floats.

    Raises:
        ValueError: If the API returns an empty result, or one that is
            neither a 1-D vector nor a 2-D ``[rows, dim]`` array.
    """
    raw = _client.feature_extraction(text)
    # huggingface_hub documents this return value as a float32 NumPy array.
    # Converting it with .tolist() gives nested lists of built-in floats. The
    # result then matches the List[float] annotation, is JSON-serializable, and
    # its norm is computed in double precision. Plain lists pass through as-is.
    data = raw.tolist() if hasattr(raw, "tolist") else raw
    if not data:
        raise ValueError("feature_extraction returned an empty result")

    first = data[0]
    if isinstance(first, (list, tuple)):
        if first and isinstance(first[0], (list, tuple)):
            raise ValueError(
                "feature_extraction returned an array of unexpected rank (>2)"
            )
        # 2-D output: use the first row, which is the single input's vector
        # for a pooled (sentence-level) model.
        # NOTE(review): for a token-level model this row is only the first
        # token's embedding. Confirm that EMBED_MODEL pools its output.
        vec = first
    else:
        # Some models/backends return a flat 1-D vector for a single input.
        vec = data

    return _l2_normalize([float(x) for x in vec])