# Marimo notebook: exploring Manifold-Constrained Hyper-Connections (mHC).
import marimo

# Version of marimo that generated this notebook file.
__generated_with = "0.18.4"
app = marimo.App(width="medium", app_title="Mhc Exploration")

# Setup cell: names defined here are in scope for every cell below.
with app.setup(hide_code=True):
    import sys
    import subprocess
    import tempfile
    from pathlib import Path
    import numpy as np
    import altair as alt
    import pandas as pd
    import marimo as mo
    # Project-local helpers: simulation driver, Sinkhorn projection, metrics.
    from mhc import run_comparison, sinkhorn_knopp, compute_all_metrics
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: title and introduction. The @app.cell decorator is
    # restored here (lost in export) — without it marimo never registers the cell.
    mo.md(r"""
    # Exploring Manifold-Constrained Hyper-Connections (mHC)

    [Open in molab](https://molab.marimo.io/github/bassrehab/mhc-visualizer/blob/main/notebook/mhc_exploration.py)

    This interactive notebook demonstrates the key insight from DeepSeek's mHC paper:
    **unconstrained residual mixing matrices cause signal explosion in deep networks,
    while doubly stochastic constraints keep signals bounded.**

    Use the controls below to explore how Sinkhorn iterations affect stability!

    **Paper:** https://arxiv.org/abs/2512.24880

    **Viz Author (under MIT):** Subhadip Mitra

    **Implementation/ Viz Repository:** https://github.com/bassrehab/mhc-visualizer
    """)
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: explains the Sinkhorn-Knopp algorithm.
    mo.md(r"""
    ## The Sinkhorn-Knopp Algorithm

    The Sinkhorn-Knopp algorithm projects any positive matrix onto the set of **doubly stochastic matrices** - matrices where all rows and columns sum to 1.

    ### Why Doubly Stochastic?

    Doubly stochastic matrices have a crucial property: **they are closed under multiplication**. This means:

    - If A and B are doubly stochastic, then A @ B is also doubly stochastic
    - The spectral norm is bounded: ||A|| <= 1
    - Signal propagation stays bounded even through many layers

    ### The Algorithm

    Starting from any positive matrix, we alternate between:

    1. Normalizing rows to sum to 1
    2. Normalizing columns to sum to 1

    This converges to a doubly stochastic matrix!
    """)
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: header for the interactive controls section.
    mo.md(r"""
    ## Interactive Controls

    Adjust the parameters below to see how they affect signal propagation stability.
    """)
    return
@app.cell
def _(active_config):
    """Build the parameter widgets, seeded from the active preset config.

    Because this cell reads active_config, clicking a preset button re-runs it
    and re-creates the widgets with the preset's default values.
    """
    sinkhorn_slider = mo.ui.slider(
        start=0,
        stop=30,
        value=active_config["k"],
        step=1,
        label="Sinkhorn Iterations (k)",
        show_value=True,
    )
    depth_slider = mo.ui.slider(
        start=10,
        stop=200,
        value=active_config["depth"],
        step=10,
        label="Network Depth",
        show_value=True,
    )
    # Dropdown options are keyed by string labels; the selected value is an int.
    n_dropdown = mo.ui.dropdown(
        options={"2": 2, "4": 4, "8": 8},
        value=active_config["n"],
        label="Number of Streams (n)",
    )
    seed_input = mo.ui.number(
        value=active_config["seed"], start=0, stop=10000, label="Random Seed"
    )
    controls = mo.hstack(
        [sinkhorn_slider, depth_slider, n_dropdown, seed_input], justify="start", gap=2
    )
    controls  # last expression -> rendered as the cell's output
    return depth_slider, n_dropdown, seed_input, sinkhorn_slider
