dylan-marimo-io commited on
Commit
8ab7e68
·
verified ·
1 Parent(s): ef9f423

Upload 7 files

Browse files
mhc/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
mHC (Manifold-Constrained Hyper-Connections) visualization library.

This package provides tools for demonstrating the stability properties
of mHC residual connections compared to unconstrained HC and baseline methods.

Modules:
- sinkhorn: Sinkhorn-Knopp projection onto doubly stochastic matrices
- metrics: Stability metrics (forward_gain, backward_gain, spectral_norm)
- simulation: Deep network signal propagation simulation
- torch_module: PyTorch implementation for use in neural networks

Author: Subhadip Mitra <contact@subhadipmitra.com>
Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
"""

from .sinkhorn import sinkhorn_knopp, is_doubly_stochastic, projection_error
from .metrics import forward_gain, backward_gain, spectral_norm, compute_all_metrics
from .simulation import generate_residual_matrix, simulate_depth, run_comparison

# PyTorch modules (optional import - requires torch). The NumPy-based API
# above remains fully usable when torch is absent; only these names are
# skipped in that case.
try:
    from .torch_module import SinkhornKnopp, mHCResidual, mHCBlock, create_mhc_mlp
except ImportError:
    pass  # torch not installed

__version__ = "0.1.0"
mhc/hyper_connections.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of Hyper-Connections (HC) and mHC.
3
+
4
+ HC extends residual connections with multiple parallel streams and learned mixing.
5
+ mHC constrains the mixing matrix to be doubly stochastic via Sinkhorn-Knopp.
6
+
7
+ Reference: https://arxiv.org/abs/2512.24880
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+
14
def sinkhorn_knopp_torch(M: torch.Tensor, iters: int = 20, eps: float = 1e-8) -> torch.Tensor:
    """
    Project a matrix onto the Birkhoff polytope with differentiable Sinkhorn-Knopp.

    The input is exponentiated (shifted by its max for numerical stability)
    and then alternately row- and column-normalized. Every operation is
    differentiable, so gradients flow through the projection.

    Args:
        M: Square input matrix of shape (n, n).
        iters: Number of alternating normalization sweeps.
        eps: Stabilizer added to denominators to avoid division by zero.

    Returns:
        Matrix of the same shape whose rows and columns each sum to ~1.
    """
    result = torch.exp(M - M.max())
    for _ in range(iters):
        row_totals = result.sum(dim=-1, keepdim=True) + eps
        result = result / row_totals
        col_totals = result.sum(dim=-2, keepdim=True) + eps
        result = result / col_totals
    return result
31
+
32
+
33
class HyperConnections(nn.Module):
    """
    Hyper-Connections: multi-stream residual with a learned mixing matrix.

    Each sample carries N parallel streams. At every layer the raw matrix
    ``H_res`` mixes the streams and the layer's own contribution is added:

        output = H @ input_streams + layer_contribution

    Args:
        n_streams: Number of parallel streams (N).
        init_scale: Scale for the random initialization of H.

    Shape:
        - Input x: (B, N, D) where B=batch, N=streams, D=features
        - Output: (B, N, D)
    """

    def __init__(self, n_streams: int = 4, init_scale: float = 0.1):
        super().__init__()
        self.n_streams = n_streams
        # Unconstrained mixing weights; subclasses may project them.
        self.H_res = nn.Parameter(init_scale * torch.randn(n_streams, n_streams))

    def get_mixing_matrix(self) -> torch.Tensor:
        """Return the raw mixing matrix (subclasses override to add constraints)."""
        return self.H_res

    def forward(self, x: torch.Tensor, layer_output: torch.Tensor) -> torch.Tensor:
        """
        Mix the input streams and add the layer output.

        Args:
            x: Streamed input, shape (B, N, ...).
            layer_output: Output of the wrapped layer F, shape (B, N, ...).

        Returns:
            ``H @ x + layer_output`` with the same shape as the inputs.
        """
        mix = self.get_mixing_matrix()
        # einsum handles arbitrary trailing dimensions after the stream axis.
        blended = torch.einsum('ij,bj...->bi...', mix, x)
        return blended + layer_output
75
+
76
+
77
class MHC(HyperConnections):
    """
    Manifold-Constrained Hyper-Connections (mHC).

    Identical to HC except the mixing matrix is projected onto the set of
    doubly stochastic matrices via Sinkhorn-Knopp before use. A doubly
    stochastic matrix has operator norm at most 1, so the mixing step cannot
    amplify signals or gradients, stabilizing very deep stacks.

    Args:
        n_streams: Number of parallel streams.
        sinkhorn_iters: Sinkhorn normalization sweeps per projection.
        init_scale: Scale for the raw-matrix random initialization.
    """

    def __init__(self, n_streams: int = 4, sinkhorn_iters: int = 20, init_scale: float = 0.1):
        super().__init__(n_streams, init_scale)
        self.sinkhorn_iters = sinkhorn_iters

    def get_mixing_matrix(self) -> torch.Tensor:
        """Project the raw weights to a doubly stochastic matrix on the fly."""
        return sinkhorn_knopp_torch(self.H_res, self.sinkhorn_iters)
100
+
101
+
102
class ResidualBlock(nn.Module):
    """
    Residual block with configurable connection type.

    Supports three modes:
    - 'baseline': Standard residual connection (x + F(x))
    - 'hc': Hyper-Connections with unconstrained mixing
    - 'mhc': Manifold-Constrained HC with doubly stochastic mixing

    Args:
        channels: Number of channels in conv layers
        method: One of 'baseline', 'hc', 'mhc'
        n_streams: Number of streams for HC/mHC
        sinkhorn_iters: Sinkhorn iterations for mHC
    """

    def __init__(
        self,
        channels: int,
        method: str = 'baseline',
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.method = method
        self.n_streams = n_streams

        # Main conv path (standard ResNet-style): conv-BN-ReLU-conv-BN.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )

        # HC/mHC mixing module; None for the plain baseline residual.
        if method == 'hc':
            self.hc = HyperConnections(n_streams)
        elif method == 'mhc':
            self.hc = MHC(n_streams, sinkhorn_iters)
        else:
            self.hc = None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with configurable residual connection.

        Args:
            x: Input tensor of shape (B, C, H, W)

        Returns:
            Output tensor of shape (B, C, H, W)
        """
        if self.method == 'baseline':
            # Standard residual: x + F(x)
            return self.relu(x + self.conv(x))

        # HC/mHC path
        B, C, H, W = x.shape
        N = self.n_streams

        # Expand input to streams: (B, C, H, W) -> (B, N, C*H*W)
        # Using view instead of expand to avoid memory copy where possible
        # NOTE(review): expand() yields a broadcast view, so all N streams
        # share the same underlying data — streams only differ through the
        # learned mixing, not through independent content. Confirm intended.
        x_flat = x.view(B, 1, -1).expand(B, N, -1)

        # Apply conv to original input (once, not per stream)
        conv_out = self.conv(x)
        conv_flat = conv_out.view(B, 1, -1).expand(B, N, -1)

        # Mix via HC/mHC: H @ x_streams + conv_streams
        mixed = self.hc(x_flat, conv_flat)

        # Collapse streams: mean over N, reshape back to image layout
        out = mixed.mean(dim=1).view(B, C, H, W)
        return self.relu(out)
