| """5 fused mHC kernels for ManifoldHyperConnection operations. |
| |
| Phase 2: Triton kernels for stream routing operations. |
| (TileLang available but Triton preferred for sm_86 RTX 3060 compatibility.) |
| |
| Phase 1: Uses torch.einsum and standard ops in ManifoldHyperConnection |
| (subsystems/mhc_mini.py). |
| |
| Kernels (fused for n_streams=2): |
| 1. stream_init: Replicate embedding across n_streams (torch broadcast) |
| 2. stream_mix: Doubly-stochastic M @ streams (fused) |
| 3. stream_inject: Additive injection of block output (fused) |
| 4. stream_extract: Extract primary stream for block input (fused) |
| 5. stream_merge: Weighted merge of streams (fused) |
| |
| For n_streams=2 (the only config used in HYDRA), the full forward pass |
| (mix -> extract -> inject) reduces to 2-3 scalar multiplies + adds per |
| element, issued as two Triton kernel launches (mix+extract, then inject). |
| |
| DSL: Triton (@triton.jit) |
| Target: RTX 3060 (sm_86), bf16 compute, fp32 accumulation |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@triton.jit
def _mhc_mix_extract_kernel(
    S0_ptr,
    S1_ptr,
    OUT_ptr,
    M00,
    M01,
    N,
    BLOCK: tl.constexpr,
):
    """Fused stream_mix + stream_extract: out = M[0,0]*s0 + M[0,1]*s1.

    Each program handles one run of BLOCK consecutive elements of the
    flattened buffers; lanes at or beyond N are masked off.

    Args:
        S0_ptr, S1_ptr: pointers to the two input streams, indexed linearly.
        OUT_ptr: pointer to the output buffer, indexed the same way.
        M00, M01: scalar mixing weights (row 0 of the mixing matrix).
        N: number of elements to process. Passed as a runtime value so a
            new element count does not force a recompile.
        BLOCK: elements per program (compile-time constant).
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Accumulate in fp32 regardless of the storage dtype.
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    mixed = M00 * s0 + M01 * s1

    # Cast to whatever dtype the output buffer holds rather than assuming
    # bf16, so fp16/fp32 outputs are not rounded through bf16 first.
    tl.store(OUT_ptr + offs, mixed.to(OUT_ptr.dtype.element_ty), mask=mask)
|
|
|
|
@triton.jit
def _mhc_inject_kernel(
    S0_ptr,
    S1_ptr,
    BLOCK_OUT_ptr,
    OUT0_ptr,
    OUT1_ptr,
    M00,
    M01,
    N,
    BLOCK: tl.constexpr,
):
    """Fused stream_inject: out_i = s_i + M[0,i] * block_output.

    Each program handles one run of BLOCK consecutive elements of the
    flattened buffers; lanes at or beyond N are masked off.

    Args:
        S0_ptr, S1_ptr: pointers to the two input streams, indexed linearly.
        BLOCK_OUT_ptr: pointer to the block output added into both streams.
        OUT0_ptr, OUT1_ptr: pointers to the two updated output streams.
        M00, M01: scalar injection weights (row 0 of the mixing matrix).
        N: number of elements to process. Passed as a runtime value so a
            new element count does not force a recompile.
        BLOCK: elements per program (compile-time constant).
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Accumulate in fp32 regardless of the storage dtype.
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    bo = tl.load(BLOCK_OUT_ptr + offs, mask=mask).to(tl.float32)

    out0 = s0 + M00 * bo
    out1 = s1 + M01 * bo

    # Cast to each output buffer's own dtype instead of hard-coding bf16.
    tl.store(OUT0_ptr + offs, out0.to(OUT0_ptr.dtype.element_ty), mask=mask)
    tl.store(OUT1_ptr + offs, out1.to(OUT1_ptr.dtype.element_ty), mask=mask)
|
|
|
|
@triton.jit
def _mhc_merge_kernel(
    S0_ptr,
    S1_ptr,
    OUT_ptr,
    N,
    BLOCK: tl.constexpr,
):
    """Fused stream_merge: out = 0.5 * (s0 + s1).

    Each program handles one run of BLOCK consecutive elements of the
    flattened buffers; lanes at or beyond N are masked off.

    Args:
        S0_ptr, S1_ptr: pointers to the two input streams, indexed linearly.
        OUT_ptr: pointer to the merged output buffer.
        N: number of elements to process. Passed as a runtime value so a
            new element count does not force a recompile.
        BLOCK: elements per program (compile-time constant).
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Accumulate in fp32 regardless of the storage dtype.
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    out = (s0 + s1) * 0.5

    # Cast to the output buffer's own dtype instead of hard-coding bf16.
    tl.store(OUT_ptr + offs, out.to(OUT_ptr.dtype.element_ty), mask=mask)
