import torch
import torch.nn as nn
from typing import Optional
from einops import rearrange, einsum  # type: ignore[import-not-found]


def sinkhorn_log(
    logits: torch.Tensor, num_iters: int = 10, tau: float = 0.05
) -> torch.Tensor:
    """
    Sinkhorn-Knopp algorithm for doubly stochastic matrix projection.

    Projects logits towards the Birkhoff Polytope (doubly stochastic
    matrices) by alternately normalizing rows and columns in log space,
    which avoids overflow when exponentiating ``logits / tau``. An exactly
    doubly stochastic matrix has spectral norm ||H||_2 <= 1, which is what
    makes it useful for training stability.

    With a finite number of iterations the projection is approximate. The
    final update normalizes columns, so columns sum to 1 (up to float
    rounding), while rows sum to 1 only approximately. With
    ``num_iters=0`` no normalization is applied and ``exp(logits / tau)``
    is returned.

    Reference: DeepSeek V3, mHC paper (2025)

    Args:
        logits: Input matrix logits (n, n)
        num_iters: Number of Sinkhorn iterations (default: 10)
        tau: Temperature parameter (default: 0.05); must be > 0. The logits
            are divided by ``tau``, so smaller values sharpen the result.

    Returns:
        H: Approximately doubly stochastic matrix with the same shape as
            ``logits``.

    Raises:
        ValueError: If ``tau`` is not strictly positive. Dividing by a
            non-positive temperature yields inf/NaN or sign-flipped logits.
    """
    if tau <= 0:
        raise ValueError(f"tau must be > 0, got {tau}")

    n = logits.size(-1)
    log_K = logits / tau
    # Log-domain scaling vectors: H_ij = exp(log_K_ij + log_u_i + log_v_j).
    log_u = torch.zeros(n, dtype=logits.dtype, device=logits.device)
    log_v = torch.zeros(n, dtype=logits.dtype, device=logits.device)

    for _ in range(num_iters):
        # Row step: pick log_u so that every row of H sums to 1.
        log_u = -torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        # Column step: pick log_v so that every column of H sums to 1.
        log_v = -torch.logsumexp(log_K + log_u.unsqueeze(1), dim=0)

    H = torch.exp(log_K + log_u.unsqueeze(1) + log_v.unsqueeze(0))
    return H


