| import torch |
| import torch.nn as nn |
| from typing import Optional |
| from einops import rearrange, einsum |
|
|
def sinkhorn_log(
    logits: torch.Tensor, num_iters: int = 10, tau: float = 0.05
) -> torch.Tensor:
    """
    Log-domain Sinkhorn-Knopp projection towards a doubly stochastic matrix.

    Starting from the kernel ``exp(logits / tau)``, alternately rescales rows
    and columns (in log space, via ``logsumexp``) so that both sum to 1.
    The column update runs last, so after a finite number of iterations the
    columns sum to 1 up to floating-point error while the rows sum to 1 only
    approximately; the approximation tightens as ``num_iters`` grows.

    Reference: DeepSeek V3, mHC paper (2025)

    Args:
        logits: Matrix logits of shape (n, n), or a batch of them with shape
            (..., n, n). Normalisation always acts on the last two axes.
        num_iters: Number of Sinkhorn iterations (default: 10). With 0 the
            un-normalised kernel ``exp(logits / tau)`` is returned.
        tau: Temperature parameter, must be positive (default: 0.05).
            Smaller values push the result towards a permutation matrix.

    Returns:
        H: Tensor with the same shape as ``logits`` whose trailing matrices
            are (approximately) doubly stochastic.

    Raises:
        ValueError: If ``logits`` has fewer than two dimensions or ``tau``
            is not positive.
    """
    if logits.dim() < 2:
        raise ValueError(
            f"logits must have at least 2 dimensions, got shape {tuple(logits.shape)}"
        )
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")

    log_K = logits / tau

    # Log-domain scaling potentials: one entry per row (log_u) and one per
    # column (log_v) of every matrix in the (optional) batch.
    log_u = log_K.new_zeros(log_K.shape[:-1])
    log_v = log_K.new_zeros(log_K.shape[:-2] + log_K.shape[-1:])

    for _ in range(num_iters):
        # Row step: choose log_u so every row of the scaled kernel sums to 1.
        log_u = -torch.logsumexp(log_K + log_v.unsqueeze(-2), dim=-1)
        # Column step: choose log_v so every column sums to 1.
        log_v = -torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=-2)

    return torch.exp(log_K + log_u.unsqueeze(-1) + log_v.unsqueeze(-2))
