| """ |
| Fused Triton kernels for mHC (manifold-constrained Hyper-Connections). |
| |
| Fuses rearrange + einsum operations to reduce kernel launch overhead: |
| - Baseline: 9 kernels per mHC module (2 rearranges + 3 einsums + 4 softmax/sinkhorn ops) |
| - Fused: 3-4 kernels per mHC module (fused stream mixing + separate sinkhorn) |
| |
| Performance target: 3-8x speedup for 48-layer model (96 mHC modules). |
| """ |
|
|
| import torch |
| import triton |
| import triton.language as tl |
| from typing import Optional |
|
|
@triton.jit
def fused_stream_mixing_kernel(
    # Input pointers
    x_ptr,
    transformed_ptr,
    H_res_ptr,
    H_pre_ptr,
    H_post_ptr,
    # Output pointer
    output_ptr,
    # Problem sizes. NOTE(review): these are constexpr, so every distinct
    # value (e.g. a new seq_len) triggers a fresh Triton compilation.
    batch_size: tl.constexpr,
    seq_len: tl.constexpr,
    num_streams: tl.constexpr,
    stream_dim: tl.constexpr,
    # Number of flattened (batch, seq) positions handled per program.
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Fused kernel for width_connection stream mixing.

    Computes, for every flattened (batch, seq) position:
    1. residual_mixed = einsum(H_res, x_streams, "n m, b s n d -> b s m d")
    2. pre_mixed = einsum(H_pre, transformed_streams, "n m, b s n d -> b s m d")
    3. post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
    4. output = residual_mixed + post_mixed

    All tensors are addressed as dense row-major buffers:
    x / transformed / output as (batch, seq_len, num_streams, stream_dim),
    and H_* as (num_streams, num_streams). The host wrapper must guarantee
    contiguity. Accumulation is done in float32.

    Each program instance handles BLOCK_SIZE consecutive flattened
    (batch, seq) positions. Lanes past the end are masked.

    NOTE(review): pre_mixed is recomputed for every output stream (O(S^3 * D)
    work per position). Folding H_pre @ H_post on the host would reduce it.
    """
    pid = tl.program_id(0)

    total_positions = batch_size * seq_len
    num_blocks = tl.cdiv(total_positions, BLOCK_SIZE)
    if pid >= num_blocks:
        return

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_positions

    # (batch, seq) is flattened row-major, so
    # batch_idx * seq_len * S * D + seq_idx * S * D == offsets * S * D.
    # Computed in int64 so large tensors do not overflow 32-bit offsets.
    row_start = offsets.to(tl.int64) * (num_streams * stream_dim)

    for s_out in range(num_streams):
        for d in range(stream_dim):
            residual_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
            post_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

            # residual[m] = sum_n H_res[n, m] * x[n]  -> read H_res[s_in, s_out]
            for s_in in range(num_streams):
                h_res_val = tl.load(
                    H_res_ptr + s_in * num_streams + s_out
                ).to(tl.float32)
                x_val = tl.load(
                    x_ptr + row_start + s_in * stream_dim + d,
                    mask=mask,
                    other=0.0,
                ).to(tl.float32)
                residual_acc += h_res_val * x_val

            # post[m'] = sum_m H_post[m, m'] * pre[m],
            # pre[m]   = sum_n H_pre[n, m] * transformed[n]
            for s_mid in range(num_streams):
                # Must be a block tensor (not a Python scalar): Triton requires
                # loop-carried variables to keep the same type across iterations.
                pre_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
                for s_in in range(num_streams):
                    h_pre_val = tl.load(
                        H_pre_ptr + s_in * num_streams + s_mid
                    ).to(tl.float32)
                    transformed_val = tl.load(
                        transformed_ptr + row_start + s_in * stream_dim + d,
                        mask=mask,
                        other=0.0,
                    ).to(tl.float32)
                    pre_acc += h_pre_val * transformed_val

                h_post_val = tl.load(
                    H_post_ptr + s_mid * num_streams + s_out
                ).to(tl.float32)
                post_acc += h_post_val * pre_acc

            output_val = residual_acc + post_acc
            tl.store(
                output_ptr + row_start + s_out * stream_dim + d,
                output_val.to(output_ptr.dtype.element_ty),
                mask=mask,
            )
|
|
def fused_width_connection_triton(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
) -> torch.Tensor:
    """
    Fused Triton implementation of width_connection.

    Computes
    einsum(H_res, x, "n m, b s n d -> b s m d")
    + einsum(H_post, einsum(H_pre, transformed, "n m, b s n d -> b s m d"),
             "m n, b s m d -> b s n d")
    in a single kernel launch.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Residual mixing matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)

    Non-contiguous inputs are copied to a contiguous layout first, because the
    kernel addresses every tensor as a dense row-major buffer.

    Returns:
        output: Mixed features (batch, seq_len, num_streams, stream_dim), same
        dtype as ``x``, contiguous.

    Raises:
        AssertionError: if x is not on a CUDA device, if shapes mismatch, or if
            any input lives on a different device than ``x``.
    """
    batch_size, seq_len, num_streams, stream_dim = x.shape

    assert x.is_cuda, "Input must be on GPU"
    assert transformed.shape == x.shape, "Shape mismatch"
    assert H_res.shape == (num_streams, num_streams), "H_res shape mismatch"
    assert H_pre.shape == (num_streams, num_streams), "H_pre shape mismatch"
    assert H_post.shape == (num_streams, num_streams), "H_post shape mismatch"
    for name, tensor in (
        ("transformed", transformed),
        ("H_res", H_res),
        ("H_pre", H_pre),
        ("H_post", H_post),
    ):
        assert tensor.device == x.device, f"{name} must be on the same device as x"

    # The kernel computes flat row-major offsets for every input. Previously
    # only x was checked, so a strided `transformed` (or H_*) was silently
    # misread. .contiguous() is a no-op for already-contiguous tensors.
    x = x.contiguous()
    transformed = transformed.contiguous()
    H_res = H_res.contiguous()
    H_pre = H_pre.contiguous()
    H_post = H_post.contiguous()

    # x is contiguous here, so empty_like yields a contiguous output as well.
    output = torch.empty_like(x)

    total_positions = batch_size * seq_len
    if total_positions == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return output

    BLOCK_SIZE = 128
    grid = (triton.cdiv(total_positions, BLOCK_SIZE),)

    fused_stream_mixing_kernel[grid](
        x, transformed, H_res, H_pre, H_post, output,
        batch_size, seq_len, num_streams, stream_dim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output
|
|
def compare_with_einops_reference(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-4,
) -> tuple[bool, float, Optional[str]]:
    """
    Run the fused Triton kernel and check it against an unfused einops version.

    The reference is built from the same three einsums the kernel fuses,
    followed by the residual add.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Residual mixing matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)
        atol: Absolute tolerance passed to torch.allclose
        rtol: Relative tolerance passed to torch.allclose

    Returns:
        A ``(matches, max_diff, error_msg)`` tuple:
        - ``matches``: whether torch.allclose accepted the two outputs.
        - ``max_diff``: the largest element-wise absolute difference.
        - ``error_msg``: a human-readable diagnostic when they disagree, or
          ``None`` when they match.
    """
    from einops import einsum

    fused_out = fused_width_connection_triton(
        x, transformed, H_res, H_pre, H_post
    )

    # Unfused reference path.
    res_part = einsum(H_res, x, "n m, b s n d -> b s m d")
    pre_part = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
    post_part = einsum(H_post, pre_part, "m n, b s m d -> b s n d")
    ref_out = res_part + post_part

    abs_err = (fused_out - ref_out).abs()
    matches = torch.allclose(fused_out, ref_out, atol=atol, rtol=rtol)
    max_diff = abs_err.max().item()

    if matches:
        return matches, max_diff, None

    mean_diff = abs_err.mean().item()
    report = (
        f"Triton kernel output does not match einops reference!\n"
        f"Max diff: {max_diff:.6e} (atol={atol}, rtol={rtol})\n"
        f"Mean diff: {mean_diff:.6e}\n"
        f"Triton output range: [{fused_out.min().item():.4f}, "
        f"{fused_out.max().item():.4f}]\n"
        f"Einops output range: [{ref_out.min().item():.4f}, "
        f"{ref_out.max().item():.4f}]"
    )
    return matches, max_diff, report
|
|
def benchmark_kernel_speedup(
    batch_size: int = 16,
    seq_len: int = 128,
    dim: int = 512,
    num_streams: int = 8,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> tuple[float, float, float]:
    """
    Benchmark the fused Triton kernel against the unfused einops reference.

    Timing is wall-clock (time.perf_counter) around a loop of ``num_iters``
    calls, bracketed by torch.cuda.synchronize(). It therefore includes
    kernel-launch overhead, which is what the fusion is meant to reduce.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        dim: Hidden dimension; must be divisible by ``num_streams``
        num_streams: Number of streams (>= 1)
        num_warmup: Warmup iterations
        num_iters: Benchmark iterations (>= 1)

    Returns:
        einops_time_ms: Average einops time (milliseconds)
        triton_time_ms: Average Triton time (milliseconds)
        speedup: Speedup factor (einops_time / triton_time)

    Raises:
        ValueError: if ``num_streams < 1``, ``dim`` is not divisible by
            ``num_streams``, or ``num_iters < 1``.
        AssertionError: if no CUDA device is available.
    """
    import time

    # Validate before touching the GPU. Previously dim // num_streams silently
    # truncated, so a different size than requested was benchmarked, and a
    # zero num_streams / num_iters crashed with ZeroDivisionError.
    if num_streams < 1:
        raise ValueError(f"num_streams must be >= 1, got {num_streams}")
    if dim % num_streams != 0:
        raise ValueError(
            f"dim ({dim}) must be divisible by num_streams ({num_streams})"
        )
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")

    from einops import einsum

    assert torch.cuda.is_available(), "GPU required for benchmarking"
    device = torch.device('cuda')

    stream_dim = dim // num_streams

    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    def _einops_reference() -> torch.Tensor:
        # Unfused baseline: three einsum kernels plus an elementwise add.
        residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
        pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
        post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        return residual_mixed + post_mixed

    def _triton_fused() -> torch.Tensor:
        return fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)

    # Warm up both paths; the first Triton call also triggers JIT compilation.
    for _ in range(num_warmup):
        _triton_fused()
        _einops_reference()
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(num_iters):
        _einops_reference()
    torch.cuda.synchronize()
    einops_time = (time.perf_counter() - start) / num_iters * 1000

    start = time.perf_counter()
    for _ in range(num_iters):
        _triton_fused()
    torch.cuda.synchronize()
    triton_time = (time.perf_counter() - start) / num_iters * 1000

    speedup = einops_time / triton_time

    return einops_time, triton_time, speedup
|
|
if __name__ == "__main__":
    # Quick smoke test: check fused-kernel correctness against the einops
    # reference, then run the default benchmark.
    import sys

    print("Testing fused mHC Triton kernel...")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S`).
        sys.exit(0)

    device = torch.device('cuda')
    # Fixed seed so a failing comparison can be reproduced exactly.
    torch.manual_seed(0)

    batch_size = 4
    seq_len = 32
    num_streams = 8
    stream_dim = 64

    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    # Row-softmax matrices are not symmetric, so any transposed indexing in
    # the kernel shows up as a mismatch.
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    matches, max_diff, error_msg = compare_with_einops_reference(
        x, transformed, H_res, H_pre, H_post
    )

    if matches:
        print(f"✓ Correctness test PASSED (max diff: {max_diff:.6e})")
    else:
        print("✗ Correctness test FAILED")
        print(error_msg)
        sys.exit(1)

    print("\nBenchmarking...")
    einops_time, triton_time, speedup = benchmark_kernel_speedup()
    print(f"Einops time: {einops_time:.3f} ms")
    print(f"Triton time: {triton_time:.3f} ms")
    print(f"Speedup: {speedup:.2f}x")

    print("\n✓ All tests passed!")
|
|