Spaces:
Running
on
Zero
Running
on
Zero
| """Hyperbolic layer with tangent space operations for hyperbolic embeddings.""" | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Dict, Tuple | |
# Numerical stability constants
# Floor applied (via clamp_min) to vector norms before they are used as
# divisors, and a tiny offset added inside sqrt / to denominators, so the
# origin and zero vectors never cause a division by zero.
MIN_NORM = 1e-15
# Safety margin from the Poincare ball boundary: points are kept at most
# (1 - BALL_EPS) of the radius away from the origin, and atanh arguments are
# clamped to at most 1 - BALL_EPS so they stay strictly below 1.
BALL_EPS = 1e-5
def project_to_ball(x: torch.Tensor, c: float = 1.0, eps: float = BALL_EPS) -> torch.Tensor:
    """
    Clip points so they lie strictly inside the Poincare ball of curvature ``c``.

    Any point whose Euclidean norm exceeds ``(1 - eps) / sqrt(c)`` is rescaled
    radially onto that radius; every other point is returned unchanged.

    Args:
        x: Points to project; the last dimension holds the coordinates.
        c: Curvature (positive); the ball radius is ``1 / sqrt(c)``.
        eps: Relative safety margin kept from the boundary.

    Returns:
        Tensor with the same shape as ``x`` whose rows lie inside the ball.
    """
    radius = (1.0 - eps) / math.sqrt(c)
    # Clamp so zero vectors do not divide by zero in the rescaled branch.
    lengths = x.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    outside = lengths > radius
    rescaled = x / lengths * radius
    return torch.where(outside, rescaled, x)
def expmap0(v: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Exponential map at the origin of the Poincare ball.

    Sends Euclidean tangent vectors at the origin onto the ball:

        exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)

    Args:
        v: Tangent vectors at the origin [*, dim]
        c: Curvature (positive)

    Returns:
        Points on the Poincare ball [*, dim]
    """
    root_c = math.sqrt(c)
    # Norm is floored at MIN_NORM so v == 0 maps to the origin instead of 0/0.
    scaled_norm = root_c * v.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    return torch.tanh(scaled_norm) * v / scaled_norm
def logmap0(y: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Logarithmic map from the Poincare ball to the tangent space at the origin.

    Inverse of ``expmap0`` for points strictly inside the ball of radius
    ``1 / sqrt(c)``:

        log_0(y) = artanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)

    Args:
        y: Points on the Poincare ball [*, dim]
        c: Curvature (positive)

    Returns:
        Tangent vectors at origin [*, dim]
    """
    sqrt_c = math.sqrt(c)
    # Floor the norm so y == 0 maps to the zero vector instead of 0/0.
    y_norm = y.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    # atanh is only defined on (-1, 1), and its argument is sqrt(c) * ||y||, so
    # the clamp must act on the *scaled* norm. Clamping ||y|| alone yields NaN
    # for c > 1 and wrongly caps valid points for c < 1. For c == 1 both forms
    # are identical.
    scaled_norm = (sqrt_c * y_norm).clamp(max=1.0 - BALL_EPS)
    return torch.atanh(scaled_norm) * y / scaled_norm
