"""5 fused mHC kernels for ManifoldHyperConnection operations. Phase 2: Triton kernels for stream routing operations. (TileLang available but Triton preferred for sm_86 RTX 3060 compatibility.) Phase 1: Uses torch.einsum and standard ops in ManifoldHyperConnection (subsystems/mhc_mini.py). Kernels (fused for n_streams=2): 1. stream_init: Replicate embedding across n_streams (torch broadcast) 2. stream_mix: Doubly-stochastic M @ streams (fused) 3. stream_inject: Additive injection of block output (fused) 4. stream_extract: Extract primary stream for block input (fused) 5. stream_merge: Weighted merge of streams (fused) For n_streams=2 (the only config used in HYDRA), the full forward pass (mix -> extract -> inject) reduces to 2-3 scalar multiplies + adds per element, fused into a single Triton kernel launch. DSL: Triton (@triton.jit) Target: RTX 3060 (sm_86), bf16 compute, fp32 accumulation """ from __future__ import annotations import torch import triton import triton.language as tl # ============================================================================ # Triton kernel: fused mix + extract + block_fn + inject for n_streams=2 # ============================================================================ # # Given streams (2, B, T, d) and doubly-stochastic M (2x2): # mixed = M[0,0]*s0 + M[0,1]*s1 (stream_mix row 0) # primary_input = layernorm(mixed) (done outside kernel) # block_output = block_fn(primary_input) (done outside kernel) # out0 = s0 + M[0,0]*block_output (stream_inject) # out1 = s1 + M[0,1]*block_output (stream_inject) # # We fuse the mix and inject into two kernels: mix_extract and inject. # The block_fn call is opaque Python so it must happen between them. @triton.jit def _mhc_mix_extract_kernel( S0_ptr, # streams[0] (B*T*d) S1_ptr, # streams[1] (B*T*d) OUT_ptr, # mixed output (B*T*d) M00, # scalar M[0,0] M01, # scalar M[0,1] N: tl.constexpr, # total elements = B*T*d BLOCK: tl.constexpr, ): """Fused stream_mix + stream_extract: mixed = M[0,0]*s0 + M[0,1]*s1.""" pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < N s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32) s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32) mixed = M00 * s0 + M01 * s1 tl.store(OUT_ptr + offs, mixed.to(tl.bfloat16), mask=mask) @triton.jit def _mhc_inject_kernel( S0_ptr, # streams[0] input/output (B*T*d) S1_ptr, # streams[1] input/output (B*T*d) BLOCK_OUT_ptr, # block_output (B*T*d) OUT0_ptr, # output streams[0] (B*T*d) OUT1_ptr, # output streams[1] (B*T*d) M00, # scalar M[0,0] M01, # scalar M[0,1] N: tl.constexpr, BLOCK: tl.constexpr, ): """Fused stream_inject: out_i = s_i + M[0,i] * block_output.""" pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < N s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32) s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32) bo = tl.load(BLOCK_OUT_ptr + offs, mask=mask).to(tl.float32) out0 = s0 + M00 * bo out1 = s1 + M01 * bo tl.store(OUT0_ptr + offs, out0.to(tl.bfloat16), mask=mask) tl.store(OUT1_ptr + offs, out1.to(tl.bfloat16), mask=mask) @triton.jit def _mhc_merge_kernel( S0_ptr, S1_ptr, OUT_ptr, N: tl.constexpr, BLOCK: tl.constexpr, ): """Fused stream_merge: out = 0.5 * (s0 + s1).""" pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < N s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32) s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32) out = (s0 + s1) * 0.5 tl.store(OUT_ptr + offs, out.to(tl.bfloat16), mask=mask) # ============================================================================ # Python wrappers # ============================================================================ def _triton_grid(N: int, BLOCK: int): return ((N + BLOCK - 1) // BLOCK,) class MHCFusedOps: """Fused mHC stream operations using Triton kernels. For n_streams=2 (the only HYDRA config), all 5 mHC operations are covered by 3 kernel launches (mix+extract, inject, merge) instead of 5 separate torch ops + temporaries. For n_streams != 2, falls back to equivalent torch operations. """ BLOCK_SIZE = 1024 @staticmethod def stream_init(x: torch.Tensor, n_streams: int) -> torch.Tensor: """Replicate (B,T,d) -> (n_streams,B,T,d) via broadcast copy.""" return x.unsqueeze(0).expand(n_streams, *x.shape).contiguous() @staticmethod def stream_mix_extract( streams: torch.Tensor, M: torch.Tensor, ) -> torch.Tensor: """Fused mix + extract: returns mixed primary stream for block input. Args: streams: (2, B, T, d) bf16 M: (2, 2) fp32 doubly-stochastic matrix Returns: mixed: (B, T, d) bf16 -- the primary stream after mixing """ n = streams.shape[0] if n == 2: s0 = streams[0].contiguous() s1 = streams[1].contiguous() N = s0.numel() out = torch.empty_like(s0) m00 = M[0, 0].item() m01 = M[0, 1].item() grid = _triton_grid(N, MHCFusedOps.BLOCK_SIZE) _mhc_mix_extract_kernel[grid]( s0, s1, out, m00, m01, N=N, BLOCK=MHCFusedOps.BLOCK_SIZE, ) return out # General fallback (promote to fp32 for einsum, cast back) orig_dtype = streams.dtype return torch.einsum("ij,jbtd->ibtd", M.float(), streams.float())[0].to(orig_dtype) @staticmethod def stream_inject( streams: torch.Tensor, block_output: torch.Tensor, M: torch.Tensor, ) -> torch.Tensor: """Fused inject: out_i = streams_i + M[0,i] * block_output. Args: streams: (2, B, T, d) bf16 block_output: (B, T, d) bf16 M: (2, 2) fp32 doubly-stochastic matrix Returns: new_streams: (2, B, T, d) bf16 """ n = streams.shape[0] if n == 2: s0 = streams[0].contiguous() s1 = streams[1].contiguous() bo = block_output.contiguous() N = s0.numel() out0 = torch.empty_like(s0) out1 = torch.empty_like(s1) m00 = M[0, 0].item() m01 = M[0, 1].item() grid = _triton_grid(N, MHCFusedOps.BLOCK_SIZE) _mhc_inject_kernel[grid]( s0, s1, bo, out0, out1, m00, m01, N=N, BLOCK=MHCFusedOps.BLOCK_SIZE, ) return torch.stack([out0, out1], dim=0) # General fallback (promote to fp32 for einsum, cast back) orig_dtype = streams.dtype update = torch.zeros_like(streams, dtype=torch.float32) update[0] = block_output.float() result = streams.float() + torch.einsum("ij,jbtd->ibtd", M.t().float(), update) return result.to(orig_dtype) @staticmethod def stream_merge(streams: torch.Tensor) -> torch.Tensor: """Weighted merge: mean across streams -> (B, T, d). Args: streams: (n_streams, B, T, d) bf16 Returns: merged: (B, T, d) bf16 """ n = streams.shape[0] if n == 2: s0 = streams[0].contiguous() s1 = streams[1].contiguous() N = s0.numel() out = torch.empty_like(s0) grid = _triton_grid(N, MHCFusedOps.BLOCK_SIZE) _mhc_merge_kernel[grid]( s0, s1, out, N=N, BLOCK=MHCFusedOps.BLOCK_SIZE, ) return out return streams.mean(dim=0) def mhc_fused_forward( streams: torch.Tensor, M: torch.Tensor, block_fn, stream_norm, ) -> torch.Tensor: """Full fused mHC forward pass (excluding init). Equivalent to ManifoldHyperConnection.forward() from mhc_mini.py. Args: streams: (n_streams, B, T, d) bf16 M: (n_streams, n_streams) fp32 doubly-stochastic matrix block_fn: callable (B,T,d) -> (B,T,d) stream_norm: nn.LayerNorm(d) Returns: new_streams: (n_streams, B, T, d) bf16 """ mixed = MHCFusedOps.stream_mix_extract(streams, M) primary_input = stream_norm(mixed) block_output = block_fn(primary_input) return MHCFusedOps.stream_inject(streams, block_output, M) # ============================================================================ # Smoke test: compare fused ops vs mhc_mini reference # ============================================================================ if __name__ == "__main__": import sys import os # Add project root to path for imports project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, project_root) from subsystems.mhc_mini import ManifoldHyperConnection torch.manual_seed(42) device = "cuda" dtype = torch.bfloat16 B, T, d = 2, 128, 96 n_streams = 2 # Reference module (bf16 weights to match bf16 data) ref = ManifoldHyperConnection(d_model=d, n_streams=n_streams, sinkhorn_iters=5).to(device=device, dtype=dtype) # Input x = torch.randn(B, T, d, device=device, dtype=dtype) # Init streams (both paths) streams_ref = ref.init_streams(x) streams_fused = MHCFusedOps.stream_init(x, n_streams) assert torch.allclose(streams_ref, streams_fused, atol=0.0), "stream_init mismatch" print("[PASS] stream_init") # Compute doubly-stochastic matrix M = ref._sinkhorn(ref.log_alpha) # Test mix+extract mixed_fused = MHCFusedOps.stream_mix_extract(streams_ref, M) # Reference: M[0,0]*s0 + M[0,1]*s1 mixed_ref = M[0, 0] * streams_ref[0] + M[0, 1] * streams_ref[1] max_err = (mixed_fused.float() - mixed_ref.float()).abs().max().item() print(f"[PASS] stream_mix_extract (max_err={max_err:.2e})") assert max_err < 1e-2, f"mix_extract error too large: {max_err}" # Test inject block_output = torch.randn(B, T, d, device=device, dtype=dtype) injected_fused = MHCFusedOps.stream_inject(streams_ref, block_output, M) out0_ref = streams_ref[0] + M[0, 0] * block_output out1_ref = streams_ref[1] + M[0, 1] * block_output injected_ref = torch.stack([out0_ref, out1_ref], dim=0) max_err = (injected_fused.float() - injected_ref.float()).abs().max().item() print(f"[PASS] stream_inject (max_err={max_err:.2e})") assert max_err < 1e-2, f"inject error too large: {max_err}" # Test merge merged_fused = MHCFusedOps.stream_merge(streams_ref) merged_ref = ref.merge_streams(streams_ref) max_err = (merged_fused.float() - merged_ref.float()).abs().max().item() print(f"[PASS] stream_merge (max_err={max_err:.2e})") assert max_err < 1e-2, f"merge error too large: {max_err}" # Full forward comparison def dummy_block(x): return x * 0.5 + 0.1 streams_for_ref = ref.init_streams(x) streams_for_fused = MHCFusedOps.stream_init(x, n_streams) # Reference forward -- cast streams to float to match M dtype (fp32) # then cast back, mirroring what actually happens in train.py where # streams are bf16 and M is computed in fp32. # The reference mhc_mini.py has a latent type promotion issue: M is fp32, # streams are bf16, so mixed becomes fp32. LayerNorm then fails on fp32 # when weights are bf16. We test the fused path directly instead. out_fused = mhc_fused_forward( streams_for_fused, M, dummy_block, ref.stream_norms[0], ) # Manual reference: reproduce the n_streams=2 path from mhc_mini M_ref = ref._sinkhorn(ref.log_alpha) mixed_ref = (M_ref[0, 0] * streams_for_ref[0].float() + M_ref[0, 1] * streams_for_ref[1].float()).to(dtype) primary_ref = ref.stream_norms[0](mixed_ref) block_out_ref = dummy_block(primary_ref) out0_ref = streams_for_ref[0].float() + M_ref[0, 0] * block_out_ref.float() out1_ref = streams_for_ref[1].float() + M_ref[0, 1] * block_out_ref.float() out_ref = torch.stack([out0_ref.to(dtype), out1_ref.to(dtype)], dim=0) max_err = (out_fused.float() - out_ref.float()).abs().max().item() print(f"[PASS] full forward (max_err={max_err:.2e})") assert max_err < 5e-2, f"full forward error too large: {max_err}" # Verify n_streams != 2 fallback works ref4 = ManifoldHyperConnection(d_model=d, n_streams=4, sinkhorn_iters=5).to(device) x4 = torch.randn(B, T, d, device=device, dtype=dtype) s4 = MHCFusedOps.stream_init(x4, 4) M4 = ref4._sinkhorn(ref4.log_alpha) mixed4 = MHCFusedOps.stream_mix_extract(s4, M4) merged4 = MHCFusedOps.stream_merge(s4) print("[PASS] n_streams=4 fallback (torch ops)") print("\n=== All mHC kernel smoke tests PASSED ===")