import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from stu import STU
from modules import Attention
from utils import get_spectral_filters, nearest_power_of_two
from flash_stu.config import FlashSTUConfig

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    flash_fft_available = False

try:
    from flash_attn.modules.mlp import GatedMlp as MLP

    triton_mlp = True
except ImportError as e:
    print(f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead.")
    from modules import MLP

    triton_mlp = False

try:
    from flash_attn.ops.triton.layer_norm import RMSNorm
except ImportError as e:
    print(f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation.")
    from torch.nn import RMSNorm

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError as e:
    print(f"Unable to import Triton-based cross entropy loss: {e}. Falling back to PyTorch implementation.")
    from torch.nn import CrossEntropyLoss


class Block(nn.Module):
    def __init__(self, config, phi, n, flash_fft) -> None:
        super(Block, self).__init__()
        # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
        self.rn_1 = RMSNorm(config.n_embd)
        self.stu = STU(config, phi, n, flash_fft)
        self.rn_2 = RMSNorm(config.n_embd)
        self.attn = Attention(config)
        self.rn_3 = RMSNorm(config.n_embd)
        self.mlp = MLP(
            config.n_embd,
            config.n_embd * config.mlp_scale,
            activation=F.silu,  # Use SwiGLU
            bias1=config.bias,
            bias2=config.bias,
        ) if triton_mlp else MLP(config)
        self.rn_4 = RMSNorm(config.n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.stu(self.rn_1(x))
        x = x + self.mlp(self.rn_2(x))
        x = x + self.attn(self.rn_3(x))
        x = x + self.mlp(self.rn_4(x))
        return x


class FlashSTU(PreTrainedModel):
    config_class = FlashSTUConfig

    def __init__(self, config) -> None:
        super(FlashSTU, self).__init__(config)
        self.config = config
        self.n_layers = config.n_layers
        self.n_embd = config.n_embd
        self.mlp_scale = config.mlp_scale
        self.seq_len = config.seq_len
        self.n = nearest_power_of_two(self.seq_len * 2 - 1, round_up=True)
        self.vocab_size = config.vocab_size
        self.K = config.num_eigh
        self.use_hankel_L = config.use_hankel_L
        self.phi = get_spectral_filters(self.seq_len, self.K, self.use_hankel_L)
        self.use_approx = config.use_approx
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        self.dropout = config.dropout
        self.bias = config.bias
        self.loss_fn = CrossEntropyLoss()

        self.flash_stu = nn.ModuleDict(
            dict(
                tok_emb=nn.Embedding(self.vocab_size, self.n_embd),
                dropout=nn.Dropout(self.dropout),
                hidden=nn.ModuleList(
                    [
                        Block(self.config, self.phi, self.n, self.flash_fft)
                        for _ in range(self.n_layers)
                    ]
                ),
                rn_f=RMSNorm(config.n_embd),
            )
        )

        self.lm_head = nn.Linear(self.n_embd, self.vocab_size, bias=self.bias)
        self.std = (self.n_embd) ** -0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tok_emb = self.flash_stu.tok_emb(x)
        x = self.flash_stu.dropout(tok_emb)
        for block in self.flash_stu.hidden:
            x = block(x)
        x = self.flash_stu.rn_f(x)
        y_hat = self.lm_head(x)
        return y_hat

    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                # Scale down the init of flagged residual projections by depth (GPT-2 style).
                self.std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                torch.nn.init.xavier_normal_(module.M_phi_minus)
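

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hedged example of building a FlashSTU model and running a forward
# pass. The exact FlashSTUConfig constructor signature lives in
# flash_stu/config.py; the keyword arguments below only mirror the attributes
# read in FlashSTU.__init__ above and are assumptions, not canonical defaults.
if __name__ == "__main__":
    config = FlashSTUConfig(
        n_embd=768,
        n_layers=12,
        mlp_scale=4,
        seq_len=1024,
        vocab_size=50257,
        num_eigh=24,
        use_hankel_L=False,
        use_approx=True,
        use_flash_fft=False,
        dropout=0.0,
        bias=False,
    )
    model = FlashSTU(config)
    tokens = torch.randint(0, config.vocab_size, (1, config.seq_len))
    logits = model(tokens)  # (batch, seq_len, vocab_size)
    print(logits.shape)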