180
+
181
+
182
class SimpleCNN(nn.Module):
    """
    Small CNN whose residual blocks use a configurable connection type.

    Architecture:
    - Stem: 3x3 conv lifting the input into the working channel dimension
    - n_blocks residual blocks (connection type chosen by ``method``)
    - Head: global average pool + linear classifier

    Args:
        n_blocks: Number of residual blocks.
        channels: Hidden dimension.
        method: Residual type ('baseline', 'hc', 'mhc').
        n_streams: Number of streams for HC/mHC.
        sinkhorn_iters: Sinkhorn iterations for mHC.
        num_classes: Number of output classes.
        in_channels: Number of input channels (3 for RGB).
    """

    def __init__(
        self,
        n_blocks: int = 8,
        channels: int = 64,
        method: str = 'baseline',
        n_streams: int = 4,
        sinkhorn_iters: int = 20,
        num_classes: int = 10,
        in_channels: int = 3
    ):
        super().__init__()
        self.method = method

        # Stem: lift input channels to the working width.
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

        # Stack of residual blocks sharing the configured connection type.
        stack = [
            ResidualBlock(channels, method, n_streams, sinkhorn_iters)
            for _ in range(n_blocks)
        ]
        self.blocks = nn.ModuleList(stack)

        # Head: pooled features -> class logits.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map input images to class logits.

        Args:
            x: Input images of shape (B, C, H, W).

        Returns:
            Logits of shape (B, num_classes).
        """
        features = self.stem(x)
        for residual_block in self.blocks:
            features = residual_block(features)
        return self.head(features)
248
+
249
+
250
def train_with_gradient_tracking(
    model: nn.Module,
    train_loader,
    epochs: int,
    device: torch.device,
    lr: float = 1e-3
) -> dict:
    """
    Train *model* with Adam + cross-entropy, logging per-step diagnostics.

    Args:
        model: PyTorch model to train.
        train_loader: Iterable of (data, target) batches.
        epochs: Number of passes over the loader.
        device: Device to move each batch to.
        lr: Adam learning rate.

    Returns:
        Dict with one entry appended per optimization step:
        - 'loss': cross-entropy loss values
        - 'grad_norms': total L2 norm over all parameter gradients
        - 'accuracy': batch accuracies
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    history = {'loss': [], 'grad_norms': [], 'accuracy': []}

    model.train()
    for _ in range(epochs):
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, target)
            loss.backward()

            # Global gradient norm: sqrt of the summed squared per-tensor norms.
            squared_total = sum(
                p.grad.norm().item() ** 2
                for p in model.parameters()
                if p.grad is not None
            )
            grad_norm = squared_total ** 0.5

            # Batch accuracy from the argmax prediction on training logits.
            batch_acc = (logits.argmax(dim=1) == target).float().mean().item()

            history['loss'].append(loss.item())
            history['grad_norms'].append(grad_norm)
            history['accuracy'].append(batch_acc)

            optimizer.step()

    return history
