Spaces:
Sleeping
Sleeping
File size: 5,233 Bytes
"""
MediGuard AI — Embedding Service
Supports Jina AI, Google, HuggingFace, and Ollama embeddings with
automatic fallback chain: Jina → Google → HuggingFace.
"""
from __future__ import annotations
import logging
from functools import lru_cache
from src.exceptions import EmbeddingError, EmbeddingProviderError
from src.settings import get_settings
logger = logging.getLogger(__name__)
class EmbeddingService:
    """Unified embedding interface — delegates to the configured provider.

    Wraps any LangChain-style embedding model (anything exposing
    ``embed_query`` / ``embed_documents``) and converts provider failures
    into :class:`EmbeddingProviderError`.
    """

    def __init__(self, model, provider_name: str, dimension: int):
        # `model` is the underlying provider object; `dimension` is the
        # embedding vector length reported by that provider.
        self._model = model
        self.provider_name = provider_name
        self.dimension = dimension

    def _delegate(self, method: str, payload):
        """Invoke `method` on the wrapped model, wrapping any failure."""
        try:
            return getattr(self._model, method)(payload)
        except Exception as exc:
            raise EmbeddingProviderError(f"{self.provider_name} {method} failed: {exc}") from exc

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query text."""
        return self._delegate("embed_query", text)

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Batch-embed a list of texts."""
        return self._delegate("embed_documents", texts)
def _make_google_embeddings():
    """Build an EmbeddingService backed by Google's text-embedding-004 (dim 768).

    Raises:
        EmbeddingError: if no Google API key is configured.
    """
    cfg = get_settings()
    # The embedding-specific key takes precedence; fall back to the LLM key.
    key = cfg.embedding.google_api_key or cfg.llm.google_api_key
    if not key:
        raise EmbeddingError("GOOGLE_API_KEY not set for Google embeddings")
    # Imported lazily so the package is only required when this provider is used.
    from langchain_google_genai import GoogleGenerativeAIEmbeddings

    embedder = GoogleGenerativeAIEmbeddings(
        model="models/text-embedding-004",
        google_api_key=key,
    )
    return EmbeddingService(embedder, "google", 768)
def _make_huggingface_embeddings():
    """Build an EmbeddingService using a local HuggingFace model (dim 384)."""
    cfg = get_settings()
    try:
        # Prefer the dedicated package; older installs only have the
        # langchain_community shim.
        from langchain_huggingface import HuggingFaceEmbeddings
    except ImportError:
        from langchain_community.embeddings import HuggingFaceEmbeddings

    embedder = HuggingFaceEmbeddings(model_name=cfg.embedding.huggingface_model)
    return EmbeddingService(embedder, "huggingface", 384)
def _make_ollama_embeddings():
    """Build an EmbeddingService backed by a local Ollama server (dim 768)."""
    cfg = get_settings()
    try:
        # Prefer the dedicated package; fall back to the community shim.
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        from langchain_community.embeddings import OllamaEmbeddings

    embedder = OllamaEmbeddings(
        model=cfg.ollama.embedding_model,
        base_url=cfg.ollama.host,
    )
    return EmbeddingService(embedder, "ollama", 768)
def _make_jina_embeddings():
    """Build an EmbeddingService backed by the Jina AI embeddings HTTP API.

    Raises:
        EmbeddingError: if no Jina API key is configured.
    """
    cfg = get_settings()
    key = cfg.embedding.jina_api_key
    if not key:
        raise EmbeddingError("JINA_API_KEY not set for Jina embeddings")
    # Jina v3 via httpx (lightweight, no extra SDK)
    import httpx

    class _JinaModel:
        """Minimal Jina AI embedding adapter."""

        def __init__(self, api_key: str, model: str):
            self._api_key = api_key
            self._model = model
            self._url = "https://api.jina.ai/v1/embeddings"

        def _call(self, texts: list[str], task: str = "retrieval.passage") -> list[list[float]]:
            # One synchronous POST per batch; the API may return items out of
            # order, so re-sort by the reported index before extracting vectors.
            response = httpx.post(
                self._url,
                json={"model": self._model, "input": texts, "task": task},
                headers={
                    "Authorization": f"Bearer {self._api_key}",
                    "Content-Type": "application/json",
                },
                timeout=60,
            )
            response.raise_for_status()
            rows = response.json()["data"]
            rows.sort(key=lambda row: row["index"])
            return [row["embedding"] for row in rows]

        def embed_query(self, text: str) -> list[float]:
            return self._call([text], task="retrieval.query")[0]

        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            return self._call(texts, task="retrieval.passage")

    return EmbeddingService(_JinaModel(key, cfg.embedding.jina_model), "jina", cfg.embedding.dimension)
# ── Fallback chain factory ───────────────────────────────────────────────────
# Maps a provider name (as used in settings.embedding.provider) to its factory.
# Note: "ollama" can be selected explicitly but is NOT part of the automatic
# FALLBACK_ORDER chain below.
_PROVIDERS = {
    "jina": _make_jina_embeddings,
    "google": _make_google_embeddings,
    "huggingface": _make_huggingface_embeddings,
    "ollama": _make_ollama_embeddings,
}
# Providers tried (in order) after the configured one fails.
FALLBACK_ORDER = ["jina", "google", "huggingface"]
@lru_cache(maxsize=1)
def make_embedding_service() -> EmbeddingService:
    """Create an embedding service with automatic fallback.

    The configured provider is tried first, then each remaining entry in
    FALLBACK_ORDER. The result is cached for the life of the process
    (maxsize=1: there is only one configuration).

    Returns:
        The first EmbeddingService that initializes successfully.

    Raises:
        EmbeddingError: if every provider in the chain fails; the last
            provider's exception is chained as the cause.
    """
    settings = get_settings()
    preferred = settings.embedding.provider
    # Try preferred first, then fallbacks (avoiding a duplicate attempt).
    order = [preferred] + [p for p in FALLBACK_ORDER if p != preferred]
    last_exc = None
    for provider in order:
        factory = _PROVIDERS.get(provider)
        if factory is None:
            # Unknown provider name in settings — skip rather than crash.
            continue
        try:
            svc = factory()
            logger.info("Embedding provider: %s (dim=%d)", svc.provider_name, svc.dimension)
            return svc
        except Exception as exc:
            last_exc = exc
            # Fixed mojibake ("β" → "→") in the arrow of this log message.
            logger.warning("Embedding provider '%s' failed: %s → trying next", provider, exc)
    # Chain the last failure so the root cause survives in tracebacks/logs.
    raise EmbeddingError(
        "All embedding providers failed. Check your API keys and configuration."
    ) from last_exc
|