File size: 1,689 Bytes
d44b33d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""Factory for LangChain embedding backends (OpenAI, Ollama, Hugging Face).

The active provider is ``Settings.llm_provider``. Used by ingest and query paths when
opening or querying Chroma collections.
"""

from langchain_core.embeddings import Embeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr

from api.config import get_settings
from rag.hf_hub_inference import HubInferenceEmbeddings


def create_embedding_function() -> Embeddings:
    """Return an ``Embeddings`` implementation matching the configured LLM provider.

    The provider name is read from ``get_settings().llm_provider`` and compared
    case-insensitively, ignoring surrounding whitespace, against ``"openai"``,
    ``"huggingface"`` and ``"ollama"``.

    Returns:
        An ``OpenAIEmbeddings``, ``HubInferenceEmbeddings`` or
        ``OllamaEmbeddings`` instance built from the matching ``Settings``
        fields (embedding model name plus API key / token or base URL).

    Raises:
        ValueError: If the provider is ``openai`` and ``openai_api_key`` is
            empty, if it is ``huggingface`` and ``huggingface_api_key`` is
            empty, or if the provider name is not one of the supported values.
    """
    settings = get_settings()
    # Strip before lower-casing: env-sourced values frequently carry stray
    # whitespace (e.g. a trailing "\r" from a CRLF .env file), which would
    # otherwise make a valid provider name fall through to "Unsupported".
    provider = settings.llm_provider.strip().lower()

    if provider == "openai":
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
        return OpenAIEmbeddings(
            model=settings.openai_embedding_model,
            # Wrap the raw key so it is masked if the client object is logged/printed.
            api_key=SecretStr(settings.openai_api_key),
        )
    if provider == "huggingface":
        if not settings.huggingface_api_key:
            raise ValueError(
                "A Hugging Face token is required when LLM_PROVIDER=huggingface "
                "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
            )
        return HubInferenceEmbeddings(
            model=settings.huggingface_embedding_model,
            api_token=settings.huggingface_api_key,
        )
    if provider == "ollama":
        # Local Ollama server: no API key, only the model name and server URL.
        return OllamaEmbeddings(
            model=settings.ollama_embedding_model,
            base_url=settings.ollama_base_url,
        )
    raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")