Buckets:
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| from torch import nn | |
| from lingua.transformer import RMSNorm, TiedLinear, cross_entropy | |
| from apps.mamba.core_mamba import BaseMambaArgs, BaseMamba | |
| class LMMambaArgs(BaseMambaArgs): | |
| seed: int = 42 | |
| vocab_size: int = -1 | |
| weight_tying: bool = False | |
| loss_reduction: str = "mean" | |
| def attention_flops_per_token(n_layers, seq_len, dim, causal): | |
| # Formula from https://github.com/Dao-AILab/flash-attention/blob/main/benchmarks/benchmark_flash_attention.py#L27-L30 | |
| return 3.5 * (4 * n_layers * seq_len * dim // (2 if causal else 1)) | |
| def get_num_flop_per_token( | |
| num_non_embed_params: int, n_layers: int, dim: int, seq_len: int | |
| ) -> int: | |
| return 6 * num_non_embed_params + attention_flops_per_token( | |
| n_layers, seq_len, dim, True | |
| ) | |
| class StateCache(nn.Module): | |
| def __init__( | |
| self, bsz, n_heads, head_dim, state_dim, conv_size, conv_dim, dtype, device | |
| ): | |
| super().__init__() | |
| state_shape = (bsz, n_heads, head_dim, state_dim) | |
| if conv_size is None: | |
| conv_shape = (0,) | |
| else: | |
| conv_shape = (bsz, conv_dim, conv_size) | |
| self.register_buffer( | |
| "conv_cache", | |
| torch.zeros(conv_shape, dtype=dtype, device=device), | |
| persistent=False, | |
| ) | |
| self.register_buffer( | |
| "state_cache", | |
| torch.zeros(state_shape, dtype=dtype, device=device), | |
| persistent=False, | |
| ) | |
| def reset(self): | |
| self.conv_cache.zero_() | |
| self.state_cache.zero_() | |
| class LMMamba(BaseMamba): | |
| def __init__(self, args: LMMambaArgs) -> None: | |
| super().__init__(args) | |
| self.weight_tying = args.weight_tying | |
| self.loss_reduction = args.loss_reduction | |
| assert args.vocab_size > 0 | |
| self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) | |
| self.norm = RMSNorm(args.dim, eps=args.norm_eps) | |
| if args.weight_tying: | |
| self.output = TiedLinear(self.tok_embeddings) | |
| else: | |
| self.output = nn.Linear( | |
| args.dim, | |
| args.vocab_size, | |
| bias=False, | |
| ) | |
| def forward( | |
| self, | |
| token_values: torch.Tensor, | |
| target: Optional[torch.Tensor] = None, | |
| tok_idx: Optional[torch.Tensor] = None, | |
| cu_seqlens: Optional[torch.Tensor] = None, | |
| ssm_impl: str = "ssm", | |
| ) -> torch.Tensor: | |
| h = self.tok_embeddings(token_values) | |
| h = super().forward( | |
| h, tok_idx=tok_idx, cu_seqlens=cu_seqlens, ssm_impl=ssm_impl | |
| ) | |
| logits = self.output(self.norm(h)) | |
| if target is not None: | |
| return cross_entropy( | |
| logits.flatten(0, 1), | |
| target.flatten(0, 1), | |
| reduction=self.loss_reduction, | |
| ) | |
| else: | |
| return logits | |
| def reset_parameters(self, init_std=None): | |
| # Either use fixed base std or sqrt model dim | |
| super().reset_parameters() | |
| init_std = init_std or (self.model_dim ** (-0.5)) | |
| self.norm.reset_parameters() | |
| nn.init.trunc_normal_( | |
| self.tok_embeddings.weight, | |
| mean=0.0, | |
| std=init_std, | |
| a=-3 * init_std, | |
| b=3 * init_std, | |
| ) | |
| if not self.weight_tying: | |
| nn.init.trunc_normal_( | |
| self.output.weight, | |
| mean=0.0, | |
| std=init_std, | |
| a=-3 * init_std, | |
| b=3 * init_std, | |
| ) | |
| def init_weights(self): | |
| super().init_weights() | |
| # Optional policy for activation checkpointing. With None, we stick to the default (defined distributed.py: default_no_recompute_ops) | |
| def get_no_recompute_ops(): | |
| return { | |
| torch.ops.aten.mm.default, | |
| torch.ops.aten._scaled_mm.default, | |
| torch.ops.c10d_functional.reduce_scatter_tensor.default, | |
| torch.ops.mamba_ssm.ssm_chunk_scan_combined_fwd.default, | |
| } | |
Xet Storage Details
- Size:
- 4.14 kB
- Xet hash:
- a9ddcc6173184718fa16ec0655ec790a3e5743ddb3d91ebfdeb148eee046f302
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.