File size: 5,233 Bytes
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
1e732dd
 
 
 
9659593
1e732dd
696f787
1e732dd
 
 
 
9659593
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
MediGuard AI β€” Embedding Service

Supports Jina AI, Google, HuggingFace, and Ollama embeddings with
automatic fallback chain: Jina β†’ Google β†’ HuggingFace.
"""

from __future__ import annotations

import logging
from functools import lru_cache

from src.exceptions import EmbeddingError, EmbeddingProviderError
from src.settings import get_settings

logger = logging.getLogger(__name__)


class EmbeddingService:
    """Unified embedding interface — delegates to the configured provider.

    Wraps a provider-specific model object that exposes ``embed_query`` and
    ``embed_documents``. Any exception the model raises is re-raised as
    ``EmbeddingProviderError``, chained to the original exception.

    Attributes:
        provider_name: Short provider identifier (e.g. ``"jina"``), used in
            log and error messages.
        dimension: Vector length reported by the factory that built this
            service.
    """

    def __init__(self, model, provider_name: str, dimension: int):
        self._model = model
        self.provider_name = provider_name
        self.dimension = dimension

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"provider_name={self.provider_name!r}, dimension={self.dimension!r})"
        )

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query text.

        Raises:
            EmbeddingProviderError: if the underlying model raises.
        """
        try:
            return self._model.embed_query(text)
        except Exception as exc:
            raise EmbeddingProviderError(f"{self.provider_name} embed_query failed: {exc}") from exc

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Batch-embed a list of texts, returning one vector per input text.

        An empty batch returns ``[]`` without contacting the provider.

        Raises:
            EmbeddingProviderError: if the underlying model raises.
        """
        # Short-circuit: some backends reject (or waste a request on) an empty batch.
        if not texts:
            return []
        try:
            return self._model.embed_documents(texts)
        except Exception as exc:
            raise EmbeddingProviderError(f"{self.provider_name} embed_documents failed: {exc}") from exc


def _make_google_embeddings():
    """Build an EmbeddingService backed by Google's ``text-embedding-004`` model.

    The API key is read from the embedding settings, falling back to the LLM
    settings' Google key. The service reports a dimension of 768.

    Raises:
        EmbeddingError: if neither settings section provides a Google API key.
    """
    cfg = get_settings()
    key = cfg.embedding.google_api_key or cfg.llm.google_api_key
    if key:
        # Imported lazily so the Google SDK is only needed when this provider is used.
        from langchain_google_genai import GoogleGenerativeAIEmbeddings

        backend = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=key,
        )
        return EmbeddingService(backend, "google", 768)
    raise EmbeddingError("GOOGLE_API_KEY not set for Google embeddings")


def _make_huggingface_embeddings():
    """Build an EmbeddingService backed by a local HuggingFace embedding model.

    The model name comes from ``settings.embedding.huggingface_model``. Because
    that model is configurable, the vector dimension is measured by embedding a
    short probe string rather than assumed, so the reported ``dimension`` always
    matches what the model actually produces.

    Prefers the ``langchain_huggingface`` package and falls back to the
    ``langchain_community`` import when it is not installed.

    Raises:
        EmbeddingError: if the probe embedding comes back empty.
    """
    settings = get_settings()
    try:
        from langchain_huggingface import HuggingFaceEmbeddings
    except ImportError:
        from langchain_community.embeddings import HuggingFaceEmbeddings

    model_name = settings.embedding.huggingface_model
    model = HuggingFaceEmbeddings(model_name=model_name)

    # A hard-coded size is only right for one model family; measure it instead.
    dimension = len(model.embed_query("dimension probe"))
    if not dimension:
        raise EmbeddingError(f"HuggingFace model {model_name!r} returned an empty embedding")
    return EmbeddingService(model, "huggingface", dimension)


def _make_ollama_embeddings():
    """Build an EmbeddingService backed by an Ollama server.

    Uses ``settings.ollama.embedding_model`` served at ``settings.ollama.host``.
    The vector dimension depends on which model is configured, so it is measured
    with a single probe request rather than hard-coded. As a consequence, an
    unreachable server or a missing model makes this factory raise immediately
    instead of failing on first use.

    Prefers the ``langchain_ollama`` package and falls back to the
    ``langchain_community`` import when it is not installed.

    Raises:
        EmbeddingError: if the probe embedding comes back empty.
    """
    settings = get_settings()
    try:
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        from langchain_community.embeddings import OllamaEmbeddings

    model_name = settings.ollama.embedding_model
    model = OllamaEmbeddings(
        model=model_name,
        base_url=settings.ollama.host,
    )

    # Different Ollama embedding models emit different vector sizes; measure it.
    dimension = len(model.embed_query("dimension probe"))
    if not dimension:
        raise EmbeddingError(f"Ollama model {model_name!r} returned an empty embedding")
    return EmbeddingService(model, "ollama", dimension)


def _make_jina_embeddings():
    """Build an EmbeddingService backed by the Jina AI REST embeddings API.

    Requests are POSTed directly to ``https://api.jina.ai/v1/embeddings`` with
    httpx, so no Jina SDK is required. Queries use the ``retrieval.query`` task
    and documents use ``retrieval.passage``.

    Raises:
        EmbeddingError: if ``settings.embedding.jina_api_key`` is not set.
    """
    settings = get_settings()
    api_key = settings.embedding.jina_api_key
    if not api_key:
        raise EmbeddingError("JINA_API_KEY not set for Jina embeddings")
    # Jina v3 via httpx (lightweight, no extra SDK); imported lazily so httpx
    # is only required when this provider is actually used.
    import httpx

    class _JinaModel:
        """Minimal Jina AI embedding adapter exposing embed_query / embed_documents."""

        _URL = "https://api.jina.ai/v1/embeddings"
        _TIMEOUT_S = 60

        def __init__(self, api_key: str, model: str):
            self._api_key = api_key
            self._model = model
            self._url = self._URL

        def _call(self, texts: list[str], task: str = "retrieval.passage") -> list[list[float]]:
            """POST *texts* for *task* and return their embeddings in input order.

            Raises ``httpx.HTTPStatusError`` on a non-2xx response.
            """
            # Skip the round-trip for an empty batch; the API is not guaranteed
            # to accept an empty ``input`` list.
            if not texts:
                return []
            headers = {"Authorization": f"Bearer {self._api_key}", "Content-Type": "application/json"}
            payload = {"model": self._model, "input": list(texts), "task": task}
            resp = httpx.post(self._url, json=payload, headers=headers, timeout=self._TIMEOUT_S)
            resp.raise_for_status()
            data = resp.json()["data"]
            # Each result carries an "index"; sort so output order matches input order.
            return [item["embedding"] for item in sorted(data, key=lambda item: item["index"])]

        def embed_query(self, text: str) -> list[float]:
            return self._call([text], task="retrieval.query")[0]

        def embed_documents(self, texts: list[str]) -> list[list[float]]:
            return self._call(texts, task="retrieval.passage")

    model = _JinaModel(api_key, settings.embedding.jina_model)
    # NOTE(review): no ``dimensions`` field is sent to the API, so this reported
    # dimension is only accurate if settings.embedding.dimension equals the
    # configured model's native output size — confirm.
    return EmbeddingService(model, "jina", settings.embedding.dimension)


# ── Fallback chain factory ───────────────────────────────────────────────────

# Provider name -> zero-argument factory that returns an EmbeddingService.
_PROVIDERS = dict(
    jina=_make_jina_embeddings,
    google=_make_google_embeddings,
    huggingface=_make_huggingface_embeddings,
    ollama=_make_ollama_embeddings,
)

# Order in which providers are tried after the configured one. "ollama" is
# not listed, so it is only used when configured as the preferred provider.
FALLBACK_ORDER = ["jina", "google", "huggingface"]


@lru_cache(maxsize=1)
def make_embedding_service() -> EmbeddingService:
    """Create an embedding service with automatic fallback.

    The provider named by ``settings.embedding.provider`` is tried first. String
    names are matched case-insensitively with surrounding whitespace ignored.
    The remaining entries of ``FALLBACK_ORDER`` are tried next. The first
    factory that returns without raising wins.

    The result is memoised by ``lru_cache``, so later calls return the same
    instance; call ``make_embedding_service.cache_clear()`` to rebuild it.

    Returns:
        The first successfully constructed EmbeddingService.

    Raises:
        EmbeddingError: if no candidate provider could be constructed. The
            message lists each failed provider's error.
    """
    settings = get_settings()
    preferred = settings.embedding.provider
    # Normalise user-supplied names such as "Jina" or " google " to the registry keys.
    if isinstance(preferred, str):
        preferred = preferred.strip().lower()

    # Try preferred first, then fallbacks
    order = [preferred] + [p for p in FALLBACK_ORDER if p != preferred]
    failures: list[str] = []
    for provider in order:
        factory = _PROVIDERS.get(provider)
        if factory is None:
            # An empty/unset preference is skipped quietly; a misspelt one is worth flagging.
            if provider:
                logger.warning("Unknown embedding provider '%s' — skipping", provider)
            continue
        try:
            svc = factory()
        except Exception as exc:
            # Any construction failure (missing key/package, unreachable server)
            # just moves on to the next provider in the chain.
            logger.warning("Embedding provider '%s' failed: %s — trying next", provider, exc)
            failures.append(f"{provider}: {exc}")
            continue
        logger.info("Embedding provider: %s (dim=%d)", svc.provider_name, svc.dimension)
        return svc

    detail = f" Errors: {'; '.join(failures)}" if failures else ""
    raise EmbeddingError(
        "All embedding providers failed. Check your API keys and configuration." + detail
    )