Spaces:
Running
Running
File size: 6,145 Bytes
23680f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
"""Embedding provider abstraction for HyperView."""
from __future__ import annotations
import hashlib
from importlib import import_module
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any
import numpy as np
from hyperview.core.sample import Sample
@dataclass
class ModelSpec:
    """Describe which embedding model to run and what space it produces.

    Attributes:
        provider: Provider identifier (e.g., "embed_anything", "hycoclip").
        model_id: Model identifier (HuggingFace model_id, checkpoint path, etc.).
        checkpoint: Optional checkpoint path or URL for weight-only models.
        config_path: Optional config path for models that need one.
        output_geometry: Geometry of the embedding space
            ("euclidean", "hyperboloid").
        curvature: Hyperbolic curvature (only relevant for hyperbolic
            geometries).
    """

    provider: str
    model_id: str
    checkpoint: str | None = None
    config_path: str | None = None
    output_geometry: str = "euclidean"
    curvature: float | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict describing this spec.

        ``output_geometry`` is stored under the key ``"geometry"``. The
        optional string fields are emitted only when truthy, so ``None`` and
        ``""`` are both omitted; ``curvature`` is emitted whenever it is not
        ``None`` (a value of ``0.0`` is kept).
        """
        payload: dict[str, Any] = dict(
            provider=self.provider,
            model_id=self.model_id,
            geometry=self.output_geometry,
        )
        # Optional string fields share the same "include only if truthy" rule.
        for field_name in ("checkpoint", "config_path"):
            field_value = getattr(self, field_name)
            if field_value:
                payload[field_name] = field_value
        # Curvature is numeric, so test against None rather than truthiness.
        if self.curvature is not None:
            payload["curvature"] = self.curvature
        return payload

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> ModelSpec:
        """Build a spec from a dict (e.g., one loaded from JSON).

        Args:
            d: Mapping with required keys ``"provider"`` and ``"model_id"``.
                ``"geometry"`` defaults to ``"euclidean"``; ``"checkpoint"``,
                ``"config_path"`` and ``"curvature"`` default to ``None``.

        Returns:
            The reconstructed ``ModelSpec``.

        Raises:
            KeyError: If ``"provider"`` or ``"model_id"`` is missing.
        """
        optional_fields = {
            key: d.get(key) for key in ("checkpoint", "config_path", "curvature")
        }
        return cls(
            provider=d["provider"],
            model_id=d["model_id"],
            output_geometry=d.get("geometry", "euclidean"),
            **optional_fields,
        )

    def content_hash(self) -> str:
        """Return a short hash of the spec for collision-resistant keys.

        The hash is the first 12 hex characters of the SHA-256 digest of
        ``to_dict()`` serialized as JSON with sorted keys.
        """
        canonical_json = json.dumps(self.to_dict(), sort_keys=True)
        full_digest = hashlib.sha256(canonical_json.encode()).hexdigest()
        return full_digest[:12]
class BaseEmbeddingProvider(ABC):
    """Abstract interface implemented by every embedding provider.

    Subclasses must define ``provider_id`` and ``compute_embeddings``;
    ``get_space_config`` comes with a default implementation.
    """

    @property
    @abstractmethod
    def provider_id(self) -> str:
        """Unique identifier for this provider."""

    @abstractmethod
    def compute_embeddings(
        self,
        samples: list[Sample],
        model_spec: ModelSpec,
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Compute embeddings for the given samples.

        Args:
            samples: Samples to embed.
            model_spec: Specification of the model to use.
            batch_size: Batch size (default 32); how it is applied is up to
                the implementation.
            show_progress: Progress flag (default True); how it is honored is
                up to the implementation.

        Returns:
            Array of shape (N, D) where N is len(samples) and D is the
            embedding dimension.
        """

    def get_space_config(self, model_spec: ModelSpec, dim: int) -> dict[str, Any]:
        """Build the config dict for SpaceInfo.config_json.

        Args:
            model_spec: Model specification.
            dim: Embedding dimension.

        Returns:
            A new dict holding everything from ``model_spec.to_dict()`` plus
            a ``"dim"`` entry; ``dim`` wins if the spec dict already has one.
        """
        # Copy first so the dict returned by to_dict() is never mutated.
        space_config = dict(model_spec.to_dict())
        space_config["dim"] = dim
        return space_config