@app.cell
def _():
    """Create the preset buttons; clicking one re-runs dependent cells."""
    preset_default = mo.ui.run_button(label="Default")
    preset_explosion = mo.ui.run_button(label="HC Explosion (k=0)")
    preset_minimal = mo.ui.run_button(label="Minimal Projection (k=5)")
    preset_deep = mo.ui.run_button(label="Deep Network (200)")
    randomize_btn = mo.ui.run_button(label="Randomize Seed")
    presets = mo.hstack(
        [preset_default, preset_explosion, preset_minimal, preset_deep, randomize_btn],
        justify="start",
        gap=1,
    )
    presets
    # preset_default is now returned as well: a run_button only triggers
    # re-runs in cells that read it, so without exporting it the "Default"
    # button could never affect anything downstream.
    return preset_deep, preset_default, preset_explosion, preset_minimal, randomize_btn
@app.cell
def _(depth_slider, n_dropdown, seed_input, sinkhorn_slider):
    """Run the baseline/HC/mHC comparison for the current widget values."""
    results = run_comparison(
        depth=depth_slider.value,
        n=int(n_dropdown.value),  # dropdown value may arrive as str; normalize
        sinkhorn_iters=sinkhorn_slider.value,
        seed=seed_input.value,
    )
    return (results,)
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: motivates the deep-composition explosion problem.
    mo.md(r"""
    ## Signal Propagation: The Explosion Problem

    The real issue isn't single-layer behavior - it's what happens when we **compose many layers**.

    In a deep network, the effective transformation is:

    $$H_{composite} = H_L \cdot H_{L-1} \cdot ... \cdot H_1$$

    Watch the chart below: **HC (red) explodes exponentially, while mHC (blue) stays bounded!**
    """)
    return
@app.cell
def _(base_chart):
    # Render the signal-propagation chart built in the chart-builder cell.
    mo.ui.altair_chart(base_chart)
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: explains the stability metrics table.
    mo.md(r"""
    ## Stability Metrics

    The table below shows metrics at the selected layer:

    - **Forward Gain**: Maximum row sum (worst-case signal amplification)
    - **Backward Gain**: Maximum column sum (gradient flow)
    - **Spectral Norm**: Largest singular value (operator norm)

    For doubly stochastic matrices (mHC), all these should be close to 1!
    """)
    return
@app.cell
def _(layer_selector, results):
    """Show gain/spectral-norm/deviation metrics at the selected layer."""
    layer_idx = layer_selector.value - 1  # slider is 1-based, lists are 0-based
    baseline_m = results["baseline"]["composite"][layer_idx]
    hc_m = results["hc"]["composite"][layer_idx]
    mhc_m = results["mhc"]["composite"][layer_idx]

    def _format_gain(g):
        # Switch to scientific notation once the gain has exploded.
        if g > 1000:
            return f"{g:.2e}"
        return f"{g:.4f}"

    def _status_badge(g):
        # NOTE(review): currently unused by this cell — presumably intended
        # for a colored-badge version of the table; confirm before removing.
        if g < 2:
            return mo.md(f"**{_format_gain(g)}** :green_circle:")
        elif g < 10:
            return mo.md(f"**{_format_gain(g)}** :yellow_circle:")
        else:
            return mo.md(f"**{_format_gain(g)}** :red_circle:")

    metrics_md = mo.md(f"""
    ### Metrics at Layer {layer_selector.value}

    | Metric | Baseline | HC (Unconstrained) | mHC (Sinkhorn) |
    |--------|----------|-------------------|----------------|
    | Forward Gain | {_format_gain(baseline_m["forward_gain"])} | {_format_gain(hc_m["forward_gain"])} | {_format_gain(mhc_m["forward_gain"])} |
    | Backward Gain | {_format_gain(baseline_m["backward_gain"])} | {_format_gain(hc_m["backward_gain"])} | {_format_gain(mhc_m["backward_gain"])} |
    | Spectral Norm | {_format_gain(baseline_m["spectral_norm"])} | {_format_gain(hc_m["spectral_norm"])} | {_format_gain(mhc_m["spectral_norm"])} |
    | Row Sum Dev | {baseline_m["row_sum_max_dev"]:.2e} | {hc_m["row_sum_max_dev"]:.2e} | {mhc_m["row_sum_max_dev"]:.2e} |
    | Col Sum Dev | {baseline_m["col_sum_max_dev"]:.2e} | {hc_m["col_sum_max_dev"]:.2e} | {mhc_m["col_sum_max_dev"]:.2e} |
    """)
    metrics_md
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: header for the matrix heatmap comparison.
    mo.md(r"""
    ## Matrix Visualization

    Compare a sample residual mixing matrix before and after Sinkhorn projection.

    - **HC (left)**: Random matrix with arbitrary row/column sums
    - **mHC (right)**: Sinkhorn-projected doubly stochastic matrix (all rows and columns sum to 1)
    """)
    return
