"""Interactive marimo notebook exploring Manifold-Constrained Hyper-Connections (mHC).

Demonstrates how Sinkhorn-projected (doubly stochastic) residual mixing
matrices keep deep-network signal propagation bounded, while unconstrained
matrices explode. Simulation logic lives in the project-local ``mhc`` module.
"""

import marimo

__generated_with = "0.18.4"

app = marimo.App(width="medium", app_title="Mhc Exploration")

with app.setup(hide_code=True):
    import sys
    import subprocess
    import tempfile
    from pathlib import Path

    import numpy as np
    import altair as alt
    import pandas as pd
    import marimo as mo

    from mhc import run_comparison, sinkhorn_knopp, compute_all_metrics


@app.cell(hide_code=True)
def _():
    # Title / intro panel.
    mo.md(
        r"""
    # Exploring Manifold-Constrained Hyper-Connections (mHC)

    [![Open in molab](https://marimo.io/molab-shield.svg)](https://molab.marimo.io/github/bassrehab/mhc-visualizer/blob/main/notebook/mhc_exploration.py)

    This interactive notebook demonstrates the key insight from DeepSeek's mHC paper:
    **unconstrained residual mixing matrices cause signal explosion in deep networks,
    while doubly stochastic constraints keep signals bounded.**

    Use the controls below to explore how Sinkhorn iterations affect stability!

    **Paper:** https://arxiv.org/abs/2512.24880

    **Viz Author (under MIT):** Subhadip Mitra

    **Implementation/ Viz Repository:** https://github.com/bassrehab/mhc-visualizer
    """
    )
    return


@app.cell(hide_code=True)
def _():
    # Background: what Sinkhorn-Knopp does and why doubly stochastic matters.
    mo.md(
        r"""
    ## The Sinkhorn-Knopp Algorithm

    The Sinkhorn-Knopp algorithm projects any positive matrix onto the set of
    **doubly stochastic matrices** - matrices where all rows and columns sum to 1.

    ### Why Doubly Stochastic?

    Doubly stochastic matrices have a crucial property: **they are closed under
    multiplication**. This means:

    - If A and B are doubly stochastic, then A @ B is also doubly stochastic
    - The spectral norm is bounded: ||A|| <= 1
    - Signal propagation stays bounded even through many layers

    ### The Algorithm

    Starting from any positive matrix, we alternate between:

    1. Normalizing rows to sum to 1
    2. Normalizing columns to sum to 1

    This converges to a doubly stochastic matrix!
    """
    )
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Interactive Controls

    Adjust the parameters below to see how they affect signal propagation stability.
    """
    )
    return


@app.cell(hide_code=True)
def _(active_config):
    # Sliders initialized from active config (re-created when preset clicked).
    sinkhorn_slider = mo.ui.slider(
        start=0,
        stop=30,
        value=active_config["k"],
        step=1,
        label="Sinkhorn Iterations (k)",
        show_value=True,
    )
    depth_slider = mo.ui.slider(
        start=10,
        stop=200,
        value=active_config["depth"],
        step=10,
        label="Network Depth",
        show_value=True,
    )
    n_dropdown = mo.ui.dropdown(
        options={"2": 2, "4": 4, "8": 8},
        value=active_config["n"],
        label="Number of Streams (n)",
    )
    seed_input = mo.ui.number(
        value=active_config["seed"], start=0, stop=10000, label="Random Seed"
    )
    controls = mo.hstack(
        [sinkhorn_slider, depth_slider, n_dropdown, seed_input],
        justify="start",
        gap=2,
    )
    controls
    return depth_slider, n_dropdown, seed_input, sinkhorn_slider


@app.cell(hide_code=True)
def _():
    # Preset buttons - clicking triggers re-run with new defaults.
    preset_default = mo.ui.run_button(label="Default")
    preset_explosion = mo.ui.run_button(label="HC Explosion (k=0)")
    preset_minimal = mo.ui.run_button(label="Minimal Projection (k=5)")
    preset_deep = mo.ui.run_button(label="Deep Network (200)")
    randomize_btn = mo.ui.run_button(label="Randomize Seed")
    presets = mo.hstack(
        [preset_default, preset_explosion, preset_minimal, preset_deep, randomize_btn],
        justify="start",
        gap=1,
    )
    presets
    # preset_default must be returned so the active_config cell can depend on
    # it; otherwise clicking "Default" would never trigger a re-run.
    return (
        preset_deep,
        preset_default,
        preset_explosion,
        preset_minimal,
        randomize_btn,
    )


@app.cell
def _(depth_slider, n_dropdown, seed_input, sinkhorn_slider):
    # Run the simulation with current parameters.
    results = run_comparison(
        depth=depth_slider.value,
        n=int(n_dropdown.value),
        sinkhorn_iters=sinkhorn_slider.value,
        seed=seed_input.value,
    )
    return (results,)


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Signal Propagation: The Explosion Problem

    The real issue isn't single-layer behavior - it's what happens when we
    **compose many layers**. In a deep network, the effective transformation is:

    $$H_{composite} = H_L \cdot H_{L-1} \cdot ... \cdot H_1$$

    Watch the chart below: **HC (red) explodes exponentially, while mHC (blue)
    stays bounded!**
    """
    )
    return