mhc/metrics.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Stability metrics for analyzing residual mixing matrices.
3
+
4
+ These metrics quantify how a matrix amplifies signals during forward/backward
5
+ propagation through a neural network layer.
6
+
7
+ Key insight from the mHC paper:
8
+ - Unconstrained matrices (HC) can have unbounded gains, causing signal explosion
9
+ - Doubly stochastic matrices (mHC) have all gains bounded by 1, ensuring stability
10
+
11
+ Metrics:
12
+ - forward_gain: Worst-case signal amplification in forward pass (max row sum)
13
+ - backward_gain: Worst-case gradient amplification in backward pass (max column sum)
14
+ - spectral_norm: Largest singular value (general operator norm)
15
+
16
+ For doubly stochastic matrices, all three equal exactly 1.
17
+
18
+ Author: Subhadip Mitra <contact@subhadipmitra.com>
19
+ Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
20
+ """
21
+
22
+ import numpy as np
23
+
24
+
25
def forward_gain(matrix: np.ndarray) -> float:
    """
    Return the largest absolute row sum of *matrix*.

    This measures worst-case amplification of a uniform signal in the
    forward pass: applied to a constant vector, row i of M scales it by its
    row sum. For a doubly stochastic matrix every row sums to 1, so the gain
    is exactly 1. For an entrywise non-negative matrix this coincides with
    the infinity norm ``||M||_inf``; with mixed-sign entries it is a lower
    bound on that norm, since entries within a row may cancel.

    Args:
        matrix: Square matrix of shape (n, n).

    Returns:
        ``max_i |sum_j M[i, j]|`` as a Python float.

    Example:
        >>> forward_gain(np.eye(4))
        1.0
        >>> forward_gain(np.ones((4, 4)))
        4.0
    """
    row_totals = np.sum(matrix, axis=1)
    return float(np.max(np.abs(row_totals)))
49
+
50
+
51
def backward_gain(matrix: np.ndarray) -> float:
    """
    Return the largest absolute column sum of *matrix*.

    This measures worst-case amplification of a uniform gradient in the
    backward pass, since backprop applies the transpose of the mixing
    matrix. For a doubly stochastic matrix every column sums to 1, so the
    gain is exactly 1. For an entrywise non-negative matrix this coincides
    with the one norm ``||M||_1``; with mixed-sign entries it is a lower
    bound on that norm, since entries within a column may cancel.

    Args:
        matrix: Square matrix of shape (n, n).

    Returns:
        ``max_j |sum_i M[i, j]|`` as a Python float.

    Example:
        >>> backward_gain(np.eye(4))
        1.0
        >>> backward_gain(np.ones((4, 4)))
        4.0
    """
    column_totals = np.sum(matrix, axis=0)
    return float(np.max(np.abs(column_totals)))
75
+
76
+
77
def spectral_norm(matrix: np.ndarray, iterations: int = 20) -> float:
    """
    Estimate the spectral norm (largest singular value) via power iteration.

    The spectral norm ||M||_2 is the maximum L2 amplification of a unit
    vector: ||Mx||_2 <= ||M||_2 for ||x||_2 = 1. For doubly stochastic
    matrices, spectral_norm <= 1.

    Algorithm: power iteration on ``M.T @ M``, whose dominant eigenvalue is
    sigma_max**2; the dominant eigenvector is the top right singular vector
    v, and sigma_max = ||M @ v||.

    Bug fix: the previous version iterated ``v <- M @ v``, which converges
    to the dominant *eigenvalue* magnitude, not the largest singular value —
    e.g. [[0, 2], [0, 0]] has spectral norm 2 but spectral radius 0, so it
    returned 0.0 instead of 2.0.

    Args:
        matrix: Input matrix of shape (n, n)
        iterations: Number of power iterations (20 is usually sufficient)

    Returns:
        Estimated spectral norm (largest singular value)

    Example:
        >>> spectral_norm(np.eye(4))  # doctest: +ELLIPSIS
        1.0...
        >>> spectral_norm(2 * np.eye(4))  # doctest: +ELLIPSIS
        2.0...
    """
    n = matrix.shape[0]

    # Initialize with a random unit vector (fixed seed for reproducibility).
    rng = np.random.default_rng(42)
    v = rng.standard_normal(n)
    v = v / np.linalg.norm(v)

    for _ in range(iterations):
        # One power-iteration step on M^T M.
        w = matrix.T @ (matrix @ v)
        norm = np.linalg.norm(w)
        if norm < 1e-10:
            # M annihilates the current subspace (numerically zero matrix).
            return 0.0
        v = w / norm

    # sigma_max = ||M v|| once v approximates the top right singular vector.
    return float(np.linalg.norm(matrix @ v))
122
+
123
+
124
def compute_all_metrics(matrix: np.ndarray) -> dict:
    """
    Bundle every stability metric for one mixing matrix.

    This is the main entry point for analyzing residual mixing matrices;
    the returned values together indicate whether training through the
    matrix will be stable.

    Args:
        matrix: Input matrix of shape (n, n)

    Returns:
        Dict containing:
        - spectral_norm: largest singular value
        - forward_gain: max absolute row sum
        - backward_gain: max absolute column sum
        - row_sum_max_dev: max deviation of row sums from 1
        - col_sum_max_dev: max deviation of column sums from 1
        - min_entry: minimum matrix entry

    Example:
        >>> metrics = compute_all_metrics(np.eye(4))
        >>> metrics['forward_gain']
        1.0
        >>> metrics['backward_gain']
        1.0
    """
    sums_by_row = matrix.sum(axis=1)
    sums_by_col = matrix.sum(axis=0)

    metrics = {'spectral_norm': spectral_norm(matrix)}
    metrics['forward_gain'] = float(np.abs(sums_by_row).max())
    metrics['backward_gain'] = float(np.abs(sums_by_col).max())
    metrics['row_sum_max_dev'] = float(np.abs(sums_by_row - 1).max())
    metrics['col_sum_max_dev'] = float(np.abs(sums_by_col - 1).max())
    metrics['min_entry'] = float(matrix.min())
    return metrics
mhc/simulation.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simulation engine for deep network signal propagation.
3
+
4
+ This module simulates how signals propagate through deep residual networks
5
+ with different residual mixing strategies:
6
+
7
+ - baseline: Identity matrices (no mixing, standard residual connections)
8
+ - hc: Random unconstrained matrices (Hyper-Connections)
9
+ - mhc: Sinkhorn-projected doubly stochastic matrices (Manifold-Constrained HC)
10
+
11
+ Key insight from the mHC paper:
12
+ The COMPOSITE mapping (product of all layer matrices H_L @ H_{L-1} @ ... @ H_0)
13
+ is what matters for signal propagation:
14
+ - For HC: composite gains explode exponentially (3000x+ at depth 64)
15
+ - For mHC: composite gains stay bounded (~1.6x at depth 64)
16
+
17
+ This happens because doubly stochastic matrices are closed under multiplication.
18
+
19
+ Author: Subhadip Mitra <contact@subhadipmitra.com>
20
+ Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
21
+ """
22
+
23
+ import numpy as np
24
+ from typing import Dict, Literal, Optional
25
+
26
+ from .sinkhorn import sinkhorn_knopp
27
+ from .metrics import compute_all_metrics
28
+
29
+
30
def generate_residual_matrix(
    n: int,
    method: Literal['baseline', 'hc', 'mhc'],
    sinkhorn_iters: int = 20,
    rng: Optional[np.random.Generator] = None
) -> np.ndarray:
    """
    Build one layer's residual mixing matrix.

    Args:
        n: Size of the square matrix (number of streams).
        method: Mixing strategy:
            - 'baseline': identity matrix (no mixing)
            - 'hc': unconstrained random matrix with N(0, 1) entries
            - 'mhc': random matrix projected to doubly stochastic via Sinkhorn
        sinkhorn_iters: Sinkhorn sweeps for the 'mhc' method. Zero skips the
            projection entirely so 'mhc' degenerates to the raw HC draw,
            which lets callers visualize the transition to stability.
        rng: Optional generator for reproducible draws.

    Returns:
        Residual mixing matrix of shape (n, n).

    Raises:
        ValueError: If ``method`` is not one of the three strategies.

    Example:
        >>> rng = np.random.default_rng(42)
        >>> M = generate_residual_matrix(4, 'mhc', sinkhorn_iters=20, rng=rng)
        >>> M.shape
        (4, 4)
    """
    generator = np.random.default_rng() if rng is None else rng

    if method == 'baseline':
        return np.eye(n)

    # HC and mHC both start from the same random draw.
    raw = generator.standard_normal((n, n))

    if method == 'hc':
        return raw
    if method == 'mhc':
        # sinkhorn_iters == 0 deliberately returns the unprojected draw.
        return raw if sinkhorn_iters == 0 else sinkhorn_knopp(raw, iterations=sinkhorn_iters)

    raise ValueError(f"Unknown method: {method}. Expected 'baseline', 'hc', or 'mhc'.")