# Provider ID -> provider class. Entries are written by register_provider().
_PROVIDER_CLASSES: dict[str, type[BaseEmbeddingProvider]] = {}
# Provider ID -> cached singleton instance. get_provider() creates entries on
# first use; register_provider() evicts an entry when its ID is re-registered.
_PROVIDER_INSTANCES: dict[str, BaseEmbeddingProvider] = {}
# Provider ID -> dotted module path that _try_auto_register() imports on demand.
# NOTE(review): presumably each module calls register_provider() at import
# time; confirm against the provider modules.
_KNOWN_PROVIDER_MODULES: dict[str, str] = {
    "embed_anything": "hyperview.embeddings.providers.embed_anything",
    "hycoclip": "hyperview.embeddings.providers.hycoclip",
    "hycoclip_onnx": "hyperview.embeddings.providers.hycoclip_onnx",
}
def register_provider(provider_id: str, provider_class: type[BaseEmbeddingProvider]) -> None:
    """Register (or replace) the provider class for an ID.

    Args:
        provider_id: Identifier the class is registered under.
        provider_class: Class that get_provider() instantiates for this ID.
    """
    # Evict any cached singleton so the next get_provider() call builds an
    # instance of the newly registered class.
    _PROVIDER_INSTANCES.pop(provider_id, None)
    _PROVIDER_CLASSES[provider_id] = provider_class
def _try_auto_register(provider_id: str, *, silent: bool = True) -> None:
    """Import the module of a known provider so it can register itself.

    Does nothing when ``provider_id`` has no entry in
    ``_KNOWN_PROVIDER_MODULES``.

    Args:
        provider_id: Provider identifier.
        silent: If True, swallow ImportError (used when listing providers).
            If False, let ImportError propagate (used when explicitly
            requesting a provider via get_provider()).

    Raises:
        ImportError: If the module cannot be imported and ``silent`` is False.
    """
    module_path = _KNOWN_PROVIDER_MODULES.get(provider_id)
    if not module_path:
        return
    try:
        import_module(module_path)
    except ImportError:
        # Best-effort mode ignores providers whose import fails.
        if not silent:
            raise
def get_provider(provider_id: str) -> BaseEmbeddingProvider:
    """Return the singleton provider instance for an ID.

    Instances are cached so model state is preserved across calls. If the ID
    is not registered yet, its known module (if any) is imported first.

    Args:
        provider_id: Provider identifier.

    Returns:
        The cached provider instance, created on first request.

    Raises:
        ValueError: If the ID is still unregistered after the import attempt.
        ImportError: If the module of a known provider fails to import.
    """
    if provider_id not in _PROVIDER_CLASSES:
        # Explicit request: let import failures surface to the caller.
        _try_auto_register(provider_id, silent=False)
        if provider_id not in _PROVIDER_CLASSES:
            available = ", ".join(sorted(_PROVIDER_CLASSES)) or "(none registered)"
            raise ValueError(
                f"Unknown embedding provider: '{provider_id}'. "
                f"Available: {available}"
            )

    if provider_id in _PROVIDER_INSTANCES:
        return _PROVIDER_INSTANCES[provider_id]

    # First request for this ID: build the instance and cache it.
    instance = _PROVIDER_CLASSES[provider_id]()
    _PROVIDER_INSTANCES[provider_id] = instance
    return instance
def list_providers() -> list[str]:
    """Return the IDs of all registered providers.

    Every known provider module is imported first on a best-effort basis;
    modules that fail to import are skipped.
    """
    for known_id in _KNOWN_PROVIDER_MODULES:
        _try_auto_register(known_id, silent=True)
    return [*_PROVIDER_CLASSES]
def make_provider_aware_space_key(model_spec: ModelSpec) -> str:
    """Build a collision-resistant space_key from a ModelSpec.

    Format: {provider}__{slugified_model_id}__{content_hash}

    Args:
        model_spec: Specification whose provider, model_id and content hash
            make up the key.

    Returns:
        The three components joined by double underscores.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # with the storage package; confirm before hoisting to module level.
    from hyperview.storage.schema import slugify_model_id

    model_slug = slugify_model_id(model_spec.model_id)
    spec_digest = model_spec.content_hash()
    return f"{model_spec.provider}__{model_slug}__{spec_digest}"
# Public API of this module. Names prefixed with an underscore (the registries
# and _try_auto_register) are intentionally not exported.
__all__ = [
    "BaseEmbeddingProvider",
    "ModelSpec",
    "get_provider",
    "list_providers",
    "make_provider_aware_space_key",
    "register_provider",
]
|