@app.cell
def _(base_chart):
    mo.ui.altair_chart(base_chart)
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Stability Metrics

    The table below shows metrics at the selected layer:

    - **Forward Gain**: Maximum row sum (worst-case signal amplification)
    - **Backward Gain**: Maximum column sum (gradient flow)
    - **Spectral Norm**: Largest singular value (operator norm)

    For doubly stochastic matrices (mHC), all these should be close to 1!
    """
    )
    return


@app.cell(hide_code=True)
def _(layer_selector, results):
    # Render a metrics comparison table for the layer picked on the slider.
    layer_idx = layer_selector.value - 1
    baseline_m = results["baseline"]["composite"][layer_idx]
    hc_m = results["hc"]["composite"][layer_idx]
    mhc_m = results["mhc"]["composite"][layer_idx]

    def _format_gain(g):
        # Switch to scientific notation once gains blow up.
        if g > 1000:
            return f"{g:.2e}"
        return f"{g:.4f}"

    def _status_badge(g):
        # NOTE(review): currently unused by the table below; kept for possible
        # badge-style rendering of gains.
        if g < 2:
            return mo.md(f"**{_format_gain(g)}** :green_circle:")
        elif g < 10:
            return mo.md(f"**{_format_gain(g)}** :yellow_circle:")
        else:
            return mo.md(f"**{_format_gain(g)}** :red_circle:")

    metrics_md = mo.md(f"""
    ### Metrics at Layer {layer_selector.value}

    | Metric | Baseline | HC (Unconstrained) | mHC (Sinkhorn) |
    |--------|----------|-------------------|----------------|
    | Forward Gain | {_format_gain(baseline_m["forward_gain"])} | {_format_gain(hc_m["forward_gain"])} | {_format_gain(mhc_m["forward_gain"])} |
    | Backward Gain | {_format_gain(baseline_m["backward_gain"])} | {_format_gain(hc_m["backward_gain"])} | {_format_gain(mhc_m["backward_gain"])} |
    | Spectral Norm | {_format_gain(baseline_m["spectral_norm"])} | {_format_gain(hc_m["spectral_norm"])} | {_format_gain(mhc_m["spectral_norm"])} |
    | Row Sum Dev | {baseline_m["row_sum_max_dev"]:.2e} | {hc_m["row_sum_max_dev"]:.2e} | {mhc_m["row_sum_max_dev"]:.2e} |
    | Col Sum Dev | {baseline_m["col_sum_max_dev"]:.2e} | {hc_m["col_sum_max_dev"]:.2e} | {mhc_m["col_sum_max_dev"]:.2e} |
    """)
    metrics_md
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Matrix Visualization

    Compare a sample residual mixing matrix before and after Sinkhorn projection.

    - **HC (left)**: Random matrix with arbitrary row/column sums
    - **mHC (right)**: Sinkhorn-projected doubly stochastic matrix (all rows and
      columns sum to 1)
    """
    )
    return


@app.cell
def _(heatmaps):
    heatmaps
    return


