# StyleForge/kernels/ffn_wrapper.py
"""
StyleForge - Fused Feed-Forward Network Wrapper
Python interface for the fused FFN CUDA kernel.
Fuses: Linear β†’ GELU β†’ Linear β†’ Bias β†’ Residual
Performance Target: 4-5x speedup over PyTorch sequential
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Optional
from utils import compile_inline
# Global module cache
_ffn_module = None
def get_ffn_module():
"""Lazy-load and compile the FFN kernel."""
global _ffn_module
if _ffn_module is not None:
return _ffn_module
kernel_path = Path(__file__).parent / "ffn.cu"
if not kernel_path.exists():
raise FileNotFoundError(f"FFN kernel not found at {kernel_path}")
cuda_source = kernel_path.read_text()
print("Compiling fused FFN kernel...")
_ffn_module = compile_inline(
name='fused_ffn',
cuda_source=cuda_source,
functions=['forward'],
build_directory=Path('build'),
verbose=False
)
print("FFN compilation complete!")
return _ffn_module
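
# For sanity-checking the fused kernel, a plain-PyTorch reference of the same
# computation is useful. This is a minimal sketch of the operation described in
# the module docstring (Linear β†’ GELU β†’ Linear + residual); the assumption that
# the CUDA kernel adds the residual internally comes from that docstring and is
# not verified here.
def ffn_reference(
    x: torch.Tensor,
    w1: torch.Tensor,
    b1: torch.Tensor,
    w2: torch.Tensor,
    b2: torch.Tensor,
) -> torch.Tensor:
    """Unfused reference FFN, with weights in the stored [in, out] layout."""
    h = F.gelu(x @ w1 + b1)      # [batch, seq_len, ffn_dim]
    return (h @ w2 + b2) + x     # project back to embed_dim and add residual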
class FusedFFN(nn.Module):
"""
Fused Feed-Forward Network Module
Fuses the entire FFN block into a single kernel:
Linear(embed_dim, ffn_dim) β†’ GELU β†’ Linear(ffn_dim, embed_dim) + Residual
Args:
embed_dim: Input/output embedding dimension
ffn_dim: Hidden dimension of FFN (typically 4x embed_dim)
        dropout: Dropout probability (applied after the fused kernel when training)
bias: Use bias in linear layers
Example:
>>> ffn = FusedFFN(embed_dim=128, ffn_dim=512).cuda()
>>> x = torch.randn(2, 256, 128).cuda()
>>> y = ffn(x)
>>> print(y.shape) # [2, 256, 128]
"""
def __init__(
self,
embed_dim: int = 128,
ffn_dim: int = 512,
dropout: float = 0.0,
bias: bool = True
):
super().__init__()
self.embed_dim = embed_dim
self.ffn_dim = ffn_dim
# FC1: embed_dim β†’ ffn_dim
self.fc1_weight = nn.Parameter(torch.empty(embed_dim, ffn_dim))
self.fc1_bias = nn.Parameter(torch.empty(ffn_dim)) if bias else None
# FC2: ffn_dim β†’ embed_dim
self.fc2_weight = nn.Parameter(torch.empty(ffn_dim, embed_dim))
self.fc2_bias = nn.Parameter(torch.empty(embed_dim)) if bias else None
self.dropout = nn.Dropout(dropout)
self._reset_parameters()
def _reset_parameters(self):
"""Initialize parameters using Xavier uniform"""
nn.init.xavier_uniform_(self.fc1_weight)
nn.init.xavier_uniform_(self.fc2_weight)
if self.fc1_bias is not None:
nn.init.zeros_(self.fc1_bias)
if self.fc2_bias is not None:
nn.init.zeros_(self.fc2_bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass with fused FFN kernel.
Args:
x: Input tensor [batch, seq_len, embed_dim]
Returns:
Output tensor [batch, seq_len, embed_dim]
"""
module = get_ffn_module()
        # Transpose weights from the stored [in, out] layout to the
        # [out, in] layout the kernel consumes
        w1_t = self.fc1_weight.T.contiguous()
        w2_t = self.fc2_weight.T.contiguous()
        # Create zero biases (matching input device and dtype) if bias is disabled
        b1 = self.fc1_bias if self.fc1_bias is not None else torch.zeros(
            self.ffn_dim, device=x.device, dtype=x.dtype
        )
        b2 = self.fc2_bias if self.fc2_bias is not None else torch.zeros(
            self.embed_dim, device=x.device, dtype=x.dtype
        )
with torch.cuda.nvtx.range("fused_ffn_forward"):
output = module.forward(
x.contiguous(),
w1_t,
b1,
w2_t,
b2,
False # use_vectorized - set to False for stability
)
# Apply dropout if training
if self.training and self.dropout.p > 0:
output = self.dropout(output)
return output
def extra_repr(self) -> str:
return f'embed_dim={self.embed_dim}, ffn_dim={self.ffn_dim}'
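
# Quick numerical check of the fused path against ffn_reference above. A sketch,
# not a test suite: it assumes the kernel output matches the reference to within
# normal float32 GEMM tolerance; loosen atol/rtol if the kernel accumulates in a
# different precision.
def check_ffn_correctness(embed_dim: int = 128, ffn_dim: int = 512) -> bool:
    ffn = FusedFFN(embed_dim, ffn_dim).cuda().eval()
    x = torch.randn(2, 256, embed_dim, device='cuda')
    with torch.no_grad():
        fused_out = ffn(x)
        ref_out = ffn_reference(
            x, ffn.fc1_weight, ffn.fc1_bias, ffn.fc2_weight, ffn.fc2_bias
        )
    max_err = (fused_out - ref_out).abs().max().item()
    print(f"Max abs error vs. reference: {max_err:.2e}")
    return torch.allclose(fused_out, ref_out, atol=1e-4, rtol=1e-4)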
def benchmark_ffn_vs_pytorch(
batch_size: int = 2,
seq_len: int = 256,
embed_dim: int = 128,
ffn_dim: int = 512,
iterations: int = 100
):
"""
Benchmark fused FFN against PyTorch sequential.
Returns:
Dictionary with benchmark results
"""
import numpy as np
print(f"\nBenchmarking FFN ({batch_size}x{seq_len}x{embed_dim})...")
print("=" * 70)
x = torch.randn(batch_size, seq_len, embed_dim, device='cuda')
results = {}
# ----------------------------------------
# PyTorch Baseline
# ----------------------------------------
print("\n1. PyTorch Sequential FFN...")
ffn_pytorch = nn.Sequential(
nn.Linear(embed_dim, ffn_dim),
nn.GELU(),
nn.Linear(ffn_dim, embed_dim)
).cuda().eval()
    times = []
    # Warm-up runs (not timed)
    for _ in range(10):
        with torch.no_grad():
            _ = ffn_pytorch(x)
    torch.cuda.synchronize()
for _ in range(iterations):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
_ = ffn_pytorch(x)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
results['pytorch'] = {
'mean_ms': np.mean(times),
'std_ms': np.std(times),
'name': 'PyTorch Sequential'
}
print(f" {results['pytorch']['mean_ms']:.2f} Β± {results['pytorch']['std_ms']:.2f} ms")
# ----------------------------------------
# Fused FFN
# ----------------------------------------
print("\n2. Fused FFN Kernel...")
ffn_fused = FusedFFN(embed_dim, ffn_dim).cuda().eval()
    times = []
    # Warm-up runs (not timed); the first call also triggers kernel compilation
    for _ in range(10):
        with torch.no_grad():
            _ = ffn_fused(x)
    torch.cuda.synchronize()
for _ in range(iterations):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
with torch.no_grad():
_ = ffn_fused(x)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
results['fused'] = {
'mean_ms': np.mean(times),
'std_ms': np.std(times),
'name': 'Fused FFN'
}
print(f" {results['fused']['mean_ms']:.2f} Β± {results['fused']['std_ms']:.2f} ms")
# ----------------------------------------
# Summary
# ----------------------------------------
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
    baseline = results['pytorch']['mean_ms']
    fused_time = results['fused']['mean_ms']
    speedup = baseline / fused_time
    print(f"\nPyTorch: {baseline:.2f} ms")
    print(f"Fused:   {fused_time:.2f} ms")
    print(f"\nπŸš€ Speedup over PyTorch (baseline / fused): {speedup:.2f}x")
return results
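
# Optional helper sketch: turn a mean latency from the results dict into an
# approximate arithmetic throughput. The FLOP count uses the usual 2*M*N*K per
# GEMM for the two projections and ignores the GELU and residual, so treat the
# number as an estimate.
def ffn_gflops(
    batch_size: int,
    seq_len: int,
    embed_dim: int,
    ffn_dim: int,
    time_ms: float,
) -> float:
    flops = 2 * 2 * batch_size * seq_len * embed_dim * ffn_dim  # two GEMMs
    return flops / (time_ms * 1e-3) / 1e9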
if __name__ == "__main__":
# Run benchmark if executed directly
benchmark_ffn_vs_pytorch()