# vietqa-api / src / utils / embeddings.py
# quanho114
# Deploy VietQA API
# ebb8326
"""Embedding models using local HuggingFace Vietnamese models."""
import torch
from langchain_core.embeddings import Embeddings
from langchain_huggingface import HuggingFaceEmbeddings
from src.config import settings
from src.utils.logging import log_pipeline
_embeddings: Embeddings | None = None
def get_device() -> str:
    """Return the name of the device the embedding model should run on.

    The result is pinned to ``"cpu"``: per the original author's note, this
    avoids silent crashes seen with CUDA/MPS. No hardware probing is done.

    To opt in to GPU execution on a machine with a working setup, extend
    this function to return ``"cuda"`` when ``torch.cuda.is_available()``
    is true, or ``"mps"`` when ``torch.backends.mps.is_available()`` is
    true, before falling back to the CPU.

    Returns:
        The device string, currently always ``"cpu"``.
    """
    forced_device = "cpu"
    return forced_device
def get_embeddings() -> Embeddings:
    """Return the shared local HuggingFace embeddings model.

    The model is built on the first call and stored in the module-level
    ``_embeddings`` variable; later calls return that same object without
    rebuilding it.

    The model name comes from ``settings.embedding_model`` and the device
    from ``get_device()``. ``normalize_embeddings=True`` is passed through
    ``encode_kwargs``.

    Returns:
        The cached ``Embeddings`` instance.
    """
    global _embeddings

    # Only build the model when nothing has been cached yet.
    if _embeddings is None:
        device = get_device()
        model_options = {"device": device}
        encode_options = {"normalize_embeddings": True}
        _embeddings = HuggingFaceEmbeddings(
            model_name=settings.embedding_model,
            model_kwargs=model_options,
            encode_kwargs=encode_options,
        )
        # Logged once, immediately after the model has been constructed.
        log_pipeline(f"[Embedding] Loaded: {settings.embedding_model} on {device}")

    return _embeddings