Spaces:
Sleeping
Sleeping
Upload 7 files
Browse files- mhc/__init__.py +27 -0
- mhc/hyper_connections.py +309 -0
- mhc/metrics.py +160 -0
- mhc/simulation.py +211 -0
- mhc/sinkhorn.py +148 -0
- mhc/torch_module.py +280 -0
- requirements.txt +3 -5
mhc/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
mHC (Manifold-Constrained Hyper-Connections) visualization library.

This package provides tools for demonstrating the stability properties
of mHC residual connections compared to unconstrained HC and baseline methods.

Modules:
- sinkhorn: Sinkhorn-Knopp projection onto doubly stochastic matrices
- metrics: Stability metrics (forward_gain, backward_gain, spectral_norm)
- simulation: Deep network signal propagation simulation
- torch_module: PyTorch implementation for use in neural networks

Author: Subhadip Mitra <contact@subhadipmitra.com>
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
"""

from .sinkhorn import sinkhorn_knopp, is_doubly_stochastic, projection_error
from .metrics import forward_gain, backward_gain, spectral_norm, compute_all_metrics
from .simulation import generate_residual_matrix, simulate_depth, run_comparison

# PyTorch modules (optional import - requires torch).
# The rest of the package stays importable in a numpy-only environment.
try:
    from .torch_module import SinkhornKnopp, mHCResidual, mHCBlock, create_mhc_mlp
except ImportError:
    pass  # torch not installed

__version__ = "0.1.0"
mhc/hyper_connections.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of Hyper-Connections (HC) and mHC.
|
| 3 |
+
|
| 4 |
+
HC extends residual connections with multiple parallel streams and learned mixing.
|
| 5 |
+
mHC constrains the mixing matrix to be doubly stochastic via Sinkhorn-Knopp.
|
| 6 |
+
|
| 7 |
+
Reference: https://arxiv.org/abs/2512.24880
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
def sinkhorn_knopp_torch(M: torch.Tensor, iters: int = 20, eps: float = 1e-8) -> torch.Tensor:
    """
    Differentiable Sinkhorn-Knopp projection onto the doubly stochastic set.

    The input is exponentiated (shifted by its max for numerical safety) and
    then alternately row- and column-normalized.

    Args:
        M: Square input matrix of shape (n, n).
        iters: Number of alternating normalization passes.
        eps: Stabilizer added to each normalizing sum to avoid division by zero.

    Returns:
        A matrix whose rows and columns each sum to (approximately) 1.
    """
    # Subtracting the max keeps exp() from overflowing; the shift cancels
    # under normalization, so the fixed point is unchanged.
    P = (M - M.max()).exp()
    for _ in range(iters):
        row_totals = P.sum(dim=-1, keepdim=True)
        P = P / (row_totals + eps)
        col_totals = P.sum(dim=-2, keepdim=True)
        P = P / (col_totals + eps)
    return P
| 31 |
+
|
| 32 |
+
|
class HyperConnections(nn.Module):
    """
    Hyper-Connections: multi-stream residual with a learned mixing matrix.

    Each layer carries N parallel streams; a learned N x N matrix H mixes
    the streams before the layer's own contribution is added back in:

        output = H @ input_streams + layer_contribution

    Args:
        n_streams: Number of parallel streams (N).
        init_scale: Scale applied to the random initialization of H.

    Shape:
        - Input x: (B, N, D) where B=batch, N=streams, D=features
        - Output: (B, N, D)
    """

    def __init__(self, n_streams: int = 4, init_scale: float = 0.1):
        super().__init__()
        self.n_streams = n_streams
        # Unconstrained raw mixing weights; subclasses may project them.
        self.H_res = nn.Parameter(torch.randn(n_streams, n_streams) * init_scale)

    def get_mixing_matrix(self) -> torch.Tensor:
        """Return the mixing matrix (unconstrained here; subclasses override)."""
        return self.H_res

    def forward(self, x: torch.Tensor, layer_output: torch.Tensor) -> torch.Tensor:
        """
        Mix the input streams and add the layer contribution.

        Args:
            x: Streamed input of shape (B, N, ...).
            layer_output: Output of the wrapped layer F, shape (B, N, ...).

        Returns:
            H @ x + layer_output
        """
        mix = self.get_mixing_matrix()
        # einsum accommodates any number of trailing feature dimensions.
        combined = torch.einsum('ij,bj...->bi...', mix, x)
        return combined + layer_output
| 75 |
+
|
| 76 |
+
|
class MHC(HyperConnections):
    """
    Manifold-Constrained Hyper-Connections (mHC).

    Identical to HC except that the mixing matrix is projected onto the
    doubly stochastic manifold with Sinkhorn-Knopp before use. This gives:
    - all eigenvalues bounded by 1,
    - stable signal propagation through depth,
    - no gradient explosion.

    Args:
        n_streams: Number of parallel streams.
        sinkhorn_iters: Sinkhorn iterations used for the projection.
        init_scale: Scale of the random initialization.
    """

    def __init__(self, n_streams: int = 4, sinkhorn_iters: int = 20, init_scale: float = 0.1):
        super().__init__(n_streams, init_scale)
        self.sinkhorn_iters = sinkhorn_iters

    def get_mixing_matrix(self) -> torch.Tensor:
        """Project the raw weights onto the doubly stochastic manifold."""
        return sinkhorn_knopp_torch(self.H_res, self.sinkhorn_iters)
| 100 |
+
|
| 101 |
+
|
class ResidualBlock(nn.Module):
    """
    Residual block whose skip connection is configurable.

    Modes:
    - 'baseline': plain residual connection, relu(x + F(x))
    - 'hc': Hyper-Connections with an unconstrained mixing matrix
    - 'mhc': Manifold-Constrained HC with doubly stochastic mixing

    Args:
        channels: Channel count of the two conv layers.
        method: One of 'baseline', 'hc', 'mhc'.
        n_streams: Stream count for HC/mHC.
        sinkhorn_iters: Sinkhorn iterations for mHC.
    """

    def __init__(
        self,
        channels: int,
        method: str = 'baseline',
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.method = method
        self.n_streams = n_streams

        # Standard two-conv ResNet-style transform F(x).
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )

        # Optional stream-mixing module for the HC/mHC variants.
        if method == 'hc':
            self.hc = HyperConnections(n_streams)
        elif method == 'mhc':
            self.hc = MHC(n_streams, sinkhorn_iters)
        else:
            self.hc = None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            Output tensor of shape (B, C, H, W).
        """
        if self.method == 'baseline':
            # Plain residual connection.
            return self.relu(x + self.conv(x))

        batch, chans, height, width = x.shape
        streams = self.n_streams

        # Replicate the flattened input across the stream axis:
        # (B, C, H, W) -> (B, N, C*H*W). expand() is a view, so no copy.
        x_streams = x.view(batch, 1, -1).expand(batch, streams, -1)

        # The conv path runs once on the original input and is likewise
        # broadcast over the streams.
        f_out = self.conv(x)
        f_streams = f_out.view(batch, 1, -1).expand(batch, streams, -1)

        # H @ x_streams + f_streams via the HC/mHC module.
        mixed = self.hc(x_streams, f_streams)

        # Average the streams back down and restore the spatial layout.
        merged = mixed.mean(dim=1).view(batch, chans, height, width)
        return self.relu(merged)
