# Copyright (c) Meta Platforms, Inc. and affiliates.
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from lingua.transformer import RMSNorm, TiedLinear, cross_entropy

from apps.mamba.core_mamba import BaseMambaArgs, BaseMamba


@dataclass
class LMMambaArgs(BaseMambaArgs):
    seed: int = 42

    vocab_size: int = -1  # must be set to a positive value before building the model
    weight_tying: bool = False  # tie the output projection to the token embeddings
    loss_reduction: str = "mean"  # reduction passed to cross_entropy


def attention_flops_per_token(n_layers, seq_len, dim, causal):
    # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30
    return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1))


def get_num_flop_per_token(
    num_non_embed_params: int, n_layers: int, dim: int, seq_len: int
) -> float:
    return 6 * num_non_embed_params + attention_flops_per_token(
        n_layers, seq_len, dim, True
    )
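
# Worked example (hypothetical numbers, for illustration only): with
# num_non_embed_params=1e9, n_layers=24, dim=2048, seq_len=4096, and a
# causal mask, the attention term is
#   3.5 * (4 * 24 * 4096 * 2048 // 2) = 3.5 * 402_653_184 ≈ 1.41e9
# FLOPs per token, so get_num_flop_per_token returns about
#   6 * 1e9 + 1.41e9 ≈ 7.4e9.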


class StateCache(nn.Module):
    """Per-layer inference cache holding the conv window and the SSM state."""

    def __init__(
        self, bsz, n_heads, head_dim, state_dim, conv_size, conv_dim, dtype, device
    ):
        super().__init__()
        state_shape = (bsz, n_heads, head_dim, state_dim)
        if conv_size is None:
            # No convolution in this layer: keep an empty placeholder buffer
            conv_shape = (0,)
        else:
            conv_shape = (bsz, conv_dim, conv_size)
        self.register_buffer(
            "conv_cache",
            torch.zeros(conv_shape, dtype=dtype, device=device),
            persistent=False,
        )
        self.register_buffer(
            "state_cache",
            torch.zeros(state_shape, dtype=dtype, device=device),
            persistent=False,
        )

    def reset(self):
        self.conv_cache.zero_()
        self.state_cache.zero_()
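
# Usage sketch (illustrative; the values below are assumptions, not taken
# from any config in this repo):
#
#   cache = StateCache(
#       bsz=2, n_heads=16, head_dim=64, state_dim=128,
#       conv_size=4, conv_dim=2048,
#       dtype=torch.bfloat16, device="cuda",
#   )
#   cache.reset()  # zero both buffers between generations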


class LMMamba(BaseMamba):
    def __init__(self, args: LMMambaArgs) -> None:
        super().__init__(args)
        self.weight_tying = args.weight_tying
        self.loss_reduction = args.loss_reduction

        assert args.vocab_size > 0

        self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim)
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        if args.weight_tying:
            self.output = TiedLinear(self.tok_embeddings)
        else:
            self.output = nn.Linear(
                args.dim,
                args.vocab_size,
                bias=False,
            )

    def forward(
        self,
        token_values: torch.Tensor,
        target: Optional[torch.Tensor] = None,
        tok_idx: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        ssm_impl: str = "ssm",
    ) -> torch.Tensor:
        h = self.tok_embeddings(token_values)

        h = super().forward(
            h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl
        )

        logits = self.output(self.norm(h))
        if target is not None:
            return cross_entropy(
                logits.flatten(0, 1),
                target.flatten(0, 1),
                reduction=self.loss_reduction,
            )
        else:
            return logits
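
    # Example (illustrative; the constructor arguments are assumptions):
    #
    #   model = LMMamba(LMMambaArgs(vocab_size=32_000))
    #   loss = model(tokens, target=targets)  # scalar if loss_reduction="mean"
    #   logits = model(tokens)                # (bsz, seq_len, vocab_size)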

    def reset_parameters(self, init_std=None):
        # Either use a fixed base std or scale with the inverse square root of
        # the model dim
        super().reset_parameters()
        init_std = init_std or (self.model_dim ** (-0.5))
        self.norm.reset_parameters()
        nn.init.trunc_normal_(
            self.tok_embeddings.weight,
            mean=0.0,
            std=init_std,
            a=-3 * init_std,
            b=3 * init_std,
        )
        if not self.weight_tying:
            nn.init.trunc_normal_(
                self.output.weight,
                mean=0.0,
                std=init_std,
                a=-3 * init_std,
                b=3 * init_std,
            )

    @torch.inference_mode()
    def init_weights(self):
        super().init_weights()


# Optional policy for activation checkpointing. With None, we stick to the
# default (defined in distributed.py: default_no_recompute_ops)
def get_no_recompute_ops():
    return {
        torch.ops.aten.mm.default,
        torch.ops.aten._scaled_mm.default,
        torch.ops.c10d_functional.reduce_scatter_tensor.default,
        torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default,
    }
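
# Hypothetical wiring (an assumption, not code from this file): with
# PyTorch's selective activation checkpointing, this set can serve as a
# save-list policy so the listed ops are saved rather than recomputed:
#
#   from functools import partial
#   from torch.utils.checkpoint import (
#       checkpoint,
#       create_selective_checkpoint_contexts,
#   )
#
#   context_fn = partial(
#       create_selective_checkpoint_contexts, list(get_no_recompute_ops())
#   )
#   out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)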