"""Clean HyCoCLIP embedding provider (PyTorch) - no external hycoclip package. This is a minimal reimplementation that loads HyCoCLIP weights directly. Only depends on torch, timm, and numpy. Architecture: - ViT backbone (timm) - Linear projection to embedding space - Exponential map to hyperboloid (Lorentz model) Checkpoints: https://huggingface.co/avik-pal/hycoclip Requirements: uv sync --extra ml """ from __future__ import annotations import math from pathlib import Path from typing import Any import numpy as np from hyperview.core.sample import Sample from hyperview.embeddings.providers import ( BaseEmbeddingProvider, ModelSpec, register_provider, ) __all__ = ["HyCoCLIPProvider"] HYCOCLIP_CHECKPOINTS: dict[str, str] = { "hycoclip_vit_s": "hf://avik-pal/hycoclip#hycoclip_vit_s.pth", "hycoclip_vit_b": "hf://avik-pal/hycoclip#hycoclip_vit_b.pth", "meru_vit_s": "hf://avik-pal/hycoclip#meru_vit_s.pth", "meru_vit_b": "hf://avik-pal/hycoclip#meru_vit_b.pth", } def _exp_map_lorentz(x: "torch.Tensor", c: float) -> "torch.Tensor": """Exponential map from tangent space at the hyperboloid vertex. Maps Euclidean tangent vectors at the origin onto the Lorentz (hyperboloid) model of hyperbolic space with curvature -c. Output is ordered as (t, x1, ..., xD) and satisfies: t^2 - ||x||^2 = 1/c This matches HyCoCLIP/MERU exp_map0 numerics by clamping the sinh input for stability and inferring the time component from the hyperboloid constraint. Args: x: Euclidean tangent vectors at the origin, shape (..., D). c: Positive curvature parameter (hyperbolic curvature is -c). Returns: Hyperboloid coordinates, shape (..., D + 1). """ import torch if c <= 0: raise ValueError(f"curvature c must be > 0, got {c}") # Compute in float32 under AMP to avoid float16/bfloat16 overflow. if x.dtype in (torch.float16, torch.bfloat16): x = x.float() sqrt_c = math.sqrt(c) rc_xnorm = sqrt_c * torch.norm(x, dim=-1, keepdim=True) eps = 1e-8 sinh_input = torch.clamp(rc_xnorm, min=eps, max=math.asinh(2**15)) spatial = torch.sinh(sinh_input) * x / torch.clamp(rc_xnorm, min=eps) t = torch.sqrt((1.0 / c) + torch.sum(spatial * spatial, dim=-1, keepdim=True)) return torch.cat([t, spatial], dim=-1) def _create_encoder( embed_dim: int = 512, curvature: float = 0.1, vit_model: str = "vit_small_patch16_224", ) -> "nn.Module": """Create HyCoCLIP image encoder using timm ViT backbone.""" import timm import torch.nn as nn class HyCoCLIPImageEncoder(nn.Module): def __init__(self) -> None: super().__init__() self.backbone = timm.create_model(vit_model, pretrained=False, num_classes=0) backbone_dim = int(getattr(self.backbone, "embed_dim")) self.proj = nn.Linear(backbone_dim, embed_dim, bias=False) self.curvature = curvature self.embed_dim = embed_dim def forward(self, x: "torch.Tensor") -> "torch.Tensor": features = self.backbone(x) spatial = self.proj(features) return _exp_map_lorentz(spatial, self.curvature) return HyCoCLIPImageEncoder() def _load_encoder(checkpoint_path: str, device: str = "cpu") -> Any: """Load HyCoCLIP image encoder from checkpoint.""" import torch ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) state = ckpt["model"] # Extract curvature (stored as log) curvature = torch.exp(state["curv"]).item() # Determine model variant from checkpoint proj_shape = state["visual_proj.weight"].shape embed_dim = proj_shape[0] backbone_dim = proj_shape[1] vit_models = { 384: "vit_small_patch16_224", 768: "vit_base_patch16_224", 1024: "vit_large_patch16_224", } vit_model = vit_models.get(backbone_dim, "vit_small_patch16_224") model = _create_encoder(embed_dim=embed_dim, curvature=curvature, vit_model=vit_model) # Remap checkpoint keys new_state = {} for key, value in state.items(): if key.startswith("visual."): new_state["backbone." + key[7:]] = value elif key == "visual_proj.weight": new_state["proj.weight"] = value model.load_state_dict(new_state, strict=False) return model.to(device).eval() class HyCoCLIPProvider(BaseEmbeddingProvider): """Clean HyCoCLIP provider (PyTorch) - no hycoclip package dependency. Requires: torch, torchvision, timm (install via `uv sync --extra ml`) """ def __init__(self) -> None: self._model: Any = None self._model_spec: ModelSpec | None = None self._device: Any = None self._transform: Any = None @property def provider_id(self) -> str: return "hycoclip" def _get_device(self) -> Any: import torch if self._device is None: self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return self._device def _get_transform(self) -> Any: if self._transform is None: from torchvision import transforms self._transform = transforms.Compose([ transforms.Resize(224, transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], ), ]) return self._transform def _resolve_checkpoint(self, checkpoint: str) -> Path: """Resolve checkpoint path, downloading from HuggingFace if needed.""" # Handle HuggingFace Hub URLs: hf://repo_id#filename if checkpoint.startswith("hf://"): from huggingface_hub import hf_hub_download path = checkpoint[5:] if "#" not in path: raise ValueError(f"HF checkpoint must include filename: {checkpoint}") repo_id, filename = path.split("#", 1) return Path(hf_hub_download(repo_id=repo_id, filename=filename)).resolve() # Local path path = Path(checkpoint).expanduser().resolve() if not path.exists(): raise FileNotFoundError(f"Checkpoint not found: {path}") return path def _load_model(self, model_spec: ModelSpec) -> None: if self._model is not None and self._model_spec == model_spec: return # Auto-resolve checkpoint from model_id if not provided checkpoint = model_spec.checkpoint if not checkpoint: checkpoint = HYCOCLIP_CHECKPOINTS.get(model_spec.model_id) if not checkpoint: available = ", ".join(sorted(HYCOCLIP_CHECKPOINTS.keys())) raise ValueError( f"Unknown HyCoCLIP model_id: '{model_spec.model_id}'. " f"Known models: {available}. " f"Or provide 'checkpoint' path explicitly." ) checkpoint_path = self._resolve_checkpoint(checkpoint) self._model = _load_encoder(str(checkpoint_path), str(self._get_device())) self._model_spec = model_spec def compute_embeddings( self, samples: list["Sample"], model_spec: ModelSpec, batch_size: int = 32, show_progress: bool = True, ) -> np.ndarray: """Compute hyperboloid embeddings for samples.""" import torch self._load_model(model_spec) assert self._model is not None device = self._get_device() transform = self._get_transform() if show_progress: print(f"Computing HyCoCLIP embeddings for {len(samples)} samples...") all_embeddings = [] for i in range(0, len(samples), batch_size): batch_samples = samples[i : i + batch_size] images = [] for sample in batch_samples: img = sample.load_image() if img.mode != "RGB": img = img.convert("RGB") images.append(transform(img)) batch_tensor = torch.stack(images).to(device) with torch.no_grad(), torch.amp.autocast( device_type=device.type, enabled=device.type == "cuda" ): embeddings = self._model(batch_tensor) all_embeddings.append(embeddings.cpu().numpy()) return np.concatenate(all_embeddings, axis=0) def get_space_config(self, model_spec: ModelSpec, dim: int) -> dict[str, Any]: """Return embedding space configuration with curvature.""" self._load_model(model_spec) assert self._model is not None return { "provider": self.provider_id, "model_id": model_spec.model_id, "geometry": "hyperboloid", "checkpoint": model_spec.checkpoint, "dim": dim, "curvature": self._model.curvature, "spatial_dim": self._model.embed_dim, } # Auto-register on import register_provider("hycoclip", HyCoCLIPProvider)