"""
Fused Triton kernels for mHC (manifold-constrained Hyper-Connections).

Fuses rearrange + einsum operations to reduce kernel launch overhead:
- Baseline: 9 kernels per mHC module (2 rearranges + 3 einsums + 4 softmax/sinkhorn ops)
- Fused: 3-4 kernels per mHC module (fused stream mixing + separate sinkhorn)

Performance target: 3-8x speedup for 48-layer model (96 mHC modules).
"""

import sys
import time
from typing import Optional

import torch
import triton  # type: ignore[import-untyped]
import triton.language as tl  # type: ignore[import-untyped]


@triton.jit
def fused_stream_mixing_kernel(  # type: ignore[no-untyped-def]
    # Input pointers
    x_ptr,  # Residual input (batch, seq_len, num_streams, stream_dim), contiguous
    transformed_ptr,  # Transformed input, same shape and layout as x
    H_res_ptr,  # Residual mixing matrix (num_streams, num_streams), row-major
    H_pre_ptr,  # Pre-mixing matrix (num_streams, num_streams), row-major
    H_post_ptr,  # Post-mixing matrix (num_streams, num_streams), row-major
    # Output pointer
    output_ptr,  # Output, same shape and layout as x
    # Dimensions
    batch_size: tl.constexpr,
    seq_len: tl.constexpr,
    num_streams: tl.constexpr,
    stream_dim: tl.constexpr,
    # Number of flattened (batch, seq) positions handled per program
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Fused kernel for width_connection stream mixing.

    Implements, in a single launch:
        1. residual_mixed = einsum(H_res, x_streams, "n m, b s n d -> b s m d")
        2. pre_mixed = einsum(H_pre, transformed_streams, "n m, b s n d -> b s m d")
        3. post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        4. output = residual_mixed + post_mixed

    Steps 2 and 3 are folded into one weight, (H_pre @ H_post)[s_in, s_out],
    so for every position and feature index d:

        output[.., s_out, d] = sum_{s_in} H_res[s_in, s_out] * x[.., s_in, d]
                             + (H_pre @ H_post)[s_in, s_out] * transformed[.., s_in, d]

    Each program handles BLOCK_SIZE consecutive flattened (batch, seq)
    positions. All tensors must be contiguous (the offsets below are flat
    row-major arithmetic). Accumulation is done in float32 and the result is
    cast to the output element type on store.
    """
    pid = tl.program_id(0)

    total_positions = batch_size * seq_len
    positions = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Masks off the tail of the last program when total_positions is not a
    # multiple of BLOCK_SIZE.
    mask = positions < total_positions

    # With a contiguous (b, s, n, d) layout, flat position p = b * seq_len + s
    # owns a contiguous row of num_streams * stream_dim elements. int64 keeps
    # the element offsets from overflowing on very large tensors.
    row_stride = num_streams * stream_dim
    row_base = positions.to(tl.int64) * row_stride

    for s_out in range(num_streams):
        for d in range(stream_dim):
            acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
            for s_in in range(num_streams):
                # residual_mixed[m] = sum_n H_res[n, m] * x[n]  ->  H_res[s_in, s_out]
                w_res = tl.load(H_res_ptr + s_in * num_streams + s_out).to(tl.float32)

                # Combined pre/post weight (H_pre @ H_post)[s_in, s_out].
                # Seeded from the s_mid = 0 term so the loop-carried value is
                # an fp32 scalar tensor from the start (Triton requires
                # loop-carried values to keep a fixed type).
                w_mix = (
                    tl.load(H_pre_ptr + s_in * num_streams).to(tl.float32)
                    * tl.load(H_post_ptr + s_out).to(tl.float32)
                )
                for s_mid in range(1, num_streams):
                    w_mix += (
                        tl.load(H_pre_ptr + s_in * num_streams + s_mid).to(tl.float32)
                        * tl.load(H_post_ptr + s_mid * num_streams + s_out).to(tl.float32)
                    )

                # NOTE: lanes read addresses row_stride elements apart, so
                # these loads are not coalesced; vectorizing over d would help.
                in_offset = row_base + s_in * stream_dim + d
                x_val = tl.load(x_ptr + in_offset, mask=mask, other=0.0).to(tl.float32)
                t_val = tl.load(
                    transformed_ptr + in_offset, mask=mask, other=0.0
                ).to(tl.float32)
                acc += w_res * x_val + w_mix * t_val

            out_offset = row_base + s_out * stream_dim + d
            tl.store(
                output_ptr + out_offset,
                acc.to(output_ptr.dtype.element_ty),
                mask=mask,
            )


def fused_width_connection_triton(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
) -> torch.Tensor:
    """
    Fused Triton implementation of width_connection.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Residual mixing matrix (num_streams, num_streams); expected to
            be doubly stochastic, which is not checked here
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)

    Returns:
        output: Mixed features (batch, seq_len, num_streams, stream_dim),
            contiguous, with the same dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not 4-D, shapes are inconsistent, ``x`` is not
            on a CUDA device, or the inputs live on different devices.
    """
    if x.dim() != 4:
        raise ValueError(
            "x must be 4-D (batch, seq_len, num_streams, stream_dim), "
            f"got shape {tuple(x.shape)}"
        )
    batch_size, seq_len, num_streams, stream_dim = x.shape

    if transformed.shape != x.shape:
        raise ValueError(
            f"Shape mismatch: x {tuple(x.shape)} vs transformed {tuple(transformed.shape)}"
        )
    for name, mat in (("H_res", H_res), ("H_pre", H_pre), ("H_post", H_post)):
        if mat.shape != (num_streams, num_streams):
            raise ValueError(
                f"{name} shape mismatch: expected ({num_streams}, {num_streams}), "
                f"got {tuple(mat.shape)}"
            )

    if not x.is_cuda:
        raise ValueError("Input must be on GPU")
    if any(t.device != x.device for t in (transformed, H_res, H_pre, H_post)):
        raise ValueError("All inputs must be on the same device")

    # The kernel uses flat row-major pointer arithmetic for every operand;
    # .contiguous() is a no-op for tensors that already are.
    x = x.contiguous()
    transformed = transformed.contiguous()
    H_res = H_res.contiguous()
    H_pre = H_pre.contiguous()
    H_post = H_post.contiguous()

    output = torch.empty_like(x, memory_format=torch.contiguous_format)
    if output.numel() == 0:
        # Nothing to compute; avoid launching a kernel with an empty grid.
        return output

    total_positions = batch_size * seq_len
    BLOCK_SIZE = 128
    grid = (triton.cdiv(total_positions, BLOCK_SIZE),)

    # Launch on the inputs' device rather than whatever device is current.
    with torch.cuda.device(x.device):
        fused_stream_mixing_kernel[grid](
            x, transformed, H_res, H_pre, H_post,
            output,
            batch_size, seq_len, num_streams, stream_dim,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return output


def _einops_reference(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
) -> torch.Tensor:
    """Unfused einops width_connection reference (from hyper_connections.py)."""
    from einops import einsum  # type: ignore[import-not-found]

    residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
    pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
    post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
    return residual_mixed + post_mixed


def compare_with_einops_reference(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-4,
) -> tuple[bool, float, Optional[str]]:
    """
    Compare Triton kernel output with einops reference implementation.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Residual mixing matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)
        atol: Absolute tolerance for comparison
        rtol: Relative tolerance for comparison

    Returns:
        matches: True if outputs match within tolerance (torch.allclose)
        max_diff: Maximum absolute difference (0.0 for empty inputs)
        error_msg: Error message if mismatch, None otherwise
    """
    triton_output = fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)
    einops_output = _einops_reference(x, transformed, H_res, H_pre, H_post)

    matches = torch.allclose(triton_output, einops_output, atol=atol, rtol=rtol)
    abs_diff = (triton_output - einops_output).abs()
    # .max() raises on an empty tensor; empty outputs trivially match.
    max_diff = abs_diff.max().item() if abs_diff.numel() > 0 else 0.0

    error_msg = None
    if not matches:
        mean_diff = abs_diff.mean().item()
        error_msg = (
            f"Triton kernel output does not match einops reference!\n"
            f"Max diff: {max_diff:.6e} (atol={atol}, rtol={rtol})\n"
            f"Mean diff: {mean_diff:.6e}\n"
            f"Triton output range: [{triton_output.min().item():.4f}, "
            f"{triton_output.max().item():.4f}]\n"
            f"Einops output range: [{einops_output.min().item():.4f}, "
            f"{einops_output.max().item():.4f}]"
        )

    return matches, max_diff, error_msg


def benchmark_kernel_speedup(
    batch_size: int = 16,
    seq_len: int = 128,
    dim: int = 512,
    num_streams: int = 8,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> tuple[float, float, float]:
    """
    Benchmark Triton kernel vs einops reference.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        dim: Hidden dimension (must be divisible by num_streams)
        num_streams: Number of streams
        num_warmup: Warmup iterations (untimed)
        num_iters: Benchmark iterations (must be positive)

    Returns:
        einops_time_ms: Average einops time (milliseconds)
        triton_time_ms: Average Triton time (milliseconds)
        speedup: Speedup factor (einops_time / triton_time)

    Raises:
        RuntimeError: if CUDA is not available.
        ValueError: if dim is not divisible by a positive num_streams, or
            num_iters is not positive.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GPU required for benchmarking")
    if num_streams <= 0 or dim % num_streams != 0:
        raise ValueError(
            f"dim ({dim}) must be divisible by a positive num_streams ({num_streams})"
        )
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    device = torch.device('cuda')
    stream_dim = dim // num_streams
    shape = (batch_size, seq_len, num_streams, stream_dim)

    # Generate random inputs
    x = torch.randn(shape, device=device)
    transformed = torch.randn(shape, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    def run_einops() -> torch.Tensor:
        return _einops_reference(x, transformed, H_res, H_pre, H_post)

    def run_triton() -> torch.Tensor:
        return fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)

    def time_ms(fn) -> float:
        # Synchronize on both sides so only the timed launches are measured.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / num_iters * 1000.0

    # Warmup (includes Triton JIT compilation)
    for _ in range(num_warmup):
        run_triton()
        run_einops()

    einops_time = time_ms(run_einops)
    triton_time = time_ms(run_triton)
    return einops_time, triton_time, einops_time / triton_time


if __name__ == "__main__":
    # Quick test of fused kernel correctness, followed by a benchmark.
    print("Testing fused mHC Triton kernel...")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        sys.exit(0)

    device = torch.device('cuda')
    torch.manual_seed(0)  # reproducible test inputs

    # Test configuration
    batch_size = 4
    seq_len = 32
    num_streams = 8
    stream_dim = 64

    # Generate test inputs
    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    # Compare with reference
    matches, max_diff, error_msg = compare_with_einops_reference(
        x, transformed, H_res, H_pre, H_post
    )

    if matches:
        print(f"✓ Correctness test PASSED (max diff: {max_diff:.6e})")
    else:
        print("✗ Correctness test FAILED")
        print(error_msg)
        sys.exit(1)

    # Benchmark
    print("\nBenchmarking...")
    einops_time, triton_time, speedup = benchmark_kernel_speedup()
    print(f"Einops time: {einops_time:.3f} ms")
    print(f"Triton time: {triton_time:.3f} ms")
    print(f"Speedup: {speedup:.2f}x")

    print("\n✓ All tests passed!")