77
+
78
+
79
def simulate_depth(
    depth: int,
    n: int,
    method: Literal['baseline', 'hc', 'mhc'],
    sinkhorn_iters: int = 20,
    seed: int = 42
) -> Dict:
    """
    Propagate the composite mixing map through ``depth`` stacked layers.

    At each layer l a fresh residual matrix H_l is drawn and the running
    product Composite(l) = H_l @ H_{l-1} @ ... @ H_0 is updated. Stability
    metrics are recorded both for each individual matrix and for every
    composite, since the composite is what signals actually pass through.

    Args:
        depth: Number of layers to simulate.
        n: Matrix size (number of streams in the multi-stream residual).
        method: Residual mixing strategy ('baseline', 'hc', or 'mhc').
        sinkhorn_iters: Sinkhorn iterations for mHC.
        seed: Random seed for reproducibility.

    Returns:
        Dict containing the run configuration ('method', 'depth', 'n',
        'sinkhorn_iters', 'seed') plus:
        - 'per_layer': list of metric dicts for each layer's matrix
        - 'composite': list of metric dicts for the composite at each depth

    Example:
        >>> result = simulate_depth(64, 4, 'mhc', seed=42)
        >>> result['composite'][-1]['forward_gain'] < 5
        True
    """
    rng = np.random.default_rng(seed)

    per_layer = []
    composite_metrics = []
    composite = np.eye(n)  # empty product of matrices is the identity

    for layer_idx in range(depth):
        # Draw this layer's residual matrix and record its own metrics.
        H = generate_residual_matrix(n, method, sinkhorn_iters, rng)
        per_layer.append({'layer': layer_idx, **compute_all_metrics(H)})

        # Left-multiply: the newest layer acts last on the signal.
        composite = H @ composite
        composite_metrics.append({
            'upto_layer': layer_idx,
            **compute_all_metrics(composite)
        })

    return {
        'method': method,
        'depth': depth,
        'n': n,
        'sinkhorn_iters': sinkhorn_iters,
        'seed': seed,
        'per_layer': per_layer,
        'composite': composite_metrics,
    }
