Spaces:
Running
Running
| """Embedding provider abstraction for HyperView.""" | |
| from __future__ import annotations | |
| import hashlib | |
| from importlib import import_module | |
| import json | |
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import numpy as np | |
| from hyperview.core.sample import Sample | |
@dataclass
class ModelSpec:
    """Structured specification for an embedding model.

    Attributes:
        provider: Provider identifier (e.g., "embed_anything", "hycoclip").
        model_id: Model identifier (HuggingFace model_id, checkpoint path, etc.).
        checkpoint: Optional checkpoint path or URL for weight-only models.
        config_path: Optional config path for models that need it.
        output_geometry: Geometry of the embedding space ("euclidean", "hyperboloid").
        curvature: Hyperbolic curvature (only relevant for hyperbolic geometries).
    """

    provider: str
    model_id: str
    checkpoint: str | None = None
    config_path: str | None = None
    output_geometry: str = "euclidean"
    curvature: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-serializable dict.

        ``output_geometry`` is written under the key ``"geometry"``. Optional
        entries are omitted when unset: ``checkpoint`` / ``config_path`` when
        falsy, ``curvature`` when ``None``. Because :meth:`content_hash` hashes
        this dict, omitted entries do not contribute to the hash.
        """
        d: dict[str, Any] = {
            "provider": self.provider,
            "model_id": self.model_id,
            "geometry": self.output_geometry,
        }
        if self.checkpoint:
            d["checkpoint"] = self.checkpoint
        if self.config_path:
            d["config_path"] = self.config_path
        if self.curvature is not None:
            d["curvature"] = self.curvature
        return d

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> ModelSpec:
        """Create a spec from a dict (e.g. the output of :meth:`to_dict` or loaded JSON).

        The geometry is read from ``"geometry"`` (the key :meth:`to_dict`
        writes). If that key is absent, it falls back to ``"output_geometry"``
        (the dataclass field name, as produced by ``dataclasses.asdict``) and
        finally to ``"euclidean"``.

        Raises:
            KeyError: If ``"provider"`` or ``"model_id"`` is missing.
        """
        if "geometry" in d:
            geometry = d["geometry"]
        else:
            # Without this fallback an asdict()-style dict would silently
            # lose its geometry and come back as "euclidean".
            geometry = d.get("output_geometry", "euclidean")
        return cls(
            provider=d["provider"],
            model_id=d["model_id"],
            checkpoint=d.get("checkpoint"),
            config_path=d.get("config_path"),
            output_geometry=geometry,
            curvature=d.get("curvature"),
        )

    def content_hash(self) -> str:
        """Return a short hash of the spec for collision-resistant keys.

        This is the first 12 hex characters of the SHA-256 of the key-sorted
        JSON encoding of :meth:`to_dict`.
        """
        content = json.dumps(self.to_dict(), sort_keys=True)
        return hashlib.sha256(content.encode()).hexdigest()[:12]
class BaseEmbeddingProvider(ABC):
    """Base class for embedding providers.

    Subclasses must implement :meth:`provider_id` and :meth:`compute_embeddings`.
    :meth:`get_space_config` has a default implementation that subclasses may
    override to add provider-specific entries.
    """

    # NOTE(review): if callers read provider_id as an attribute rather than
    # calling it, it should additionally be a @property (placed above
    # @abstractmethod) — confirm against the concrete providers.
    @abstractmethod
    def provider_id(self) -> str:
        """Unique identifier for this provider."""
        ...

    @abstractmethod
    def compute_embeddings(
        self,
        samples: list[Sample],
        model_spec: ModelSpec,
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Compute embeddings for samples.

        Args:
            samples: Samples to embed.
            model_spec: Model specification selecting the model to use.
            batch_size: Batch size hint; how it is used is up to the implementation.
            show_progress: Progress-display flag; how it is used is up to the implementation.

        Returns:
            Array of shape (N, D) where N is len(samples) and D is embedding dim.
        """
        ...

    def get_space_config(self, model_spec: ModelSpec, dim: int) -> dict[str, Any]:
        """Get config dict for SpaceInfo.config_json.

        Args:
            model_spec: Model specification.
            dim: Embedding dimension.

        Returns:
            ``model_spec.to_dict()`` plus a ``"dim"`` entry (which wins over any
            same-named key), i.e. provider, geometry, model_id, dim and any extras.
        """
        return {
            **model_spec.to_dict(),
            "dim": dim,
        }
# Provider classes keyed by provider ID; populated by register_provider().
_PROVIDER_CLASSES: dict[str, type[BaseEmbeddingProvider]] = dict()

