Corin1998's picture
Update rag/embeddings.py
3840298 verified
from typing import List
import numpy as np
from sentence_transformers import SentenceTransformer
# Loaded models keyed by model name, so that asking for a different model
# never hands back one that was loaded earlier under another name.
_MODELS = {}
# Most recently returned model; retained so code that reads the module-level
# ``_MODEL`` keeps working.
_MODEL = None


def get_embedding_model(
    name: str = "sentence-transformers/all-MiniLM-L6-v2",
) -> SentenceTransformer:
    """Lazy-load and cache a SentenceTransformer model, one instance per name.

    Args:
        name: Model identifier passed to ``SentenceTransformer``.

    Returns:
        The cached model for ``name``. It is constructed on the first
        request for that name and reused on every later request.
    """
    global _MODEL
    model = _MODELS.get(name)
    if model is None:
        # First request for this name: construct the model and remember it.
        model = SentenceTransformer(name)
        _MODELS[name] = model
    _MODEL = model
    return model
def embed_texts(texts: List[str]) -> np.ndarray:
    """Encode texts into a 2-D float32 array of L2-normalized embeddings.

    The model comes from ``get_embedding_model()`` called with its default
    arguments, and ``encode`` is asked to normalize the embeddings.

    Args:
        texts: Texts to embed.

    Returns:
        A 2-D float32 array with one row per text:
        - empty input -> an all-zero array of shape ``(0, dim)``
        - a 1-D result from ``encode`` -> reshaped to ``(1, dim)``
    """
    encoder = get_embedding_model()

    if texts:
        encoded = encoder.encode(
            texts,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        matrix = np.asarray(encoded, dtype="float32")
        # Guard against encode() handing back a 1-D vector: always return 2-D.
        return matrix.reshape(1, -1) if matrix.ndim == 1 else matrix

    # Empty input: still report the embedding width so callers (e.g. a FAISS
    # index) can determine their dimension. Falls back to 384 when the model
    # object has no ``get_sentence_embedding_dimension`` attribute.
    dim_getter = getattr(encoder, "get_sentence_embedding_dimension", lambda: 384)
    width = int(dim_getter())
    return np.zeros((0, width), dtype="float32")