LimmeDev's picture
Initial MANIFOLD upload - CS2 cheat detection training
454ecdd verified
"""Hyperbolic layer with tangent space operations for hyperbolic embeddings."""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple
# Numerical-stability constants.
# MIN_NORM: tiny floor / offset applied to vector norms and denominators so
# that divisions (and the sqrt in distance computations) never see exactly 0.
MIN_NORM: float = 1.0e-15
# BALL_EPS: relative safety margin kept from the Poincare ball boundary when
# projecting points and when clamping values fed to atanh.
BALL_EPS: float = 1.0e-5
def project_to_ball(x: torch.Tensor, c: float = 1.0, eps: float = BALL_EPS) -> torch.Tensor:
    """
    Pull points back inside the Poincare ball of curvature ``c``.

    Every vector (taken along the last dimension) whose Euclidean norm exceeds
    ``(1 - eps) / sqrt(c)`` is rescaled, keeping its direction, to exactly that
    norm. Vectors already within the limit keep their values.

    Args:
        x: Points to project; norms are computed over the last dimension.
        c: Curvature (positive); the ball radius is ``1 / sqrt(c)``.
        eps: Relative safety margin from the boundary.

    Returns:
        Tensor with the same shape as ``x``.
    """
    radius_limit = (1.0 - eps) / math.sqrt(c)
    # Floor the norm so the rescale below never divides by zero.
    lengths = x.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    outside = lengths > radius_limit
    rescaled = x / lengths * radius_limit
    return torch.where(outside, rescaled, x)
