import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel

from stu import STU
from modules import Attention
from utils import get_spectral_filters, nearest_power_of_two
from flash_stu.config import FlashSTUConfig

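# Prefer fused Triton/CUDA kernels when available, falling back to portable
# PyTorch implementations so the model still runs without them.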
try:
    from flashfftconv import FlashFFTConv
    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    flash_fft_available = False

try:
    from flash_attn.modules.mlp import GatedMlp as MLP
    triton_mlp = True
except ImportError as e:
    print(f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead.")
    from modules import MLP
    triton_mlp = False

try:
    from flash_attn.ops.triton.layer_norm import RMSNorm
except ImportError as e:
    print(f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation.")
    from torch.nn import RMSNorm

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError as e:
    print(f"Unable to import Triton-based cross entropy loss: {e}. Falling back to PyTorch implementation.")
    from torch.nn import CrossEntropyLoss

class Block(nn.Module):
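    """Flash STU block: an STU sub-layer and an attention sub-layer, each
    followed by an MLP, with pre-RMSNorm residual connections throughout.
    """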
    def __init__(self, config, phi, n, flash_fft) -> None:
        super().__init__()
        # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
        self.rn_1 = RMSNorm(config.n_embd)
        self.stu = STU(config, phi, n, flash_fft)
        self.rn_2 = RMSNorm(config.n_embd)
        self.attn = Attention(config)
        self.rn_3 = RMSNorm(config.n_embd)
        self.mlp = MLP(
            config.n_embd, 
            config.n_embd * config.mlp_scale, 
            activation=F.silu,  # SiLU gating makes this a SwiGLU MLP
            bias1=config.bias,
            bias2=config.bias,
        ) if triton_mlp else MLP(config)
        self.rn_4 = RMSNorm(config.n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
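        # Pre-norm residual stream: STU -> MLP -> attention -> MLP.
        # Note: self.mlp is a single module, so the two MLP sub-layers
        # share weights within a block.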
        x = x + self.stu(self.rn_1(x))
        x = x + self.mlp(self.rn_2(x))
        x = x + self.attn(self.rn_3(x))
        x = x + self.mlp(self.rn_4(x))
        return x

class FlashSTU(PreTrainedModel):
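    """Flash STU language model: token embedding, a stack of Blocks, a final
    RMSNorm, and an (untied) linear language-model head.
    """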
    config_class = FlashSTUConfig

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.n_layers = config.n_layers
        self.n_embd = config.n_embd
        self.mlp_scale = config.mlp_scale
        self.seq_len = config.seq_len
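        # FFT length for causal convolution: a linear convolution over
        # seq_len inputs produces 2 * seq_len - 1 outputs; round up to a
        # power of two for FFT efficiency.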
        self.n = nearest_power_of_two(self.seq_len * 2 - 1, round_up=True)
        self.vocab_size = config.vocab_size
        self.K = config.num_eigh
        self.use_hankel_L = config.use_hankel_L
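        # Spectral filters (top-K Hankel eigenvectors), shared by all STU layers.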
        self.phi = get_spectral_filters(self.seq_len, self.K, self.use_hankel_L)
        self.use_approx = config.use_approx
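        # Optional fused FFT convolution kernel; None selects the plain
        # PyTorch convolution path inside STU.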
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) 
            if config.use_flash_fft and flash_fft_available 
            else None
        )
        self.dropout = config.dropout
        self.bias = config.bias
        self.loss_fn = CrossEntropyLoss()

        self.flash_stu = nn.ModuleDict(
            dict(
                tok_emb=nn.Embedding(self.vocab_size, self.n_embd),
                dropout=nn.Dropout(self.dropout),
                hidden=nn.ModuleList(
                    [
                        Block(self.config, self.phi, self.n, self.flash_fft)
                        for _ in range(self.n_layers)
                    ]
                ),
                rn_f=RMSNorm(config.n_embd)
            )
        )
        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias)
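        # Note: lm_head weights are not tied to tok_emb.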

        self.std = self.n_embd ** -0.5
        self.apply(self._init_weights)
        print(f"Model Parameter Count: {self._get_num_params() / 1e6:.2f}M\n")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tok_emb = self.flash_stu.tok_emb(x)
        x = self.flash_stu.dropout(tok_emb)

        for block in self.flash_stu.hidden:
            x = block(x)
        x = self.flash_stu.rn_f(x)

        y_hat = self.lm_head(x)
        return y_hat

    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                # Scale residual projections down with depth (GPT-2-style init);
                # use a local copy so repeated calls do not compound the scaling.
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                torch.nn.init.xavier_normal_(module.M_phi_minus)
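
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). The
    # config fields below are only those read in this file; STU and Attention
    # likely require additional fields, and the values here are illustrative,
    # not the repo's defaults.
    config = FlashSTUConfig(
        n_layers=2,
        n_embd=128,
        mlp_scale=4,
        seq_len=256,
        vocab_size=1024,
        num_eigh=16,
        use_hankel_L=False,
        use_approx=True,
        use_flash_fft=False,
        dropout=0.0,
        bias=False,
    )
    model = FlashSTU(config)
    tokens = torch.randint(0, config.vocab_size, (1, config.seq_len))
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([1, 256, 1024])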