"""Hyperbolic layer with tangent space operations for hyperbolic embeddings."""

from __future__ import annotations

import math
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# Numerical stability constants
MIN_NORM = 1e-15  # floor applied to vector norms before dividing by them
BALL_EPS = 1e-5  # relative safety margin kept between points and the ball boundary


def project_to_ball(x: torch.Tensor, c: float = 1.0, eps: float = BALL_EPS) -> torch.Tensor:
    """
    Project points to Poincare ball (ensure ||x|| < 1/sqrt(c)).

    Points whose norm exceeds (1 - eps) / sqrt(c) are rescaled radially onto
    that radius; all other points are returned unchanged.

    Args:
        x: Points to project [*, dim]
        c: Curvature (positive, ball radius = 1/sqrt(c))
        eps: Safety margin from boundary

    Returns:
        Projected points [*, dim]
    """
    max_norm = (1.0 - eps) / math.sqrt(c)
    norm = x.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    cond = norm > max_norm
    x_proj = x / norm * max_norm
    return torch.where(cond, x_proj, x)


def expmap0(v: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Exponential map from tangent space at origin to Poincare ball.

    exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)

    Args:
        v: Tangent vectors at origin [*, dim]
        c: Curvature

    Returns:
        Points on Poincare ball [*, dim]
    """
    sqrt_c = math.sqrt(c)
    v_norm = v.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    return torch.tanh(sqrt_c * v_norm) * v / (sqrt_c * v_norm)


def logmap0(y: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Logarithmic map from Poincare ball to tangent space at origin.

    Inverse of expmap0:
        log_0(y) = arctanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)

    Args:
        y: Points on Poincare ball [*, dim]
        c: Curvature

    Returns:
        Tangent vectors at origin [*, dim]. Points at or beyond the ball
        boundary are treated as lying on the radius (1 - BALL_EPS) / sqrt(c)
        in the same direction, so the result stays finite.
    """
    sqrt_c = math.sqrt(c)
    y_norm = y.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    # atanh is only defined on (-1, 1): clamp the *curvature-scaled* norm so the
    # guard matches the ball radius 1/sqrt(c) for every curvature.
    scaled_norm = (sqrt_c * y_norm).clamp(max=1.0 - BALL_EPS)
    return torch.atanh(scaled_norm) * y / (sqrt_c * y_norm)


