# Commit 23680f2 by morozovdd: "feat: add HyperView app for space"
"""Embedding provider abstraction for HyperView."""
from __future__ import annotations
import hashlib
from importlib import import_module
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import numpy as np
from hyperview.core.sample import Sample
@dataclass
class ModelSpec:
    """Describe which embedding model to run and what space it produces.

    Attributes:
        provider: Provider identifier, e.g. ``"embed_anything"`` or ``"hycoclip"``.
        model_id: Model identifier (HuggingFace model id, checkpoint path, ...).
        checkpoint: Optional checkpoint path or URL for weight-only models.
        config_path: Optional path to a model config, for models that need one.
        output_geometry: Geometry of the embedding space, e.g. ``"euclidean"``
            or ``"hyperboloid"``. Serialized under the ``"geometry"`` key.
        curvature: Hyperbolic curvature; only meaningful for hyperbolic geometries.
    """

    provider: str
    model_id: str
    checkpoint: str | None = None
    config_path: str | None = None
    output_geometry: str = "euclidean"
    curvature: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict describing this spec.

        ``provider``, ``model_id`` and ``geometry`` are always present.
        ``checkpoint`` and ``config_path`` appear only when truthy.
        ``curvature`` appears whenever it is not ``None``.
        """
        result: dict[str, Any] = dict(
            provider=self.provider,
            model_id=self.model_id,
            geometry=self.output_geometry,
        )
        for key, value in (
            ("checkpoint", self.checkpoint),
            ("config_path", self.config_path),
        ):
            if value:
                result[key] = value
        # Unlike the string fields, a curvature of 0.0 is kept; only None is dropped.
        if self.curvature is not None:
            result["curvature"] = self.curvature
        return result

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> ModelSpec:
        """Build a spec from a dict such as one produced by :meth:`to_dict`.

        ``provider`` and ``model_id`` are required (``KeyError`` if missing).
        ``geometry`` defaults to ``"euclidean"``. Every other key defaults to ``None``.
        """
        provider = d["provider"]
        model_id = d["model_id"]
        checkpoint = d.get("checkpoint")
        config_path = d.get("config_path")
        geometry = d.get("geometry", "euclidean")
        curvature = d.get("curvature")
        return cls(
            provider=provider,
            model_id=model_id,
            checkpoint=checkpoint,
            config_path=config_path,
            output_geometry=geometry,
            curvature=curvature,
        )

    def content_hash(self) -> str:
        """Return the first 12 hex digits of a SHA-256 over the spec's JSON form.

        Keys are sorted before hashing, so the digest does not depend on dict order.
        """
        canonical = json.dumps(self.to_dict(), sort_keys=True).encode()
        digest = hashlib.sha256(canonical)
        return digest.hexdigest()[:12]
class BaseEmbeddingProvider(ABC):
    """Abstract interface that every embedding provider implements.

    Subclasses must supply :attr:`provider_id` and :meth:`compute_embeddings`.
    :meth:`get_space_config` has a default implementation that subclasses may
    override to add provider-specific entries.
    """

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Identifier under which this provider is known."""
        ...

    @abstractmethod
    def compute_embeddings(
        self,
        samples: list[Sample],
        model_spec: ModelSpec,
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Embed ``samples`` with the model described by ``model_spec``.

        Args:
            samples: Samples to embed.
            model_spec: Which model to use.
            batch_size: Batch size for the implementation (how it is applied is
                up to the subclass).
            show_progress: Whether the implementation should report progress.

        Returns:
            Array of shape ``(N, D)``: one row per sample (``N == len(samples)``),
            where ``D`` is the embedding dimension.
        """
        ...

    def get_space_config(self, model_spec: ModelSpec, dim: int) -> dict[str, Any]:
        """Build the config dict stored in ``SpaceInfo.config_json``.

        Args:
            model_spec: Model specification whose ``to_dict()`` seeds the config.
            dim: Embedding dimension, stored under ``"dim"``.

        Returns:
            A new dict holding ``model_spec.to_dict()`` plus a ``"dim"`` entry.
        """
        config = {**model_spec.to_dict()}
        config["dim"] = dim
        return config
# Provider classes registered via register_provider(), keyed by provider id.
_PROVIDER_CLASSES: dict[str, type[BaseEmbeddingProvider]] = dict()
# Lazily created provider singletons, keyed by provider id (filled by get_provider()).
_PROVIDER_INSTANCES: dict[str, BaseEmbeddingProvider] = dict()
# Built-in providers: provider id -> module imported on demand by
# _try_auto_register() so the provider can be registered.
_KNOWN_PROVIDER_MODULES: dict[str, str] = dict(
    embed_anything="hyperview.embeddings.providers.embed_anything",
    hycoclip="hyperview.embeddings.providers.hycoclip",
    hycoclip_onnx="hyperview.embeddings.providers.hycoclip_onnx",
)
def register_provider(provider_id: str, provider_class: type[BaseEmbeddingProvider]) -> None:
    """Register (or replace) the provider class used for ``provider_id``.

    Any singleton already cached for that id is discarded, so the next
    :func:`get_provider` call instantiates the newly registered class.
    """
    # The cached instance, if any, belongs to the previous class; drop it.
    _PROVIDER_INSTANCES.pop(provider_id, None)
    _PROVIDER_CLASSES[provider_id] = provider_class
def _try_auto_register(provider_id: str, *, silent: bool = True) -> None:
    """Import the module of a built-in provider so it can register itself.

    Ids not listed in ``_KNOWN_PROVIDER_MODULES`` are ignored.

    Args:
        provider_id: Provider identifier.
        silent: When True, an ``ImportError`` from the import is swallowed
            (used by :func:`list_providers`). When False, it propagates to the
            caller (used by :func:`get_provider`, where the provider was
            explicitly requested).
    """
    module_name = _KNOWN_PROVIDER_MODULES.get(provider_id)
    if not module_name:
        return
    try:
        import_module(module_name)
    except ImportError:
        if not silent:
            raise
def get_provider(provider_id: str) -> BaseEmbeddingProvider:
    """Return the cached singleton instance of the provider ``provider_id``.

    Instances are cached so model state is preserved across calls. If no class
    is registered for ``provider_id`` yet and it names a built-in provider, that
    provider's module is imported first so it can register itself.

    Raises:
        ImportError: If a built-in provider's module cannot be imported.
        ValueError: If no provider class is registered under ``provider_id``.
    """
    if provider_id not in _PROVIDER_CLASSES:
        # The caller explicitly asked for this provider, so let a failed import
        # (e.g. a missing optional dependency) surface instead of swallowing it.
        _try_auto_register(provider_id, silent=False)

    provider_class = _PROVIDER_CLASSES.get(provider_id)
    if provider_class is None:
        # Built-in providers only register once their module is imported, so
        # listing registered classes alone could misleadingly report
        # "(none registered)". Advertise the known built-in ids as well.
        candidates = set(_PROVIDER_CLASSES) | set(_KNOWN_PROVIDER_MODULES)
        available = ", ".join(sorted(candidates)) or "(none registered)"
        raise ValueError(
            f"Unknown embedding provider: '{provider_id}'. "
            f"Available: {available}"
        )

    instance = _PROVIDER_INSTANCES.get(provider_id)
    if instance is None:
        instance = provider_class()
        _PROVIDER_INSTANCES[provider_id] = instance
    return instance
def list_providers() -> list[str]:
    """Return the ids of all registered providers.

    Each built-in provider's module is imported first on a best-effort basis.
    A module whose import raises ``ImportError`` is skipped, so its provider
    simply does not appear. Ids come back in the order they were first
    registered.
    """
    for known_id in _KNOWN_PROVIDER_MODULES:
        _try_auto_register(known_id, silent=True)
    return [*_PROVIDER_CLASSES]
def make_provider_aware_space_key(model_spec: ModelSpec) -> str:
    """Build a space key that distinguishes both provider and full model spec.

    The key has the form ``{provider}__{slugified_model_id}__{content_hash}``.
    The trailing hash comes from ``model_spec.content_hash()``, so specs that
    share a provider and model id but differ in other fields still map to
    different keys (barring a collision of the truncated hash).
    """
    # Imported inside the function; presumably to avoid an import cycle or
    # import-time cost. Confirm before hoisting to module level.
    from hyperview.storage.schema import slugify_model_id

    provider = model_spec.provider
    slug = slugify_model_id(model_spec.model_id)
    digest = model_spec.content_hash()
    return f"{provider}__{slug}__{digest}"
# Public names exported by ``from <this module> import *``. The private
# registries and helpers (underscore-prefixed) are intentionally left out.
__all__ = [
    "BaseEmbeddingProvider",
    "ModelSpec",
    "get_provider",
    "list_providers",
    "make_provider_aware_space_key",
    "register_provider",
]