@app.cell(hide_code=True)
def _(hc_sample, mhc_sample):
    # Show row and column sums to make "doubly stochastic" concrete.
    hc_row_sums = hc_sample.sum(axis=1)
    hc_col_sums = hc_sample.sum(axis=0)
    mhc_row_sums = mhc_sample.sum(axis=1)
    mhc_col_sums = mhc_sample.sum(axis=0)
    sums_md = mo.md(f"""
    **Row/Column Sums:**

    | | HC | mHC |
    |---|---|---|
    | Row Sums | {np.array2string(hc_row_sums, precision=2)} | {np.array2string(mhc_row_sums, precision=3)} |
    | Col Sums | {np.array2string(hc_col_sums, precision=2)} | {np.array2string(mhc_col_sums, precision=3)} |
    | Max Row Dev from 1 | {np.abs(hc_row_sums - 1).max():.4f} | {np.abs(mhc_row_sums - 1).max():.2e} |
    | Max Col Dev from 1 | {np.abs(hc_col_sums - 1).max():.4f} | {np.abs(mhc_col_sums - 1).max():.2e} |

    Notice how mHC row/column sums are all ~1.0 (doubly stochastic)!
    """)
    sums_md
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## The Manifold Dial: Varying Sinkhorn Iterations

    Watch how the matrix transforms as we increase iterations:

    - **k=0**: Raw random matrix (same as HC) - unconstrained
    - **k=1-5**: Partial projection, rapid stabilization
    - **k=10-20**: Fully doubly stochastic
    """
    )
    return


@app.cell(hide_code=True)
def _(seed_input):
    # Show progression of Sinkhorn iterations on one fixed 4x4 matrix.
    iters_to_show = [0, 1, 2, 5, 10, 20]
    rng_dial = np.random.default_rng(seed_input.value)
    dial_base = rng_dial.standard_normal((4, 4))

    dial_data = []
    for k in iters_to_show:
        if k == 0:
            mat = dial_base  # Raw random matrix (same as HC)
        else:
            mat = sinkhorn_knopp(dial_base, iterations=k)
        for _i in range(4):
            for _j in range(4):
                dial_data.append(
                    {
                        "row": str(_i),
                        "col": str(_j),
                        "value": float(mat[_i, _j]),
                        "k": f"k={k}",
                    }
                )

    dial_df = pd.DataFrame(dial_data)
    dial_chart = (
        alt.Chart(dial_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title=None, axis=alt.Axis(labels=False)),
            y=alt.Y("row:O", title=None, axis=alt.Axis(labels=False)),
            color=alt.Color(
                "value:Q",
                scale=alt.Scale(scheme="blues", domain=[0, 0.6]),
                legend=None,
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=100, height=100)
        .facet(
            # Sort facets by iteration count: the field holds labels like
            # "k=5", so the sort list must use the same labels (a list of
            # raw ints would never match and leave the facets unsorted).
            column=alt.Column(
                "k:N",
                title="Sinkhorn Iterations",
                sort=[f"k={k}" for k in iters_to_show],
            )
        )
        .properties(
            title="The Manifold Dial: Sinkhorn Iterations Transform Random to Doubly Stochastic"
        )
    )
    dial_chart
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Why Does This Work?

    ### The Mathematics of Stability

    Doubly stochastic matrices have three key properties:

    1. **Spectral norm <= 1**: The matrix doesn't amplify signals
    2. **Closed under multiplication**: Products remain doubly stochastic
    3. **Convex combinations of permutations**: Acts like a weighted average
       (Birkhoff-von Neumann theorem)

    When you multiply many doubly stochastic matrices together, the result stays
    bounded because each multiplication is **non-expansive**.

    In contrast, unconstrained matrices compound their gains exponentially:

    - If each matrix has gain 1.1, after 64 layers: $1.1^{64} \approx 300$
    - If each matrix has gain 1.5, after 64 layers: $1.5^{64} \approx 10^{11}$!
    """
    )
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ## Key Takeaways

    1. **HC (Hyper-Connections)** use unconstrained residual mixing matrices
       - Each layer's matrix can have arbitrary row/column sums
       - Over many layers, these compound into **exponential explosion**

    2. **mHC (Manifold-Constrained HC)** projects matrices onto the Birkhoff polytope
       - Uses Sinkhorn-Knopp to ensure doubly stochastic matrices
       - Spectral norm <= 1, so products stay bounded - **stable signals**

    3. **The "Manifold Dial"** is the Sinkhorn iteration count (k)
       - k=0: Unconstrained (like HC) - unstable
       - k>=10: Well-projected - stable
       - Sweet spot around k=20 for most applications

    ---

    **Try it yourself!** Modify the sliders above and re-run to explore:

    - Different depths (try 100, 200)
    - Different matrix sizes (n=2, 8)
    - Different random seeds

    **Interactive Demo:** https://github.com/bassrehab/mhc-visualizer
    """
    )
    return


@app.cell(hide_code=True)
def _():
    mo.md(
        r"""
    ---

    ## References

    - **mHC Paper**: [DeepSeek-AI, arXiv:2512.24880](https://arxiv.org/abs/2512.24880)
    - **Sinkhorn-Knopp Algorithm**: Sinkhorn & Knopp (1967)
    - **This Notebook**: [github.com/bassrehab/mhc-visualizer](https://github.com/bassrehab/mhc-visualizer)

    **Author**: Subhadip Mitra (contact@subhadipmitra.com)
    """
    )
    return


@app.cell
def _(depth_slider, results):
    # Build DataFrame from results: one row per (layer, method) forward gain.
    data = []
    for method in ["baseline", "hc", "mhc"]:
        for _i, _m in enumerate(results[method]["composite"]):
            data.append(
                {
                    "layer": _i + 1,
                    "gain": _m["forward_gain"],
                    "method": method.upper() if method != "mhc" else "mHC",
                }
            )
    df = pd.DataFrame(data)

    # Layer selector for the metrics table (defaults to the deepest layer).
    layer_selector = mo.ui.slider(
        start=1,
        stop=depth_slider.value,
        value=depth_slider.value,
        label="Inspect Layer",
        show_value=True,
    )

    # Altair chart with log scale - exponential HC growth shows as a line.
    base_chart = (
        alt.Chart(df)
        .mark_line(strokeWidth=2.5)
        .encode(
            x=alt.X(
                "layer:Q",
                title="Layer Depth",
                scale=alt.Scale(domain=[1, depth_slider.value]),
            ),
            y=alt.Y(
                "gain:Q",
                scale=alt.Scale(type="log"),
                title="Composite Forward Gain (log scale)",
            ),
            color=alt.Color(
                "method:N",
                scale=alt.Scale(
                    domain=["BASELINE", "HC", "mHC"],
                    range=["#10b981", "#ef4444", "#3b82f6"],
                ),
                legend=alt.Legend(title="Method"),
            ),
            strokeDash=alt.StrokeDash(
                "method:N",
                scale=alt.Scale(
                    domain=["BASELINE", "HC", "mHC"],
                    range=[[1, 0], [1, 0], [1, 0]],
                ),
                legend=None,
            ),
        )
        .properties(
            width=700,
            height=400,
            title="The Manifold Dial: HC Explosion vs mHC Stability",
        )
    )
    return base_chart, layer_selector


@app.cell
def _(n_dropdown, seed_input, sinkhorn_slider):
    # Sample one matrix pair (raw vs Sinkhorn-projected) for the heatmaps.
    n_val = int(n_dropdown.value)
    rng = np.random.default_rng(seed_input.value)

    # Generate matrices
    base_matrix = rng.standard_normal((n_val, n_val))
    hc_sample = base_matrix  # Unconstrained
    mhc_sample = sinkhorn_knopp(base_matrix, iterations=sinkhorn_slider.value)

    # Create heatmap data
    def matrix_to_df(mat, name):
        # Flatten a matrix to long-form (row, col, value) rows for Altair.
        rows = []
        for i in range(mat.shape[0]):
            for j in range(mat.shape[1]):
                rows.append(
                    {
                        "row": str(i),
                        "col": str(j),
                        "value": float(mat[i, j]),
                        "type": name,
                    }
                )
        return pd.DataFrame(rows)

    hc_df = matrix_to_df(hc_sample, "HC")
    mhc_df = matrix_to_df(mhc_sample, "mHC")

    # HC heatmap - use diverging colorscale (values can be negative)
    hc_heatmap = (
        alt.Chart(hc_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title="Column", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("row:O", title="Row"),
            color=alt.Color(
                "value:Q",
                scale=alt.Scale(scheme="redblue", domain=[-2, 2]),
                legend=alt.Legend(title="Value"),
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=180, height=180, title="HC (Random)")
    )

    # mHC heatmap - use sequential colorscale (entries are non-negative)
    mhc_heatmap = (
        alt.Chart(mhc_df)
        .mark_rect()
        .encode(
            x=alt.X("col:O", title="Column", axis=alt.Axis(labelAngle=0)),
            y=alt.Y("row:O", title="Row"),
            color=alt.Color(
                "value:Q",
                scale=alt.Scale(scheme="blues", domain=[0, 0.6]),
                legend=alt.Legend(title="Value"),
            ),
            tooltip=["row", "col", alt.Tooltip("value:Q", format=".3f")],
        )
        .properties(width=180, height=180, title=f"mHC (k={sinkhorn_slider.value})")
    )

    heatmaps = hc_heatmap | mhc_heatmap
    return hc_sample, heatmaps, mhc_sample


@app.cell(hide_code=True)
def _(preset_deep, preset_default, preset_explosion, preset_minimal, randomize_btn):
    # Active config based on which preset button was clicked. Depending on
    # preset_default ensures clicking "Default" also re-runs this cell (its
    # value is checked implicitly: all other buttons are False, so we fall
    # through to the default branch).
    if preset_explosion.value:
        active_config = {"k": 0, "depth": 64, "n": "4", "seed": 42}
    elif preset_minimal.value:
        active_config = {"k": 5, "depth": 64, "n": "4", "seed": 42}
    elif preset_deep.value:
        active_config = {"k": 20, "depth": 200, "n": "4", "seed": 42}
    elif randomize_btn.value:
        active_config = {
            "k": 20,
            "depth": 64,
            "n": "4",
            "seed": int(np.random.randint(0, 10000)),
        }
    else:
        # default (including preset_default.value)
        active_config = {"k": 20, "depth": 64, "n": "4", "seed": 42}
    return (active_config,)


if __name__ == "__main__":
    app.run()