# sem-v6-training/src/sem_v6/modules/hyper_connections.py
# Uploaded by icarus112 using huggingface_hub (commit 518db7a, verified).
import torch
import torch.nn as nn
from typing import Optional
from einops import rearrange, einsum # type: ignore[import-not-found]
def sinkhorn_log(
    logits: torch.Tensor, num_iters: int = 10, tau: float = 0.05
) -> torch.Tensor:
    """
    Sinkhorn-Knopp projection of ``logits`` toward a doubly stochastic matrix.

    Works in the log domain for numerical stability. It alternately rescales
    the rows and columns of ``K = exp(logits / tau)`` so that each sums to 1.

    An exactly doubly stochastic matrix lies in the Birkhoff polytope and has
    spectral norm ||H||_2 = 1. After a finite number of iterations the result
    is only approximately doubly stochastic. Because the column update runs
    last, columns sum to 1 up to floating-point error, while row sums are
    approximate and converge as ``num_iters`` grows.

    Reference: DeepSeek V3, mHC paper (2025)

    Args:
        logits: Input logits of shape (n, n), or (..., n, n) for a batch of
            matrices. Each trailing matrix is projected independently.
        num_iters: Number of Sinkhorn iterations (default: 10). With 0
            iterations the result is the unnormalized ``exp(logits / tau)``.
        tau: Temperature (default: 0.05). Logits are divided by ``tau``, so
            smaller values amplify differences between entries.

    Returns:
        H: Tensor with the same shape as ``logits``.

    Raises:
        ValueError: If ``logits`` has fewer than 2 dimensions.
    """
    if logits.dim() < 2:
        raise ValueError(
            "logits must have at least 2 dimensions (..., n, n), "
            f"got shape {tuple(logits.shape)}"
        )
    log_K = logits / tau
    # Log-domain scaling vectors: log_u rescales rows, log_v rescales columns.
    # Indexing the trailing axes (instead of hard-coding dims 0/1) lets any
    # leading batch dimensions pass through unchanged.
    log_u = torch.zeros_like(log_K[..., :, 0])
    log_v = torch.zeros_like(log_K[..., 0, :])
    for _ in range(num_iters):
        # Row normalization: logsumexp across the column axis.
        log_u = -torch.logsumexp(log_K + log_v.unsqueeze(-2), dim=-1)
        # Column normalization: logsumexp across the row axis.
        log_v = -torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=-2)
    return torch.exp(log_K + log_u.unsqueeze(-1) + log_v.unsqueeze(-2))
