# sem-v6-training/src/sem_v6/kernels/fused_mhc_kernel.py
# Uploaded by icarus112 via huggingface_hub (commit 518db7a, verified).
"""
Fused Triton kernels for mHC (manifold-constrained Hyper-Connections).
Fuses rearrange + einsum operations to reduce kernel launch overhead:
- Baseline: 9 kernels per mHC module (2 rearranges + 3 einsums + 4 softmax/sinkhorn ops)
- Fused: 3-4 kernels per mHC module (fused stream mixing + separate sinkhorn)
Performance target: 3-8x speedup for 48-layer model (96 mHC modules).
"""
import torch
import triton # type: ignore[import-untyped]
import triton.language as tl # type: ignore[import-untyped]
from typing import Optional
@triton.jit
def fused_stream_mixing_kernel(  # type: ignore[no-untyped-def]
    # Input pointers
    x_ptr,  # Residual input (batch, seq_len, num_streams, stream_dim), contiguous
    transformed_ptr,  # Transformed input, same shape and layout as x
    H_res_ptr,  # Residual mixing matrix (num_streams, num_streams), row-major
    H_pre_ptr,  # Pre-mixing matrix (num_streams, num_streams), row-major
    H_post_ptr,  # Post-mixing matrix (num_streams, num_streams), row-major
    # Output pointer
    output_ptr,  # Output (batch, seq_len, num_streams, stream_dim), contiguous
    # Dimensions
    batch_size: tl.constexpr,
    seq_len: tl.constexpr,
    num_streams: tl.constexpr,
    stream_dim: tl.constexpr,
    # Block size
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """
    Fused kernel for width_connection stream mixing.

    Computes, for every (b, s, m, d):
        1. residual_mixed = einsum(H_res, x_streams, "n m, b s n d -> b s m d")
        2. pre_mixed = einsum(H_pre, transformed_streams, "n m, b s n d -> b s m d")
        3. post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        4. output = residual_mixed + post_mixed

    In all three einsums the INPUT stream indexes the ROW of the mixing
    matrix, so output stream `o` reads column `o`: H[s_in, o].

    Steps 2+3 are folded algebraically:
        post_mixed[o] = sum_m H_post[m, o] * sum_k H_pre[k, m] * t[k]
                      = sum_k (H_pre @ H_post)[k, o] * t[k]
    This yields the same result as the unfused form (up to float rounding
    order). It avoids recomputing pre_mixed for every output stream.

    Each program handles BLOCK_SIZE flattened (batch, seq) positions and
    accumulates in float32.
    """
    pid = tl.program_id(0)

    total_positions = batch_size * seq_len
    num_blocks = tl.cdiv(total_positions, BLOCK_SIZE)
    if pid >= num_blocks:
        return

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total_positions

    # For contiguous (batch, seq, streams, dim) tensors,
    # batch_idx * seq_len + seq_idx == offsets. Each position's row of
    # num_streams * stream_dim elements therefore starts at offsets * row_stride.
    # int64 keeps the address arithmetic from overflowing on large inputs.
    row_stride = num_streams * stream_dim
    row_base = offsets.to(tl.int64) * row_stride

    for s_out in range(num_streams):
        for d in range(stream_dim):
            residual_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
            post_acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)

            for s_in in range(num_streams):
                elem_offset = row_base + s_in * stream_dim + d
                x_val = tl.load(
                    x_ptr + elem_offset, mask=mask, other=0.0
                ).to(tl.float32)
                t_val = tl.load(
                    transformed_ptr + elem_offset, mask=mask, other=0.0
                ).to(tl.float32)

                # H_res[s_in, s_out]: row = input stream, column = output stream.
                h_res_val = tl.load(
                    H_res_ptr + s_in * num_streams + s_out
                ).to(tl.float32)

                # Scalar (H_pre @ H_post)[s_in, s_out].
                # Python float init -> fp32 scalar, so the loop-carried type stays
                # consistent (a block-shaped reassignment would be rejected by Triton).
                coeff = 0.0
                for s_mid in range(num_streams):
                    h_pre_val = tl.load(
                        H_pre_ptr + s_in * num_streams + s_mid
                    ).to(tl.float32)
                    h_post_val = tl.load(
                        H_post_ptr + s_mid * num_streams + s_out
                    ).to(tl.float32)
                    coeff += h_pre_val * h_post_val

                residual_acc += h_res_val * x_val
                post_acc += coeff * t_val

            output_val = residual_acc + post_acc
            tl.store(
                output_ptr + row_base + s_out * stream_dim + d,
                output_val,
                mask=mask,
            )
