| from langchain_core.embeddings import Embeddings |
| from langchain_ollama import OllamaEmbeddings |
| from langchain_openai import OpenAIEmbeddings |
| from pydantic import SecretStr |
|
|
| from api.config import get_settings |
| from rag.hf_hub_inference import HubInferenceEmbeddings |
|
|
|
|
def create_embedding_function() -> Embeddings:
    """Build the embedding backend selected by the ``llm_provider`` setting.

    The provider name comes from ``get_settings().llm_provider``. It is matched
    against ``"openai"``, ``"huggingface"`` and ``"ollama"``, ignoring case and
    surrounding whitespace.

    Returns:
        An ``Embeddings`` implementation for the chosen provider, configured
        from its ``*_embedding_model`` setting plus the API key (OpenAI,
        Hugging Face) or base URL (Ollama) that the provider needs.

    Raises:
        ValueError: If the selected provider requires an API key that is not
            configured, or if the provider name is not one of the supported
            values.
    """
    settings = get_settings()
    # Strip as well as lowercase. Otherwise a stray space or newline in the
    # configured value (e.g. "openai ") would be reported as unsupported.
    provider = settings.llm_provider.strip().lower()

    if provider == "openai":
        if not settings.openai_api_key:
            raise ValueError("OPENAI_API_KEY is required when LLM_PROVIDER=openai")
        # The key is handed over as a pydantic SecretStr, which masks its
        # value in repr()/str() output.
        return OpenAIEmbeddings(
            model=settings.openai_embedding_model,
            api_key=SecretStr(settings.openai_api_key),
        )

    if provider == "huggingface":
        if not settings.huggingface_api_key:
            raise ValueError(
                "A Hugging Face token is required when LLM_PROVIDER=huggingface "
                "(set HUGGINGFACE_API_KEY or HF_TOKEN / HUGGING_FACE_HUB_TOKEN on Spaces)."
            )
        return HubInferenceEmbeddings(
            model=settings.huggingface_embedding_model,
            api_token=settings.huggingface_api_key,
        )

    if provider == "ollama":
        # No API key check here: only the model name and server URL are passed.
        return OllamaEmbeddings(
            model=settings.ollama_embedding_model,
            base_url=settings.ollama_base_url,
        )

    # Report the value exactly as configured, not the normalised form.
    raise ValueError(f"Unsupported LLM_PROVIDER: {settings.llm_provider}")
|
|
|
|