class HyperConnections(nn.Module):
    """
    Manifold-constrained Hyper-Connections (mHC) module.

    Replaces standard residual connections with parameterized layer update:
        x_{l+1} = H_res · x_l + H_post^T · F(H_pre · x_l, W_l)

    In this implementation the caller supplies both the residual ``x`` and
    the already-transformed features ``transformed``. The mixing matrices
    act on ``num_streams`` equal-width slices ("streams") of the hidden
    dimension; see ``width_connection`` for the exact contraction order.

    Key Features:
        - H_res: Doubly stochastic matrix (Sinkhorn-Knopp projection)
        - H_pre, H_post: Non-negative mixing matrices (row-wise softmax)
        - width_connection: Stream-based multi-head mixing
        - depth_connection: Layer-wise residual mixing via a learned
          scalar gate

    Reference: DeepSeek V3, tokenbender/mHC-manifold-constrained-hyper-connections
    """

    def __init__(
        self,
        dim: int,
        num_streams: int = 8,
        sinkhorn_iters: int = 10,
        sinkhorn_tau: float = 0.05,
        use_width: bool = True,
        use_depth: bool = True,
    ) -> None:
        """
        Initialize HyperConnections module.

        Args:
            dim: Hidden dimension of the model
            num_streams: Number of parallel processing streams (default: 8)
            sinkhorn_iters: Number of Sinkhorn iterations (default: 10)
            sinkhorn_tau: Temperature for Sinkhorn projection (default: 0.05)
            use_width: Enable width connections (stream mixing)
            use_depth: Enable depth connections (residual mixing)

        Raises:
            ValueError: If ``dim`` is not divisible by ``num_streams``.
        """
        super().__init__()
        self.dim = dim
        self.num_streams = num_streams
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.use_width = use_width
        self.use_depth = use_depth

        if dim % num_streams != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_streams ({num_streams})"
            )
        # Width of each stream slice of the hidden dimension.
        self.stream_dim = dim // num_streams

        if self.use_width:
            # Small random logits (std 0.01). After the row-wise softmax,
            # H_pre and H_post therefore start close to uniform.
            self.H_res_logits_width = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            self.H_pre_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            self.H_post_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )

        if self.use_depth:
            # NOTE(review): this parameter does not contribute to any output.
            # It receives no gradient, which matters for DDP with
            # find_unused_parameters=False. It is kept so that existing
            # state_dicts continue to load.
            self.H_res_logits_depth = nn.Parameter(torch.randn(1, 1) * 0.01)
            # Raw (pre-sigmoid) gate logit. The initial gate is
            # sigmoid(0.5) ~= 0.62, not 0.5.
            self.alpha = nn.Parameter(torch.ones(1) * 0.5)

    def width_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Width connection: stream-based mixing across parallel heads.

        Steps:
            1. Split both inputs' hidden dim into ``num_streams`` streams of
               ``stream_dim`` features each.
            2. Mix the residual streams with the Sinkhorn-projected H_res.
            3. Mix the transformed streams with H_pre, then with H_post.
            4. Add both paths and merge the streams back into ``dim``.

        Orientation: every contraction sums over the FIRST index of the
        mixing matrix, so output stream ``m`` receives
        ``sum_n H[n, m] * stream_n``. Written as matrix products over the
        stream axis, the result is
        ``H_res^T · x + H_post^T · H_pre^T · transformed``. The transpose of
        a doubly stochastic matrix is itself doubly stochastic. Note that
        H_pre is applied to features the caller has already transformed,
        not to the layer input as in the class-level formula.

        Args:
            x: Residual input (batch, seq_len, dim)
            transformed: Transformed features (batch, seq_len, dim)

        Returns:
            output: Mixed features (batch, seq_len, dim)

        Raises:
            ValueError: If ``x`` is not 3-D, or if ``transformed`` does not
                have exactly the same shape as ``x``. A shape mismatch would
                otherwise silently broadcast in the final addition.
        """
        if x.dim() != 3:
            raise ValueError(
                f"x must have shape (batch, seq_len, dim), got {tuple(x.shape)}"
            )
        if transformed.shape != x.shape:
            raise ValueError(
                f"transformed shape {tuple(transformed.shape)} must match "
                f"x shape {tuple(x.shape)}"
            )

        x_streams = rearrange(
            x, "b s (n d) -> b s n d", n=self.num_streams, d=self.stream_dim
        )
        transformed_streams = rearrange(
            transformed,
            "b s (n d) -> b s n d",
            n=self.num_streams,
            d=self.stream_dim,
        )

        H_res = sinkhorn_log(
            self.H_res_logits_width,
            num_iters=self.sinkhorn_iters,
            tau=self.sinkhorn_tau,
        )
        # Contracts over the first index of H_res, i.e. applies H_res^T.
        residual_mixed = einsum(H_res, x_streams, "n m, b s n d -> b s m d")

        # Row-wise softmax: each row of H_pre / H_post is non-negative and
        # sums to 1.
        H_pre = torch.softmax(self.H_pre_logits, dim=-1)
        H_post = torch.softmax(self.H_post_logits, dim=-1)

        pre_mixed = einsum(
            H_pre, transformed_streams, "n m, b s n d -> b s m d"
        )
        post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")

        output_streams = residual_mixed + post_mixed
        output: torch.Tensor = rearrange(output_streams, "b s n d -> b s (n d)")
        return output

    def depth_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Depth connection: layer-wise residual mixing.

        Simplified mHC for cross-layer connections:
            output = α · x + (1 - α) · transformed
        where α = sigmoid(self.alpha) is learned and lies in (0, 1).

        ``H_res_logits_depth`` is intentionally not used here. For
        ``num_iters >= 1``, the Sinkhorn projection of a 1x1 matrix is
        always [[1.0]], so it cannot influence the output.

        Args:
            x: Residual input from previous layer (batch, seq_len, dim)
            transformed: Transformed features from current layer (batch, seq_len, dim)

        Returns:
            output: Mixed features (batch, seq_len, dim)
        """
        alpha_bounded = torch.sigmoid(self.alpha)
        output: torch.Tensor = alpha_bounded * x + (1 - alpha_bounded) * transformed
        return output

    # NOTE(review): max-autotune compiles (and autotunes) on the first call.
    # Python string arguments such as ``connection_type`` are presumably
    # specialized as constants by Dynamo, so each distinct value may trigger
    # its own compilation. Confirm this is intended.
    @torch.compile(mode='max-autotune')
    def forward(
        self,
        x: torch.Tensor,
        transformed: torch.Tensor,
        connection_type: str = "width",
    ) -> torch.Tensor:
        """
        Forward pass through HyperConnections.

        Args:
            x: Residual input (batch, seq_len, dim)
            transformed: Transformed features (batch, seq_len, dim)
            connection_type: Type of connection ('width' or 'depth')

        Returns:
            output: Mixed features (batch, seq_len, dim). If the requested
                connection type is disabled, 'width' returns ``transformed``
                unchanged and 'depth' returns ``x + transformed``.

        Raises:
            ValueError: If ``connection_type`` is neither 'width' nor 'depth'.
        """
        if connection_type == "width" and self.use_width:
            return self.width_connection(x, transformed)
        elif connection_type == "depth" and self.use_depth:
            return self.depth_connection(x, transformed)
        elif connection_type == "width":
            # NOTE(review): unlike the depth fallback, this drops the
            # residual ``x`` entirely. Confirm this is intended.
            return transformed
        elif connection_type == "depth":
            return x + transformed
        else:
            raise ValueError(
                f"Unknown connection_type: {connection_type}. Expected 'width' or 'depth'"
            )