@app.cell
def _(heatmaps):
    # Render the side-by-side HC/mHC heatmaps built in the heatmap cell.
    heatmaps
    return
@app.cell
def _(hc_sample, mhc_sample):
    """Tabulate row/column sums to demonstrate the doubly stochastic property."""
    hc_row_sums = hc_sample.sum(axis=1)
    hc_col_sums = hc_sample.sum(axis=0)
    mhc_row_sums = mhc_sample.sum(axis=1)
    mhc_col_sums = mhc_sample.sum(axis=0)
    sums_md = mo.md(f"""
    **Row/Column Sums:**

    | | HC | mHC |
    |---|---|---|
    | Row Sums | {np.array2string(hc_row_sums, precision=2)} | {np.array2string(mhc_row_sums, precision=3)} |
    | Col Sums | {np.array2string(hc_col_sums, precision=2)} | {np.array2string(mhc_col_sums, precision=3)} |
    | Max Row Dev from 1 | {np.abs(hc_row_sums - 1).max():.4f} | {np.abs(mhc_row_sums - 1).max():.2e} |
    | Max Col Dev from 1 | {np.abs(hc_col_sums - 1).max():.4f} | {np.abs(mhc_col_sums - 1).max():.2e} |

    Notice how mHC row/column sums are all ~1.0 (doubly stochastic)!
    """)
    sums_md
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: introduces the Sinkhorn-iteration "dial" facet plot.
    mo.md(r"""
    ## The Manifold Dial: Varying Sinkhorn Iterations

    Watch how the matrix transforms as we increase iterations:

    - **k=0**: Raw random matrix (same as HC) - unconstrained
    - **k=1-5**: Partial projection, rapid stabilization
    - **k=10-20**: Fully doubly stochastic
    """)
    return
