# mentorme/services/embedding_service.py
# Author: Doanh Van Vu
# Commit 6a14fa9: Add test for recommend mentors endpoint and enhance logging in services
import torch
from sentence_transformers import SentenceTransformer
import logging
import time
from typing import List, Union
from config.settings import get_settings
logger = logging.getLogger(__name__)
class EmbeddingService:
    """Process-wide singleton wrapper around a SentenceTransformer model.

    The model is loaded lazily on first instantiation and shared by every
    ``EmbeddingService()`` call site via the class-level ``_model`` attribute.
    """

    _instance = None  # the single shared EmbeddingService instance
    _model = None     # shared SentenceTransformer, loaded once per process

    def __new__(cls):
        # Classic singleton: every EmbeddingService() returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every EmbeddingService() call; guard so the
        # (expensive) model load happens only once.
        if EmbeddingService._model is None:
            self._load_model()

    def _load_model(self):
        """Load the SentenceTransformer named in settings, on GPU if available.

        Raises:
            Exception: any loading failure is logged with traceback and
                re-raised to the caller.
        """
        settings = get_settings()
        try:
            start_time = time.perf_counter()
            logger.info(
                "[EMBEDDING] Starting to load embedding model: %s",
                settings.EMBEDDING_MODEL_NAME,
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("[EMBEDDING] Using device: %s", device)
            EmbeddingService._model = SentenceTransformer(
                settings.EMBEDDING_MODEL_NAME,
                device=device,
            )
            # Allow long inputs; the tokenizer truncates anything beyond this.
            EmbeddingService._model.max_seq_length = 2048
            load_time = time.perf_counter() - start_time
            logger.info(
                "[EMBEDDING] Embedding model loaded successfully in %.3fs",
                load_time,
            )
        except Exception as e:
            logger.error(
                "[EMBEDDING] Failed to load embedding model: %s",
                str(e),
                exc_info=True,
            )
            raise

    def _expected_dimension(self) -> int:
        """Return the loaded model's embedding dimension, falling back to 1024.

        Uses ``get_sentence_embedding_dimension()`` so no forward pass is
        needed; 1024 is kept as the historical fallback.
        """
        if EmbeddingService._model is not None:
            try:
                dim = EmbeddingService._model.get_sentence_embedding_dimension()
                if dim:
                    return dim
            except Exception as e:
                logger.warning("Could not determine model dimension: %s", str(e))
        return 1024

    def encode(
        self,
        texts: Union[str, List[str]],
        is_query: bool = False,
        batch_size: int = 32,
        max_length: int = 2048
    ) -> Union[List[float], List[List[float]]]:
        """Encode text(s) into embedding vectors.

        Args:
            texts: a single string or a list of strings.
            is_query: logged for observability only; does not change encoding.
            batch_size: forwarded to ``SentenceTransformer.encode``.
            max_length: NOTE(review) accepted but currently unused —
                truncation is governed by the model's ``max_seq_length``
                (2048, set at load time). Confirm whether callers expect it
                to take effect.

        Returns:
            A single embedding (list of floats) when ``texts`` is a string,
            otherwise a list of embeddings (one per input, original order).

        Raises:
            RuntimeError: if the model has not been loaded.
            ValueError: if ``texts`` is an empty list.
        """
        if EmbeddingService._model is None:
            raise RuntimeError("Embedding model not loaded")
        single_text = isinstance(texts, str)
        if single_text:
            texts = [texts]
        if not texts:
            raise ValueError("Texts cannot be empty")
        try:
            encode_start = time.perf_counter()
            embeddings = EmbeddingService._model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=False,
                convert_to_numpy=True,
                normalize_embeddings=False,
            )
            encode_time = time.perf_counter() - encode_start
            logger.info(
                "[EMBEDDING] Encoded %d text(s) in %.3fs (is_query=%s)",
                len(texts), encode_time, is_query,
            )
            # Derive the expected dimension from the model instead of the
            # previous hard-coded 1024 (kept as fallback); single/batch paths
            # share one loop rather than duplicating the check.
            expected_dim = self._expected_dimension()
            result = []
            for emb in embeddings:
                emb_list = emb.tolist()
                if len(emb_list) != expected_dim:
                    logger.warning(
                        "[EMBEDDING] Embedding dimension mismatch: expected %d, got %d",
                        expected_dim, len(emb_list),
                    )
                result.append(emb_list)
            return result[0] if single_text else result
        except Exception as e:
            logger.error(
                "[EMBEDDING] Error encoding texts: %s",
                str(e),
                exc_info=True,
            )
            raise

    def get_model_info(self) -> dict:
        """Return model name, embedding dimension, device, and max seq length.

        Asks the model for its dimension directly instead of encoding a probe
        sentence (the previous approach ran a full forward pass just for this).
        """
        settings = get_settings()
        # _expected_dimension handles both the unloaded-model case and any
        # lookup failure, returning the 1024 fallback.
        dimension = self._expected_dimension()
        return {
            "model_name": settings.EMBEDDING_MODEL_NAME,
            "dimension": dimension,
            "device": "cuda" if torch.cuda.is_available() else "cpu",
            "max_seq_length": EmbeddingService._model.max_seq_length if EmbeddingService._model else 2048,
        }