# Cached provider instances keyed by provider ID; populated by get_provider()
# and invalidated by register_provider().
_PROVIDER_INSTANCES: dict[str, BaseEmbeddingProvider] = dict()

# Built-in provider IDs mapped to the module imported on demand by
# _try_auto_register(). Importing the module is presumably what registers the
# provider (the provider modules are not visible here — confirm).
_KNOWN_PROVIDER_MODULES: dict[str, str] = dict(
    embed_anything="hyperview.embeddings.providers.embed_anything",
    hycoclip="hyperview.embeddings.providers.hycoclip",
    hycoclip_onnx="hyperview.embeddings.providers.hycoclip_onnx",
)
def register_provider(provider_id: str, provider_class: type[BaseEmbeddingProvider]) -> None:
    """Register, or replace, the embedding provider class for *provider_id*.

    Any instance already cached for this ID is discarded, so the next
    get_provider() call constructs an instance of the newly registered class.
    """
    # Forget a stale singleton before (re)binding the class; the two dicts are
    # independent, so the order of these statements does not matter.
    if provider_id in _PROVIDER_INSTANCES:
        del _PROVIDER_INSTANCES[provider_id]
    _PROVIDER_CLASSES[provider_id] = provider_class
def _try_auto_register(provider_id: str, *, silent: bool = True) -> None:
    """Import the module of a known built-in provider, if *provider_id* is one.

    The imported module is assumed to register its provider as an import
    side effect (not visible here — confirm). IDs absent from
    ``_KNOWN_PROVIDER_MODULES`` are ignored.

    Args:
        provider_id: Provider identifier.
        silent: When True, an ImportError is swallowed (used when listing
            providers). When False, it propagates (used when a provider is
            explicitly requested via get_provider()).
    """
    module_path = _KNOWN_PROVIDER_MODULES.get(provider_id)
    if not module_path:
        return
    try:
        import_module(module_path)
    except ImportError:
        if not silent:
            raise
def get_provider(provider_id: str) -> BaseEmbeddingProvider:
    """Return the cached provider instance registered under *provider_id*.

    If the ID is not registered yet, the module of a known built-in provider
    is imported first via _try_auto_register(silent=False). On first request
    the class is instantiated with no arguments; the same instance is returned
    afterwards, preserving model state across calls.

    Raises:
        ValueError: If no class is registered under *provider_id*, even after
            the auto-registration attempt.
        ImportError: If importing a known provider's module fails.
    """
    if provider_id not in _PROVIDER_CLASSES:
        # An explicit request should surface import failures, hence silent=False.
        _try_auto_register(provider_id, silent=False)
        if provider_id not in _PROVIDER_CLASSES:
            registered = ", ".join(sorted(_PROVIDER_CLASSES)) or "(none registered)"
            raise ValueError(
                f"Unknown embedding provider: '{provider_id}'. "
                f"Available: {registered}"
            )

    try:
        return _PROVIDER_INSTANCES[provider_id]
    except KeyError:
        pass
    instance = _PROVIDER_CLASSES[provider_id]()
    _PROVIDER_INSTANCES[provider_id] = instance
    return instance
def list_providers() -> list[str]:
    """Return the IDs of all registered providers.

    Every known built-in provider module is imported first on a best-effort
    basis (modules that raise ImportError are skipped). The IDs are returned in
    the registry's insertion order.
    """
    for known_id in _KNOWN_PROVIDER_MODULES:
        _try_auto_register(known_id, silent=True)
    return [*_PROVIDER_CLASSES]
def make_provider_aware_space_key(model_spec: ModelSpec) -> str:
    """Build a collision-resistant space key for *model_spec*.

    The key has the form ``{provider}__{slugified_model_id}__{content_hash}``.
    The trailing ``model_spec.content_hash()`` covers every field serialized by
    ``to_dict``, so specs that share provider and model_id but differ in
    checkpoint, config, geometry or curvature get different keys.
    """
    # Function-local import, as in the original — presumably to avoid an
    # import cycle with hyperview.storage (confirm).
    from hyperview.storage.schema import slugify_model_id

    model_slug = slugify_model_id(model_spec.model_id)
    spec_digest = model_spec.content_hash()
    return "{}__{}__{}".format(model_spec.provider, model_slug, spec_digest)
# Public API of this module. The registries and the auto-registration helper
# are private implementation details.
__all__ = [
    # Types
    "BaseEmbeddingProvider",
    "ModelSpec",
    # Registry / lookup helpers
    "get_provider",
    "list_providers",
    "make_provider_aware_space_key",
    "register_provider",
]