@app.cell
def _(seed_input):
    """Facet plot: one heatmap per Sinkhorn iteration count applied to one matrix."""
    iters_to_show = [0, 1, 2, 5, 10, 20]
    dial_size = 4  # small fixed matrix so all facets fit side by side
    rng_dial = np.random.default_rng(seed_input.value)
    dial_base = rng_dial.standard_normal((dial_size, dial_size))
    dial_data = []
    for k in iters_to_show:
        if k == 0:
            mat = dial_base  # Raw random matrix (same as HC)
        else:
            mat = sinkhorn_knopp(dial_base, iterations=k)
        for _i in range(dial_size):
            for _j in range(dial_size):
                dial_data.append(
                    {
                        "row": str(_i),
                        "col": str(_j),
                        "value": float(mat[_i, _j]),
                        "k": f"k={k}",
                    }
                )
    dial_df = pd.DataFrame(dial_data)
    dial_chart = (
        alt.Chart(dial_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title=None, axis=alt.Axis(labels=False)),
            y=alt.Y("row:O", title=None, axis=alt.Axis(labels=False)),
            color=alt.Color(
                "value:Q", scale=alt.Scale(scheme="blues", domain=[0, 0.6]), legend=None
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=100, height=100)
        .facet(
            column=alt.Column(
                "k:N",
                title="Sinkhorn Iterations",
                # BUG FIX: the field holds "k=<n>" strings, so the sort list
                # must use those same labels. Passing the bare int list never
                # matched, leaving facets in lexicographic order
                # ("k=10" before "k=2").
                sort=[f"k={k}" for k in iters_to_show],
            )
        )
        .properties(
            title="The Manifold Dial: Sinkhorn Iterations Transform Random to Doubly Stochastic"
        )
    )
    dial_chart
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: the mathematical intuition behind mHC stability.
    mo.md(r"""
    ## Why Does This Work?

    ### The Mathematics of Stability

    Doubly stochastic matrices have three key properties:

    1. **Spectral norm <= 1**: The matrix doesn't amplify signals
    2. **Closed under multiplication**: Products remain doubly stochastic
    3. **Convex combinations of permutations**: Acts like a weighted average (Birkhoff-von Neumann theorem)

    When you multiply many doubly stochastic matrices together, the result stays bounded because each multiplication is **non-expansive**.

    In contrast, unconstrained matrices compound their gains exponentially:

    - If each matrix has gain 1.1, after 64 layers: $1.1^{64} \approx 300$
    - If each matrix has gain 1.5, after 64 layers: $1.5^{64} \approx 10^{11}$!
    """)
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: summary and suggested experiments.
    mo.md(r"""
    ## Key Takeaways

    1. **HC (Hyper-Connections)** use unconstrained residual mixing matrices
        - Each layer's matrix can have arbitrary row/column sums
        - Over many layers, these compound into **exponential explosion**
    2. **mHC (Manifold-Constrained HC)** projects matrices onto the Birkhoff polytope
        - Uses Sinkhorn-Knopp to ensure doubly stochastic matrices
        - Spectral norm <= 1, so products stay bounded - **stable signals**
    3. **The "Manifold Dial"** is the Sinkhorn iteration count (k)
        - k=0: Unconstrained (like HC) - unstable
        - k>=10: Well-projected - stable
        - Sweet spot around k=20 for most applications

    ---

    **Try it yourself!** Modify the sliders above and re-run to explore:

    - Different depths (try 100, 200)
    - Different matrix sizes (n=2, 8)
    - Different random seeds

    **Interactive Demo:** https://github.com/bassrehab/mhc-visualizer
    """)
    return
@app.cell(hide_code=True)
def _():
    # Markdown-only cell: references and attribution.
    mo.md(r"""
    ---

    ## References

    - **mHC Paper**: [DeepSeek-AI, arXiv:2512.24880](https://arxiv.org/abs/2512.24880)
    - **Sinkhorn-Knopp Algorithm**: Sinkhorn & Knopp (1967)
    - **This Notebook**: [github.com/bassrehab/mhc-visualizer](https://github.com/bassrehab/mhc-visualizer)

    **Author**: Subhadip Mitra (contact@subhadipmitra.com)
    """)
    return
@app.cell
def _(depth_slider, results):
    """Build the log-scale forward-gain chart and the layer-inspection slider."""
    # Flatten per-method, per-layer composite metrics into long-form rows.
    data = []
    for method in ["baseline", "hc", "mhc"]:
        for _i, _m in enumerate(results[method]["composite"]):
            data.append(
                {
                    "layer": _i + 1,  # 1-based layer index for display
                    "gain": _m["forward_gain"],
                    "method": method.upper() if method != "mhc" else "mHC",
                }
            )
    df = pd.DataFrame(data)
    # Layer selector (1-based), defaulting to the deepest layer.
    layer_selector = mo.ui.slider(
        start=1,
        stop=depth_slider.value,
        value=depth_slider.value,
        label="Inspect Layer",
        show_value=True,
    )
    # Log-scale y axis keeps the exploding HC curve on screen next to mHC.
    base_chart = (
        alt.Chart(df)
        .mark_line(strokeWidth=2.5)
        .encode(
            x=alt.X(
                "layer:Q",
                title="Layer Depth",
                scale=alt.Scale(domain=[1, depth_slider.value]),
            ),
            y=alt.Y(
                "gain:Q",
                scale=alt.Scale(type="log"),
                title="Composite Forward Gain (log scale)",
            ),
            color=alt.Color(
                "method:N",
                scale=alt.Scale(
                    domain=["BASELINE", "HC", "mHC"],
                    range=["#10b981", "#ef4444", "#3b82f6"],
                ),
                legend=alt.Legend(title="Method"),
            ),
            # NOTE(review): every method maps to the solid pattern [1, 0], so
            # this encoding is currently a visual no-op; kept so per-method
            # dashing can be enabled by editing the range.
            strokeDash=alt.StrokeDash(
                "method:N",
                scale=alt.Scale(
                    domain=["BASELINE", "HC", "mHC"], range=[[1, 0], [1, 0], [1, 0]]
                ),
                legend=None,
            ),
        )
        .properties(
            width=700,
            height=400,
            title="The Manifold Dial: HC Explosion vs mHC Stability",
        )
    )
    return base_chart, layer_selector
