File size: 1,101 Bytes
16fa4e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Local embeddings via Sentence-Transformers (no inference provider)."""

from __future__ import annotations

from functools import lru_cache

from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer

from src.config import settings


@lru_cache(maxsize=1)
def _model() -> SentenceTransformer:
    """Return the SentenceTransformer named by ``settings.embedding_model``.

    The result is memoized with ``lru_cache``, so the model object is
    constructed on the first call and the same instance is returned on
    every later call.
    """
    model_name = settings.embedding_model
    return SentenceTransformer(model_name)


class LocalSentenceTransformerEmbeddings(Embeddings):
    """``Embeddings`` implementation backed by the module's local model.

    Texts are encoded through ``_model().encode`` with
    ``normalize_embeddings=True`` and the result is converted to nested
    Python lists.
    """

    def embed_documents(self, texts: list[str]) -> list[list[float]]:  # type: ignore[override]
        """Embed a batch of texts.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding (a list of floats) per input text, or an empty
            list when ``texts`` is empty.
        """
        if not texts:
            # Nothing to encode: avoid calling (and, on first use, loading)
            # the model just to produce an empty result.
            return []

        vecs = _model().encode(texts, normalize_embeddings=True)

        # Prefer the result's own ``tolist`` when it has one; otherwise
        # convert element by element.
        tolist = getattr(vecs, "tolist", None)
        if callable(tolist):
            return tolist()
        return [[float(x) for x in vec] for vec in vecs]

    def embed_query(self, text: str) -> list[float]:  # type: ignore[override]
        """Embed a single query string.

        Args:
            text: The query text to embed.

        Returns:
            The embedding of ``text`` as a list of floats.
        """
        # Encode as a one-element batch so the conversion logic lives in
        # exactly one place.
        return self.embed_documents([text])[0]


def get_embeddings() -> Embeddings:
    """Create a new ``LocalSentenceTransformerEmbeddings`` instance.

    Returns:
        A freshly constructed ``LocalSentenceTransformerEmbeddings``.
    """
    embeddings = LocalSentenceTransformerEmbeddings()
    return embeddings