def hyperbolic_distance_tangent(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Approximate hyperbolic distance using tangent space.

    Valid when ||u||, ||v|| < 0.5 (near origin approximation).
    This is much faster than full Poincare distance.

    NOTE(review): near the origin, poincare_distance behaves like
    (2/sqrt(c)) * atanh(sqrt(c) * ||v - u||) ~= 2 * ||u - v||, so this
    approximation returns roughly half of the exact value. Callers switching
    between the two (e.g. HyperbolicDistanceLayer.use_tangent_approx) see a
    different distance scale — confirm whether that is intended.

    Args:
        u, v: Points [*, dim] (broadcastable)
        c: Curvature

    Returns:
        Distances [*]
    """
    diff_norm_sq = ((u - v) ** 2).sum(dim=-1)
    u_norm_sq = (u ** 2).sum(dim=-1)
    v_norm_sq = (v ** 2).sum(dim=-1)

    # First-order correction for curvature
    # d(u,v) ~ ||u-v|| * (1 + c*(||u||^2 + ||v||^2)/12)
    correction = 1.0 + c * (u_norm_sq + v_norm_sq) / 12.0

    # MIN_NORM keeps the sqrt gradient finite when u == v.
    return torch.sqrt(diff_norm_sq + MIN_NORM) * correction


def poincare_distance(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Full Poincare ball distance (more expensive but exact).

    d(u,v) = (2/sqrt(c)) * arctanh(sqrt(c) * ||(-u) (+)_c v||)

    where (+)_c is Mobius addition:
        x (+)_c y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y)
                    / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2)

    Args:
        u, v: Points in Poincare ball [*, dim] (broadcastable)
        c: Curvature

    Returns:
        Distances [*]
    """
    sqrt_c = math.sqrt(c)

    u_norm_sq = (u ** 2).sum(dim=-1, keepdim=True)
    v_norm_sq = (v ** 2).sum(dim=-1, keepdim=True)
    uv = (u * v).sum(dim=-1, keepdim=True)

    # Mobius addition with x = -u, y = v: <x,y> = -<u,v> and ||x||^2 = ||u||^2.
    num = (1 - 2 * c * uv + c * v_norm_sq) * (-u) + (1 - c * u_norm_sq) * v
    denom = 1 - 2 * c * uv + c * c * u_norm_sq * v_norm_sq
    mobius_add = num / denom.clamp_min(MIN_NORM)

    # Clamp the curvature-scaled norm so atanh stays finite for any c.
    scaled_norm = (sqrt_c * mobius_add.norm(dim=-1)).clamp(max=1.0 - BALL_EPS)
    return (2.0 / sqrt_c) * torch.atanh(scaled_norm)


def _soft_clip_tangent(
    tangent: torch.Tensor,
    curvature: float,
    max_radius: float = 0.9,
) -> torch.Tensor:
    """Smoothly shrink tangent vectors so their norm stays below max_radius / sqrt(curvature).

    Each vector keeps its direction; its norm r becomes m * tanh(r / m) with
    m = max_radius / sqrt(curvature), which is ~r for small r and saturates at m.
    """
    norm = tangent.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    max_norm = max_radius / math.sqrt(curvature)
    return tangent * (max_norm * torch.tanh(norm / max_norm) / norm)


class TangentSpaceProjection(nn.Module):
    """
    Project Euclidean features to hyperbolic tangent space.

    Uses tangent space at origin for efficiency.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        use_bias: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.curvature = curvature

        self.linear = nn.Linear(input_dim, output_dim, bias=use_bias)

        # Initialize for small outputs (stay near origin)
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        if use_bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Project input to tangent space.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (tangent space vectors [*, output_dim], norm
            below 0.9 / sqrt(curvature)) and 'ball' (their image on the
            Poincare ball under expmap0)
        """
        # Linear map into the output_dim-dimensional tangent space at the origin.
        tangent = self.linear(x)

        # Keep tangent vectors in a bounded region so the ball image stays
        # strictly inside the boundary.
        tangent = _soft_clip_tangent(tangent, self.curvature)

        ball = expmap0(tangent, self.curvature)

        return {
            "tangent": tangent,
            "ball": ball,
        }


class HyperbolicMLP(nn.Module):
    """
    MLP operating in tangent space with hyperbolic output.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.curvature = curvature

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        # Initialize for small outputs
        for m in self.layers:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (MLP output soft-clipped to norm below
            0.9 / sqrt(curvature)) and 'ball' (its expmap0 image)
        """
        tangent = _soft_clip_tangent(self.layers(x), self.curvature)
        ball = expmap0(tangent, self.curvature)

        return {
            "tangent": tangent,
            "ball": ball,
        }


class HyperbolicDistanceLayer(nn.Module):
    """
    Compute distances to learnable anchor points in hyperbolic space.

    Anchors are an nn.Parameter of shape [num_anchors, dim].
    """

    def __init__(
        self,
        dim: int,
        num_anchors: int,
        curvature: float = 1.0,
        use_tangent_approx: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_anchors = num_anchors
        self.curvature = curvature
        # True -> hyperbolic_distance_tangent (fast approximation);
        # False -> poincare_distance (exact).
        self.use_tangent_approx = use_tangent_approx

        # Learnable anchors (initialized small to stay near origin)
        self.anchors = nn.Parameter(torch.randn(num_anchors, dim) * 0.1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute distances to all anchors.

        Args:
            x: Input points [batch, ..., dim]

        Returns:
            Dict with 'distances' [batch, ..., num_anchors]
        """
        # [batch, ..., 1, dim] broadcasts against anchors [num_anchors, dim].
        x_expanded = x.unsqueeze(-2)

        if self.use_tangent_approx:
            distances = hyperbolic_distance_tangent(
                x_expanded, self.anchors, self.curvature
            )
        else:
            # Anchors are free parameters and may drift outside the ball during
            # training; the exact formula is only meaningful inside it.
            anchors = project_to_ball(self.anchors, self.curvature)
            distances = poincare_distance(x_expanded, anchors, self.curvature)

        return {"distances": distances}