@app.cell
def _(n_dropdown, seed_input, sinkhorn_slider):
    """Build side-by-side heatmaps of a raw (HC) vs Sinkhorn-projected (mHC) matrix."""
    n_val = int(n_dropdown.value)
    rng = np.random.default_rng(seed_input.value)
    # Generate the sample matrices from the same random draw.
    base_matrix = rng.standard_normal((n_val, n_val))
    hc_sample = base_matrix  # Unconstrained
    mhc_sample = sinkhorn_knopp(base_matrix, iterations=sinkhorn_slider.value)

    def matrix_to_df(mat, name):
        # Long-form (row, col, value) records, as Altair expects.
        rows = []
        for i in range(mat.shape[0]):
            for j in range(mat.shape[1]):
                rows.append(
                    {
                        "row": str(i),
                        "col": str(j),
                        "value": float(mat[i, j]),
                        "type": name,
                    }
                )
        return pd.DataFrame(rows)

    hc_df = matrix_to_df(hc_sample, "HC")
    mhc_df = matrix_to_df(mhc_sample, "mHC")
    # HC heatmap - diverging colorscale (values may be negative).
    hc_heatmap = (
        alt.Chart(hc_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title="Column", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("row:O", title="Row"),
            color=alt.Color(
                "value:Q",
                scale=alt.Scale(scheme="redblue", domain=[-2, 2]),
                legend=alt.Legend(title="Value"),
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=180, height=180, title="HC (Random)")
    )
    # mHC heatmap - sequential colorscale (entries are non-negative weights).
    mhc_heatmap = (
        alt.Chart(mhc_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title="Column", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("row:O", title="Row"),
            color=alt.Color(
                "value:Q",
                scale=alt.Scale(scheme="blues", domain=[0, 0.6]),
                legend=alt.Legend(title="Value"),
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=180, height=180, title=f"mHC (k={sinkhorn_slider.value})")
    )
    heatmaps = hc_heatmap | mhc_heatmap
    return hc_sample, heatmaps, mhc_sample
@app.cell
def _(preset_deep, preset_explosion, preset_minimal, randomize_btn):
    """Resolve the active configuration from whichever preset was clicked last.

    Note: "n" is kept as a string because the stream-count dropdown's options
    are keyed by string labels.
    """
    if preset_explosion.value:
        active_config = {"k": 0, "depth": 64, "n": "4", "seed": 42}
    elif preset_minimal.value:
        active_config = {"k": 5, "depth": 64, "n": "4", "seed": 42}
    elif preset_deep.value:
        active_config = {"k": 20, "depth": 200, "n": "4", "seed": 42}
    elif randomize_btn.value:
        active_config = {
            "k": 20,
            "depth": 64,
            "n": "4",
            # Modern Generator API instead of the legacy np.random.randint.
            "seed": int(np.random.default_rng().integers(0, 10000)),
        }
    else:  # default (including preset_default.value)
        active_config = {"k": 20, "depth": 64, "n": "4", "seed": 42}
    return (active_config,)
# Script entry point: lets the notebook file be executed directly as an app.
if __name__ == "__main__":
    app.run()