def fused_width_connection_triton(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
) -> torch.Tensor:
    """
    Fused Triton implementation of width_connection.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Doubly stochastic matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)

    Returns:
        output: Mixed features (batch, seq_len, num_streams, stream_dim),
            contiguous, with x's dtype.

    Raises:
        AssertionError: if x is not on a CUDA device, any shape is
            inconsistent, or any tensor lives on a different device than x.
    """
    batch_size, seq_len, num_streams, stream_dim = x.shape

    # Validate inputs
    assert x.is_cuda, "Input must be on GPU"
    assert transformed.shape == x.shape, "Shape mismatch"
    assert H_res.shape == (num_streams, num_streams), "H_res shape mismatch"
    assert H_pre.shape == (num_streams, num_streams), "H_pre shape mismatch"
    assert H_post.shape == (num_streams, num_streams), "H_post shape mismatch"
    # The kernel does raw pointer arithmetic, so every operand must live on
    # the same device as x.
    for name, tensor in (
        ("transformed", transformed),
        ("H_res", H_res),
        ("H_pre", H_pre),
        ("H_post", H_post),
    ):
        assert tensor.device == x.device, f"{name} must be on {x.device}"

    # The kernel assumes dense row-major layouts for ALL operands.
    # .contiguous() returns the tensor itself when it already is contiguous.
    x = x.contiguous()
    transformed = transformed.contiguous()
    H_res = H_res.contiguous()
    H_pre = H_pre.contiguous()
    H_post = H_post.contiguous()

    # Force a contiguous output: the kernel writes in row-major order
    # regardless of x's original strides.
    output = torch.empty_like(x, memory_format=torch.contiguous_format)

    total_positions = batch_size * seq_len
    if output.numel() == 0:
        return output

    BLOCK_SIZE = 128
    grid = (triton.cdiv(total_positions, BLOCK_SIZE),)
    fused_stream_mixing_kernel[grid](
        x, transformed, H_res, H_pre, H_post, output,
        batch_size, seq_len, num_streams, stream_dim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
def compare_with_einops_reference(
    x: torch.Tensor,
    transformed: torch.Tensor,
    H_res: torch.Tensor,
    H_pre: torch.Tensor,
    H_post: torch.Tensor,
    atol: float = 1e-5,
    rtol: float = 1e-4,
) -> tuple[bool, float, Optional[str]]:
    """
    Check the fused Triton kernel against an unfused einops reference.

    The reference replays the three einsums from hyper_connections.py one by
    one and sums the residual and post-mixed branches.

    Args:
        x: Residual input (batch, seq_len, num_streams, stream_dim)
        transformed: Transformed input (batch, seq_len, num_streams, stream_dim)
        H_res: Doubly stochastic matrix (num_streams, num_streams)
        H_pre: Pre-mixing matrix (num_streams, num_streams)
        H_post: Post-mixing matrix (num_streams, num_streams)
        atol: Absolute tolerance passed to torch.allclose
        rtol: Relative tolerance passed to torch.allclose

    Returns:
        A tuple (matches, max_diff, error_msg):
            matches: whether torch.allclose accepts the two outputs
            max_diff: largest element-wise absolute difference
            error_msg: diagnostic text on mismatch, otherwise None
    """
    from einops import einsum  # type: ignore[import-not-found]

    fused = fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)

    # Unfused reference, one einsum per step.
    res_branch = einsum(H_res, x, "n m, b s n d -> b s m d")
    pre_branch = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
    post_branch = einsum(H_post, pre_branch, "m n, b s m d -> b s n d")
    reference = res_branch + post_branch

    matches = torch.allclose(fused, reference, atol=atol, rtol=rtol)
    abs_diff = (fused - reference).abs()
    max_diff = abs_diff.max().item()

    if matches:
        return matches, max_diff, None

    mean_diff = abs_diff.mean().item()
    error_msg = (
        f"Triton kernel output does not match einops reference!\n"
        f"Max diff: {max_diff:.6e} (atol={atol}, rtol={rtol})\n"
        f"Mean diff: {mean_diff:.6e}\n"
        f"Triton output range: [{fused.min().item():.4f}, "
        f"{fused.max().item():.4f}]\n"
        f"Einops output range: [{reference.min().item():.4f}, "
        f"{reference.max().item():.4f}]"
    )
    return matches, max_diff, error_msg
