Spaces:
Running
Running
File size: 9,254 Bytes
23680f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
"""Clean HyCoCLIP embedding provider (PyTorch) - no external hycoclip package.
This is a minimal reimplementation that loads HyCoCLIP weights directly.
Depends on torch, torchvision, timm, and numpy; huggingface_hub is additionally needed to download `hf://` checkpoints.
Architecture:
- ViT backbone (timm)
- Linear projection to embedding space
- Exponential map to hyperboloid (Lorentz model)
Checkpoints: https://huggingface.co/avik-pal/hycoclip
Requirements:
uv sync --extra ml
"""
from __future__ import annotations
import math
from pathlib import Path
from typing import Any
import numpy as np
from hyperview.core.sample import Sample
from hyperview.embeddings.providers import (
BaseEmbeddingProvider,
ModelSpec,
register_provider,
)
# Public API: only the provider class is exported by a star-import; the
# helper functions and the checkpoint table are not part of it.
__all__ = ["HyCoCLIPProvider"]
# Built-in model ids mapped to their checkpoint references, written in the
# ``hf://<repo_id>#<filename>`` form. Every entry lives in the same repository
# and its file is named after the model id.
HYCOCLIP_CHECKPOINTS: dict[str, str] = {
    model_id: f"hf://avik-pal/hycoclip#{model_id}.pth"
    for model_id in (
        "hycoclip_vit_s",
        "hycoclip_vit_b",
        "meru_vit_s",
        "meru_vit_b",
    )
}
def _exp_map_lorentz(x: "torch.Tensor", c: float) -> "torch.Tensor":
    """Lift origin tangent vectors onto the Lorentz hyperboloid.

    Applies the exponential map at the hyperboloid vertex for a hyperbolic
    space of curvature ``-c``. The result is laid out as ``(t, x1, ..., xD)``
    and satisfies ``t^2 - ||x||^2 = 1/c``.

    Following the HyCoCLIP/MERU ``exp_map0`` numerics, the sinh argument is
    clamped for stability and the time component is recovered from the
    hyperboloid constraint.

    Args:
        x: Euclidean tangent vectors at the origin, shape ``(..., D)``.
        c: Positive curvature parameter (the hyperbolic curvature is ``-c``).

    Returns:
        Hyperboloid coordinates, shape ``(..., D + 1)``. float16/bfloat16
        inputs are promoted to float32 before any arithmetic.

    Raises:
        ValueError: If ``c`` is zero or negative.
    """
    import torch

    if c <= 0:
        raise ValueError(f"curvature c must be > 0, got {c}")

    # Under AMP the input may be half precision; promote it so sinh cannot
    # overflow the narrower float formats.
    reduced_precision = (torch.float16, torch.bfloat16)
    tangent = x.float() if x.dtype in reduced_precision else x

    tiny = 1e-8
    sinh_cap = math.asinh(2**15)

    scaled_norm = math.sqrt(c) * torch.norm(tangent, dim=-1, keepdim=True)
    sinh_term = torch.sinh(scaled_norm.clamp(min=tiny, max=sinh_cap))
    # The denominator is clamped separately so a zero vector maps to zero
    # spatial coordinates instead of dividing by zero.
    space = sinh_term * tangent / scaled_norm.clamp(min=tiny)

    # Time coordinate from the constraint t^2 = 1/c + ||space||^2.
    time_coord = ((1.0 / c) + (space * space).sum(dim=-1, keepdim=True)).sqrt()
    return torch.cat((time_coord, space), dim=-1)
def _create_encoder(
    embed_dim: int = 512,
    curvature: float = 0.1,
    vit_model: str = "vit_small_patch16_224",
) -> "nn.Module":
    """Build a HyCoCLIP image encoder around a timm ViT backbone.

    The returned module runs the backbone, applies a bias-free linear
    projection to ``embed_dim`` features, and maps the result onto the
    hyperboloid with ``_exp_map_lorentz``. The backbone is created with
    ``pretrained=False``, so no weights are loaded here.

    Args:
        embed_dim: Output width of the linear projection.
        curvature: Curvature parameter passed to ``_exp_map_lorentz``.
        vit_model: timm model name used for the backbone.

    Returns:
        A new encoder module exposing ``backbone``, ``proj``, ``curvature``
        and ``embed_dim`` attributes.
    """
    from timm import create_model
    from torch import nn

    class HyCoCLIPImageEncoder(nn.Module):
        """ViT backbone -> linear projection -> Lorentz exponential map."""

        def __init__(self) -> None:
            super().__init__()
            # num_classes=0 is passed so the backbone is built without a
            # classification layer of its own.
            vit = create_model(vit_model, pretrained=False, num_classes=0)
            self.backbone = vit
            self.proj = nn.Linear(int(vit.embed_dim), embed_dim, bias=False)
            self.curvature = curvature
            self.embed_dim = embed_dim

        def forward(self, x: "torch.Tensor") -> "torch.Tensor":
            tangent = self.proj(self.backbone(x))
            return _exp_map_lorentz(tangent, self.curvature)

    return HyCoCLIPImageEncoder()