class HyperConnections(nn.Module):
    """
    Manifold-constrained Hyper-Connections (mHC) module.

    Replaces a plain residual connection ``x + transformed`` with a
    parameterized combination of the residual ``x`` and the sub-layer output
    ``transformed``. The reference mHC layer update is

        x_{l+1} = H_res · x_l + H_post^T · F(H_pre · x_l, W_l)

    This module only receives the already-computed ``transformed`` tensor, so
    it cannot mix the *input* of F. The width path instead applies both H_pre
    and H_post to the sub-layer output (see ``width_connection``).

    Key features:
        - H_res: doubly stochastic matrix (Sinkhorn-Knopp projection)
        - H_pre, H_post: non-negative, row-stochastic mixing matrices
          (softmax over the last axis)
        - width_connection: mixing across ``num_streams`` feature streams
        - depth_connection: learned convex blend of residual and output

    Reference: DeepSeek V3, tokenbender/mHC-manifold-constrained-hyper-connections
    """

    def __init__(
        self,
        dim: int,
        num_streams: int = 8,
        sinkhorn_iters: int = 10,
        sinkhorn_tau: float = 0.05,
        use_width: bool = True,
        use_depth: bool = True,
    ) -> None:
        """
        Initialize HyperConnections module.

        Args:
            dim: Hidden dimension of the model. Must be divisible by num_streams.
            num_streams: Number of streams the hidden dimension is split into
                (default: 8).
            sinkhorn_iters: Number of Sinkhorn iterations for H_res (default: 10).
            sinkhorn_tau: Temperature for the Sinkhorn projection (default: 0.05).
            use_width: Create and use the stream-mixing parameters.
            use_depth: Create and use the residual-blend parameters.

        Raises:
            ValueError: If num_streams is not positive or does not divide dim.
        """
        super().__init__()
        # Validate before the modulo so num_streams == 0 gives a clear error
        # instead of ZeroDivisionError.
        if num_streams <= 0:
            raise ValueError(f"num_streams must be positive, got {num_streams}")
        if dim % num_streams != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_streams ({num_streams})"
            )
        self.dim = dim
        self.num_streams = num_streams
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_tau = sinkhorn_tau
        self.use_width = use_width
        self.use_depth = use_depth
        self.stream_dim = dim // num_streams

        if self.use_width:
            self.H_res_logits_width = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            self.H_pre_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
            self.H_post_logits = nn.Parameter(
                torch.randn(num_streams, num_streams) * 0.01
            )
        if self.use_depth:
            # Kept only so existing state_dicts still load. A 1x1 Sinkhorn
            # projection is always exactly [[1.0]], and nothing in this class
            # reads this parameter, so it never receives a gradient.
            # NOTE(review): unused parameters require
            # find_unused_parameters=True under DDP; consider removing it once
            # checkpoints are migrated.
            self.H_res_logits_depth = nn.Parameter(torch.randn(1, 1) * 0.01)
            # Raw (pre-sigmoid) blend weight. sigmoid(0.5) ≈ 0.62, so the
            # initial blend favours x. NOTE(review): if an even 0.5/0.5 start
            # was intended, the raw value should be initialized to 0.
            self.alpha = nn.Parameter(torch.ones(1) * 0.5)

    def width_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Width connection: mixing across the ``num_streams`` feature streams.

        The last dimension is split into ``num_streams`` contiguous streams of
        size ``stream_dim``, and the output is::

            out = H_res^T · x_streams + H_post^T · (H_pre^T · transformed_streams)

        Each product mixes along the stream axis only. The result is then
        flattened back to the original last dimension.

        Args:
            x: Residual input (..., dim), typically (batch, seq_len, dim).
            transformed: Sub-layer output with the same shape as ``x``.

        Returns:
            Mixed features with the same shape as ``x``.
        """
        n, d = self.num_streams, self.stream_dim
        # Row-major reshape of the last axis: index = stream * d + offset.
        x_streams = x.reshape(*x.shape[:-1], n, d)
        t_streams = transformed.reshape(*transformed.shape[:-1], n, d)

        H_res = sinkhorn_log(
            self.H_res_logits_width, num_iters=self.sinkhorn_iters, tau=self.sinkhorn_tau
        )
        H_pre = torch.softmax(self.H_pre_logits, dim=-1)
        H_post = torch.softmax(self.H_post_logits, dim=-1)

        # out[m] = sum_n H[n, m] * in[n]: the transpose of H applied along the
        # stream axis. The ellipsis lets any leading dims pass through.
        residual_mixed = torch.einsum("nm,...nd->...md", H_res, x_streams)
        pre_mixed = torch.einsum("nm,...nd->...md", H_pre, t_streams)
        post_mixed = torch.einsum("mn,...md->...nd", H_post, pre_mixed)

        out_streams = residual_mixed + post_mixed
        return out_streams.reshape(*out_streams.shape[:-2], n * d)

    def depth_connection(
        self, x: torch.Tensor, transformed: torch.Tensor
    ) -> torch.Tensor:
        """
        Depth connection: learned convex blend of residual and sub-layer output.

            output = α · x + (1 - α) · transformed,  α = sigmoid(self.alpha)

        α lies in the open interval (0, 1).

        Args:
            x: Residual input from the previous layer.
            transformed: Output of the current layer's sub-layer. It must
                broadcast against ``x``.

        Returns:
            The blended tensor.
        """
        alpha_bounded = torch.sigmoid(self.alpha)
        return alpha_bounded * x + (1 - alpha_bounded) * transformed

    # NOTE(review): this compiles forward at class-definition time for every
    # instance, using mode='max-autotune' (long autotuning; per the
    # torch.compile docs it also enables CUDA graphs on GPU). Consider
    # compiling the whole model at the call site instead.
    @torch.compile(mode='max-autotune')
    def forward(
        self,
        x: torch.Tensor,
        transformed: torch.Tensor,
        connection_type: str = "width",
    ) -> torch.Tensor:
        """
        Forward pass through HyperConnections.

        If the requested connection was disabled at construction time, this
        falls back to a plain residual ``x + transformed``.

        Args:
            x: Residual input (batch, seq_len, dim).
            transformed: Transformed features (batch, seq_len, dim).
            connection_type: Type of connection ('width' or 'depth').

        Returns:
            Mixed features with the same shape as ``x``.

        Raises:
            ValueError: If connection_type is neither 'width' nor 'depth'.
        """
        if connection_type == "width":
            if self.use_width:
                return self.width_connection(x, transformed)
            # Width mixing with identity H_res/H_pre/H_post reduces to a plain
            # residual. Returning only `transformed` here would silently drop
            # the residual stream.
            return x + transformed
        if connection_type == "depth":
            if self.use_depth:
                return self.depth_connection(x, transformed)
            return x + transformed
        raise ValueError(
            f"Unknown connection_type: {connection_type}. Expected 'width' or 'depth'"
        )