def benchmark_kernel_speedup(
    batch_size: int = 16,
    seq_len: int = 128,
    dim: int = 512,
    num_streams: int = 8,
    num_warmup: int = 10,
    num_iters: int = 100,
) -> tuple[float, float, float]:
    """
    Benchmark the Triton kernel against the einops reference.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        dim: Hidden dimension; must be divisible by num_streams
        num_streams: Number of streams (positive)
        num_warmup: Warmup iterations (run for both implementations)
        num_iters: Timed iterations (positive)

    Returns:
        einops_time_ms: Average einops time (milliseconds)
        triton_time_ms: Average Triton time (milliseconds)
        speedup: Speedup factor (einops_time / triton_time)

    Raises:
        ValueError: if dim is not divisible by a positive num_streams, or
            num_iters is not positive.
        AssertionError: if CUDA is unavailable.
    """
    # Validate cheap arguments before touching the GPU or importing einops.
    # Previously, a non-divisible dim was silently floored and num_iters == 0
    # raised ZeroDivisionError.
    if num_streams <= 0 or dim % num_streams != 0:
        raise ValueError(
            f"dim ({dim}) must be divisible by a positive num_streams ({num_streams})"
        )
    if num_iters <= 0:
        raise ValueError(f"num_iters must be positive, got {num_iters}")

    assert torch.cuda.is_available(), "GPU required for benchmarking"

    import time
    from einops import einsum  # type: ignore[import-not-found]

    device = torch.device('cuda')
    stream_dim = dim // num_streams

    # Generate random inputs (same generation order as before)
    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    def run_einops() -> torch.Tensor:
        """Unfused reference: three einsums plus one add."""
        residual_mixed = einsum(H_res, x, "n m, b s n d -> b s m d")
        pre_mixed = einsum(H_pre, transformed, "n m, b s n d -> b s m d")
        post_mixed = einsum(H_post, pre_mixed, "m n, b s m d -> b s n d")
        return residual_mixed + post_mixed

    def run_triton() -> torch.Tensor:
        """Fused Triton path."""
        return fused_width_connection_triton(x, transformed, H_res, H_pre, H_post)

    def time_ms(fn) -> float:
        """Average wall-clock milliseconds per call of fn over num_iters runs."""
        # CUDA launches are asynchronous: synchronize on both sides of the
        # timed region so only the work of these calls is measured.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / num_iters * 1000

    # Warmup (includes Triton JIT compilation)
    for _ in range(num_warmup):
        run_triton()
        run_einops()
    torch.cuda.synchronize()

    einops_time = time_ms(run_einops)
    triton_time = time_ms(run_triton)
    return einops_time, triton_time, einops_time / triton_time
if __name__ == "__main__":
    # Quick smoke test: verify fused-kernel correctness, then benchmark it.
    # sys.exit is used instead of the site-provided exit(), which is absent
    # under `python -S` or in frozen builds.
    import sys

    print("Testing fused mHC Triton kernel...")
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        sys.exit(0)

    device = torch.device('cuda')
    # Fixed seed so a failure is reproducible run-to-run.
    torch.manual_seed(0)

    # Test configuration
    batch_size = 4
    seq_len = 32
    num_streams = 8
    stream_dim = 64

    # Generate test inputs
    x = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    transformed = torch.randn(batch_size, seq_len, num_streams, stream_dim, device=device)
    H_res = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_pre = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)
    H_post = torch.randn(num_streams, num_streams, device=device).softmax(dim=-1)

    # Compare with reference
    matches, max_diff, error_msg = compare_with_einops_reference(
        x, transformed, H_res, H_pre, H_post
    )
    if matches:
        print(f"✓ Correctness test PASSED (max diff: {max_diff:.6e})")
    else:
        print("✗ Correctness test FAILED")
        print(error_msg)
        sys.exit(1)

    # Benchmark
    print("\nBenchmarking...")
    einops_time, triton_time, speedup = benchmark_kernel_speedup()
    print(f"Einops time: {einops_time:.3f} ms")
    print(f"Triton time: {triton_time:.3f} ms")
    print(f"Speedup: {speedup:.2f}x")
    print("\n✓ All tests passed!")