"""
StyleForge - Fused Attention V3 Python Wrapper
V3 uses register-based accumulation (no shared memory for V).
Educational kernel: still slower than FlashAttention-2 because its
matmuls run element-wise on CUDA cores rather than on tensor cores.
"""
import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional

from utils import compile_inline

# Compiled extension, cached at module level so the kernel is built only
# once per process.
_attention_v3_module = None


def get_attention_v3_module():
    global _attention_v3_module
    if _attention_v3_module is not None:
        return _attention_v3_module

    kernel_path = Path(__file__).parent / "attention_v3.cu"
    if not kernel_path.exists():
        raise FileNotFoundError(f"V3 kernel not found at {kernel_path}")
    cuda_source = kernel_path.read_text()

    print("Compiling fused attention V3 kernel (register-based)...")
    _attention_v3_module = compile_inline(
        name='fused_attention_v3',
        cuda_source=cuda_source,
        functions=['fused_attention_v3'],
        build_directory=Path('build_v3'),
        verbose=False
    )
    print("V3 compilation complete!")
    return _attention_v3_module


class FusedAttentionV3Function(torch.autograd.Function):
    MAX_SEQ_LEN = 4096    # Conservative sequence-length limit
    MAX_HEAD_DIM = 128    # Register-based accumulation bounds head_dim
    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        w_qkv: torch.Tensor,
        w_out: torch.Tensor,
        bias_qkv: Optional[torch.Tensor],
        bias_out: Optional[torch.Tensor],
        num_heads: int,
        scale: float
    ) -> torch.Tensor:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")

        batch_size = x.size(0)
        seq_len = x.size(1)
        embed_dim = x.size(2)
        head_dim = embed_dim // num_heads

        if seq_len > FusedAttentionV3Function.MAX_SEQ_LEN:
            raise ValueError(
                f"seq_len {seq_len} exceeds MAX_SEQ_LEN "
                f"{FusedAttentionV3Function.MAX_SEQ_LEN}"
            )
        # Register-based accumulation caps the head dimension as well.
        if head_dim > FusedAttentionV3Function.MAX_HEAD_DIM:
            raise ValueError(
                f"head_dim {head_dim} exceeds MAX_HEAD_DIM "
                f"{FusedAttentionV3Function.MAX_HEAD_DIM}"
            )

        module = get_attention_v3_module()

        ctx.save_for_backward(x, w_qkv, w_out, bias_qkv, bias_out)
        ctx.num_heads = num_heads
        ctx.scale = scale
        ctx.embed_dim = embed_dim

        output = module.fused_attention_v3(
            x.contiguous(),
            w_qkv.contiguous(),
            w_out.contiguous(),
            bias_qkv.contiguous() if bias_qkv is not None else None,
            bias_out.contiguous() if bias_out is not None else None,
            scale,
            num_heads
        )
        return output
    @staticmethod
    def backward(ctx, grad_output):
        # No autograd support: the CUDA kernel implements only the forward
        # pass, so this op is inference-only.
        return None, None, None, None, None, None, None
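

# Hypothetical pure-PyTorch reference for sanity-checking the kernel's output.
# Assumption: fused_attention_v3 computes standard multi-head attention,
# softmax(Q K^T * scale) V, with nn.Linear-style (out_features, in_features)
# weights; the actual semantics live in attention_v3.cu.
def reference_attention(x, w_qkv, w_out, bias_qkv, bias_out, num_heads, scale):
    B, S, E = x.shape
    qkv = nn.functional.linear(x, w_qkv, bias_qkv)  # (B, S, 3E)
    q, k, v = qkv.chunk(3, dim=-1)

    # Split heads: (B, S, E) -> (B, H, S, E/H)
    def split(t):
        return t.view(B, S, num_heads, E // num_heads).transpose(1, 2)

    q, k, v = split(q), split(k), split(v)
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(B, S, E)
    return nn.functional.linear(out, w_out, bias_out)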


class FusedAttentionV3(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 4,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.w_qkv = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.bias_qkv = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None
        self.w_out = nn.Parameter(torch.empty(embed_dim, embed_dim))
        self.bias_out = nn.Parameter(torch.empty(embed_dim)) if bias else None
        self._reset_parameters()
    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.w_qkv)
        nn.init.xavier_uniform_(self.w_out)
        if self.bias_qkv is not None:
            nn.init.zeros_(self.bias_qkv)
        if self.bias_out is not None:
            nn.init.zeros_(self.bias_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return FusedAttentionV3Function.apply(
            x,
            self.w_qkv,
            self.w_out,
            self.bias_qkv,
            self.bias_out,
            self.num_heads,
            self.scale
        )
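

# Minimal usage sketch, assuming a CUDA device, float32 kernel support, and
# attention_v3.cu next to this file; shapes are illustrative only. Backward
# is unimplemented, so run under torch.no_grad() for inference.
if __name__ == "__main__":
    attn = FusedAttentionV3(embed_dim=256, num_heads=4).cuda()
    x = torch.randn(2, 128, 256, device="cuda")
    with torch.no_grad():
        out = attn(x)
    print(out.shape)  # torch.Size([2, 128, 256])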