File size: 3,844 Bytes
3386f25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
StyleForge - Fused Attention V3 Python Wrapper

V3 uses register-based accumulation (no shared memory for V).
Educational kernel - still slower than Flash Attention 2 due to
fundamental limitations (element-wise matmul vs tensor cores).
"""

import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional

from utils import compile_inline

# Cache slot for the compiled V3 extension. Starts empty; it is filled by
# get_attention_v3_module() on its first successful build and reused after.
_attention_v3_module = None

def get_attention_v3_module():
    """Return the compiled V3 attention extension, building it on first use.

    A successful build is stored in the module-level ``_attention_v3_module``
    and handed back unchanged by every later call.

    Raises:
        FileNotFoundError: If ``attention_v3.cu`` does not exist in the
            directory containing this file.
    """
    global _attention_v3_module

    if _attention_v3_module is None:
        # Cache miss: locate the kernel source next to this wrapper.
        source_file = Path(__file__).parent.joinpath("attention_v3.cu")
        if not source_file.exists():
            raise FileNotFoundError(f"V3 kernel not found at {source_file}")

        kernel_code = source_file.read_text()

        print("Compiling fused attention V3 kernel (register-based)...")
        _attention_v3_module = compile_inline(
            name='fused_attention_v3',
            cuda_source=kernel_code,
            functions=['fused_attention_v3'],
            build_directory=Path('build_v3'),
            verbose=False,
        )
        print("V3 Compilation complete!")

    return _attention_v3_module

class FusedAttentionV3Function(torch.autograd.Function):
    """Autograd wrapper around the compiled ``fused_attention_v3`` kernel.

    Forward-only: ``backward`` returns ``None`` for every input, so no
    gradients flow through this op.
    """

    MAX_SEQ_LEN = 4096  # Conservative limit
    # NOTE(review): presumably the largest per-head width the kernel
    # supports -- confirm against attention_v3.cu.
    MAX_HEAD_DIM = 128

    @staticmethod
    def _validate_dims(seq_len: int, embed_dim: int, num_heads: int) -> None:
        """Reject problem sizes outside the limits declared on this class.

        Raises:
            ValueError: If ``num_heads`` is not positive, ``embed_dim`` is not
                a multiple of ``num_heads``, ``seq_len`` exceeds
                ``MAX_SEQ_LEN``, or the per-head width exceeds
                ``MAX_HEAD_DIM``.
        """
        limits = FusedAttentionV3Function

        # Checked first so the floor division below cannot silently truncate
        # (or divide by zero).
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        if seq_len > limits.MAX_SEQ_LEN:
            raise ValueError(f"seq_len {seq_len} exceeds MAX_SEQ_LEN {FusedAttentionV3Function.MAX_SEQ_LEN}")

        head_dim = embed_dim // num_heads
        if head_dim > limits.MAX_HEAD_DIM:
            raise ValueError(
                f"head_dim {head_dim} exceeds MAX_HEAD_DIM {limits.MAX_HEAD_DIM}"
            )

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        w_qkv: torch.Tensor,
        w_out: torch.Tensor,
        bias_qkv: Optional[torch.Tensor],
        bias_out: Optional[torch.Tensor],
        num_heads: int,
        scale: float
    ) -> torch.Tensor:
        """Run the fused attention kernel.

        Args:
            ctx: Autograd context.
            x: Input tensor; dimensions 0, 1, 2 are read as batch, sequence
                length and embedding width.
            w_qkv: QKV projection weight, forwarded to the kernel.
            w_out: Output projection weight, forwarded to the kernel.
            bias_qkv: Optional QKV bias, forwarded to the kernel.
            bias_out: Optional output bias, forwarded to the kernel.
            num_heads: Number of attention heads.
            scale: Scaling factor forwarded to the kernel.

        Returns:
            Whatever tensor the compiled ``fused_attention_v3`` returns.

        Raises:
            RuntimeError: If CUDA is not available.
            ValueError: If the sizes fail ``_validate_dims``.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")

        seq_len = x.size(1)
        embed_dim = x.size(2)
        FusedAttentionV3Function._validate_dims(seq_len, embed_dim, num_heads)

        module = get_attention_v3_module()

        # Kept for parity with a future backward; the current backward does
        # not read these tensors.
        ctx.save_for_backward(x, w_qkv, w_out, bias_qkv, bias_out)
        ctx.num_heads = num_heads
        ctx.scale = scale
        ctx.embed_dim = embed_dim

        # Biases get the same contiguity guarantee as the other tensors.
        output = module.fused_attention_v3(
            x.contiguous(),
            w_qkv.contiguous(),
            w_out.contiguous(),
            bias_qkv.contiguous() if bias_qkv is not None else None,
            bias_out.contiguous() if bias_out is not None else None,
            scale,
            num_heads
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``None`` for each of the seven forward inputs.

        No autograd support: parameters and inputs receive no gradient
        through this op.
        """
        return None, None, None, None, None, None, None

class FusedAttentionV3(nn.Module):
    """Self-attention layer that delegates to ``FusedAttentionV3Function``.

    Holds the fused QKV projection and the output projection as parameters
    and passes them, together with the input, to the fused kernel wrapper.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 4,
        dropout: float = 0.0,
        bias: bool = True
    ):
        """Create the projection parameters.

        Args:
            embed_dim: Embedding width; must be a multiple of ``num_heads``.
            num_heads: Number of attention heads (positive).
            dropout: Accepted for signature compatibility only. This layer
                never applies dropout; a non-zero value triggers a warning.
            bias: Whether to create the QKV and output bias parameters.

        Raises:
            ValueError: If ``num_heads`` is not positive or does not divide
                ``embed_dim``.
        """
        super().__init__()

        # Explicit raise instead of ``assert`` so the check survives -O.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        if dropout:
            import warnings

            warnings.warn(
                f"FusedAttentionV3 does not apply dropout; ignoring dropout={dropout}",
                stacklevel=2,
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.w_qkv = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.bias_qkv = nn.Parameter(torch.empty(3 * embed_dim)) if bias else None

        self.w_out = nn.Parameter(torch.empty(embed_dim, embed_dim))
        self.bias_out = nn.Parameter(torch.empty(embed_dim)) if bias else None

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for the weights, zeros for any biases."""
        nn.init.xavier_uniform_(self.w_qkv)
        nn.init.xavier_uniform_(self.w_out)
        if self.bias_qkv is not None:
            nn.init.zeros_(self.bias_qkv)
        if self.bias_out is not None:
            nn.init.zeros_(self.bias_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fused attention to ``x`` using this layer's parameters.

        Args:
            x: Input tensor handed to ``FusedAttentionV3Function``, which
                reads its first three dimensions as batch, sequence length
                and embedding width.

        Returns:
            The tensor produced by ``FusedAttentionV3Function``.
        """
        return FusedAttentionV3Function.apply(
            x,
            self.w_qkv,
            self.w_out,
            self.bias_qkv,
            self.bias_out,
            self.num_heads,
            self.scale
        )