# Copied verbatim from vortex (minus the commented out code)
# Copyright (c) 2024, Michael Poli.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Callable

from .utils import grab_first_if_tuple

from transformer_engine.pytorch import Linear
from transformer_engine.common.recipe import Format, DelayedScaling
import transformer_engine.pytorch as te

# Not bothering with ops right now (which is an interface with custom Triton
# kernels)
# try:
#     from hyena_ops import hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd
# except ImportError:
#     hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None
hyena_se_fwd, hyena_mr_fwd, hyena_li_fwd = None, None, None


def set_format_recipe():
    fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    return fp8_format, fp8_recipe


class TELinear(Linear):
    """
    Wrapper for Transformer-Engine's `Linear` layer.

    Note that if Megatron's parallel_state has not been initialized yet, the
    tp_group passed to TE will be None and must be set later via
    set_tensor_parallel_group().
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        init_method: Callable,
        bias: bool = True,
        skip_bias_add: bool = False,
        use_fp8: bool = False,
        **kwargs,
    ):
        # Parameters are initialized at higher precision even if fp8 is used
        params_dtype = torch.bfloat16

        # TE returns a zero length Tensor when bias=False and return_bias=True,
        # but we prefer None. So in that case we tell TE to not return the bias,
        # and return None ourselves. This way our forward always returns two
        # values and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        self.use_fp8_input_projections = use_fp8
        if use_fp8:
            self.fp8_format, self.fp8_recipe = set_format_recipe()

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            sequence_parallel=False,
            fuse_wgrad_accumulation=False,
            tp_group=None,
            tp_size=1,
            init_method=init_method,
            params_dtype=params_dtype,
            parallel_mode=None,
            bias=bias,
            return_bias=self.te_return_bias,
            **kwargs,
        )

    def forward(self, x):
        if self.use_fp8_input_projections:
            with te.fp8_autocast(enabled=True, fp8_recipe=self.fp8_recipe):
                out = super().forward(x)
        else:
            out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise it
        # returns a single Tensor; we always want to return two values
        # regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None
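

# Illustrative sketch (not part of the original vortex module): constructing a
# TELinear with fp8 input projections and running a forward pass. The
# init_method below is an arbitrary normal initializer chosen for the example,
# and a CUDA device with Transformer Engine fp8 support is assumed.
def _example_te_linear_fp8():
    proj = TELinear(
        input_size=4096,
        output_size=4096,
        init_method=lambda w: torch.nn.init.normal_(w, mean=0.0, std=0.02),
        bias=False,
        use_fp8=True,
    )
    x = torch.randn(2, 128, 4096, dtype=torch.bfloat16, device="cuda")
    # forward always returns a pair; the bias slot is None here because bias=False
    y, bias = proj(x)
    return y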


class FlexLinear:
    """
    Megatron and Transformer Engine linear layer compatible with fp8,
    bf16, fp16 and fp32
    """

    def __new__(
        self,
        input_size,
        output_size,
        config,
        parallel_mode: str,
        bias: bool = False,
        skip_bias_add: bool = True,
        use_fp8: bool = False,
        input_is_parallel=False,  # for row parallel
        gather_output: bool = True,  # for column parallel
        parallel_output: bool = False,  # for row parallel
        **kwargs,
    ):
        # use_fp8 = config.use_fp8_linears
        self.config = config
        # Only the Transformer Engine fp8 path is implemented here; with
        # use_fp8=False this factory returns None.
        instance = None

        if use_fp8:
            instance = TELinear(
                input_size=input_size,
                output_size=output_size,
                config=self.config,
                parallel_mode=parallel_mode,
                bias=bias,
                skip_bias_add=skip_bias_add,
                **kwargs,
            )

        return instance


class RMSNorm(torch.nn.Module):
    def __init__(self, config):
        super(RMSNorm, self).__init__()
        self.eps, self.hidden_size = config.eps, config.hidden_size
        self.scale = torch.nn.Parameter(torch.ones(self.hidden_size, dtype=config.params_dtype))
        self.register_parameter("scale", self.scale)
        self.use_flash_rmsnorm = config.get("use_flash_rmsnorm", False)

        if self.use_flash_rmsnorm:
            from flash_attn.ops.rms_norm import rms_norm as rmsnorm_func

            self.rmsnorm_func = rmsnorm_func

    def forward(self, x):
        if self.use_flash_rmsnorm:
            return self.rmsnorm_func(x, self.scale, self.eps)
        else:
            y = x / (x.norm(2, dim=-1, keepdim=True) * self.hidden_size ** (-1.0 / 2) + self.eps)
            return self.scale * y


class ParallelGatedMLP(nn.Module):
    def __init__(
        self,
        config,
        layer_idx,
    ):
        super().__init__()
        self.layer_idx = layer_idx

        multiple_of = config.get("inner_size_multiple_of", 64)
        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError

        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()

        self.multiple_of = multiple_of * config.model_parallel_size

        inner_size = int(2 * config.hidden_size * 4 / 3)
        inner_size = self.multiple_of * ((inner_size + self.multiple_of - 1) // self.multiple_of)
        inner_size = config.get("inner_mlp_size", inner_size)

        self.l1 = nn.Linear(
            in_features=config.hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=config.hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=config.hidden_size,
            bias=False,
        )

    def forward(self, z):
        z1, z2 = self.l1(z), self.l2(z)
        z1, z2 = grab_first_if_tuple(z1), grab_first_if_tuple(z2)
        y = self.l3(self.act(z1) * z2)
        return grab_first_if_tuple(y)
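

# Illustrative sketch (not part of the original vortex module): exercising the
# gated MLP above, y = l3(act(l1(z)) * l2(z)). `_ExampleConfig` is a hypothetical
# stand-in for the config object, providing attribute access plus the dict-style
# .get() that ParallelGatedMLP relies on.
def _example_gated_mlp():
    class _ExampleConfig(dict):
        def __getattr__(self, name):
            return self[name]

    cfg = _ExampleConfig(hidden_size=256, model_parallel_size=1)
    mlp = ParallelGatedMLP(cfg, layer_idx=0)
    z = torch.randn(2, 16, 256)
    # inner width defaults to int(2 * 256 * 4 / 3) = 682, rounded up to 704 (a multiple of 64)
    y = mlp(z)
    return y  # shape (2, 16, 256)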


class Embedding(nn.Module):
    _train_dtype = "bf16"

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)

    def embed(self, input_ids, position_ids=None, tokentype_ids=None):
        embeddings = self.word_embeddings(input_ids)
        return embeddings

    def unembed(self, u):
        # Project hidden states back onto the vocabulary with the tied embedding
        # weight; the transpose maps (..., hidden_size) -> (..., vocab_size).
        weight = self.word_embeddings.weight
        return torch.matmul(u, weight.T)


class VocabParallelEmbedding(nn.Embedding):
    "Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py"

    def __init__(self, config):
        vocab_size, process_group, padding_idx = (
            config.vocab_size,
            config.get("process_group", None),
            config.get("padding_idx", None),
        )
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if vocab_size % world_size != 0:
                raise ValueError(f"vocab_size ({vocab_size}) must be divisible by " f"world_size ({world_size})")
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(
            vocab_size // world_size,
            embedding_dim=config.hidden_size,
            padding_idx=padding_idx,
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = (
                rank * vocab_size,
                (rank + 1) * vocab_size,
            )
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            # Reduce across the process group so every rank sees the full embeddings
            torch.distributed.all_reduce(embeddings, group=self.process_group)
            return embeddings

    def unembed(self, u: Tensor) -> Tensor:
        if self.process_group is None:
            return u @ self.weight.T
        else:
            raise NotImplementedError


class VocabParallelUnembedding(VocabParallelEmbedding):
    def forward(self, input: Tensor) -> Tensor:
        return self.unembed(input)
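

# Illustrative sketch (not part of the original vortex module): tying the input
# and output projections with the simple Embedding class above. The config is a
# hypothetical SimpleNamespace carrying only the two fields Embedding reads.
def _example_tied_embedding():
    from types import SimpleNamespace

    cfg = SimpleNamespace(vocab_size=512, hidden_size=64)
    emb = Embedding(cfg)
    tokens = torch.randint(0, cfg.vocab_size, (2, 16))
    h = emb.embed(tokens)       # (2, 16, 64)
    logits = emb.unembed(h)     # (2, 16, 512), reusing word_embeddings.weight
    return logits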