|
|
|
|
| |
| |
| |
|
|
| def _triton_grid(N: int, BLOCK: int): |
| return ((N + BLOCK - 1) // BLOCK,) |
|
|
|
|
class MHCFusedOps:
    """Fused mHC stream operations using Triton kernels.

    For n_streams=2 (the only HYDRA config), the mHC operations are covered
    by 3 kernel launches (mix+extract, inject, merge) instead of separate
    torch ops + temporaries.

    The Triton path is only taken for 2-stream CUDA bf16 inputs. Every other
    case (n_streams != 2, CPU tensors, other dtypes) falls back to equivalent
    torch operations computed in fp32 where mixing weights are involved.

    NOTE: the Triton path reads the mixing weights as Python floats and
    writes into freshly allocated tensors, so no autograd graph is recorded
    through it. NOTE(review): confirm callers only rely on it where
    gradients w.r.t. ``streams`` / ``M`` are not required.
    """

    # Elements handled per Triton program.
    BLOCK_SIZE = 1024

    @staticmethod
    def _use_fused(streams: torch.Tensor) -> bool:
        """Return True when the 2-stream Triton kernels can handle *streams*."""
        return (
            streams.shape[0] == 2
            and streams.is_cuda
            and streams.dtype == torch.bfloat16
        )

    @staticmethod
    def _row0_pair(M: torch.Tensor) -> tuple[float, float]:
        """Fetch (M[0,0], M[0,1]) as Python floats with one host transfer."""
        m00, m01 = M[0, :2].tolist()
        return m00, m01

    @staticmethod
    def stream_init(x: torch.Tensor, n_streams: int) -> torch.Tensor:
        """Replicate (B,T,d) -> (n_streams,B,T,d) via broadcast copy."""
        return x.unsqueeze(0).expand(n_streams, *x.shape).contiguous()

    @staticmethod
    def stream_mix_extract(
        streams: torch.Tensor,
        M: torch.Tensor,
    ) -> torch.Tensor:
        """Fused mix + extract: returns mixed primary stream for block input.

        Computes ``sum_j M[0, j] * streams[j]``.

        Args:
            streams: (n_streams, B, T, d); bf16 on CUDA for the fused path
            M: (n_streams, n_streams) mixing matrix

        Returns:
            mixed: (B, T, d) in the dtype of ``streams``
        """
        if MHCFusedOps._use_fused(streams):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            n_elem = s0.numel()
            out = torch.empty_like(s0)
            m00, m01 = MHCFusedOps._row0_pair(M)
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_mix_extract_kernel[grid](
                s0, s1, out, m00, m01,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return out

        # Only row 0 of the mixed result is needed, so contract against
        # M[0] directly instead of forming every mixed stream.
        mixed = torch.einsum("j,jbtd->btd", M[0].float(), streams.float())
        return mixed.to(streams.dtype)

    @staticmethod
    def stream_inject(
        streams: torch.Tensor,
        block_output: torch.Tensor,
        M: torch.Tensor,
    ) -> torch.Tensor:
        """Fused inject: out_i = streams_i + M[0,i] * block_output.

        Args:
            streams: (n_streams, B, T, d); bf16 on CUDA for the fused path
            block_output: (B, T, d)
            M: (n_streams, n_streams) mixing matrix

        Returns:
            new_streams: (n_streams, B, T, d) in the dtype of ``streams``
        """
        if MHCFusedOps._use_fused(streams):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            bo = block_output.contiguous()
            n_elem = s0.numel()
            out0 = torch.empty_like(s0)
            out1 = torch.empty_like(s1)
            m00, m01 = MHCFusedOps._row0_pair(M)
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_inject_kernel[grid](
                s0, s1, bo, out0, out1, m00, m01,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return torch.stack([out0, out1], dim=0)

        # Broadcast the per-stream weight M[0, i] over (B, T, d); this is the
        # same update as M^T applied to a tensor whose only non-zero stream
        # is block_output, without materialising that tensor.
        coeff = M[0].float()[:, None, None, None]
        result = streams.float() + coeff * block_output.float().unsqueeze(0)
        return result.to(streams.dtype)

    @staticmethod
    def stream_merge(streams: torch.Tensor) -> torch.Tensor:
        """Merge streams by taking their mean -> (B, T, d).

        Args:
            streams: (n_streams, B, T, d); bf16 on CUDA for the fused path

        Returns:
            merged: (B, T, d) in the dtype of ``streams``
        """
        if MHCFusedOps._use_fused(streams):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            n_elem = s0.numel()
            out = torch.empty_like(s0)
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_merge_kernel[grid](
                s0, s1, out,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return out
        return streams.mean(dim=0)
|
|
|
|
def mhc_fused_forward(
    streams: torch.Tensor,
    M: torch.Tensor,
    block_fn,
    stream_norm,
) -> torch.Tensor:
    """Run one mHC step on existing streams: mix+extract -> norm -> block -> inject.

    Stream initialisation is not performed here; pass already-initialised
    streams. Intended as the fused counterpart of
    ManifoldHyperConnection.forward() in mhc_mini.py (that module is not
    visible from this file, so the equivalence is by design, not checked here).

    Args:
        streams: (n_streams, B, T, d) stream stack
        M: (n_streams, n_streams) mixing matrix
        block_fn: callable mapping the normalised block input to the block output
        stream_norm: callable applied to the mixed primary stream before block_fn

    Returns:
        new_streams: the updated stream stack returned by MHCFusedOps.stream_inject
    """
    # Row 0 of the mixed streams, normalised, is the block's input.
    block_input = stream_norm(MHCFusedOps.stream_mix_extract(streams, M))
    # The block output is added into the *original* (unmixed) streams.
    return MHCFusedOps.stream_inject(streams, block_fn(block_input), M)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke tests: compare the fused ops against plain-torch references.
    import os
    import sys

    if not torch.cuda.is_available():
        sys.exit("mHC kernel smoke tests require a CUDA device")

    # Make the directory three levels above this file importable so that
    # `subsystems` resolves.
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.insert(0, project_root)

    from subsystems.mhc_mini import ManifoldHyperConnection

    def _check(label, got, want, tol, err_name):
        """Assert max |got - want| < tol and only then report PASS."""
        max_err = (got.float() - want.float()).abs().max().item()
        assert max_err < tol, f"{err_name} error too large: {max_err}"
        print(f"[PASS] {label} (max_err={max_err:.2e})")

    torch.manual_seed(42)
    device = "cuda"
    dtype = torch.bfloat16

    B, T, d = 2, 128, 96
    n_streams = 2

    # Reference module supplying init/merge/sinkhorn and the stream norm.
    ref = ManifoldHyperConnection(d_model=d, n_streams=n_streams, sinkhorn_iters=5).to(device=device, dtype=dtype)

    x = torch.randn(B, T, d, device=device, dtype=dtype)

    # --- stream_init: must match the reference exactly ---
    streams_ref = ref.init_streams(x)
    streams_fused = MHCFusedOps.stream_init(x, n_streams)
    assert torch.allclose(streams_ref, streams_fused, rtol=0.0, atol=0.0), "stream_init mismatch"
    print("[PASS] stream_init")

    # Mixing matrix produced by the reference module.
    M = ref._sinkhorn(ref.log_alpha)

    # --- stream_mix_extract ---
    mixed_fused = MHCFusedOps.stream_mix_extract(streams_ref, M)
    mixed_ref = M[0, 0] * streams_ref[0] + M[0, 1] * streams_ref[1]
    _check("stream_mix_extract", mixed_fused, mixed_ref, 1e-2, "mix_extract")

    # --- stream_inject ---
    block_output = torch.randn(B, T, d, device=device, dtype=dtype)
    injected_fused = MHCFusedOps.stream_inject(streams_ref, block_output, M)
    out0_ref = streams_ref[0] + M[0, 0] * block_output
    out1_ref = streams_ref[1] + M[0, 1] * block_output
    injected_ref = torch.stack([out0_ref, out1_ref], dim=0)
    _check("stream_inject", injected_fused, injected_ref, 1e-2, "inject")

    # --- stream_merge ---
    merged_fused = MHCFusedOps.stream_merge(streams_ref)
    merged_ref = ref.merge_streams(streams_ref)
    _check("stream_merge", merged_fused, merged_ref, 1e-2, "merge")

    # --- full forward with a trivial block function ---
    def dummy_block(x):
        return x * 0.5 + 0.1

    streams_for_ref = ref.init_streams(x)
    streams_for_fused = MHCFusedOps.stream_init(x, n_streams)

    out_fused = mhc_fused_forward(
        streams_for_fused, M, dummy_block, ref.stream_norms[0],
    )

    # Hand-rolled reference for the same mix -> norm -> block -> inject steps.
    M_ref = ref._sinkhorn(ref.log_alpha)
    mixed_ref = (M_ref[0, 0] * streams_for_ref[0].float() + M_ref[0, 1] * streams_for_ref[1].float()).to(dtype)
    primary_ref = ref.stream_norms[0](mixed_ref)
    block_out_ref = dummy_block(primary_ref)
    out0_ref = streams_for_ref[0].float() + M_ref[0, 0] * block_out_ref.float()
    out1_ref = streams_for_ref[1].float() + M_ref[0, 1] * block_out_ref.float()
    out_ref = torch.stack([out0_ref.to(dtype), out1_ref.to(dtype)], dim=0)

    _check("full forward", out_fused, out_ref, 5e-2, "full forward")

    # --- n_streams=4: exercises the torch fallback paths ---
    ref4 = ManifoldHyperConnection(d_model=d, n_streams=4, sinkhorn_iters=5).to(device)
    x4 = torch.randn(B, T, d, device=device, dtype=dtype)
    s4 = MHCFusedOps.stream_init(x4, 4)
    M4 = ref4._sinkhorn(ref4.log_alpha)
    mixed4 = MHCFusedOps.stream_mix_extract(s4, M4)
    merged4 = MHCFusedOps.stream_merge(s4)
    injected4 = MHCFusedOps.stream_inject(s4, x4, M4)
    # Check the results instead of only checking that nothing raised.
    assert mixed4.shape == x4.shape and mixed4.dtype == dtype, "n_streams=4 mix_extract shape/dtype mismatch"
    assert merged4.shape == x4.shape and merged4.dtype == dtype, "n_streams=4 merge shape/dtype mismatch"
    assert injected4.shape == s4.shape and injected4.dtype == dtype, "n_streams=4 inject shape/dtype mismatch"
    print("[PASS] n_streams=4 fallback (torch ops)")

    print("\n=== All mHC kernel smoke tests PASSED ===")
|
|