|
|
class HyperConnections(nn.Module):
    """
    Manifold-constrained Hyper-Connections (mHC) module.

    Replaces a standard residual connection with a parameterized update
    modelled on:
        x_{l+1} = H_res · x_l + H_post^T · F(H_pre · x_l, W_l)

    In this module the layer function F is evaluated by the caller; its
    output is passed in as ``transformed`` and the mixing matrices are
    applied to that tensor.

    Key Features:
        - H_res: Doubly stochastic matrix (Sinkhorn-Knopp projection)
        - H_pre, H_post: Non-negative mixing matrices (row-wise softmax)
        - width_connection: Stream-based multi-head mixing
        - depth_connection: Layer-wise residual mixing

    Reference: DeepSeek V3, tokenbender/mHC-manifold-constrained-hyper-connections
    """

    def __init__(
        self,
        dim: int,
        num_streams: int = 8,
        sinkhorn_iters: int = 10,
        sinkhorn_tau: float = 0.05,
        use_width: bool = True,
        use_depth: bool = True,
    ) -> None:
        """
        Initialize HyperConnections module.

        Args:
            dim: Hidden dimension of the model
            num_streams: Number of parallel processing streams (default: 8)
            sinkhorn_iters: Number of Sinkhorn iterations (default: 10)
            sinkhorn_tau: Temperature for Sinkhorn projection (default: 0.05)
            use_width: Enable width connections (stream mixing)
            use_depth: Enable depth connections (residual mixing)

        Raises:
            ValueError: If ``num_streams`` is not positive or ``dim`` is not
                divisible by ``num_streams``.
        """
        super().__init__()

        if num_streams <= 0:
            raise ValueError(f"num_streams must be positive, got {num_streams}")
        if dim % num_streams != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_streams ({num_streams})"
            )

        self.dim = dim
        self.num_streams = num_streams
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.use_width = use_width
        self.use_depth = use_depth
        self.stream_dim = dim // num_streams

        # NOTE: parameters are created in a fixed order so that seeded
        # initialisation and state_dict / optimizer ordering stay stable.
        if self.use_width:
            # Logits projected to a doubly stochastic matrix by Sinkhorn.
            self.H_res_logits_width = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            # Logits turned into row-stochastic matrices by softmax.
            self.H_pre_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            self.H_post_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )

        if self.use_depth:
            # Not read by depth_connection (a 1x1 doubly stochastic matrix is
            # always [[1.0]]); kept so existing checkpoints keep loading.
            self.H_res_logits_depth = nn.Parameter(torch.randn(1, 1) * 0.01)
            # Raw gate; squashed with a sigmoid before use.
            self.alpha = nn.Parameter(torch.ones(1) * 0.5)

    def width_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Width connection: stream-based mixing across parallel heads.

        Steps:
            1. Split both inputs into ``num_streams`` streams of size
               ``stream_dim`` along the last axis.
            2. Mix the residual streams with H_res (Sinkhorn projection).
            3. Mix the transformed streams with H_pre, then with H_post.
            4. Add both results and merge the streams back together.

        Every einsum below contracts over the first index of the mixing
        matrix, i.e. ``out[j] = sum_i H[i, j] * in[i]``.

        Args:
            x: Residual input (batch, seq_len, dim)
            transformed: Transformed features (batch, seq_len, dim)

        Returns:
            output: Mixed features (batch, seq_len, dim)

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                f"x must have shape (batch, seq_len, dim), got {tuple(x.shape)}"
            )

        x_streams = rearrange(
            x, "b s (n d) -> b s n d", n=self.num_streams, d=self.stream_dim
        )
        transformed_streams = rearrange(
            transformed, "b s (n d) -> b s n d", n=self.num_streams, d=self.stream_dim
        )

        # Residual path: (approximately) doubly stochastic stream mixing.
        H_res = sinkhorn_log(
            self.H_res_logits_width, num_iters=self.sinkhorn_iters, tau=self.sinkhorn_tau
        )
        residual_mixed = einsum(
            H_res, x_streams, "n m, b s n d -> b s m d"
        )

        # Transformed path: softmax over the last axis makes each ROW of
        # H_pre / H_post sum to 1.
        H_pre = torch.softmax(self.H_pre_logits, dim=-1)
        H_post = torch.softmax(self.H_post_logits, dim=-1)

        pre_mixed = einsum(
            H_pre, transformed_streams, "n m, b s n d -> b s m d"
        )
        post_mixed = einsum(
            H_post, pre_mixed, "m n, b s m d -> b s n d"
        )

        output_streams = residual_mixed + post_mixed

        output: torch.Tensor = rearrange(output_streams, "b s n d -> b s (n d)")
        return output

    def depth_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Depth connection: layer-wise residual mixing.

        Simplified mHC for cross-layer connections:
            output = α · x + (1 - α) · transformed

        where α = sigmoid(self.alpha) is learned and lies strictly in (0, 1).

        Args:
            x: Residual input from previous layer (batch, seq_len, dim)
            transformed: Transformed features from current layer (batch, seq_len, dim)

        Returns:
            output: Mixed features (batch, seq_len, dim)
        """
        alpha_bounded = torch.sigmoid(self.alpha)

        # No Sinkhorn projection here: the depth matrix is 1x1, whose only
        # doubly stochastic value is 1.0, and it never entered the output.
        output: torch.Tensor = alpha_bounded * x + (1 - alpha_bounded) * transformed
        return output

    @torch.compile(mode='max-autotune')
    def forward(
        self,
        x: torch.Tensor,
        transformed: torch.Tensor,
        connection_type: str = "width",
    ) -> torch.Tensor:
        """
        Forward pass through HyperConnections.

        Dispatches on ``connection_type``. When the requested connection is
        disabled a fixed fallback is used: a disabled width connection
        returns ``transformed`` unchanged, a disabled depth connection
        returns the plain residual sum ``x + transformed``.

        Args:
            x: Residual input (batch, seq_len, dim)
            transformed: Transformed features (batch, seq_len, dim)
            connection_type: Type of connection ('width' or 'depth')

        Returns:
            output: Mixed features (batch, seq_len, dim)

        Raises:
            ValueError: If ``connection_type`` is neither 'width' nor 'depth'.
        """
        if connection_type == "width" and self.use_width:
            return self.width_connection(x, transformed)
        elif connection_type == "depth" and self.use_depth:
            return self.depth_connection(x, transformed)
        elif connection_type == "width":
            # Width mixing disabled: pass the transformed features through.
            return transformed
        elif connection_type == "depth":
            # Depth mixing disabled: standard residual connection.
            return x + transformed
        else:
            raise ValueError(
                f"Unknown connection_type: {connection_type}. Expected 'width' or 'depth'"
            )
|
|