| 180 |
+
|
| 181 |
+
|
class SimpleCNN(nn.Module):
    """
    Small CNN whose residual blocks use a configurable connection type.

    Architecture:
    - Stem: 3x3 conv lifting the input to `channels` feature maps
    - n_blocks residual blocks (connection type per `method`)
    - Head: global average pool + linear classifier

    Args:
        n_blocks: Number of residual blocks.
        channels: Hidden channel width.
        method: Residual type ('baseline', 'hc', 'mhc').
        n_streams: Stream count for HC/mHC.
        sinkhorn_iters: Sinkhorn iterations for mHC.
        num_classes: Number of output classes.
        in_channels: Number of input channels (3 for RGB).
    """

    def __init__(
        self,
        n_blocks: int = 8,
        channels: int = 64,
        method: str = 'baseline',
        n_streams: int = 4,
        sinkhorn_iters: int = 20,
        num_classes: int = 10,
        in_channels: int = 3
    ):
        super().__init__()
        self.method = method

        # Input stem.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

        # Stack of identically configured residual blocks.
        self.blocks = nn.ModuleList([
            ResidualBlock(channels, method, n_streams, sinkhorn_iters)
            for _ in range(n_blocks)
        ])

        # Pool + linear classification head.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits.

        Args:
            x: Input images of shape (B, C, H, W).

        Returns:
            Logits of shape (B, num_classes).
        """
        out = self.stem(x)
        for blk in self.blocks:
            out = blk(out)
        return self.head(out)
| 248 |
+
|
| 249 |
+
|
def train_with_gradient_tracking(
    model: nn.Module,
    train_loader,
    epochs: int,
    device: torch.device,
    lr: float = 1e-3
) -> dict:
    """
    Train `model` while logging per-step loss, gradient norm, and accuracy.

    Args:
        model: Model to train (updated in place).
        train_loader: Iterable of (data, target) batches.
        epochs: Number of passes over `train_loader`.
        device: Device the batches are moved to.
        lr: Adam learning rate.

    Returns:
        Dict of per-step lists:
        - 'loss': scalar loss values
        - 'grad_norms': total L2 gradient norm across all parameters
        - 'accuracy': batch accuracies
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = {'loss': [], 'grad_norms': [], 'accuracy': []}

    model.train()
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            # Total gradient norm: sqrt of the summed squared per-parameter norms,
            # measured before the optimizer step so it reflects this batch.
            squared = sum(
                p.grad.norm().item() ** 2
                for p in model.parameters()
                if p.grad is not None
            )
            grad_norm = squared ** 0.5

            # Batch accuracy from the argmax prediction.
            batch_acc = (output.argmax(dim=1) == target).float().mean().item()

            history['loss'].append(loss.item())
            history['grad_norms'].append(grad_norm)
            history['accuracy'].append(batch_acc)

            optimizer.step()

    return history