def expmap0(v: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Exponential map at the origin of the Poincare ball.

    Sends Euclidean tangent vectors at the origin onto the ball via

        exp_0(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)

    A zero vector maps to the origin (the norm is floored at ``MIN_NORM``).

    Args:
        v: Tangent vectors at the origin, shape ``[*, dim]``.
        c: Curvature (positive).

    Returns:
        Points on the Poincare ball, shape ``[*, dim]``.
    """
    scaled_norm = math.sqrt(c) * v.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    return torch.tanh(scaled_norm) * v / scaled_norm
def logmap0(y: torch.Tensor, c: float = 1.0) -> torch.Tensor:
    """
    Logarithmic map from the Poincare ball to the tangent space at the origin.

    Inverse of ``expmap0``:

        log_0(y) = artanh(sqrt(c) * ||y||) * y / (sqrt(c) * ||y||)

    The artanh argument ``sqrt(c) * ||y||`` is clamped to ``1 - BALL_EPS``.
    Points on, or numerically beyond, the boundary therefore give a finite
    result. Only the magnitude is clamped; the output always points along ``y``.

    Args:
        y: Points on the Poincare ball, shape ``[*, dim]``.
        c: Curvature (positive).

    Returns:
        Tangent vectors at the origin, shape ``[*, dim]``.
    """
    sqrt_c = math.sqrt(c)
    # Floor the norm so y / y_norm is defined for y == 0 (result is then 0).
    y_norm = y.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
    # Clamp the artanh argument itself, not ||y||. The valid range of ||y|| is
    # [0, 1/sqrt(c)), so a fixed clamp on ||y|| is wrong whenever c != 1.
    scaled_norm = (sqrt_c * y_norm).clamp(max=1.0 - BALL_EPS)
    # ||log_0(y)|| = artanh(sqrt(c) * ||y||) / sqrt(c), in direction y / ||y||.
    return (torch.atanh(scaled_norm) / sqrt_c) * (y / y_norm)
def hyperbolic_distance_tangent(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Cheap near-origin approximation of hyperbolic distance.

    Computes

        sqrt(||u - v||^2 + MIN_NORM) * (1 + c * (||u||^2 + ||v||^2) / 12)

    That is, the Euclidean distance times a curvature correction that grows
    with the squared norms of both points. It is intended for
    ``||u||, ||v|| < 0.5`` and is much cheaper than ``poincare_distance``.

    NOTE(review): near the origin the exact Poincare distance behaves like
    ``2 * ||u - v||`` (conformal factor 2). This approximation is therefore
    roughly half of ``poincare_distance`` for the same small inputs. Confirm
    whether that scale difference is intended.

    Args:
        u, v: Points, shape ``[*, dim]`` (broadcastable against each other).
        c: Curvature.

    Returns:
        Distances, shape ``[*]``.
    """
    sq_dist = (u - v).pow(2).sum(dim=-1)
    curvature_factor = 1.0 + c * (u.pow(2).sum(dim=-1) + v.pow(2).sum(dim=-1)) / 12.0
    # MIN_NORM inside the sqrt keeps the gradient finite when u == v.
    return torch.sqrt(sq_dist + MIN_NORM) * curvature_factor
def poincare_distance(
    u: torch.Tensor,
    v: torch.Tensor,
    c: float = 1.0,
) -> torch.Tensor:
    """
    Exact geodesic distance on the Poincare ball of curvature ``c``.

        d_c(u, v) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-u) (+)_c v||)

    where ``(+)_c`` is Mobius addition:

        x (+)_c y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y)
                    / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2)

    With this formula d_c(u, u) = 0. The artanh argument is clamped to
    ``1 - BALL_EPS`` so points at the boundary give a finite distance.

    Args:
        u, v: Points inside the Poincare ball, shape ``[*, dim]``
            (broadcastable against each other).
        c: Curvature (positive).

    Returns:
        Distances, shape ``[*]``.
    """
    sqrt_c = math.sqrt(c)
    u_norm_sq = (u ** 2).sum(dim=-1, keepdim=True)
    v_norm_sq = (v ** 2).sum(dim=-1, keepdim=True)
    uv = (u * v).sum(dim=-1, keepdim=True)
    # Mobius addition with x = -u, y = v, so <x, y> = -<u, v> and
    # ||x||^2 = ||u||^2. The coefficient on v must be (1 - c||u||^2);
    # (1 + c||u||^2) would make d(u, u) non-zero.
    num = (1 - 2 * c * uv + c * v_norm_sq) * (-u) + (1 - c * u_norm_sq) * v
    denom = 1 - 2 * c * uv + c * c * u_norm_sq * v_norm_sq
    mobius_add = num / (denom + MIN_NORM)
    # Clamp the artanh argument (not the raw norm) so the bound is correct
    # for every curvature, not only c == 1.
    scaled_norm = (sqrt_c * mobius_add.norm(dim=-1)).clamp(max=1.0 - BALL_EPS)
    return (2.0 / sqrt_c) * torch.atanh(scaled_norm)
class TangentSpaceProjection(nn.Module):
    """
    Map Euclidean features into the tangent space at the origin, then onto the
    Poincare ball.

    A linear layer produces the tangent vector. Its norm is softly squashed
    below ``0.9 / sqrt(curvature)``, and ``expmap0`` then places it on the ball.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        use_bias: bool = True,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.curvature = curvature
        self.linear = nn.Linear(input_dim, output_dim, bias=use_bias)
        # Small-gain init keeps the initial outputs close to the origin.
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        if use_bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Project input features to the tangent space and the Poincare ball.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (tangent vectors at the origin, [*, output_dim])
            and 'ball' (their image under expmap0, [*, output_dim]).
        """
        raw = self.linear(x)
        radius = 0.9 / math.sqrt(self.curvature)
        raw_norm = raw.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
        # Soft clamp: the new norm is radius * tanh(raw_norm / radius), which
        # is always < radius and close to raw_norm when raw_norm is small.
        squash = radius * torch.tanh(raw_norm / radius) / raw_norm
        tangent = raw * squash
        return {"tangent": tangent, "ball": expmap0(tangent, self.curvature)}
class HyperbolicMLP(nn.Module):
    """
    Two-layer MLP (Linear -> GELU -> Dropout -> Linear) whose output is used as
    a tangent vector at the origin and mapped onto the Poincare ball.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        curvature: float = 1.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.curvature = curvature
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )
        # Small-gain init keeps the initial outputs close to the origin.
        linear_layers = [layer for layer in self.layers if isinstance(layer, nn.Linear)]
        for lin in linear_layers:
            nn.init.xavier_uniform_(lin.weight, gain=0.1)
            if lin.bias is not None:
                nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Run the MLP and map its output onto the Poincare ball.

        Args:
            x: Input features [*, input_dim]

        Returns:
            Dict with 'tangent' (norm-squashed MLP output, [*, output_dim])
            and 'ball' (its image under expmap0, [*, output_dim]).
        """
        raw = self.layers(x)
        radius = 0.9 / math.sqrt(self.curvature)
        raw_norm = raw.norm(dim=-1, keepdim=True).clamp_min(MIN_NORM)
        # Soft clamp keeping the tangent norm strictly below `radius`.
        squash = radius * torch.tanh(raw_norm / radius) / raw_norm
        tangent = raw * squash
        return {"tangent": tangent, "ball": expmap0(tangent, self.curvature)}
class HyperbolicDistanceLayer(nn.Module):
    """
    Distances from input points to a set of learnable anchor points.

    Uses ``hyperbolic_distance_tangent`` (fast approximation) when
    ``use_tangent_approx`` is true, otherwise the exact ``poincare_distance``.
    """

    def __init__(
        self,
        dim: int,
        num_anchors: int,
        curvature: float = 1.0,
        use_tangent_approx: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_anchors = num_anchors
        self.curvature = curvature
        self.use_tangent_approx = use_tangent_approx
        # Anchors start small (Gaussian scaled by 0.1) so they begin near the
        # origin.
        # NOTE(review): anchors are unconstrained parameters. Nothing here keeps
        # them inside the Poincare ball during training, which the exact
        # distance assumes.
        self.anchors = nn.Parameter(torch.randn(num_anchors, dim) * 0.1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute distances from every input point to every anchor.

        Args:
            x: Input points [batch, ..., dim]

        Returns:
            Dict with 'distances' [batch, ..., num_anchors]
        """
        if self.use_tangent_approx:
            distance_fn = hyperbolic_distance_tangent
        else:
            distance_fn = poincare_distance
        # [..., 1, dim] against [num_anchors, dim] broadcasts to
        # [..., num_anchors, dim]; the distance reduces the last axis.
        points = x.unsqueeze(-2)
        return {"distances": distance_fn(points, self.anchors, self.curvature)}