156
+
157
+
158
def run_comparison(
    depth: int = 64,
    n: int = 4,
    sinkhorn_iters: int = 20,
    seed: int = 42
) -> Dict:
    """
    Simulate all three residual strategies with identical settings.

    This is the main entry point for generating comparison data: it runs
    ``simulate_depth`` for baseline, HC, and mHC with the same depth, size,
    and seed so the three curves are directly comparable.

    Args:
        depth: Number of layers to simulate.
        n: Matrix size (number of streams).
        sinkhorn_iters: Number of Sinkhorn iterations for mHC.
        seed: Random seed (shared across methods for a fair comparison).

    Returns:
        Dict with keys 'baseline', 'hc', 'mhc' containing simulation results.

    Example:
        >>> results = run_comparison(depth=64, n=4, seed=42)
        >>> # Baseline should stay at 1
        >>> results['baseline']['composite'][-1]['forward_gain']
        1.0
        >>> # HC should explode
        >>> results['hc']['composite'][-1]['forward_gain'] > 10
        True
        >>> # mHC should stay bounded
        >>> results['mhc']['composite'][-1]['forward_gain'] < 5
        True
    """
    return {
        strategy: simulate_depth(depth, n, strategy, sinkhorn_iters, seed)
        for strategy in ('baseline', 'hc', 'mhc')
    }
197
+
198
+
199
if __name__ == "__main__":
    # Quick demo when run directly.
    print("Running mHC simulation comparison...")
    print("=" * 50)

    comparison = run_comparison(depth=64, n=4, seed=42)

    for name in ('baseline', 'hc', 'mhc'):
        final = comparison[name]['composite'][-1]
        print(f"\n{name.upper()}:")
        print(f"  Final composite forward_gain: {final['forward_gain']:.4f}")
        print(f"  Final composite backward_gain: {final['backward_gain']:.4f}")
        print(f"  Final composite spectral_norm: {final['spectral_norm']:.4f}")