mhc/metrics.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Stability metrics for analyzing residual mixing matrices.
|
| 3 |
+
|
| 4 |
+
These metrics quantify how a matrix amplifies signals during forward/backward
|
| 5 |
+
propagation through a neural network layer.
|
| 6 |
+
|
| 7 |
+
Key insight from the mHC paper:
|
| 8 |
+
- Unconstrained matrices (HC) can have unbounded gains, causing signal explosion
|
| 9 |
+
- Doubly stochastic matrices (mHC) have all gains bounded by 1, ensuring stability
|
| 10 |
+
|
| 11 |
+
Metrics:
|
| 12 |
+
- forward_gain: Worst-case signal amplification in forward pass (max row sum)
|
| 13 |
+
- backward_gain: Worst-case gradient amplification in backward pass (max column sum)
|
| 14 |
+
- spectral_norm: Largest singular value (general operator norm)
|
| 15 |
+
|
| 16 |
+
For doubly stochastic matrices, all three equal exactly 1.
|
| 17 |
+
|
| 18 |
+
Author: Subhadip Mitra <contact@subhadipmitra.com>
|
| 19 |
+
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
|
def forward_gain(matrix: np.ndarray) -> float:
    """
    Maximum absolute row sum: worst-case forward signal amplification.

    This is max_i |sum_j M[i, j]|. For a doubly stochastic matrix every row
    sums to 1, so forward_gain = 1; for unconstrained matrices it can be
    arbitrarily large.

    Args:
        matrix: Square matrix of shape (n, n).

    Returns:
        The maximum absolute row sum as a Python float.

    Example:
        >>> forward_gain(np.eye(4))
        1.0
        >>> forward_gain(np.ones((4, 4)))
        4.0
    """
    row_totals = np.abs(matrix.sum(axis=1))
    return float(row_totals.max())
| 49 |
+
|
| 50 |
+
|
def backward_gain(matrix: np.ndarray) -> float:
    """
    Maximum absolute column sum: worst-case gradient amplification.

    This is max_j |sum_i M[i, j]|. Gradients flow through M^T in the
    backward pass, so column sums of M bound their growth. For a doubly
    stochastic matrix every column sums to 1, so backward_gain = 1; for
    unconstrained matrices it can be arbitrarily large.

    Args:
        matrix: Square matrix of shape (n, n).

    Returns:
        The maximum absolute column sum as a Python float.

    Example:
        >>> backward_gain(np.eye(4))
        1.0
        >>> backward_gain(np.ones((4, 4)))
        4.0
    """
    col_totals = np.abs(matrix.sum(axis=0))
    return float(col_totals.max())
| 75 |
+
|
| 76 |
+
|
def spectral_norm(matrix: np.ndarray, iterations: int = 20) -> float:
    """
    Estimate the spectral norm (largest singular value) via power iteration.

    The spectral norm ||M||_2 bounds L2 amplification: for any unit vector x,
    ||Mx||_2 <= ||M||_2. For doubly stochastic matrices it is <= 1.

    Implementation note: power iteration is run on the Gram matrix
    A = M^T @ M, whose dominant eigenvector is the top *right singular
    vector* of M. Iterating on M directly converges to the largest
    |eigenvalue| instead, which underestimates the spectral norm for
    non-normal matrices (e.g. [[0, 1], [0, 0]] has all eigenvalues 0 but
    spectral norm 1).

    Args:
        matrix: Square matrix of shape (n, n).
        iterations: Number of power iterations (20 is usually sufficient).

    Returns:
        Estimated spectral norm (largest singular value).

    Example:
        >>> spectral_norm(np.eye(4))  # doctest: +ELLIPSIS
        1.0...
        >>> spectral_norm(2 * np.eye(4))  # doctest: +ELLIPSIS
        2.0...
    """
    n = matrix.shape[0]

    # Gram matrix: its top eigenvalue is sigma_max(M)^2.
    gram = matrix.T @ matrix

    # Fixed seed keeps the estimate reproducible across calls.
    rng = np.random.default_rng(42)
    v = rng.standard_normal(n)
    v = v / np.linalg.norm(v)

    for _ in range(iterations):
        w = gram @ v
        norm = np.linalg.norm(w)
        if norm < 1e-10:
            # Numerically zero image: the matrix is (effectively) zero.
            return 0.0
        v = w / norm

    # v approximates the top right singular vector, so ||M @ v|| ~= sigma_max.
    return float(np.linalg.norm(matrix @ v))
