| from __future__ import annotations |
|
|
| import math |
| import time |
| from functools import lru_cache |
|
|
| import httpx |
| from huggingface_hub import InferenceClient |
|
|
| from app.config import ( |
| EMBEDDING_API_RETRIES, |
| EMBEDDING_API_RETRY_BACKOFF, |
| EMBEDDING_API_TIMEOUT, |
| EMBEDDING_API_URL, |
| EMBEDDING_DIM, |
| EMBEDDING_MODEL, |
| EMBEDDING_PROVIDER, |
| HF_INFERENCE_PROVIDER, |
| ) |
| from app.runtime_auth import get_hf_api_key |
|
|
|
|
| class EmbeddingModel: |
| def __init__(self) -> None: |
| self.dim = EMBEDDING_DIM |
| self.provider = EMBEDDING_PROVIDER |
|
|
| def encode(self, texts: list[str]) -> list[list[float]]: |
| if self.provider != "hf_api": |
| raise RuntimeError("Local embedding providers are disabled. Use Hugging Face API only.") |
| return self._api_embedding(texts) |
|
|
| def _api_embedding(self, texts: list[str]) -> list[list[float]]: |
| api_key = get_hf_api_key() |
| if not api_key: |
| raise RuntimeError("Enter a Hugging Face token to use API embeddings") |
| if not texts: |
| return [] |
|
|
| client_kwargs = {"api_key": api_key, "timeout": EMBEDDING_API_TIMEOUT} |
| if EMBEDDING_API_URL: |
| client = InferenceClient(model=EMBEDDING_API_URL, **client_kwargs) |
| model = None |
| else: |
| client = InferenceClient(provider=HF_INFERENCE_PROVIDER, **client_kwargs) |
| model = EMBEDDING_MODEL |
|
|
| payload = self._feature_extraction_with_retry(client, texts, model) |
| if hasattr(payload, "tolist"): |
| payload = payload.tolist() |
|
|
| vectors = self._coerce_api_vectors(payload, expected_count=len(texts)) |
| return [self._normalize_vector(vector) for vector in vectors] |
|
|
| def _feature_extraction_with_retry(self, client: InferenceClient, texts: list[str], model: str | None): |
| attempts = max(1, EMBEDDING_API_RETRIES) |
| last_error: Exception | None = None |
| for attempt in range(1, attempts + 1): |
| try: |
| return client.feature_extraction(texts, model=model) |
| except (httpx.TimeoutException, httpx.TransportError) as exc: |
| last_error = exc |
| if attempt == attempts: |
| break |
| time.sleep(EMBEDDING_API_RETRY_BACKOFF * attempt) |
| raise RuntimeError( |
| "Hugging Face embedding request timed out. " |
| "Try lowering EMBEDDING_BATCH_SIZE or increasing EMBEDDING_API_TIMEOUT." |
| ) from last_error |
|
|
| def _coerce_api_vectors(self, payload, expected_count: int) -> list[list[float]]: |
| if not isinstance(payload, list): |
| raise RuntimeError(f"Unexpected embedding API response: {type(payload).__name__}") |
| if expected_count == 1 and self._is_vector(payload): |
| return [self._fit_dimension([float(value) for value in payload])] |
| if len(payload) != expected_count: |
| raise RuntimeError(f"Expected {expected_count} embeddings, received {len(payload)}") |
|
|
| vectors = [] |
| for item in payload: |
| if self._is_vector(item): |
| vectors.append(self._fit_dimension([float(value) for value in item])) |
| elif isinstance(item, list) and item and all(self._is_vector(token_vector) for token_vector in item): |
| vectors.append(self._fit_dimension(self._mean_pool(item))) |
| else: |
| raise RuntimeError("Unexpected embedding vector shape from API") |
| return vectors |
|
|
| def _is_vector(self, value) -> bool: |
| return isinstance(value, list) and all(isinstance(item, int | float) for item in value) |
|
|
| def _mean_pool(self, token_vectors: list[list[float]]) -> list[float]: |
| width = len(token_vectors[0]) |
| pooled = [] |
| for index in range(width): |
| pooled.append(sum(float(vector[index]) for vector in token_vectors) / len(token_vectors)) |
| return pooled |
|
|
| def _fit_dimension(self, vector: list[float]) -> list[float]: |
| if len(vector) == self.dim: |
| return vector |
| if len(vector) > self.dim: |
| return vector[: self.dim] |
| return vector + [0.0] * (self.dim - len(vector)) |
|
|
| def _normalize_vector(self, vector: list[float]) -> list[float]: |
| norm = math.sqrt(sum(value * value for value in vector)) |
| if norm == 0: |
| return vector |
| return [value / norm for value in vector] |
|
|
|
|
@lru_cache(maxsize=1)
def get_embedding_model() -> EmbeddingModel:
    """Return the shared ``EmbeddingModel``, constructing it on first use.

    The single-slot ``lru_cache`` memoises the result, so later calls reuse
    the instance created by the first call.
    """
    shared_model = EmbeddingModel()
    return shared_model
|
|