mhc/sinkhorn.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sinkhorn-Knopp algorithm for projecting matrices onto doubly stochastic matrices.
3
+
4
+ A doubly stochastic matrix has:
5
+ - All entries >= 0
6
+ - All rows sum to 1
7
+ - All columns sum to 1
8
+
9
+ The Sinkhorn-Knopp algorithm projects any matrix onto this set by:
10
+ 1. Exponentiating the matrix to make all entries positive
11
+ 2. Alternating row and column normalization until convergence
12
+
13
+ Mathematical background:
14
+ The set of doubly stochastic matrices forms the Birkhoff polytope. Sinkhorn-Knopp
15
+ finds the unique doubly stochastic matrix of the form D1 * exp(M) * D2 where
16
+ D1 and D2 are diagonal matrices with positive entries.
17
+
18
+ Key property for mHC: The product of doubly stochastic matrices is also
19
+ doubly stochastic (closure under multiplication), which bounds composite gains.
20
+
21
+ Author: Subhadip Mitra <contact@subhadipmitra.com>
22
+ Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
23
+ """
24
+
25
+ import numpy as np
26
+
27
+
28
def sinkhorn_knopp(matrix: np.ndarray, iterations: int = 20, eps: float = 1e-8) -> np.ndarray:
    """
    Project *matrix* onto the set of doubly stochastic matrices.

    Entries are first made positive via a max-shifted exponential, then rows
    and columns are alternately rescaled to unit sums. The iteration
    converges to the unique doubly stochastic matrix of the form
    ``D1 @ exp(matrix) @ D2`` with positive diagonal D1, D2.

    Args:
        matrix: Input matrix of shape (n, n). Can have any real values.
        iterations: Number of row/column normalization sweeps (20 typically
            reaches ~1e-3 accuracy).
        eps: Floor applied to the normalizing sums to avoid division by zero.

    Returns:
        Matrix of shape (n, n) with non-negative entries whose row and
        column sums are all approximately 1.

    Example:
        >>> M = np.random.randn(4, 4)
        >>> P = sinkhorn_knopp(M, iterations=20)
        >>> np.allclose(P.sum(axis=1), 1, atol=1e-3)
        True
        >>> np.allclose(P.sum(axis=0), 1, atol=1e-3)
        True
    """
    # Max-shift before exp: the largest entry maps to exp(0) = 1, which
    # prevents overflow for matrices with large positive values.
    projected = np.exp(matrix - matrix.max())

    for _ in range(iterations):
        # Rows to unit sum, then columns; the floor guards degenerate sums.
        projected = projected / np.maximum(projected.sum(axis=1, keepdims=True), eps)
        projected = projected / np.maximum(projected.sum(axis=0, keepdims=True), eps)

    return projected
73
+
74
+
75
def is_doubly_stochastic(matrix: np.ndarray, tol: float = 1e-3) -> bool:
    """
    Check whether *matrix* is approximately doubly stochastic.

    A matrix qualifies when all entries are non-negative and every row and
    every column sums to 1 (each condition checked within ``tol``).

    Args:
        matrix: Input matrix to check, shape (n, n).
        tol: Tolerance for non-negativity and for row/column sum deviation
            from 1.0.

    Returns:
        True if the matrix satisfies all three conditions within tolerance.

    Example:
        >>> I = np.eye(4)
        >>> is_doubly_stochastic(I)
        True
        >>> M = np.random.randn(4, 4)
        >>> is_doubly_stochastic(M)
        False
    """
    nonnegative = matrix.min() >= -tol
    rows_ok = np.allclose(matrix.sum(axis=1), 1.0, atol=tol)
    cols_ok = np.allclose(matrix.sum(axis=0), 1.0, atol=tol)
    return bool(nonnegative and rows_ok and cols_ok)
115
+
116
+
117
def projection_error(matrix: np.ndarray) -> dict:
    """
    Measure how far *matrix* is from being doubly stochastic.

    Useful for verifying Sinkhorn-Knopp convergence, debugging numerical
    issues, and visualizing the projection process.

    Args:
        matrix: Input matrix to analyze, shape (n, n).

    Returns:
        Dict containing:
        - 'row_sum_max_dev': max absolute deviation of any row sum from 1
        - 'col_sum_max_dev': max absolute deviation of any column sum from 1
        - 'min_entry': minimum entry (should be >= 0 for doubly stochastic)

    Example:
        >>> P = sinkhorn_knopp(np.random.randn(4, 4), iterations=20)
        >>> err = projection_error(P)
        >>> err['row_sum_max_dev'] < 1e-3
        True
    """
    row_dev = np.abs(matrix.sum(axis=1) - 1.0).max()
    col_dev = np.abs(matrix.sum(axis=0) - 1.0).max()

    return {
        'row_sum_max_dev': float(row_dev),
        'col_sum_max_dev': float(col_dev),
        'min_entry': float(matrix.min()),
    }
mhc/torch_module.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of mHC (Manifold-Constrained Hyper-Connections).
3
+
4
+ This module provides differentiable implementations that can be used
5
+ directly in neural network training:
6
+
7
+ - SinkhornKnopp: Differentiable projection onto doubly stochastic matrices
8
+ - mHCResidual: Complete mHC residual connection module
9
+ - mHCBlock: Wrapper to add mHC residuals to any layer
10
+
11
+ Usage:
12
+ # Wrap any layer with mHC residuals
13
+ layer = nn.Linear(256, 256)
14
+ mhc_layer = mHCBlock(layer, dim=256, n_streams=4)
15
+
16
+ # Forward pass
17
+ x = torch.randn(32, 4, 256) # (batch, n_streams, dim)
18
+ output = mhc_layer(x)
19
+
20
+ Author: Subhadip Mitra <contact@subhadipmitra.com>
21
+ Based on DeepSeek's mHC paper: https://arxiv.org/abs/2512.24880
22
+ """
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+ from typing import Optional
27
+
28
+
29
class SinkhornKnopp(nn.Module):
    """
    Differentiable Sinkhorn-Knopp projection onto doubly stochastic matrices.

    Projects any matrix onto the Birkhoff polytope (set of doubly stochastic
    matrices) using alternating row and column normalization.

    Args:
        iterations: Number of normalization iterations (default: 20)
        eps: Small value for numerical stability (default: 1e-8)

    Example:
        >>> sinkhorn = SinkhornKnopp(iterations=20)
        >>> M = torch.randn(4, 4)
        >>> P = sinkhorn(M)
        >>> P.sum(dim=1)  # Should be close to 1
        tensor([1., 1., 1., 1.])
    """

    def __init__(self, iterations: int = 20, eps: float = 1e-8):
        super().__init__()
        self.iterations = iterations
        self.eps = eps

    def forward(self, matrix: torch.Tensor) -> torch.Tensor:
        """
        Project matrix onto doubly stochastic matrices.

        Args:
            matrix: Input tensor of shape (..., n, n)

        Returns:
            Approximately doubly stochastic matrix of same shape
        """
        # Subtract the per-matrix max (not the global max) before exp.
        # With batched input, a single global max would shift every entry
        # of lower-magnitude matrices to exp(-large) == 0 in float32,
        # letting eps dominate and breaking the normalization. A constant
        # per-matrix shift cancels in the normalizations, so this only
        # affects numerical stability, not the fixed point.
        shift = matrix.amax(dim=(-2, -1), keepdim=True)
        P = torch.exp(matrix - shift)

        for _ in range(self.iterations):
            # Row normalization
            P = P / (P.sum(dim=-1, keepdim=True) + self.eps)
            # Column normalization
            P = P / (P.sum(dim=-2, keepdim=True) + self.eps)

        return P
