# Copyright (c) Meta Platforms, Inc. and affiliates.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask, BlockMask
import torch.utils.checkpoint
from xformers.ops import fmha, AttentionBias
from lingua.transformer import (
BaseTransformer,
BaseTransformerArgs,
RMSNorm,
cross_entropy,
)


def create_causal_mask(seqlen, attn_impl, sliding_window):
    """Build a causal (optionally sliding-window) attention mask for the requested backend."""
if sliding_window is not None and attn_impl == "xformers":
return fmha.attn_bias.LocalAttentionFromBottomRightMask(
window_left=sliding_window - 1, window_right=0
)
elif attn_impl == "xformers":
return fmha.attn_bias.LowerTriangularMask()
elif attn_impl == "sdpa":
return "causal"
elif attn_impl == "flex_attention":
return create_block_mask(causal_mask, None, None, seqlen, seqlen)
else:
raise NotImplementedError(
f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
)
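
# Usage sketch (illustrative; the exact mask object depends on the backend):
#   create_causal_mask(2048, "sdpa", None)      -> "causal"
#   create_causal_mask(2048, "xformers", 512)   -> fmha.attn_bias.LocalAttentionFromBottomRightMask(
#                                                      window_left=511, window_right=0)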


def attention_flops_per_token(n_layers, seq_len, dim, causal):
    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))


def get_num_flop_per_token(
num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> int:
return 6 * num_non_embed_params + attention_flops_per_token(
n_layers, seq_len, dim, True
)
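
# Worked example (illustrative numbers): with n_layers=12, dim=768, seq_len=1024
# and ~85M non-embedding parameters, the causal attention term is
#   3.5 * (4 * 12 * 1024 * 768 // 2) = 66,060,288 FLOPs/token,
# so get_num_flop_per_token(...) ~= 6 * 85e6 + 6.6e7 ~= 5.8e8 FLOPs/token.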


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod for flex_attention's create_block_mask: attend only to non-future positions
    return q_idx >= kv_idx


@dataclass
class LMMTPArgs(BaseTransformerArgs):
seed: int = 42
n_future_head: int = 1
vocab_size: int = -1
attn_impl: str = "sdpa"
mask: str = "causal"
sliding_window: Optional[int] = None
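
# Construction sketch (dim/n_layers/norm_eps are assumed to be inherited from
# BaseTransformerArgs, as used below):
#   args = LMMTPArgs(dim=768, n_layers=12, vocab_size=32_000, n_future_head=4)
#   model = LMTransformer(args)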


class LMTransformer(BaseTransformer):
def __init__(self, args: LMMTPArgs):
super().__init__(args)
self.sliding_window = args.sliding_window
self.mask = args.mask
self.attn_impl = args.attn_impl
self.n_future_head = args.n_future_head
assert self.n_future_head >= 1
assert args.vocab_size > 0
self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.heads = nn.ModuleList()
for _ in range(self.n_future_head):
self.heads.append(
nn.Linear(
args.dim,
args.vocab_size,
bias=False,
)
)

    def forward(
self,
token_values: torch.Tensor,
        target: Optional[torch.Tensor] = None,
tok_idx: Optional[torch.Tensor] = None,
mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
attn_impl: str = "sdpa",
):
bsz, seqlen = token_values.shape
h = self.tok_embeddings(token_values)
mask = (
mask
if mask is not None
else create_causal_mask(seqlen, self.attn_impl, self.sliding_window)
)
h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
norm_h = self.norm(h)
        if target is not None:
            if self.training:
                # Multi-token prediction: head i is trained to predict the
                # (i + 1)-th future token, so target is expected to carry a
                # trailing head dimension: [bsz, seqlen, n_future_head].
                ce = []
                for i, head in enumerate(self.heads):
                    # Checkpoint each head so the full [bsz, seqlen, vocab_size]
                    # logits are recomputed in backward rather than all kept
                    # live at once.
                    logits = torch.utils.checkpoint.checkpoint(
                        head,
                        norm_h,
                        use_reentrant=False,
                        preserve_rng_state=False,
                    )
                    ce.append(cross_entropy(logits, target[..., i]))
            else:
                # At eval time only the next-token head is scored, with a
                # plain [bsz, seqlen] target.
                head = self.heads[0]
                logits = head(norm_h)
                ce = cross_entropy(logits, target)
            return ce
else:
return self.heads[0](norm_h)
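
    # Forward usage sketch (target shapes inferred from the loss code above):
    #   model.train(); losses = model(tokens, target=mtp_targets)  # list of n_future_head losses
    #   model.eval();  loss = model(tokens, target=next_tokens)    # single next-token loss
    #   logits = model(tokens)  # [bsz, seqlen, vocab_size] from the first head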

    def reset_parameters(self, init_std=None):
        # Either use a fixed base std or default to 1 / sqrt(model dim)
        super().reset_parameters()
        init_std = init_std or (self.dim ** (-0.5))
self.norm.reset_parameters()
nn.init.trunc_normal_(
self.tok_embeddings.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)
for head in self.heads:
nn.init.trunc_normal_(
head.weight,
mean=0.0,
std=init_std,
a=-3 * init_std,
b=3 * init_std,
)

    def init_weights(self):
super().init_weights()


def build_fsdp_grouping_plan(model_args: LMMTPArgs) -> List[Tuple[str, bool]]:
    group_plan: List[Tuple[str, bool]] = []
    # Group the token embeddings separately from the transformer layers
    group_plan.append(("tok_embeddings", False))
# Grouping by layers
for i in range(model_args.n_layers):
group_plan.append((f"layers.{i}", False))
return group_plan
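
# For a hypothetical 2-layer config this returns:
#   [("tok_embeddings", False), ("layers.0", False), ("layers.1", False)]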