| 122 |
+
|
| 123 |
+
|
def compute_all_metrics(matrix: np.ndarray) -> dict:
    """
    Compute every stability metric for a residual mixing matrix.

    This is the main entry point for analyzing a mixing matrix; the result
    contains everything needed to assess training stability.

    Args:
        matrix: Square matrix of shape (n, n).

    Returns:
        Dict containing:
        - spectral_norm: largest singular value
        - forward_gain: max absolute row sum
        - backward_gain: max absolute column sum
        - row_sum_max_dev: max |row sum - 1|
        - col_sum_max_dev: max |col sum - 1|
        - min_entry: smallest matrix entry

    Example:
        >>> metrics = compute_all_metrics(np.eye(4))
        >>> metrics['forward_gain']
        1.0
        >>> metrics['backward_gain']
        1.0
    """
    sums_by_row = matrix.sum(axis=1)
    sums_by_col = matrix.sum(axis=0)

    metrics = {
        'spectral_norm': spectral_norm(matrix),
        'forward_gain': float(np.abs(sums_by_row).max()),
        'backward_gain': float(np.abs(sums_by_col).max()),
        'row_sum_max_dev': float(np.abs(sums_by_row - 1).max()),
        'col_sum_max_dev': float(np.abs(sums_by_col - 1).max()),
        'min_entry': float(matrix.min()),
    }
    return metrics