73
+
74
+
75
class mHCResidual(nn.Module):
    """
    Manifold-Constrained Hyper-Connection residual module.

    Maintains ``n_streams`` parallel copies of the hidden state and mixes
    them with a learnable matrix that is projected onto the doubly
    stochastic manifold (via Sinkhorn-Knopp) on every forward pass, which
    constrains the mixing to preserve signal magnitude.

    Args:
        dim: Hidden dimension size
        n_streams: Number of parallel streams (default: 4)
        sinkhorn_iters: Number of Sinkhorn iterations (default: 20)

    Example:
        >>> mhc = mHCResidual(dim=256, n_streams=4)
        >>> x = torch.randn(32, 4, 256)  # (batch, n_streams, dim)
        >>> layer_out = torch.randn(32, 256)  # Output from layer F
        >>> y = mhc(x, layer_out)
        >>> y.shape
        torch.Size([32, 4, 256])
    """

    def __init__(
        self,
        dim: int,
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.dim = dim
        self.n_streams = n_streams

        # Projection operator onto the Birkhoff polytope.
        self.sinkhorn = SinkhornKnopp(iterations=sinkhorn_iters)

        # Raw (unconstrained) stream-mixing matrix; projected before use.
        self.H_res = nn.Parameter(torch.randn(n_streams, n_streams) * 0.01)

        # Aggregation weights: streams -> single layer input (1 x n_streams).
        self.H_pre = nn.Parameter(torch.ones(1, n_streams) / n_streams)

        # Distribution weights: layer output -> streams (n_streams x 1).
        self.H_post = nn.Parameter(torch.ones(n_streams, 1) / n_streams)

        # Gating scalars start near zero so the block begins close to
        # an identity map, which stabilizes early training.
        self.alpha_res = nn.Parameter(torch.tensor(0.01))
        self.alpha_pre = nn.Parameter(torch.tensor(0.01))
        self.alpha_post = nn.Parameter(torch.tensor(0.01))

        # Per-stream additive biases.
        self.bias_res = nn.Parameter(torch.zeros(n_streams, dim))
        self.bias_post = nn.Parameter(torch.zeros(n_streams, dim))

    def forward(
        self,
        x: torch.Tensor,
        layer_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the mHC residual update.

        Args:
            x: Input hidden state, shape (batch, n_streams, dim)
            layer_output: Output from layer function F, shape (batch, dim)

        Returns:
            Updated hidden state, shape (batch, n_streams, dim)
        """
        # Constrain the mixing matrix to be doubly stochastic.
        mixing = self.sinkhorn(self.H_res)

        # Mix streams: out[b, r, d] = sum_s mixing[r, s] * x[b, s, d].
        mixed = torch.einsum('bsd,rs->brd', x, mixing)
        residual_term = self.alpha_res * mixed + self.bias_res

        # Broadcast the layer output across streams, weighted per stream:
        # (batch, 1, dim) * (n_streams, 1) -> (batch, n_streams, dim).
        broadcast = layer_output.unsqueeze(1) * self.H_post
        layer_term = self.alpha_post * broadcast + self.bias_post

        # Identity path + gated stream mixing + gated layer contribution.
        return x + residual_term + layer_term

    def get_aggregated_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        Collapse the multi-stream state into a single layer input.

        Args:
            x: Hidden state, shape (batch, n_streams, dim)

        Returns:
            Aggregated input for layer, shape (batch, dim)
        """
        # Non-negative weighted sum over the stream axis, scaled by the gate.
        weights = self.H_pre.abs()
        return self.alpha_pre * torch.einsum('bsd,os->bd', x, weights)
182
+
183
+
184
class mHCBlock(nn.Module):
    """
    Wrapper that adds mHC residual connections to any layer.

    This is the main interface for using mHC in your models. It wraps
    any PyTorch module (e.g., Linear, Attention) with mHC residuals.

    Args:
        layer: The layer module to wrap (e.g., nn.Linear)
        dim: Hidden dimension
        n_streams: Number of parallel streams (default: 4)
        sinkhorn_iters: Number of Sinkhorn iterations (default: 20)

    Example:
        >>> # Wrap a linear layer
        >>> layer = nn.Linear(256, 256)
        >>> block = mHCBlock(layer, dim=256, n_streams=4)
        >>>
        >>> # Input has shape (batch, n_streams, dim)
        >>> x = torch.randn(32, 4, 256)
        >>> output = block(x)
        >>> output.shape
        torch.Size([32, 4, 256])
    """

    def __init__(
        self,
        layer: nn.Module,
        dim: int,
        n_streams: int = 4,
        sinkhorn_iters: int = 20
    ):
        super().__init__()
        self.layer = layer
        self.mhc = mHCResidual(dim, n_streams, sinkhorn_iters)
        self.n_streams = n_streams

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with mHC residual.

        Args:
            x: Input tensor, shape (batch, n_streams, dim)

        Returns:
            Output tensor, shape (batch, n_streams, dim)
        """
        # Collapse streams, run the wrapped layer on the aggregate,
        # then fold the result back into the streams via the mHC residual.
        aggregated = self.mhc.get_aggregated_input(x)
        return self.mhc(x, self.layer(aggregated))
241
+
242
+
243
def create_mhc_mlp(
    dim: int,
    n_layers: int,
    n_streams: int = 4,
    sinkhorn_iters: int = 20
) -> nn.Sequential:
    """
    Create an MLP with mHC residual connections.

    Convenience function to create a multi-layer perceptron where
    each layer is wrapped with mHC residuals. All layers maintain
    the same dimension for mHC stream compatibility.

    Args:
        dim: Hidden dimension (constant throughout)
        n_layers: Number of mHC blocks
        n_streams: Number of mHC streams
        sinkhorn_iters: Sinkhorn iterations

    Returns:
        nn.Sequential module with mHC blocks

    Example:
        >>> mlp = create_mhc_mlp(dim=256, n_layers=4)
        >>> x = torch.randn(32, 4, 256)  # (batch, n_streams, dim)
        >>> y = mlp(x)
        >>> y.shape
        torch.Size([32, 4, 256])
    """
    modules = []
    last = n_layers - 1
    for idx in range(n_layers):
        block = mHCBlock(nn.Linear(dim, dim), dim, n_streams, sinkhorn_iters)
        modules.append(block)
        # Elementwise nonlinearity between blocks; omitted after the last.
        if idx != last:
            modules.append(nn.GELU())
    return nn.Sequential(*modules)
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- marimo
2
  torch
3
- # Or a specific version
4
- # marimo>=0.9.0
5
-
6
- # Add other dependencies as needed
 
1
+ numpy
2
  torch
3
+ matplotlib
4
+ pytest