File size: 1,689 Bytes
"""Factory for LangChain embedding backends (OpenAI, Ollama, Hugging Face).
The active provider is ``Settings.llm_provider``. Used by ingest and query paths when
opening or querying Chroma collections.
"""
from langchain_core.embeddings import Embeddings
from langchain_ollama import OllamaEmbeddings
from langchain_openai import OpenAIEmbeddings
from pydantic import SecretStr
from api.config import get_settings
from rag.hf_hub_inference import HubInferenceEmbeddings
def create_embedding_function() -> Embeddings:
    """Build the embeddings backend selected by ``Settings.llm_provider``.

    The provider name is lower-cased before being compared against
    ``"openai"``, ``"huggingface"`` and ``"ollama"``.

    Returns:
        An ``OllamaEmbeddings``, ``OpenAIEmbeddings`` or
        ``HubInferenceEmbeddings`` instance configured from the settings
        object returned by ``get_settings()``.

    Raises:
        ValueError: If the OpenAI or Hugging Face provider is selected while
            its API key setting is empty, or if the provider name is not one
            of the three supported values.
    """
    cfg = get_settings()
    backend = cfg.llm_provider.lower()

    # Ollama is the only backend that needs no credential check.
    if backend == "ollama":
        return OllamaEmbeddings(
            model=cfg.ollama_embedding_model,
            base_url=cfg.ollama_base_url,
        )

    if backend == "openai":
        if settings_key_missing := not cfg.openai_api_key:
            raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
        # The raw key is wrapped in SecretStr before it is handed to the client.
        wrapped_key = SecretStr(cfg.openai_api_key)
        return OpenAIEmbeddings(
            model=cfg.openai_embedding_model,
            api_key=wrapped_key,
        )

    if backend == "huggingface":
        if cfg.huggingface_api_key:
            return HubInferenceEmbeddings(
                model=cfg.huggingface_embedding_model,
                api_token=cfg.huggingface_api_key,
            )
        missing_token_msg = (
            "A Hugging Face token is required when LLM_PROVIDER=huggingface "
            "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
        )
        raise ValueError(missing_token_msg)

    # Report the value exactly as configured (not the lower-cased copy).
    raise ValueError(f"Unsupported LLM_PROVIDER: {cfg.llm_provider}")
|