mhc/simulation.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simulation engine for deep network signal propagation.
|
| 3 |
+
|
| 4 |
+
This module simulates how signals propagate through deep residual networks
|
| 5 |
+
with different residual mixing strategies:
|
| 6 |
+
|
| 7 |
+
- baseline: Identity matrices (no mixing, standard residual connections)
|
| 8 |
+
- hc: Random unconstrained matrices (Hyper-Connections)
|
| 9 |
+
- mhc: Sinkhorn-projected doubly stochastic matrices (Manifold-Constrained HC)
|
| 10 |
+
|
| 11 |
+
Key insight from the mHC paper:
|
| 12 |
+
The COMPOSITE mapping (product of all layer matrices H_L @ H_{L-1} @ ... @ H_0)
|
| 13 |
+
is what matters for signal propagation:
|
| 14 |
+
- For HC: composite gains explode exponentially (3000x+ at depth 64)
|
| 15 |
+
- For mHC: composite gains stay bounded (~1.6x at depth 64)
|
| 16 |
+
|
| 17 |
+
This happens because doubly stochastic matrices are closed under multiplication.
|
| 18 |
+
|
| 19 |
+
Author: Subhadip Mitra <contact@subhadipmitra.com>
|
| 20 |
+
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import numpy as np
|
| 24 |
+
from typing import Dict, Literal, Optional
|
| 25 |
+
|
| 26 |
+
from .sinkhorn import sinkhorn_knopp
|
| 27 |
+
from .metrics import compute_all_metrics
|
| 28 |
+
|
| 29 |
+
|
def generate_residual_matrix(
    n: int,
    method: Literal['baseline', 'hc', 'mhc'],
    sinkhorn_iters: int = 20,
    rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """
    Build one residual mixing matrix for the given strategy.

    Args:
        n: Matrix size (number of streams).
        method: One of:
            - 'baseline': identity matrix (no mixing)
            - 'hc': raw N(0, 1) random matrix
            - 'mhc': N(0, 1) matrix projected to doubly stochastic
        sinkhorn_iters: Sinkhorn iterations for 'mhc'. With 0 iterations the
            raw matrix is returned (identical to 'hc'), which demonstrates
            the transition from explosive to stable behavior as k grows.
        rng: Optional generator for reproducibility.

    Returns:
        Mixing matrix of shape (n, n).

    Raises:
        ValueError: If `method` is not one of the three recognized names.

    Example:
        >>> rng = np.random.default_rng(42)
        >>> M = generate_residual_matrix(4, 'mhc', sinkhorn_iters=20, rng=rng)
        >>> M.shape
        (4, 4)
    """
    if rng is None:
        rng = np.random.default_rng()

    if method == 'baseline':
        return np.eye(n)

    # 'hc' and 'mhc' both start from the same Gaussian draw.
    raw = rng.standard_normal((n, n))

    if method == 'hc':
        return raw

    if method == 'mhc':
        # k = 0 deliberately skips the projection so the starting point
        # matches raw HC behavior.
        if sinkhorn_iters == 0:
            return raw
        return sinkhorn_knopp(raw, iterations=sinkhorn_iters)

    raise ValueError(f"Unknown method: {method}. Expected 'baseline', 'hc', or 'mhc'.")
| 77 |
+
|
| 78 |
+
|
def simulate_depth(
    depth: int,
    n: int,
    method: Literal['baseline', 'hc', 'mhc'],
    sinkhorn_iters: int = 20,
    seed: int = 42
) -> Dict:
    """
    Simulate signal propagation through a deep residual network.

    Draws one mixing matrix per layer and maintains the running product

        Composite(l) = H_l @ H_{l-1} @ ... @ H_0,

    i.e. the total transformation applied to a signal between the input
    and layer l. Metrics are recorded for each layer matrix and for the
    composite at every depth.

    Args:
        depth: Number of layers to simulate.
        n: Matrix size (number of streams in the multi-stream residual).
        method: Residual mixing strategy ('baseline', 'hc', or 'mhc').
        sinkhorn_iters: Number of Sinkhorn iterations for mHC.
        seed: Random seed for reproducibility.

    Returns:
        Dict with the run configuration ('method', 'depth', 'n',
        'sinkhorn_iters', 'seed') plus:
        - 'per_layer': list of metric dicts, one per layer matrix
        - 'composite': list of metric dicts for the composite at each depth

    Example:
        >>> result = simulate_depth(64, 4, 'mhc', seed=42)
        >>> result['composite'][-1]['forward_gain'] < 5
        True
    """
    rng = np.random.default_rng(seed)

    layer_records = []
    composite_records = []

    composite = np.eye(n)  # Depth-0 product is the identity.

    for idx in range(depth):
        # Draw this layer's residual mixing matrix.
        H = generate_residual_matrix(n, method, sinkhorn_iters, rng)

        layer_records.append({
            'layer': idx,
            **compute_all_metrics(H)
        })

        # Left-multiply so the newest layer is applied last:
        # Composite(l) = H_l @ Composite(l-1).
        composite = H @ composite

        composite_records.append({
            'upto_layer': idx,
            **compute_all_metrics(composite)
        })

    return {
        'method': method,
        'depth': depth,
        'n': n,
        'sinkhorn_iters': sinkhorn_iters,
        'seed': seed,
        'per_layer': layer_records,
        'composite': composite_records,
    }
| 156 |
+
|
| 157 |
+
|
def run_comparison(
    depth: int = 64,
    n: int = 4,
    sinkhorn_iters: int = 20,
    seed: int = 42
) -> Dict:
    """
    Run simulate_depth for baseline, HC, and mHC and return all results.

    This is the main entry point for generating comparison data. Every
    method is run with identical parameters (including the seed), so the
    three result sets are directly comparable.

    Args:
        depth: Number of layers to simulate.
        n: Matrix size (number of streams).
        sinkhorn_iters: Number of Sinkhorn iterations for mHC.
        seed: Shared random seed for a fair comparison.

    Returns:
        Dict with keys 'baseline', 'hc', 'mhc' containing simulation results.

    Example:
        >>> results = run_comparison(depth=64, n=4, seed=42)
        >>> # Baseline should stay at 1
        >>> results['baseline']['composite'][-1]['forward_gain']
        1.0
        >>> # HC should explode
        >>> results['hc']['composite'][-1]['forward_gain'] > 10
        True
        >>> # mHC should stay bounded
        >>> results['mhc']['composite'][-1]['forward_gain'] < 5
        True
    """
    methods = ('baseline', 'hc', 'mhc')
    return {
        name: simulate_depth(depth, n, name, sinkhorn_iters, seed)
        for name in methods
    }
| 197 |
+
|
| 198 |
+
|
if __name__ == "__main__":
    # Quick demo when run directly: compare composite gains for the three
    # residual strategies at depth 64 with 4 streams.
    print("Running mHC simulation comparison...")
    print("=" * 50)

    results = run_comparison(depth=64, n=4, seed=42)

    for method in ['baseline', 'hc', 'mhc']:
        # Metrics of the full-depth composite mapping for this method.
        final_composite = results[method]['composite'][-1]
        print(f"\n{method.upper()}:")
        print(f"  Final composite forward_gain: {final_composite['forward_gain']:.4f}")
        print(f"  Final composite backward_gain: {final_composite['backward_gain']:.4f}")
        print(f"  Final composite spectral_norm: {final_composite['spectral_norm']:.4f}")
mhc/sinkhorn.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sinkhorn-Knopp algorithm for projecting matrices onto doubly stochastic matrices.
|
| 3 |
+
|
| 4 |
+
A doubly stochastic matrix has:
|
| 5 |
+
- All entries >= 0
|
| 6 |
+
- All rows sum to 1
|
| 7 |
+
- All columns sum to 1
|
| 8 |
+
|
| 9 |
+
The Sinkhorn-Knopp algorithm projects any matrix onto this set by:
|
| 10 |
+
1. Exponentiating the matrix to make all entries positive
|
| 11 |
+
2. Alternating row and column normalization until convergence
|
| 12 |
+
|
| 13 |
+
Mathematical background:
|
| 14 |
+
The set of doubly stochastic matrices forms the Birkhoff polytope. Sinkhorn-Knopp
|
| 15 |
+
finds the unique doubly stochastic matrix of the form D1 * exp(M) * D2 where
|
| 16 |
+
D1 and D2 are diagonal matrices with positive entries.
|
| 17 |
+
|
| 18 |
+
Key property for mHC: The product of doubly stochastic matrices is also
|
| 19 |
+
doubly stochastic (closure under multiplication), which bounds composite gains.
|
| 20 |
+
|
| 21 |
+
Author: Subhadip Mitra <contact@subhadipmitra.com>
|
| 22 |
+
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def sinkhorn_knopp(matrix: np.ndarray, iterations: int = 20, eps: float = 1e-8) -> np.ndarray:
    """
    Project ``matrix`` onto the Birkhoff polytope of doubly stochastic matrices.

    The input is exponentiated (after shifting by its maximum so ``exp``
    cannot overflow), then rows and columns are alternately rescaled to sum
    to 1. Sinkhorn-Knopp converges to the unique doubly stochastic matrix of
    the form D1 @ exp(matrix) @ D2 with positive diagonal D1, D2.

    Args:
        matrix: Real-valued square matrix of shape (n, n).
        iterations: Number of row/column normalization sweeps. 20 is
            typically sufficient for ~1e-3 accuracy.
        eps: Floor applied to the row/column sums to avoid division by zero.

    Returns:
        Matrix of shape (n, n) with non-negative entries whose rows and
        columns each sum to approximately 1.

    Example:
        >>> M = np.random.randn(4, 4)
        >>> P = sinkhorn_knopp(M, iterations=20)
        >>> np.allclose(P.sum(axis=1), 1, atol=1e-3)
        True
        >>> np.allclose(P.sum(axis=0), 1, atol=1e-3)
        True
    """
    # Shift by the maximum entry before exponentiating; this guards against
    # overflow and cancels out in the first row normalization.
    P = np.exp(matrix - matrix.max())

    # Alternate normalization: axis=1 rescales rows, axis=0 rescales columns.
    for _ in range(iterations):
        for axis in (1, 0):
            sums = P.sum(axis=axis, keepdims=True)
            P = P / np.maximum(sums, eps)

    return P
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def is_doubly_stochastic(matrix: np.ndarray, tol: float = 1e-3) -> bool:
    """
    Test whether ``matrix`` is approximately doubly stochastic.

    A doubly stochastic matrix has non-negative entries and every row and
    every column summing to 1; each condition is checked up to ``tol``.

    Args:
        matrix: Square matrix to test, shape (n, n).
        tol: Allowed deviation for the entry sign and the row/column sums.

    Returns:
        True if all three conditions hold within tolerance, else False.

    Example:
        >>> is_doubly_stochastic(np.eye(4))
        True
        >>> is_doubly_stochastic(np.random.randn(4, 4))
        False
    """
    # Reject any clearly negative entry first.
    if matrix.min() < -tol:
        return False

    rows_ok = np.allclose(matrix.sum(axis=1), 1.0, atol=tol)
    cols_ok = np.allclose(matrix.sum(axis=0), 1.0, atol=tol)
    return bool(rows_ok and cols_ok)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def projection_error(matrix: np.ndarray) -> dict:
    """
    Measure how far ``matrix`` is from being doubly stochastic.

    Useful for verifying Sinkhorn-Knopp convergence, debugging numerical
    issues, and visualizing the projection process.

    Args:
        matrix: Matrix to analyze, shape (n, n).

    Returns:
        Dict with:
        - 'row_sum_max_dev': largest |row sum - 1| over all rows
        - 'col_sum_max_dev': largest |column sum - 1| over all columns
        - 'min_entry': smallest entry (>= 0 for a doubly stochastic matrix)

    Example:
        >>> P = sinkhorn_knopp(np.random.randn(4, 4), iterations=20)
        >>> projection_error(P)['row_sum_max_dev'] < 1e-3
        True
    """
    row_dev = np.abs(matrix.sum(axis=1) - 1.0).max()
    col_dev = np.abs(matrix.sum(axis=0) - 1.0).max()

    return {
        'row_sum_max_dev': float(row_dev),
        'col_sum_max_dev': float(col_dev),
        'min_entry': float(matrix.min()),
    }
|
mhc/torch_module.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PyTorch implementation of mHC (Manifold-Constrained Hyper-Connections).
|
| 3 |
+
|
| 4 |
+
This module provides differentiable implementations that can be used
|
| 5 |
+
directly in neural network training:
|
| 6 |
+
|
| 7 |
+
- SinkhornKnopp: Differentiable projection onto doubly stochastic matrices
|
| 8 |
+
- mHCResidual: Complete mHC residual connection module
|
| 9 |
+
- mHCBlock: Wrapper to add mHC residuals to any layer
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
# Wrap any layer with mHC residuals
|
| 13 |
+
layer = nn.Linear(256, 256)
|
| 14 |
+
mhc_layer = mHCBlock(layer, dim=256, n_streams=4)
|
| 15 |
+
|
| 16 |
+
# Forward pass
|
| 17 |
+
x = torch.randn(32, 4, 256) # (batch, n_streams, dim)
|
| 18 |
+
output = mhc_layer(x)
|
| 19 |
+
|
| 20 |
+
Author: Subhadip Mitra <contact@subhadipmitra.com>
|
| 21 |
+
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
from typing import Optional
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class SinkhornKnopp(nn.Module):
    """
    Differentiable Sinkhorn-Knopp projection onto doubly stochastic matrices.

    Projects any matrix onto the Birkhoff polytope (set of doubly stochastic
    matrices) using alternating row and column normalization. Supports
    batched input of shape (..., n, n); each matrix is projected
    independently.

    Args:
        iterations: Number of normalization iterations (default: 20)
        eps: Small value for numerical stability (default: 1e-8)

    Example:
        >>> sinkhorn = SinkhornKnopp(iterations=20)
        >>> M = torch.randn(4, 4)
        >>> P = sinkhorn(M)
        >>> P.sum(dim=1)  # Should be close to 1
        tensor([1., 1., 1., 1.])
    """

    def __init__(self, iterations: int = 20, eps: float = 1e-8):
        super().__init__()
        self.iterations = iterations
        self.eps = eps

    def forward(self, matrix: torch.Tensor) -> torch.Tensor:
        """
        Project matrix onto doubly stochastic matrices.

        Args:
            matrix: Input tensor of shape (..., n, n)

        Returns:
            Approximately doubly stochastic matrix of same shape
        """
        # Subtract the per-matrix max (NOT the global max) before exp.
        # With a global max, a batch element far below the max underflows
        # in exp(), and the eps added to the denominators then dominates
        # its tiny sums, destroying the projection for that element.
        # The per-matrix constant shift cancels in the first row
        # normalization, so the result is unchanged up to floating point.
        shift = matrix.amax(dim=(-2, -1), keepdim=True)
        P = torch.exp(matrix - shift)

        for _ in range(self.iterations):
            # Row normalization
            P = P / (P.sum(dim=-1, keepdim=True) + self.eps)
            # Column normalization
            P = P / (P.sum(dim=-2, keepdim=True) + self.eps)

        return P
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class mHCResidual(nn.Module):
    """
    Manifold-Constrained Hyper-Connection (mHC) residual module.

    Maintains ``n_streams`` parallel copies of the hidden state and mixes
    them with a learnable matrix that is projected onto the set of doubly
    stochastic matrices via Sinkhorn-Knopp on every forward pass. Doubly
    stochastic matrices are closed under multiplication, which keeps the
    composite stream-mixing map bounded with depth.

    Args:
        dim: Hidden dimension size
        n_streams: Number of parallel streams (default: 4)
        sinkhorn_iters: Number of Sinkhorn iterations (default: 20)

    Example:
        >>> mhc = mHCResidual(dim=256, n_streams=4)
        >>> x = torch.randn(32, 4, 256)  # (batch, n_streams, dim)
        >>> layer_out = torch.randn(32, 256)  # Output from layer F
        >>> mhc(x, layer_out).shape
        torch.Size([32, 4, 256])
    """

    def __init__(
        self,
        dim: int,
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.dim = dim
        self.n_streams = n_streams

        # Projection onto doubly stochastic matrices.
        self.sinkhorn = SinkhornKnopp(iterations=sinkhorn_iters)

        # H_res: learnable stream-mixing logits; the Sinkhorn projection is
        # applied in forward(), so these are unconstrained here.
        self.H_res = nn.Parameter(torch.randn(n_streams, n_streams) * 0.01)

        # H_pre: weights aggregating the streams into the layer input
        # (1 x n_streams).
        self.H_pre = nn.Parameter(torch.ones(1, n_streams) / n_streams)

        # H_post: weights distributing the layer output back to the streams
        # (n_streams x 1).
        self.H_post = nn.Parameter(torch.ones(n_streams, 1) / n_streams)

        # Gating scalars, initialized near zero so each block starts close
        # to the identity map (stable early training).
        self.alpha_res = nn.Parameter(torch.tensor(0.01))
        self.alpha_pre = nn.Parameter(torch.tensor(0.01))
        self.alpha_post = nn.Parameter(torch.tensor(0.01))

        # Per-stream bias terms.
        self.bias_res = nn.Parameter(torch.zeros(n_streams, dim))
        self.bias_post = nn.Parameter(torch.zeros(n_streams, dim))

    def forward(
        self,
        x: torch.Tensor,
        layer_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the mHC residual connection.

        Args:
            x: Input hidden state, shape (batch, n_streams, dim)
            layer_output: Output from layer function F, shape (batch, dim)

        Returns:
            Updated hidden state, shape (batch, n_streams, dim)
        """
        # (Removed an unused `batch_size` local that the original computed.)

        # Constrain the stream-mixing matrix to be doubly stochastic.
        H_res_proj = self.sinkhorn(self.H_res)

        # Mix the streams: output stream r receives sum_s H[r, s] * x[:, s].
        x_mixed = torch.einsum('bsd,rs->brd', x, H_res_proj)
        x_mixed = self.alpha_res * x_mixed + self.bias_res

        # Broadcast the layer output into every stream, weighted by H_post:
        # (batch, 1, dim) * (n_streams, 1) -> (batch, n_streams, dim).
        layer_contrib = layer_output.unsqueeze(1) * self.H_post
        layer_contrib = self.alpha_post * layer_contrib + self.bias_post

        # Identity path + gated stream mixing + gated layer contribution.
        return x + x_mixed + layer_contrib

    def get_aggregated_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        Aggregate the multi-stream input into a single layer input.

        Args:
            x: Hidden state, shape (batch, n_streams, dim)

        Returns:
            Aggregated input for the layer, shape (batch, dim)
        """
        # abs() keeps the aggregation weights non-negative; alpha_pre gates
        # the overall magnitude of the pooled input.
        aggregated = torch.einsum('bsd,os->bd', x, self.H_pre.abs())
        return self.alpha_pre * aggregated
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class mHCBlock(nn.Module):
    """
    Wrap an arbitrary layer with mHC residual connections.

    This is the main user-facing interface: wrap any module that maps a
    (batch, dim) tensor to a (batch, dim) tensor (e.g. nn.Linear, an
    attention block) and the wrapper manages the multi-stream hidden state.

    Args:
        layer: The module to wrap (e.g. nn.Linear)
        dim: Hidden dimension
        n_streams: Number of parallel streams (default: 4)
        sinkhorn_iters: Number of Sinkhorn iterations (default: 20)

    Example:
        >>> block = mHCBlock(nn.Linear(256, 256), dim=256, n_streams=4)
        >>> x = torch.randn(32, 4, 256)  # (batch, n_streams, dim)
        >>> block(x).shape
        torch.Size([32, 4, 256])
    """

    def __init__(
        self,
        layer: nn.Module,
        dim: int,
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.layer = layer
        self.mhc = mHCResidual(dim, n_streams, sinkhorn_iters)
        self.n_streams = n_streams

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the wrapped layer under an mHC residual.

        Args:
            x: Input tensor, shape (batch, n_streams, dim)

        Returns:
            Output tensor, shape (batch, n_streams, dim)
        """
        # Collapse the streams into a single (batch, dim) input, apply the
        # wrapped layer, then fold its output back into every stream via
        # the mHC residual update.
        pooled = self.mhc.get_aggregated_input(x)
        transformed = self.layer(pooled)
        return self.mhc(x, transformed)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def create_mhc_mlp(
    dim: int,
    n_layers: int,
    n_streams: int = 4,
    sinkhorn_iters: int = 20
) -> nn.Sequential:
    """
    Build an MLP in which every Linear layer carries mHC residuals.

    The width ``dim`` is constant throughout so the multi-stream hidden
    state stays compatible from block to block; a GELU nonlinearity is
    inserted between consecutive blocks (but not after the last one).

    Args:
        dim: Hidden dimension (constant throughout)
        n_layers: Number of mHC blocks
        n_streams: Number of mHC streams
        sinkhorn_iters: Sinkhorn iterations

    Returns:
        nn.Sequential expecting input of shape (batch, n_streams, dim)

    Example:
        >>> mlp = create_mhc_mlp(dim=256, n_layers=4)
        >>> x = torch.randn(32, 4, 256)  # (batch, n_streams, dim)
        >>> mlp(x).shape
        torch.Size([32, 4, 256])
    """
    modules = []
    for idx in range(n_layers):
        modules.append(mHCBlock(nn.Linear(dim, dim), dim, n_streams, sinkhorn_iters))
        if idx != n_layers - 1:  # no activation after the final block
            modules.append(nn.GELU())
    return nn.Sequential(*modules)
|
requirements.txt
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
|
| 2 |
torch
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
# Add other dependencies as needed
|
|
|
|
| 1 |
+
numpy
|
| 2 |
torch
|
| 3 |
+
matplotlib
|
| 4 |
+
pytest
|
|
|
|
|
|