icarus112's picture
Update Feather a10g-large training runtime image
c475135 verified
"""5 fused mHC kernels for ManifoldHyperConnection operations.
Phase 2: Triton kernels for stream routing operations.
(TileLang available but Triton preferred for sm_86 RTX 3060 compatibility.)
Phase 1: Uses torch.einsum and standard ops in ManifoldHyperConnection
(subsystems/mhc_mini.py).
Kernels (fused for n_streams=2):
1. stream_init: Replicate embedding across n_streams (torch broadcast)
2. stream_mix: Doubly-stochastic M @ streams (fused)
3. stream_inject: Additive injection of block output (fused)
4. stream_extract: Extract primary stream for block input (fused)
5. stream_merge: Weighted merge of streams (fused)
For n_streams=2 (the only config used in HYDRA), the full forward pass
(mix -> extract -> inject) reduces to 2-3 scalar multiplies + adds per
element, fused into two Triton kernel launches (mix+extract before the
opaque block_fn call, inject after it).
DSL: Triton (@triton.jit)
Target: RTX 3060 (sm_86), bf16 compute, fp32 accumulation
"""
from __future__ import annotations
import torch
import triton
import triton.language as tl
# ============================================================================
# Triton kernels: fused mix + extract, and fused inject, for n_streams=2
# ============================================================================
#
# Given streams (2, B, T, d) and doubly-stochastic M (2x2):
# mixed = M[0,0]*s0 + M[0,1]*s1 (stream_mix row 0)
# primary_input = layernorm(mixed) (done outside kernel)
# block_output = block_fn(primary_input) (done outside kernel)
# out0 = s0 + M[0,0]*block_output (stream_inject)
# out1 = s1 + M[0,1]*block_output (stream_inject)
#
# We fuse the mix and inject into two kernels: mix_extract and inject.
# The block_fn call is opaque Python so it must happen between them.
@triton.jit
def _mhc_mix_extract_kernel(
    S0_ptr,  # streams[0], flat view of B*T*d elements (read only)
    S1_ptr,  # streams[1], flat view of B*T*d elements (read only)
    OUT_ptr,  # mixed output, B*T*d elements
    M00,  # scalar M[0,0]
    M01,  # scalar M[0,1]
    N,  # total elements = B*T*d (runtime value, so new sizes do not recompile)
    BLOCK: tl.constexpr,
):
    """Fused stream_mix + stream_extract: mixed = M[0,0]*s0 + M[0,1]*s1.

    Each program handles one contiguous chunk of BLOCK elements. Inputs are
    promoted to fp32 for the arithmetic and the result is cast to the
    element type of OUT_ptr, so the output buffer's dtype decides the
    rounding instead of a hard-coded bf16.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the tail: the last program may extend past N.
    mask = offs < N
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    mixed = M00 * s0 + M01 * s1
    tl.store(OUT_ptr + offs, mixed.to(OUT_ptr.dtype.element_ty), mask=mask)
@triton.jit
def _mhc_inject_kernel(
    S0_ptr,  # streams[0], flat view of B*T*d elements (read only)
    S1_ptr,  # streams[1], flat view of B*T*d elements (read only)
    BLOCK_OUT_ptr,  # block_output, B*T*d elements (read only)
    OUT0_ptr,  # output streams[0], B*T*d elements
    OUT1_ptr,  # output streams[1], B*T*d elements
    M00,  # scalar M[0,0]
    M01,  # scalar M[0,1]
    N,  # total elements = B*T*d (runtime value, so new sizes do not recompile)
    BLOCK: tl.constexpr,
):
    """Fused stream_inject: out_i = s_i + M[0,i] * block_output.

    Each program handles one contiguous chunk of BLOCK elements. The block
    output is loaded once and reused for both streams. Arithmetic is done
    in fp32 and each result is cast to the element type of its own output
    pointer rather than to a hard-coded bf16.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the tail: the last program may extend past N.
    mask = offs < N
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    bo = tl.load(BLOCK_OUT_ptr + offs, mask=mask).to(tl.float32)
    out0 = s0 + M00 * bo
    out1 = s1 + M01 * bo
    tl.store(OUT0_ptr + offs, out0.to(OUT0_ptr.dtype.element_ty), mask=mask)
    tl.store(OUT1_ptr + offs, out1.to(OUT1_ptr.dtype.element_ty), mask=mask)
