"""
StyleForge - Fused Attention V3 Python Wrapper
V3 uses register-based accumulation (no shared memory for V).
Educational kernel - still slower than Flash Attention 2 due to
fundamental limitations (element-wise matmul vs tensor cores).
"""
import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional
from utils import compile_inline
_attention_v3_module = None


def get_attention_v3_module():
    """Compile the V3 CUDA kernel on first use and cache the module."""
    global _attention_v3_module
    if _attention_v3_module is not None:
        return _attention_v3_module

    kernel_path = Path(__file__).parent / "attention_v3.cu"
    if not kernel_path.exists():
        raise FileNotFoundError(f"V3 kernel not found at {kernel_path}")
    cuda_source = kernel_path.read_text()

    print("Compiling fused attention V3 kernel (register-based)...")
    _attention_v3_module = compile_inline(
        name='fused_attention_v3',
        cuda_source=cuda_source,
        functions=['fused_attention_v3'],
        build_directory=Path('build_v3'),
        verbose=False
    )
    print("V3 compilation complete!")
    return _attention_v3_module
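
# Optional warm-up (a sketch): calling get_attention_v3_module() once at
# startup moves the one-time JIT compilation cost out of the first forward
# pass.
#   _ = get_attention_v3_module()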


class FusedAttentionV3Function(torch.autograd.Function):
    """Autograd wrapper around the fused V3 kernel (forward only)."""

    MAX_SEQ_LEN = 4096  # Conservative limit
    MAX_HEAD_DIM = 128

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        w_qkv: torch.Tensor,
        w_out: torch.Tensor,
        bias_qkv: Optional[torch.Tensor],
        bias_out: Optional[torch.Tensor],
        num_heads: int,
        scale: float
    ) -> torch.Tensor:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        if not x.is_cuda:
            raise RuntimeError("Input tensor must be on a CUDA device")

        batch_size = x.size(0)
        seq_len = x.size(1)
        embed_dim = x.size(2)
        head_dim = embed_dim // num_heads

        if seq_len > FusedAttentionV3Function.MAX_SEQ_LEN:
            raise ValueError(
                f"seq_len {seq_len} exceeds MAX_SEQ_LEN "
                f"{FusedAttentionV3Function.MAX_SEQ_LEN}"
            )
        if head_dim > FusedAttentionV3Function.MAX_HEAD_DIM:
            raise ValueError(
                f"head_dim {head_dim} exceeds MAX_HEAD_DIM "
                f"{FusedAttentionV3Function.MAX_HEAD_DIM}"
            )

        module = get_attention_v3_module()

        # Saved for a potential future backward implementation; backward
        # currently raises (see below).
        ctx.save_for_backward(x, w_qkv, w_out, bias_qkv, bias_out)
        ctx.num_heads = num_heads
        ctx.scale = scale
        ctx.embed_dim = embed_dim

        output = module.fused_attention_v3(
            x.contiguous(),
            w_qkv.contiguous(),
            w_out.contiguous(),
            bias_qkv,
            bias_out,
            scale,
            num_heads
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # No autograd support: fail loudly instead of silently returning
        # None gradients for every input.
        raise NotImplementedError(
            "FusedAttentionV3 does not implement backward; "
            "use it for inference only."
        )
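

# Pure-PyTorch reference of the same computation (a sketch for sanity checks,
# not part of the kernel API). `reference_attention` is a helper introduced
# here for illustration; it assumes the kernel splits the 3*embed_dim QKV
# projection into contiguous Q, K, V chunks, matching the weight layout of
# FusedAttentionV3 below.
def reference_attention(x, w_qkv, w_out, bias_qkv, bias_out, num_heads, scale):
    batch, seq, embed = x.shape
    head_dim = embed // num_heads
    # nn.Linear-style weights: (out_features, in_features), so x @ W.T + b.
    qkv = torch.nn.functional.linear(x, w_qkv, bias_qkv)  # (B, S, 3E)
    q, k, v = qkv.chunk(3, dim=-1)

    def split_heads(t):
        # (B, S, E) -> (B, H, S, D)
        return t.view(batch, seq, num_heads, head_dim).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(batch, seq, embed)
    return torch.nn.functional.linear(out, w_out, bias_out)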


class FusedAttentionV3(nn.Module):
    """Multi-head self-attention module backed by the fused V3 kernel."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 4,
        dropout: float = 0.0,  # accepted for API parity; the kernel applies no dropout
        bias: bool = True
    ):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.w_qkv = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.bias_qkv = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
        self.w_out = nn.Parameter(torch.empty(embed_dim, embed_dim))
        self.bias_out = nn.Parameter(torch.empty(embed_dim)) if bias else None
        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.w_qkv)
        nn.init.xavier_uniform_(self.w_out)
        if self.bias_qkv is not None:
            nn.init.zeros_(self.bias_qkv)
        if self.bias_out is not None:
            nn.init.zeros_(self.bias_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return FusedAttentionV3Function.apply(
            x,
            self.w_qkv,
            self.w_out,
            self.bias_qkv,
            self.bias_out,
            self.num_heads,
            self.scale
        )
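

# Minimal smoke test (a sketch; sizes are arbitrary). Compares the fused
# kernel against the pure-PyTorch reference above under no_grad, since the
# kernel has no backward.
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = FusedAttentionV3(embed_dim=256, num_heads=4).cuda()
    x = torch.randn(2, 128, 256, device="cuda")
    with torch.no_grad():
        fused = attn(x)
        ref = reference_attention(
            x, attn.w_qkv, attn.w_out, attn.bias_qkv, attn.bias_out,
            attn.num_heads, attn.scale
        )
    print("max abs diff vs reference:", (fused - ref).abs().max().item())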