File size: 4,819 Bytes
5a1cdf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import nearest_power_of_two

try:
    from flash_attn import flash_attn_func as fa2
except ImportError as e:
    print(
        f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
    )
    # TODO: Add FlexAttention + local attention mask when it's in stable release

class Attention(nn.Module):
    """Multi-head causal self-attention backed by FlashAttention-2.

    Combines ALiBi positional slopes (shrunk by an interpolation factor),
    a sliding attention window, and logit softcapping.
    """

    def __init__(self, config):
        super().__init__()
        # Accept the dtype either as a string name ("bfloat16") or as a
        # torch.dtype object.
        dtype = (
            getattr(torch, config.torch_dtype)
            if isinstance(config.torch_dtype, str)
            else config.torch_dtype
        )
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda")
        self.bsz = config.bsz
        # Fused QKV projection, then the output projection (order matters for
        # RNG-reproducible initialization).
        self.c_attn = nn.Linear(
            config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=dtype
        )
        self.c_proj = nn.Linear(
            config.n_embd, config.n_embd, bias=config.bias, dtype=dtype
        )
        self.c_proj.SCALE_INIT = 1  # marker read by the weight-init pass
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)
        self.alibi_slopes = self._get_alibi_slopes(self.n_heads)
        self.window_size = config.window_size
        self.softcap = config.softcap

    def _generate_slopes(self, n: int):
        """Return the geometric ALiBi slope sequence ratio^1 ... ratio^n."""
        ratio = 2 ** (-(2 ** -(math.log2(n) - 3)))
        slopes = []
        current = ratio
        for _ in range(n):
            slopes.append(current)
            current *= ratio
        return slopes

    def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
        """Return per-head ALiBi slopes as a float32 tensor on self.device."""
        if math.log2(n_heads).is_integer():
            # Power-of-two head count: direct geometric sequence.
            slopes = self._generate_slopes(n_heads)
        else:
            # Otherwise: slopes for the nearest lower power of two, topped up
            # with every other slope from the doubled sequence.
            base = nearest_power_of_two(n_heads, round_up=False)
            slopes = self._generate_slopes(base)
            slopes += self._generate_slopes(2 * base)[0::2][: n_heads - base]
        # Interpolation factor per https://arxiv.org/pdf/2310.13017
        scaled = torch.tensor(slopes, device=self.device) * interpolation_factor
        return scaled.to(torch.float32)  # Ensure slopes are in float32

    def forward(self, x):
        """Windowed causal flash attention over x of shape (bsz, seq_len, n_embd)."""
        bsz, seq_len, dim = x.size()
        head_dim = dim // self.n_heads

        # One fused projection, split into query/key/value.
        q, k, v = torch.chunk(self.c_attn(x), 3, dim=2)
        q = q.view(bsz, seq_len, self.n_heads, head_dim)
        k = k.view(bsz, seq_len, self.n_heads, head_dim)
        v = v.view(bsz, seq_len, self.n_heads, head_dim)

        # FlashAttention-2 (https://arxiv.org/pdf/2307.08691) with
        # ALiBi biases (https://arxiv.org/pdf/2108.12409) and
        # softcapped logits (https://arxiv.org/pdf/2408.00118).
        out = fa2(
            q,
            k,
            v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,
            softcap=self.softcap,
        )
        out = out.contiguous().view(bsz, seq_len, dim)
        return self.resid_dropout(self.c_proj(out))

class AttentionSDPA(nn.Module):
    """Multi-head causal self-attention via torch.nn.functional.scaled_dot_product_attention.

    Drop-in alternative to the flash-attention `Attention` module; note it
    applies no ALiBi slopes, sliding window, or softcap.
    """

    def __init__(self, config):
        # BUG FIX: was `super(Attention, self).__init__()`. AttentionSDPA is
        # not a subclass of Attention, so that call raised
        # `TypeError: super(type, obj): obj must be an instance or subtype of type`
        # on every construction.
        super().__init__()
        # Accept the dtype either as a string name or a torch.dtype object.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        assert torch.cuda.is_available(), "CUDA is required."
        assert config.n_embd % config.n_heads == 0
        self.n_heads = config.n_heads

        self.device = torch.device("cuda") # Technically don't need CUDA for SDPA
        self.bsz = config.bsz
        # Fused QKV projection and output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias, dtype=torch_dtype)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias, dtype=torch_dtype)
        # NOTE(review): unlike Attention, c_proj.SCALE_INIT is not set here —
        # confirm whether the scaled-init pass should also cover this module.
        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        """Causal SDPA over x of shape (bsz, seq_len, n_embd); returns same shape."""
        bsz, seq_len, d_in = x.size()

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # Reshape to (bsz, n_heads, seq_len, head_dim) as SDPA expects.
        q = q.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        k = k.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)
        v = v.view(bsz, seq_len, self.n_heads, d_in // self.n_heads).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.dropout if self.training else 0.0
        )

        # Merge heads back into the embedding dimension.
        y = y.transpose(1, 2).contiguous().view(bsz, seq_len, d_in)

        y = self.resid_dropout(self.c_proj(y))
        return y