KnowledgeMesh / app /services /embeddings.py
pkheria's picture
pushing to git
b5e0c74
Raw
History Blame Contribute Delete
3.64 kB
from functools import cached_property, lru_cache
from app.core.config import settings
from app.utils.zerogpu import is_enabled as zerogpu_is_enabled
class LocalEmbeddingClient:
def __init__(self, model: str | None = None, device: str | None = None):
self.model_name = model or settings.NEMOTRON_EMBED_MODEL
self.device = device or _resolve_device()
@cached_property
def model(self):
try:
from sentence_transformers import SentenceTransformer
except ImportError as exc:
raise ImportError(
"sentence-transformers is required for local embeddings. "
"Install dependencies with `pip install -r requirements.txt`."
) from exc
return SentenceTransformer(
self.model_name,
device=self.device,
token=settings.HF_TOKEN or None,
trust_remote_code=True,
)
@cached_property
def native_model(self):
try:
from transformers import AutoModel
except ImportError as exc:
raise ImportError(
"transformers is required for native local embeddings. "
"Install dependencies with `pip install -r requirements.txt`."
) from exc
model = AutoModel.from_pretrained(
self.model_name,
token=settings.HF_TOKEN or None,
trust_remote_code=True,
dtype="auto" if self.device != "cpu" else None,
)
if self.device:
model = model.to(self.device)
return model.eval()
def embed_texts(self, texts: list[str]) -> list[list[float]]:
if not texts:
return []
try:
embeddings = self.model.encode(
texts,
batch_size=8,
normalize_embeddings=True,
show_progress_bar=False,
)
return embeddings.tolist()
except ValueError as exc:
if "Modality 'text' is not supported" not in str(exc):
raise
embeddings = self._embed_with_native_query_encoder(texts)
return embeddings.tolist()
def _embed_with_native_query_encoder(self, texts: list[str]):
try:
import torch
import torch.nn.functional as F
except ImportError as exc:
raise ImportError(
"torch is required for the native Nemotron embedding path. "
"Install dependencies with `pip install -r requirements.txt`."
) from exc
if not hasattr(self.native_model, "forward_queries"):
raise ValueError(
f"{self.model_name} does not support SentenceTransformer text encoding "
"or a native forward_queries API."
)
with torch.no_grad():
output = self.native_model.forward_queries(texts, batch_size=4)
if isinstance(output, (list, tuple)):
output = output[0]
if not torch.is_tensor(output):
output = torch.as_tensor(output)
if output.ndim == 3:
output = output.float().mean(dim=1)
elif output.ndim != 2:
raise ValueError(f"Unexpected embedding shape from {self.model_name}: {tuple(output.shape)}")
return F.normalize(output.float(), p=2, dim=1).cpu()
@lru_cache(maxsize=1)
def get_embedding_client() -> LocalEmbeddingClient:
    """Return the shared embedding client, building it on first use.

    The single-slot ``lru_cache`` hands back the same default-configured
    ``LocalEmbeddingClient`` on every later call.
    """
    client = LocalEmbeddingClient()
    return client
def _resolve_device() -> str:
    """Pick the device string used for embedding models.

    Returns ``settings.EMBEDDING_DEVICE`` unchanged, except that a configured
    ``"cpu"`` becomes ``"cuda"`` when ZeroGPU is enabled.
    """
    if not zerogpu_is_enabled():
        return settings.EMBEDDING_DEVICE
    configured = settings.EMBEDDING_DEVICE
    if configured == "cpu":
        return "cuda"
    return configured