def hyperbolic_distance_tangent(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Cheap near-origin approximation of hyperbolic distance.

    Returns ``||u - v|| * (1 + c * (||u||^2 + ||v||^2) / 12)``: the Euclidean
    distance scaled by a first-order curvature correction. Intended for points
    close to the origin (roughly ||u||, ||v|| < 0.5); much faster than the
    exact Poincare distance.

    NOTE(review): near the origin the exact Poincare distance
    (2 / sqrt(c)) * artanh(sqrt(c) * ||(-u) (+) v||) is about 2 * ||u - v||,
    so this approximation is on roughly half that scale, and the 1/12
    coefficient is unverified. Confirm the intended scale before mixing the
    two distance functions.

    Args:
        u, v: Points [*, dim] (broadcastable against each other)
        c: Curvature

    Returns:
        Distances [*]
    """

    def _sq_norm(t: torch.Tensor) -> torch.Tensor:
        # Squared Euclidean norm over the last dimension.
        return (t ** 2).sum(dim=-1)

    correction = 1.0 + c * (_sq_norm(u) + _sq_norm(v)) / 12.0
    # MIN_NORM inside the sqrt keeps the gradient finite when u == v.
    return torch.sqrt(_sq_norm(u - v) + MIN_NORM) * correction
def poincare_distance(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Exact geodesic distance on the Poincare ball (more expensive than the
    tangent-space approximation).

        d(u, v) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-u) (+) v||)

    where ``(+)`` is Mobius addition:

        x (+) y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y)
                  / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2)

    Args:
        u, v: Points in the Poincare ball [*, dim] (broadcastable)
        c: Curvature (positive)

    Returns:
        Distances [*]
    """
    sqrt_c = math.sqrt(c)
    u_norm_sq = (u ** 2).sum(dim=-1, keepdim=True)
    v_norm_sq = (v ** 2).sum(dim=-1, keepdim=True)
    uv = (u * v).sum(dim=-1, keepdim=True)
    # Mobius addition with x = -u, y = v, so <x, y> = -<u, v> and
    # ||x||^2 = ||u||^2. The coefficient on v is (1 - c||u||^2): with a plus
    # sign, d(u, u) would be non-zero and distances would be overestimated.
    num = (1 - 2 * c * uv + c * v_norm_sq) * (-u) + (1 - c * u_norm_sq) * v
    denom = 1 - 2 * c * uv + c * c * u_norm_sq * v_norm_sq
    mobius_add = num / (denom + MIN_NORM)
    # atanh needs its argument sqrt(c) * ||.|| strictly below 1, so clamp the
    # scaled norm (clamping the raw norm only works for c == 1).
    scaled_norm = (sqrt_c * mobius_add.norm(dim=-1)).clamp(max=1.0 - BALL_EPS)
    return (2.0 / sqrt_c) * torch.atanh(scaled_norm)
class TangentSpaceProjection(nn.Module):
    """
    Linear map from Euclidean features into the tangent space at the origin,
    followed by the exponential map onto the Poincare ball.

    The linear output is softly norm-limited to at most
    ``0.9 / sqrt(curvature)`` before being mapped with ``expmap0``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        use_bias: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.curvature = curvature
        self.linear = nn.Linear(input_dim, output_dim, bias=use_bias)
        # Small-gain Xavier init (and zero bias) keeps early outputs close to
        # the origin, where the hyperbolic approximations are most accurate.
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        if use_bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Map features to tangent vectors and to points on the Poincare ball.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (norm-limited tangent vectors [*, output_dim])
            and 'ball' (their image under ``expmap0``).
        """
        raw = self.linear(x)
        raw_norm = raw.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
        limit = 0.9 / math.sqrt(self.curvature)
        # Soft clip: the new norm is limit * tanh(||raw|| / limit), which never
        # exceeds limit while staying ~||raw|| for small inputs.
        tangent = raw * (limit * torch.tanh(raw_norm / limit) / raw_norm)
        return {"tangent": tangent, "ball": expmap0(tangent, self.curvature)}
class HyperbolicMLP(nn.Module):
    """
    Two-layer MLP whose output is treated as a tangent vector at the origin
    and mapped onto the Poincare ball.

    Pipeline: Linear(input_dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, output_dim) -> soft norm limit of
    ``0.9 / sqrt(curvature)`` -> ``expmap0``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.curvature = curvature
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )
        # Shrink the initial weights (gain 0.1) and zero the biases so the
        # first outputs sit close to the origin.
        linear_layers = [layer for layer in self.layers if isinstance(layer, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.1)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the MLP and map its output onto the Poincare ball.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (norm-limited MLP output [*, output_dim]) and
            'ball' (``expmap0`` of that output).
        """
        out = self.layers(x)
        out_norm = out.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
        limit = 0.9 / math.sqrt(self.curvature)
        # Smoothly squash the norm to limit * tanh(||out|| / limit) <= limit.
        tangent = out * (limit * torch.tanh(out_norm / limit) / out_norm)
        ball = expmap0(tangent, self.curvature)
        return {"tangent": tangent, "ball": ball}
class HyperbolicDistanceLayer(nn.Module):
    """
    Distances from input points to a set of learnable anchor points.

    Uses ``hyperbolic_distance_tangent`` (fast near-origin approximation) when
    ``use_tangent_approx`` is True, otherwise the exact ``poincare_distance``.
    """

    def __init__(
        self,
        dim: int,
        num_anchors: int,
        curvature: float = 1.0,
        use_tangent_approx: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_anchors = num_anchors
        self.curvature = curvature
        self.use_tangent_approx = use_tangent_approx
        # Anchors start as N(0, 0.1^2) per coordinate, i.e. near the origin.
        # NOTE(review): anchors are never projected into the ball. For large
        # ``dim`` their initial norm (~0.1 * sqrt(dim)) can approach or exceed
        # the ball radius, and training may move them outside it. Consider
        # project_to_ball before the exact distance; confirm intent.
        self.anchors = nn.Parameter(torch.randn(num_anchors, dim) * 0.1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute the distance from every input point to every anchor.

        Args:
            x: Input points [batch, ..., dim]

        Returns:
            Dict with 'distances' [batch, ..., num_anchors]
        """
        if self.use_tangent_approx:
            distance_fn = hyperbolic_distance_tangent
        else:
            distance_fn = poincare_distance
        # [..., 1, dim] against [num_anchors, dim] broadcasts to
        # [..., num_anchors, dim]; the distance reduces the last axis.
        points = x.unsqueeze(-2)
        return {"distances": distance_fn(points, self.anchors, self.curvature)}