def _load_encoder(checkpoint_path: str, device: str = "cpu") -> Any:
    """Load a HyCoCLIP image encoder from a checkpoint file.

    The checkpoint must hold a ``"model"`` state dict containing ``curv``
    (the log of the curvature), ``visual_proj.weight`` and the backbone
    weights under the ``visual.`` prefix. The ViT variant is chosen from the
    input width of the projection matrix.

    NOTE(review): only ``curv``, ``visual.*`` and ``visual_proj.weight`` are
    consumed. Any other learned parameters in the checkpoint (for example a
    feature scale applied before the exponential map) are ignored - confirm
    against the reference implementation.

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device the encoder is moved to after the weights are loaded.

    Returns:
        The encoder module, in eval mode, on ``device``.

    Raises:
        ValueError: If the projection input width does not match any known
            ViT variant.
    """
    import warnings

    import torch

    # SECURITY: weights_only=False lets the unpickler execute arbitrary code
    # embedded in the file. Only load checkpoints from sources you trust;
    # switch to weights_only=True if the checkpoints turn out to contain
    # plain tensors only.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    state = ckpt["model"]

    # The checkpoint stores log-curvature; exponentiate to recover c.
    curvature = torch.exp(state["curv"]).item()

    # Projection weight is (embed_dim, backbone_dim).
    embed_dim, backbone_dim = state["visual_proj.weight"].shape

    vit_models = {
        384: "vit_small_patch16_224",
        768: "vit_base_patch16_224",
        1024: "vit_large_patch16_224",
    }
    vit_model = vit_models.get(backbone_dim)
    if vit_model is None:
        # Falling back to a default variant could never work: the projection
        # and backbone shapes would not match the checkpoint tensors.
        supported = ", ".join(str(width) for width in vit_models)
        raise ValueError(
            f"Unsupported backbone width {backbone_dim} in checkpoint "
            f"'{checkpoint_path}'; supported widths: {supported}"
        )

    model = _create_encoder(embed_dim=embed_dim, curvature=curvature, vit_model=vit_model)

    # Remap checkpoint keys onto the encoder's attribute names.
    visual_prefix = "visual."
    new_state = {}
    for key, value in state.items():
        if key.startswith(visual_prefix):
            new_state["backbone." + key[len(visual_prefix):]] = value
        elif key == "visual_proj.weight":
            new_state["proj.weight"] = value

    # strict=False tolerates naming differences, but a mismatch means part of
    # the encoder keeps its random initialisation - surface that loudly.
    load_result = model.load_state_dict(new_state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "HyCoCLIP checkpoint did not fully match the encoder: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )

    return model.to(device).eval()
class HyCoCLIPProvider(BaseEmbeddingProvider):
    """Clean HyCoCLIP provider (PyTorch) - no hycoclip package dependency.

    The encoder, the torch device and the preprocessing transform are created
    lazily and cached on the instance. The encoder is reloaded whenever a
    different ``ModelSpec`` is requested.

    Requires: torch, torchvision, timm (install via `uv sync --extra ml`).
    Resolving ``hf://`` checkpoints additionally imports ``huggingface_hub``.
    """

    def __init__(self) -> None:
        self._model: Any = None  # cached encoder, None until first load
        self._model_spec: ModelSpec | None = None  # spec the cached encoder belongs to
        self._device: Any = None  # cached torch.device
        self._transform: Any = None  # cached torchvision transform

    @property
    def provider_id(self) -> str:
        """Identifier of this provider."""
        return "hycoclip"

    def _get_device(self) -> Any:
        """Return the cached torch device, choosing CUDA when it is available."""
        import torch

        if self._device is None:
            self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self._device

    def _get_transform(self) -> Any:
        """Return the cached image preprocessing pipeline.

        The pipeline resizes with bicubic interpolation, center-crops to 224,
        converts to a tensor and normalizes per channel.
        """
        if self._transform is None:
            from torchvision import transforms

            self._transform = transforms.Compose([
                transforms.Resize(224, transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                # NOTE(review): these look like the usual CLIP statistics -
                # confirm they match what the checkpoints were trained with.
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711],
                ),
            ])
        return self._transform

    def _resolve_checkpoint(self, checkpoint: str) -> Path:
        """Resolve a checkpoint reference to a local file path.

        Args:
            checkpoint: Either ``hf://<repo_id>#<filename>``, which is fetched
                through ``huggingface_hub``, or a local filesystem path.

        Returns:
            The resolved path of the checkpoint file.

        Raises:
            ValueError: If an ``hf://`` reference has no filename after ``#``.
            FileNotFoundError: If a local path does not exist.
        """
        hf_scheme = "hf://"
        if checkpoint.startswith(hf_scheme):
            repo_id, _, filename = checkpoint[len(hf_scheme):].partition("#")
            # Covers both a missing '#' and an empty filename after it.
            if not filename:
                raise ValueError(f"HF checkpoint must include filename: {checkpoint}")

            from huggingface_hub import hf_hub_download

            return Path(hf_hub_download(repo_id=repo_id, filename=filename)).resolve()

        # Local path
        path = Path(checkpoint).expanduser().resolve()
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {path}")
        return path

    def _load_model(self, model_spec: ModelSpec) -> None:
        """Load the encoder for ``model_spec`` unless it is already cached.

        When the spec carries no checkpoint, the checkpoint is looked up in
        ``HYCOCLIP_CHECKPOINTS`` by ``model_spec.model_id``.

        Raises:
            ValueError: If no checkpoint is given and the model id is unknown.
        """
        if self._model is not None and self._model_spec == model_spec:
            return

        # Auto-resolve checkpoint from model_id if not provided
        checkpoint = model_spec.checkpoint or HYCOCLIP_CHECKPOINTS.get(model_spec.model_id)
        if not checkpoint:
            available = ", ".join(sorted(HYCOCLIP_CHECKPOINTS.keys()))
            raise ValueError(
                f"Unknown HyCoCLIP model_id: '{model_spec.model_id}'. "
                f"Known models: {available}. "
                f"Or provide 'checkpoint' path explicitly."
            )

        checkpoint_path = self._resolve_checkpoint(checkpoint)
        self._model = _load_encoder(str(checkpoint_path), str(self._get_device()))
        self._model_spec = model_spec

    def compute_embeddings(
        self,
        samples: list["Sample"],
        model_spec: ModelSpec,
        batch_size: int = 32,
        show_progress: bool = True,
    ) -> np.ndarray:
        """Compute hyperboloid embeddings for samples.

        Args:
            samples: Samples whose images are loaded via ``load_image()``.
            model_spec: Model to embed with; loaded on first use.
            batch_size: Number of images per forward pass; must be >= 1.
            show_progress: When true, print a one-line status message.

        Returns:
            Array with one row per sample. For an empty ``samples`` the result
            has shape ``(0, embed_dim + 1)``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        import torch

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self._load_model(model_spec)
        assert self._model is not None
        device = self._get_device()
        transform = self._get_transform()

        if show_progress:
            print(f"Computing HyCoCLIP embeddings for {len(samples)} samples...")

        all_embeddings = []
        for i in range(0, len(samples), batch_size):
            batch_samples = samples[i : i + batch_size]
            images = []
            for sample in batch_samples:
                img = sample.load_image()
                if img.mode != "RGB":
                    img = img.convert("RGB")
                images.append(transform(img))
            batch_tensor = torch.stack(images).to(device)

            # Mixed precision is only enabled on CUDA.
            with torch.no_grad(), torch.amp.autocast(
                device_type=device.type, enabled=device.type == "cuda"
            ):
                embeddings = self._model(batch_tensor)
            all_embeddings.append(embeddings.cpu().numpy())

        if not all_embeddings:
            # np.concatenate rejects an empty list. Return a correctly shaped
            # empty array instead: the spatial dims plus the time coordinate.
            return np.empty((0, self._model.embed_dim + 1), dtype=np.float32)
        return np.concatenate(all_embeddings, axis=0)

    def get_space_config(self, model_spec: ModelSpec, dim: int) -> dict[str, Any]:
        """Return embedding space configuration with curvature.

        Loads the model if needed so that the curvature and spatial dimension
        can be read from it. ``dim`` is passed through unchanged.
        """
        self._load_model(model_spec)
        assert self._model is not None
        return {
            "provider": self.provider_id,
            "model_id": model_spec.model_id,
            "geometry": "hyperboloid",
            "checkpoint": model_spec.checkpoint,
            "dim": dim,
            "curvature": self._model.curvature,
            "spatial_dim": self._model.embed_dim,
        }
# Auto-register on import: this module-level call runs when the module is
# imported and registers the provider class under the "hycoclip" key, the
# same string that HyCoCLIPProvider.provider_id returns.
register_provider("hycoclip", HyCoCLIPProvider)
|