| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn import Module, ModuleList |
| | import torchaudio |
| | from einops import rearrange |
| | import numpy as np |
| | |
| |
|
| | from torchtune.modules import RotaryPositionalEmbeddings |
| | |
| |
|
| | |
class RMSNorm(nn.Module):
    r"""Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by the square root of its
    mean squared value (plus ``eps``), then scaled element-wise by a
    learnable ``weight`` initialised to ones.

    Reference: https://github.com/meta-llama/llama/blob/main/llama/model.py

    Args:
        dim: size of the last dimension of the input (length of ``weight``).
        eps: constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""Return ``x`` normalised by its RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
| |
|
| |
|
| | |
class MLP(nn.Module):
    r"""Two-layer feed-forward network with a SiLU non-linearity.

    The input is projected from ``dim`` to ``4 * dim``, passed through SiLU,
    and projected back to ``dim``. Neither linear layer has a bias.

    Args:
        dim: size of the last dimension of the input and of the output.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()

        hidden_dim = 4 * dim

        # Sub-modules are created in this order so parameter initialisation
        # and ``state_dict`` ordering stay the same.
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        r"""Apply ``fc2(silu(fc1(x)))`` to the last dimension of ``x``."""
        return self.fc2(self.silu(self.fc1(x)))
| |
|
| |
|
class Attention(nn.Module):
    r"""Multi-head self-attention with rotary position embeddings.

    A single bias-free linear layer produces queries, keys and values; the
    rotary embedding is applied to queries and keys; non-causal scaled
    dot-product attention is computed per head; and a bias-free linear layer
    projects the concatenated heads back to ``dim``.

    Args:
        dim: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        rotary_embed: module applied to the query and key tensors. It is
            called with tensors laid out as ``(batch, time, heads, head_dim)``,
            which is the input layout torchtune's ``RotaryPositionalEmbeddings``
            documents (``[b, s, n_h, h_d]``).

    Note:
        Earlier revisions passed ``(batch, heads, time, head_dim)`` tensors to
        ``rotary_embed``, so the rotation was indexed by head rather than by
        time step. Outputs therefore differ from those revisions for the same
        weights.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""Apply self-attention.

        Args:
            x: (b, t, h*d)

        Returns:
            Tensor of shape (b, t, h*d).

        Constants:
            b: batch_size
            t: time steps
            r: 3
            h: heads_num
            d: heads_dim
        """
        B, T, C = x.size()
        head_dim = C // self.n_heads

        # Split the fused projection: (b, t, r*h*d) -> r tensors of (b, t, h, d).
        qkv = self.c_attn(x).reshape(B, T, 3, self.n_heads, head_dim)
        q, k, v = qkv.unbind(dim=2)

        # The rotary module indexes positions along dim 1, so it must see the
        # time axis there (not the head axis).
        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        # scaled_dot_product_attention attends over the second-to-last axis,
        # so move heads in front of time: (b, t, h, d) -> (b, h, t, d).
        q, k, v = (z.transpose(1, 2) for z in (q, k, v))

        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: (b, h, t, d) -> (b, t, h*d).
        y = y.transpose(1, 2).reshape(B, T, C)

        return self.c_proj(y)
| |
|
| |
|
class TransformerBlock(nn.Module):
    r"""Transformer layer with two residual branches.

    Each branch normalises its input with ``RMSNorm`` first: the first
    branch runs ``Attention``, the second runs ``MLP``. The branch output is
    added back onto the branch input.

    Args:
        dim: model width passed to the norms, the attention and the MLP.
        n_heads: number of attention heads.
        rotary_embed: rotary embedding module forwarded to ``Attention``.
    """

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads

        # Registration order is kept so parameter initialisation and
        # ``state_dict`` ordering stay the same.
        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(
        self,
        x: torch.Tensor,
    ):
        r"""Return ``x`` after the attention and MLP residual branches."""
        after_attention = x + self.att(self.att_norm(x))
        return after_attention + self.mlp(self.ffn_norm(after_attention))
| | |
| |
|
if __name__ == '__main__':
    # Smoke test: push a random batch through one block and print the
    # output shape.
    dim = 1024
    n_heads = 8
    batch_size, seq_len = 2, 128

    # The rotary embedding acts on one head at a time, so its dim is derived
    # from the model width instead of being hard-coded.
    rotary_embed = RotaryPositionalEmbeddings(dim=dim // n_heads)
    transformer_block = TransformerBlock(
        dim=dim,
        n_heads=n_heads,
        rotary_embed=rotary_embed,
    )
    x = torch.randn(batch_size, seq_len, dim)
    y = transformer_block(x)
    print(y.shape)