@triton.jit
def _mhc_merge_kernel(
    S0_ptr,  # streams[0], flat view of N elements (read only)
    S1_ptr,  # streams[1], flat view of N elements (read only)
    OUT_ptr,  # merged output, N elements
    N,  # total elements (runtime value, so new sizes do not recompile)
    BLOCK: tl.constexpr,
):
    """Fused stream_merge: out = 0.5 * (s0 + s1).

    Each program handles one contiguous chunk of BLOCK elements. The sum is
    formed in fp32 and the result is cast to the element type of OUT_ptr
    rather than to a hard-coded bf16.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the tail: the last program may extend past N.
    mask = offs < N
    s0 = tl.load(S0_ptr + offs, mask=mask).to(tl.float32)
    s1 = tl.load(S1_ptr + offs, mask=mask).to(tl.float32)
    out = (s0 + s1) * 0.5
    tl.store(OUT_ptr + offs, out.to(OUT_ptr.dtype.element_ty), mask=mask)
# ============================================================================
# Python wrappers
# ============================================================================
def _triton_grid(N: int, BLOCK: int):
    """Return the 1-D launch grid as a 1-tuple.

    The single entry is N divided by BLOCK, rounded up by adding
    BLOCK - 1 before the floor division.
    """
    padded = N + BLOCK - 1
    n_programs = padded // BLOCK
    return (n_programs,)
class MHCFusedOps:
    """Fused mHC stream operations using Triton kernels.

    For n_streams=2 (the only HYDRA config), the mHC operations are covered
    by 3 kernels (mix+extract, inject, merge) instead of separate torch ops
    plus temporaries.

    The Triton path is taken only when it is safe: exactly two streams, a
    CUDA ``streams`` tensor, and no autograd tracking required. The kernels
    are raw launches that autograd cannot see, and the entries of M are
    read back as Python floats, so when gradients are needed the equivalent
    torch operations are used instead. All other cases (n_streams != 2,
    CPU tensors) use the torch operations as well.
    """

    BLOCK_SIZE = 1024

    @staticmethod
    def _use_triton(streams: torch.Tensor, *others: torch.Tensor) -> bool:
        """Return True when the Triton fast path may be used.

        Args:
            streams: stacked streams; decides stream count and device.
            *others: further tensors whose gradients must not be dropped.
        """
        if streams.shape[0] != 2 or not streams.is_cuda:
            return False
        if not torch.is_grad_enabled():
            return True
        # Outputs of a raw kernel launch have no grad_fn, so taking the
        # fast path here would silently cut the autograd graph.
        return not (streams.requires_grad or any(t.requires_grad for t in others))

    @staticmethod
    def stream_init(x: torch.Tensor, n_streams: int) -> torch.Tensor:
        """Replicate (B,T,d) -> (n_streams,B,T,d) via broadcast copy."""
        return x.unsqueeze(0).expand(n_streams, *x.shape).contiguous()

    @staticmethod
    def stream_mix_extract(
        streams: torch.Tensor,
        M: torch.Tensor,
    ) -> torch.Tensor:
        """Fused mix + extract: returns mixed primary stream for block input.

        Computes sum_j M[0, j] * streams[j] with fp32 accumulation and
        returns it in the dtype of ``streams``.

        Args:
            streams: (n_streams, B, T, d), documented as bf16
            M: (n_streams, n_streams) fp32 doubly-stochastic matrix

        Returns:
            mixed: (B, T, d) -- the primary stream after mixing
        """
        if MHCFusedOps._use_triton(streams, M):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            out = torch.empty_like(s0)
            # One device->host transfer for both weights instead of two.
            m00, m01 = M[0, :2].tolist()
            n_elem = s0.numel()
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_mix_extract_kernel[grid](
                s0, s1, out, m00, m01,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return out
        # Torch path (promote to fp32, cast back). Only row 0 of the mix is
        # ever used, so contract with that row alone instead of forming
        # every mixed stream and discarding all but the first.
        orig_dtype = streams.dtype
        mixed = torch.einsum("j,jbtd->btd", M[0].float(), streams.float())
        return mixed.to(orig_dtype)

    @staticmethod
    def stream_inject(
        streams: torch.Tensor,
        block_output: torch.Tensor,
        M: torch.Tensor,
    ) -> torch.Tensor:
        """Fused inject: out_i = streams_i + M[0,i] * block_output.

        Args:
            streams: (n_streams, B, T, d), documented as bf16
            block_output: (B, T, d), documented as bf16
            M: (n_streams, n_streams) fp32 doubly-stochastic matrix

        Returns:
            new_streams: (n_streams, B, T, d) in the dtype of ``streams``
        """
        if MHCFusedOps._use_triton(streams, block_output, M):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            bo = block_output.contiguous()
            # Write both results straight into one stacked buffer so no
            # extra torch.stack copy is needed afterwards.
            out = torch.empty((2, *s0.shape), dtype=s0.dtype, device=s0.device)
            m00, m01 = M[0, :2].tolist()
            n_elem = s0.numel()
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_inject_kernel[grid](
                s0, s1, bo, out[0], out[1], m00, m01,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return out
        # Torch path (promote to fp32, cast back). Broadcasting row 0 of M
        # over the block output gives M[0,i] * block_output for every i
        # without building a zero-padded update tensor.
        orig_dtype = streams.dtype
        weights = M[0].float().reshape(-1, 1, 1, 1)
        result = streams.float() + weights * block_output.float()
        return result.to(orig_dtype)

    @staticmethod
    def stream_merge(streams: torch.Tensor) -> torch.Tensor:
        """Weighted merge: mean across streams -> (B, T, d).

        Args:
            streams: (n_streams, B, T, d), documented as bf16

        Returns:
            merged: (B, T, d) in the dtype of ``streams``
        """
        if MHCFusedOps._use_triton(streams):
            s0 = streams[0].contiguous()
            s1 = streams[1].contiguous()
            out = torch.empty_like(s0)
            n_elem = s0.numel()
            grid = _triton_grid(n_elem, MHCFusedOps.BLOCK_SIZE)
            _mhc_merge_kernel[grid](
                s0, s1, out,
                N=n_elem, BLOCK=MHCFusedOps.BLOCK_SIZE,
            )
            return out
        return streams.mean(dim=0)
def mhc_fused_forward(
    streams: torch.Tensor,
    M: torch.Tensor,
    block_fn,
    stream_norm,
) -> torch.Tensor:
    """Run one mHC step: mix+extract, normalise, apply the block, inject.

    Stream initialisation is not part of this function. Intended to mirror
    ManifoldHyperConnection.forward() from mhc_mini.py.

    Args:
        streams: (n_streams, B, T, d) bf16
        M: (n_streams, n_streams) fp32 doubly-stochastic matrix
        block_fn: callable (B,T,d) -> (B,T,d)
        stream_norm: normalisation applied to the mixed stream before the
            block, documented as nn.LayerNorm(d)

    Returns:
        new_streams: (n_streams, B, T, d) bf16
    """
    # The block sees the normalised primary stream produced by the mix.
    normed_primary = stream_norm(MHCFusedOps.stream_mix_extract(streams, M))
    injected = MHCFusedOps.stream_inject(streams, block_fn(normed_primary), M)
    return injected
# ============================================================================
# Smoke test: compare fused ops vs mhc_mini reference
# ============================================================================
if __name__ == "__main__":
    import os
    import sys

    # The Triton kernels under test can only be launched on a CUDA device.
    if not torch.cuda.is_available():
        sys.exit("mHC kernel smoke test requires a CUDA device")

    # Add project root to path for imports
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.insert(0, project_root)
    from subsystems.mhc_mini import ManifoldHyperConnection

    torch.manual_seed(42)
    device = "cuda"
    dtype = torch.bfloat16
    B, T, d = 2, 128, 96
    n_streams = 2

    # Reference module (bf16 weights to match bf16 data)
    ref = ManifoldHyperConnection(d_model=d, n_streams=n_streams, sinkhorn_iters=5).to(device=device, dtype=dtype)

    # Input
    x = torch.randn(B, T, d, device=device, dtype=dtype)

    # Each check below asserts first and prints [PASS] only afterwards, so
    # a failing check never reports success.

    # Init streams (both paths); replication must be bit-exact, so both
    # tolerances are zero (allclose's default rtol is non-zero).
    streams_ref = ref.init_streams(x)
    streams_fused = MHCFusedOps.stream_init(x, n_streams)
    assert torch.allclose(streams_ref, streams_fused, rtol=0.0, atol=0.0), "stream_init mismatch"
    print("[PASS] stream_init")

    # Compute doubly-stochastic matrix
    M = ref._sinkhorn(ref.log_alpha)

    # Test mix+extract
    mixed_fused = MHCFusedOps.stream_mix_extract(streams_ref, M)
    # Reference: M[0,0]*s0 + M[0,1]*s1
    mixed_ref = M[0, 0] * streams_ref[0] + M[0, 1] * streams_ref[1]
    max_err = (mixed_fused.float() - mixed_ref.float()).abs().max().item()
    assert max_err < 1e-2, f"mix_extract error too large: {max_err}"
    print(f"[PASS] stream_mix_extract (max_err={max_err:.2e})")

    # Test inject
    block_output = torch.randn(B, T, d, device=device, dtype=dtype)
    injected_fused = MHCFusedOps.stream_inject(streams_ref, block_output, M)
    out0_ref = streams_ref[0] + M[0, 0] * block_output
    out1_ref = streams_ref[1] + M[0, 1] * block_output
    injected_ref = torch.stack([out0_ref, out1_ref], dim=0)
    max_err = (injected_fused.float() - injected_ref.float()).abs().max().item()
    assert max_err < 1e-2, f"inject error too large: {max_err}"
    print(f"[PASS] stream_inject (max_err={max_err:.2e})")

    # Test merge
    merged_fused = MHCFusedOps.stream_merge(streams_ref)
    merged_ref = ref.merge_streams(streams_ref)
    max_err = (merged_fused.float() - merged_ref.float()).abs().max().item()
    assert max_err < 1e-2, f"merge error too large: {max_err}"
    print(f"[PASS] stream_merge (max_err={max_err:.2e})")

    # Full forward comparison
    def dummy_block(h):
        return h * 0.5 + 0.1

    streams_for_ref = ref.init_streams(x)
    streams_for_fused = MHCFusedOps.stream_init(x, n_streams)
    # Reference forward -- cast streams to float to match M dtype (fp32)
    # then cast back, mirroring what actually happens in train.py where
    # streams are bf16 and M is computed in fp32.
    # The reference mhc_mini.py has a latent type promotion issue: M is fp32,
    # streams are bf16, so mixed becomes fp32. LayerNorm then fails on fp32
    # when weights are bf16. We test the fused path directly instead.
    out_fused = mhc_fused_forward(
        streams_for_fused, M, dummy_block, ref.stream_norms[0],
    )
    # Manual reference: reproduce the n_streams=2 path from mhc_mini
    M_ref = ref._sinkhorn(ref.log_alpha)
    mixed_ref = (M_ref[0, 0] * streams_for_ref[0].float() + M_ref[0, 1] * streams_for_ref[1].float()).to(dtype)
    primary_ref = ref.stream_norms[0](mixed_ref)
    block_out_ref = dummy_block(primary_ref)
    out0_ref = streams_for_ref[0].float() + M_ref[0, 0] * block_out_ref.float()
    out1_ref = streams_for_ref[1].float() + M_ref[0, 1] * block_out_ref.float()
    out_ref = torch.stack([out0_ref.to(dtype), out1_ref.to(dtype)], dim=0)
    max_err = (out_fused.float() - out_ref.float()).abs().max().item()
    assert max_err < 5e-2, f"full forward error too large: {max_err}"
    print(f"[PASS] full forward (max_err={max_err:.2e})")

    # Verify n_streams != 2 fallback works: check output shapes so the
    # section actually asserts something, and cover inject as well.
    ref4 = ManifoldHyperConnection(d_model=d, n_streams=4, sinkhorn_iters=5).to(device)
    x4 = torch.randn(B, T, d, device=device, dtype=dtype)
    s4 = MHCFusedOps.stream_init(x4, 4)
    M4 = ref4._sinkhorn(ref4.log_alpha)
    mixed4 = MHCFusedOps.stream_mix_extract(s4, M4)
    assert tuple(mixed4.shape) == (B, T, d), f"fallback mix shape: {tuple(mixed4.shape)}"
    injected4 = MHCFusedOps.stream_inject(s4, x4, M4)
    assert tuple(injected4.shape) == (4, B, T, d), f"fallback inject shape: {tuple(injected4.shape)}"
    merged4 = MHCFusedOps.stream_merge(s4)
    assert tuple(merged4.shape) == (B, T, d), f"fallback merge shape: {tuple(merged4.shape)}"
    print("[PASS] n_streams=4 fallback (torch ops)")

    print("\n=== All mHC kernel smoke tests PASSED ===")