# Copied verbatim from vortex
# Copyright (c) 2024, Michael Poli.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .cache import (
    InferenceParams,
    HyenaCascadeFIRInferenceParams,
    HyenaCascadeIIRInferenceParams,
)
from .engine import HyenaInferenceEngine
from .layers import (
    ParallelGatedMLP,
    RMSNorm,
    VocabParallelEmbedding,
    VocabParallelUnembedding,
    TELinear,
)
from .utils import (
    Lambda,
    column_split,
    interleave,
    print_rank_0,
    move_to_device,
    fixup_fp8_extra_states,
    fixup_te_workspace,
)
from .rich_logging import activations_logger, enable_activations_logging

import logging

from tqdm import tqdm

from .attention import MHA

try:
    from .positional_embeddings import swap_mha_rope
except ImportError:
    # could not import swap_mha_rope from .positional_embeddings; it is only
    # needed when interpolated rotary embeddings are enabled in the config.
    pass


class AttentionBlock(nn.Module):
    def __init__(self, config, layer_idx) -> None:
        super().__init__()
        self.config = config
        self.pre_norm, self.post_norm = RMSNorm(config), RMSNorm(config)
        self.layer_idx = layer_idx
        self.print_activations = config.get("print_activations", False)

        self.proj_groups = config.get("proj_groups", 1)
        dtype = config.get("attn_block_dtype", torch.bfloat16)
        mlp_dtype = config.get("mlp_dtype", torch.bfloat16)

        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.hidden_size_per_attention_head = config.hidden_size // config.num_attention_heads

        self.counter = 0
        self.inner_mha_cls = MHA(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_heads_kv=config.num_attention_heads // self.proj_groups,
            rotary_emb_dim=config.hidden_size // config.num_attention_heads,
            qkv_proj_bias=config.get("qkv_proj_bias", True),
            rotary_emb_base=config.get("rotary_emb_base", 1000000),
            causal=True,
            layer_idx=layer_idx,
            out_proj_bias=config.get("mha_out_proj_bias", True),
            use_flash_attn=self.config.use_flash_attn,
        ).to(dtype=dtype)

        # check if using interpolated rotary pos emb from config, and swap the rope emb
        if config.get("use_interpolated_rotary_pos_emb", False):
            swap_mha_rope(
                mha=self.inner_mha_cls,
                kwargs_new_rope={"scaling_factor": config.get("rotary_emb_scaling_factor", 1.0)},
            )

        if self.config.get("smeared_gqa", False):
            self.inner_mha_cls.num_heads_kv = self.inner_mha_cls.num_heads
        self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)

        self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype)

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        if (
            type(padding_mask) == torch.Tensor
        ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
            # and attention scores will be also automatically zeroed.
            u = u * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre mha: {u}")

        u = (
            self.inner_mha_cls(
                self.pre_norm(u),
                inference_params=inference_params,
            )
            + u
        )

        if self.print_activations:
            activations_logger.info(f"post mha: {u}")

        if type(padding_mask) == torch.Tensor:  # guard against bias
            u = u * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre mlp: {u} {u.min()} {u.max()} {self.mlp.__class__}")
            activations_logger.info(
                f"post mlp norm: {self.post_norm(u)} {self.post_norm(u).min()} {self.post_norm(u).max()}"
            )
            activations_logger.info(
                f"post mlp: {self.mlp(self.post_norm(u))} {self.mlp(self.post_norm(u)).min()} {self.mlp(self.post_norm(u)).max()}"
            )

        u = self.mlp(self.post_norm(u)) + u
        return u, None
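
# ---------------------------------------------------------------------------
# Editorial note (not part of the original vortex source): AttentionBlock is a
# standard pre-norm residual block. With generic stand-ins for MHA and the
# gated MLP, its dataflow reduces to
#
#     u = mha(pre_norm(u)) + u
#     u = mlp(post_norm(u)) + u
#
# The multiplicative padding mask is a valid masking strategy here only
# because Wqkv carries no bias, so padded positions stay zero through the
# projections and their attention scores are zeroed as well.
# ---------------------------------------------------------------------------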


class HyenaCascade(nn.Module):
    def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hyena_filter_groups = hyena_filter_groups
        self.print_activations = config.get("print_activations", False)
        self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)

        self.use_flashfft = config.get("use_flashfft", False)
        self.state_size = config.state_size
        self.hidden_size = config.hidden_size
        self.num_filters = config.num_filters
        self.inference_mode = config.get("inference_mode", True)
        self.counter = 0
        self.column_split_hyena = config.get("column_split_hyena", True)
        self.hyena_flip_x1x2 = config.get("hyena_flip_x1x2", False)

        assert self.hidden_size % self.num_filters == 0 and self.num_filters <= self.hidden_size

        # attention heads are not used except to split post short_filter
        # projections in the same way as the checkpoint
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads

        self.fir_inner_filter_length = fir_inner_filter_length
        self.short_filter_length = config.short_filter_length
        self.short_filter_weight = nn.Parameter(torch.randn(3 * config.hidden_size, 1, config.short_filter_length))
        self.short_filter_bias = nn.Parameter(torch.randn(3 * config.hidden_size)) if config.short_filter_bias else None

        self.engine = HyenaInferenceEngine(
            layer_idx=layer_idx,
            ground_truth_activations_path=self.ground_truth_activations_path,
            print_activations=self.print_activations,
            hyena_flip_x1x2=config.get("hyena_flip_x1x2", False),
        )
        self.use_flash_depthwise = config.get("use_flash_depthwise", False)
        self.data_dtype = None

        if self.use_flash_depthwise:
            try:
                from flashfftconv import FlashDepthwiseConv1d

                self.fir_fn = FlashDepthwiseConv1d(
                    channels=3 * self.hidden_size,
                    kernel_size=self.short_filter_length,
                    padding=self.short_filter_length - 1,
                    weights=self.short_filter_weight,
                    bias=self.short_filter_bias,
                    device=None,
                    dtype=self.config.get("depthwise_dtype", torch.bfloat16),
                )
            except ImportError:
                # flashfftconv not installed; fall back to the standard conv1d path
                self.fir_fn = F.conv1d
        else:
            self.fir_fn = F.conv1d

        self.fir_inner_fn = F.conv1d

        self.fftconv_fn = None
        self.long_fir_threshold = config.get("long_fir_threshold", None)
        if self.long_fir_threshold is not None:
            assert self.use_flashfft is False, "long_fir_threshold not compatible with fused flashfft"

        self.num_systems = self.hyena_filter_groups
        self.channels_per_group = self.hidden_size // self.hyena_filter_groups

        if self.fir_inner_filter_length:
            self.h = nn.Parameter(torch.randn(self.hyena_filter_groups, 1, fir_inner_filter_length))
            if fir_inner_filter_length >= 128:
                self.D = nn.Parameter(torch.zeros(self.hidden_size))
            if fir_inner_filter_length < 128:
                self.D = None
        else:
            log_poles = torch.randn(self.num_systems, self.state_size, 1, dtype=torch.float32)
            # TODO: bring over init from internals
            # poles[..., 0] = 1e-2 * torch.randn(self.num_systems, self.state_size, 1)
            # poles[..., 1] = 1e-3 * torch.randn(self.num_systems, self.state_size, 1)
            self.log_poles = nn.Parameter(log_poles)
            self.residues = nn.Parameter(torch.randn(self.num_systems, self.state_size, dtype=torch.float32))
            self.D = nn.Parameter(torch.zeros(self.hidden_size))
            self.h = None

        self.t = None
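
    # ------------------------------------------------------------------
    # Editorial sketch (not part of the original vortex source): the
    # short_filter_weight / short_filter_bias parameters above describe a
    # depthwise causal FIR convolution over the 3 * hidden_size projected
    # channels. Assuming a channel-first (B, D, L) input, a minimal
    # stand-alone version of that operation would be
    #
    #     y = F.conv1d(u, weight, bias=bias, groups=u.shape[1], padding=fir_length - 1)
    #     y = y[..., : u.shape[-1]]  # trim the acausal tail back to length L
    #
    # which mirrors how the FlashDepthwiseConv1d fast path is configured.
    # ------------------------------------------------------------------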

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        if inference_params is not None and self.layer_idx in inference_params.fir_state_dict.keys():
            return self.sequential_forward(u, inference_params)
        else:
            return self.parallel_forward(u, inference_params, padding_mask)
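
    # ------------------------------------------------------------------
    # Editorial note (not part of the original vortex source): forward() above
    # dispatches on whether a cached FIR state already exists for this layer.
    # The parallel path handles the full-sequence prefill pass; once
    # inference_params holds per-layer filter state, subsequent calls take the
    # sequential path and advance the recurrence one token at a time.
    # ------------------------------------------------------------------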

    def parallel_forward(self, u, inference_params=None, padding_mask=None):
        L = u.shape[1]
        dims = (
            self.hidden_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            self.state_size,
            self.hyena_filter_groups,
        )
        if self.print_activations:
            activations_logger.info(f"pre 1 parallel fir: {u}, {u.min()}, {u.max()}")

        z_pre, fir_state = self.engine.parallel_fir(
            self.fir_fn,
            u,
            self.short_filter_weight,
            self.short_filter_bias,
            L,
            dims=dims,
            gate=False,
            column_split_hyena=self.column_split_hyena,
            fir_length=self.short_filter_length,
            inference_params=inference_params,
            padding_mask=padding_mask,
            dim_last=True,
        )
        if inference_params:
            inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.config.interleave:
            z_pre = interleave(z_pre)

        if self.h is None:
            h, _, _, _ = self.compute_filter(L, u.device)
        else:
            h = self.h

        D = self.D

        if self.hyena_filter_groups > 1:
            h = h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)

        # if inference_params is not None, we plan to perform generation:
        # prefilling is handled by the engine.
        if self.fir_inner_filter_length is not None:
            if self.print_activations:
                activations_logger.info(
                    f"pre 2 parallel fir: {z_pre}, {z_pre.min()}, {z_pre.max()}, {self.fir_inner_filter_length}"
                )
            y, fir_inner_state = self.engine.parallel_fir(
                self.fir_inner_fn,
                z_pre,
                h,
                D,
                L,
                dims=dims,
                gate=True,
                gated_bias=self.fir_inner_filter_length >= 128,
                dim_last=False,
                column_split_hyena=self.column_split_hyena,
                fir_length=self.fir_inner_filter_length,
                inference_params=inference_params,
                padding_mask=padding_mask,
                groups=self.hyena_filter_groups,
            )
            if self.print_activations:
                activations_logger.info(f"post 2 parallel fir: {y}, {y.min()}, {y.max()}")

            y = y.permute(0, 2, 1)

            if inference_params:
                inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
        else:
            if self.print_activations:
                activations_logger.info(f"pre 2 parallel iir: {z_pre}, {z_pre.min()}, {z_pre.max()}")
            y = self.engine.parallel_iir(
                z_pre,
                h,
                D,
                L,
                t=self.t,
                poles=self.log_poles,
                residues=self.residues,
                dims=dims,
                inference_params=inference_params,
                layer_idx=self.layer_idx,
                prefill_style=self.config.get("prefill_style", "fft"),
                use_flashfft=self.use_flashfft,
                fftconv_fn=self.fftconv_fn,
                column_split_hyena=self.column_split_hyena,
                long_fir_threshold=self.long_fir_threshold,
                padding_mask=padding_mask,
            )
            if self.print_activations:
                activations_logger.info(f"post 2 parallel iir: {y}, {y.min()}, {y.max()}")

        return y, inference_params

    def sequential_forward(self, u, inference_params):
        if self.data_dtype is None:
            self.data_dtype = u.dtype

        if len(u.shape) > 2:
            u = u[:, -1]

        z_pre, fir_state = self.engine.step_fir(
            u,
            inference_params.fir_state_dict[self.layer_idx],
            weight=self.short_filter_weight,
            bias=self.short_filter_bias,
        )
        inference_params.fir_state_dict[self.layer_idx] = fir_state

        if self.config.interleave:
            z_pre = interleave(z_pre)

        x2, x1, v = (
            column_split(z_pre, self.num_attention_heads, self.hidden_size_per_attention_head)
            if self.column_split_hyena
            else z_pre.split([self.hidden_size, self.hidden_size, self.hidden_size], dim=1)
        )

        if self.hyena_flip_x1x2:
            x1, x2 = x2, x1

        if self.fir_inner_filter_length is not None:
            if self.hyena_filter_groups > 1:
                h = self.h.repeat_interleave(self.hidden_size // self.hyena_filter_groups, 0)
            else:
                h = self.h

            y, fir_inner_state = self.engine.step_fir(
                x1 * v,
                inference_params.fir_inner_state_dict[self.layer_idx],
                weight=h,
                bias=self.D,
                flip_filter=self.fir_inner_filter_length >= 128,
                gated_bias=self.fir_inner_filter_length >= 128,
            )
            y = y * x2
            inference_params.fir_inner_state_dict[self.layer_idx] = fir_inner_state
        else:
            y, iir_state = self.engine.step_iir(
                x2,
                x1,
                v,
                self.D,
                self.residues,
                self.log_poles,
                inference_params.state_dict[self.layer_idx],
                iir_groups=1,
            )
            inference_params.state_dict[self.layer_idx] = iir_state

        y = y.to(dtype=self.data_dtype)
        return y[:, None], inference_params

    def update_time(self, L, device):
        """
        Set [0, 1, ..., L-1] where L is the length of the current batch of inputs.
        If L is greater than the length of the previous batch, then the time vector is
        reinitialized. Otherwise, the time vector is truncated from cache.
        """
        if self.t is None:
            self.t = torch.arange(L, device=device)[None, None]
        elif self.t.shape[-1] < L:
            self.t = torch.arange(L, device=device)[None, None]
        else:
            self.t = self.t[..., :L]

    def compute_filter(self, L, device):
        self.update_time(L, device)
        filter_dtype = torch.float32
        residues, log_poles = (
            self.residues.to(filter_dtype),
            self.log_poles.to(filter_dtype),
        )
        h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None]  # B, D, L
        return h, filter_dtype, log_poles, residues
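

# ---------------------------------------------------------------------------
# Editorial sketch (not part of the original vortex source). compute_filter
# above materializes the long implicit filter from its pole/residue (modal)
# parameterization: for each filter group g,
#
#     h[g, l] = sum_s residues[g, s] * exp(log_poles[g, s] * l),  l = 0, ..., L - 1
#
# A minimal stand-alone equivalent, assuming float32 tensors shaped as in the
# module (log_poles: (G, S, 1), residues: (G, S)):
def _sketch_modal_filter(log_poles, residues, L):
    t = torch.arange(L, dtype=torch.float32)[None, None]  # (1, 1, L)
    return (residues[..., None] * (log_poles * t).exp()).sum(1)[None]  # (1, G, L)
# ---------------------------------------------------------------------------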
""" if self.t is None: self.t = torch.arange(L, device=device)[None, None] elif self.t.shape[-1] < L: self.t = torch.arange(L, device=device)[None, None] else: self.t = self.t[..., :L] def compute_filter(self, L, device): self.update_time(L, device) filter_dtype = torch.float32 residues, log_poles = ( self.residues.to(filter_dtype), self.log_poles.to(filter_dtype), ) h = (residues[..., None] * (log_poles * self.t).exp()).sum(1)[None] # B, D, L return h, filter_dtype, log_poles, residues class ParallelGatedConvBlock(nn.Module): def __init__(self, config, layer_idx, hyena_filter_groups=None, fir_inner_filter_length=None) -> None: super().__init__() self.config = config self.layer_idx = layer_idx self.print_activations = config.get("print_activations", False) self.ground_truth_activations_path = config.get("ground_truth_activations_path", None) self.low_mem_mode = config.get("low_mem_mode", False) self.fir_inner_filter_length = fir_inner_filter_length self.hyena_filter_groups = hyena_filter_groups if hyena_filter_groups is not None else config.hidden_size dtype = config.get("hyena_block_dtype", torch.bfloat16) mlp_dtype = config.get("mlp_dtype", torch.bfloat16) self.pre_norm, self.post_norm = ( RMSNorm(config).to(dtype=dtype), RMSNorm(config).to(dtype=dtype), ) self.filter = HyenaCascade( config, layer_idx, hyena_filter_groups=self.hyena_filter_groups, fir_inner_filter_length=fir_inner_filter_length, ).to(dtype=dtype) # For posterity/debugging: TELinear can be easily replaced by # nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.qkv_proj_bias).to(dtype=dtype) # which sometimes is very useful when debugging FP8. # Ishan: replacing TELinear with nn.Linear to get meta tensor loading to # behave. # self.projections = TELinear( # config.hidden_size, # 3 * config.hidden_size, # bias=config.qkv_proj_bias, # init_method=torch.nn.init.xavier_uniform_, # use_fp8=config.get("use_fp8_input_projections", False), # ) self.projections = nn.Linear( config.hidden_size, 3 * config.hidden_size, bias=config.qkv_proj_bias, ).to(dtype=dtype) self.out_filter_dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.hyena_out_proj_bias).to( dtype ) self.mlp = ParallelGatedMLP(config, layer_idx).to(dtype=mlp_dtype) # self.proj_norm_fn = self.proj_norm # self.res_mlp_norm_fn = self.res_mlp_norm if self.config.get("compile", False): self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead") self.res_mlp_norm_fn = torch.compile( self.res_mlp_norm, fullgraph=True, dynamic=False, mode="reduce-overhead" ) def pad_to_multiple(self, x, multiple=16): """Pad input tensor to multiple of 16 only when FP8 is enabled""" if not self.config.get("use_fp8_input_projections", False): return x batch_size, seq_len, hidden_dim = x.size() pad_len = (multiple - (seq_len % multiple)) % multiple if pad_len == 0: return x return F.pad(x, (0, 0, 0, pad_len)) def proj_norm(self, x): if self.print_activations: activations_logger.info(f"pre mixer norm: {x} {x.min()} {x.max()} {self.projections.__class__}") activations_logger.info( f"post mixer norm: {self.pre_norm(x)} {self.pre_norm(x).min()} {self.pre_norm(x).max()}" ) if self.ground_truth_activations_path: pre_norm_savanna = torch.load( f"{self.ground_truth_activations_path}/pre_mixer_norm_{self.layer_idx}.pt" ) post_norm_savanna = torch.load( f"{self.ground_truth_activations_path}/post_mixer_norm_{self.layer_idx}.pt" ) activation_diff = (x.squeeze() - pre_norm_savanna.squeeze()).abs() activations_logger.info( f"pre mixer 

    def proj_norm(self, x):
        if self.print_activations:
            activations_logger.info(f"pre mixer norm: {x} {x.min()} {x.max()} {self.projections.__class__}")
            activations_logger.info(
                f"post mixer norm: {self.pre_norm(x)} {self.pre_norm(x).min()} {self.pre_norm(x).max()}"
            )
            if self.ground_truth_activations_path:
                pre_norm_savanna = torch.load(
                    f"{self.ground_truth_activations_path}/pre_mixer_norm_{self.layer_idx}.pt"
                )
                post_norm_savanna = torch.load(
                    f"{self.ground_truth_activations_path}/post_mixer_norm_{self.layer_idx}.pt"
                )
                activation_diff = (x.squeeze() - pre_norm_savanna.squeeze()).abs()
                activations_logger.info(
                    f"pre mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
                activation_diff = (self.pre_norm(x).squeeze() - post_norm_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post mixer norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
                activations_logger.info(
                    f"pre norm scale: {self.pre_norm.scale}, {self.pre_norm.scale.min()}, {self.pre_norm.scale.max()}"
                )

        normalized = self.pre_norm(x)
        normalized = self.pad_to_multiple(normalized)

        # Ishan: comment out this vestige of manual device management
        # with torch.cuda.device(x.device):
        #     projected = self.projections(normalized)
        projected = self.projections(normalized)
        if isinstance(projected, tuple):
            projected = projected[0]

        original_seq_len = x.size(1)
        # Slice back to original sequence length if padding was added
        if projected.size(1) > original_seq_len:
            projected = projected[:, :original_seq_len, :]

        return projected

    def res_mlp_norm(self, x):
        if self.print_activations:
            activations_logger.info(f"pre mlp: {x} {x.min()} {x.max()} {self.mlp.__class__}")
            activations_logger.info(
                f"post mlp norm: {self.post_norm(x)} {self.post_norm(x).min()} {self.post_norm(x).max()}"
            )
            activations_logger.info(
                f"post mlp: {self.mlp(self.post_norm(x))} {self.mlp(self.post_norm(x)).min()} {self.mlp(self.post_norm(x)).max()}"
            )
            if self.ground_truth_activations_path:
                pre_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_mlp_{self.layer_idx}.pt")
                post_mlp_savanna = torch.load(f"{self.ground_truth_activations_path}/post_mlp_norm_{self.layer_idx}.pt")
                activation_diff = (x.squeeze() - pre_mlp_savanna.squeeze()).abs()
                activations_logger.info(f"pre mlp activation_diff: {activation_diff.max()}, {activation_diff.mean()}")
                activation_diff = (self.post_norm(x).squeeze() - post_mlp_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post mlp norm activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )

        return self.mlp(self.post_norm(x)) + x

    def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
        z = self.proj_norm(u)

        if type(padding_mask) == torch.Tensor:  # guard against bias
            z = z * padding_mask[..., None]

        if self.print_activations:
            activations_logger.info(f"pre filter: {z} {z.min()} {z.max()} {self.filter.__class__}")
            if self.ground_truth_activations_path:
                z_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_filter_{self.layer_idx}.pt")
                activation_diff = (z - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"pre filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )

        z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)

        if self.print_activations:
            activations_logger.info(f"post postgate: {z} {z.min()} {z.max()} {self.filter.__class__}")
            activations_logger.info(
                f"post out proj: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()} {self.out_filter_dense.__class__}"
            )
            activations_logger.info(
                f"post mixer dense and residual: {self.out_filter_dense(z) + u} {(self.out_filter_dense(z) + u).min()} {(self.out_filter_dense(z) + u).max()}"
            )
            activations_logger.info(
                f"post mixer dense: {self.out_filter_dense(z)} {self.out_filter_dense(z).min()} {self.out_filter_dense(z).max()}"
            )
            activations_logger.info(f"post mixer: {z} {z.min()} {z.max()}")
            if self.ground_truth_activations_path:
                z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_filter_{self.layer_idx}.pt")
                activation_diff = (z - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post filter activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )
                z_savanna = torch.load(f"{self.ground_truth_activations_path}/post_out_proj_{self.layer_idx}.pt")
                z_ = F.linear(z, self.out_filter_dense.weight)
                activation_diff = (z_ - z_savanna.squeeze()).abs()
                activations_logger.info(
                    f"post out proj activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                )

        z_in = self.out_filter_dense(z) + u
        # if self.layer_idx == 0:
        #     z_in = z_savanna.squeeze() + u + self.out_filter_dense.bias

        if type(padding_mask) == torch.Tensor:  # guard against bias
            z_in = z_in * padding_mask[..., None]

        y = self.res_mlp_norm(z_in)

        return y, inference_params
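
# ---------------------------------------------------------------------------
# Editorial note (not part of the original vortex source): at a high level
# ParallelGatedConvBlock.forward computes
#
#     z = projections(pre_norm(u))        # split downstream into (x2, x1, v)
#     z = filter(z)                       # gated short FIR + long FIR/IIR convolution
#     u = out_filter_dense(z) + u         # mixer residual
#     u = mlp(post_norm(u)) + u           # MLP residual (res_mlp_norm)
#
# with optional multiplicative padding masks applied between the stages.
# ---------------------------------------------------------------------------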


def get_block(config, layer_idx, flash_fft=None):
    if layer_idx in config.attn_layer_idxs:
        return AttentionBlock(config, layer_idx)
    elif layer_idx in config.hcl_layer_idxs:
        block = ParallelGatedConvBlock(config, layer_idx)
        if config.get("use_flashfft", False):
            block.filter.fftconv_fn = flash_fft
        return block
    elif layer_idx in config.hcm_layer_idxs:
        block = ParallelGatedConvBlock(
            config,
            layer_idx,
            hyena_filter_groups=config.hcm_filter_groups,
            fir_inner_filter_length=config.hcm_filter_length,
        )
        return block
    elif layer_idx in config.hcs_layer_idxs:
        block = ParallelGatedConvBlock(
            config,
            layer_idx,
            hyena_filter_groups=config.hcs_filter_groups,
            fir_inner_filter_length=config.hcs_filter_length,
        )
        return block
    else:
        raise NotImplementedError
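
# ---------------------------------------------------------------------------
# Editorial note (not part of the original vortex source): blocks are routed
# purely by the index lists in the config. For example (hypothetical values):
#
#     attn_layer_idxs: [3, 10]   -> AttentionBlock
#     hcl_layer_idxs:  [0, 4, 8] -> ParallelGatedConvBlock (long, IIR/modal filter)
#     hcm_layer_idxs:  [1, 5, 9] -> ParallelGatedConvBlock (medium explicit FIR)
#     hcs_layer_idxs:  [2, 6, 7] -> ParallelGatedConvBlock (short explicit FIR)
#
# Every layer index must appear in one of the four lists, otherwise get_block
# raises NotImplementedError.
# ---------------------------------------------------------------------------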


class StripedHyena(nn.Module):
    def __init__(self, config):
        super().__init__()
        fixup_te_workspace()  # Workaround global cublas workspaces in TE
        self.config = config
        self.print_activations = config.get("print_activations", False)
        if self.print_activations:
            enable_activations_logging()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.ground_truth_activations_path = config.get("ground_truth_activations_path", None)
        self.logger.info(f"Initializing StripedHyena with config: {config}")

        with torch.device("cuda:0" if torch.cuda.is_available() else "cpu"):
            self.embedding_layer = VocabParallelEmbedding(config)

        if config.get("use_flashfft", True):
            try:
                from flashfftconv import FlashFFTConv

                self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
            except ImportError:
                # flashfftconv not installed; fall back to the non-fused path
                self.flash_fft = None
        else:
            self.flash_fft = None

        if not self.config.get('evo2_style_activations', False):
            self.logger.warning(
                "⚠️ Not using Evo2 style activations ⚠️\n"
                "⚠️ Set 'evo2_style_activations: True' in config if you are using Evo 2 checkpoints ⚠️"
            )

        self.logger.info(f"Initializing {config.num_layers} blocks...")
        self.blocks = nn.ModuleList()
        self.block_idx_to_device = {}

        # Calculate layers per GPU
        # num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1
        # layers_per_gpu = math.ceil(config.num_layers / num_gpus)
        # self.logger.info(f"Distributing across {num_gpus} GPUs, approximately {layers_per_gpu} layers per GPU")

        for layer_idx in tqdm(range(config.num_layers)):
            # Determine which GPU should handle this layer
            # device_idx = min(layer_idx // layers_per_gpu, num_gpus - 1)
            # device = f"cuda:{device_idx}" if torch.cuda.is_available() else "cpu"
            # with torch.device(device):
            # TELinear uses `device="cuda"` device to allocate empty bias
            # tensor. This makes sure that the empty tensor is allocated on the
            # correct device. (torch.device(), unlike torch.cuda.device(),
            # doesn't override current CUDA device.)
            # with torch.cuda.device(device):
            block = get_block(config, layer_idx, flash_fft=self.flash_fft)
            # move_to_device(block, device)
            self.blocks.append(block)
            # self.block_idx_to_device[layer_idx] = device
            # self.logger.info(f"Assigned {layer_idx=} to {device=}")
            # self.logger.info(
            #     f"Parameter count for block {layer_idx}: {sum(p.numel() for p in self.blocks[-1].parameters())}"
            # )

        # with torch.device(self.block_idx_to_device[0]):
        # with torch.cuda.device(self.block_idx_to_device[0]):
        self.norm = RMSNorm(config) if config.get("final_norm", True) else None

        if config.tie_embeddings:
            # Lambda usage is to be able to use forward() on caller side, which in
            # turn is needed for PyTorch hooks to work properly.
            self.unembed = Lambda(self.embedding_layer.unembed)
        else:
            if config.tie_embeddings:
                # Technically we can support this mode, just need to
                # copy tensors across GPUs then. But let's implement it
                # once/if needed.
                self.logger.info("Ignoring tie_embeddings for now.")
            self.unembed = VocabParallelUnembedding(config)

        self.logger.info("Initialized model")

    def forward(self, x, inference_params_dict=None, padding_mask=None):
        L = x.shape[1]

        if self.print_activations:
            activations_logger.info(f"pre embedding: {x}, {x.min()}, {x.max()}")

        x = self.embedding_layer(x)

        if self.print_activations:
            activations_logger.info(f"post embedding: {x}, {x.min()}, {x.max()}")

        if inference_params_dict is not None:
            x, inference_params_dict_out = self.stateful_forward(
                x,
                inference_params_dict=inference_params_dict,
            )
        else:
            x, inference_params_dict_out = self.stateless_forward(x, padding_mask=padding_mask)

        if self.print_activations:
            activations_logger.info(f"pre norm: {x}, {x.min()}, {x.max()}")

        # By convention, this line used to return results on the first device.
        # Since we're systematically ridding this code of custom device
        # management, it's no longer needed.
        # x = x.to(self.block_idx_to_device[0])
        x = self.norm(x)

        if self.print_activations:
            activations_logger.info(f"post norm: {x}, {x.min()}, {x.max()}, {self.norm.scale}")

        x = self.unembed(x)

        return x, inference_params_dict_out

    def block_idx_to_name(self, block_idx):
        if block_idx in self.config.attn_layer_idxs:
            return "mha"
        elif block_idx in self.config.hcl_layer_idxs:
            return "hcl"
        elif block_idx in self.config.hcm_layer_idxs:
            return "hcm"
        elif block_idx in self.config.hcs_layer_idxs:
            return "hcs"
        else:
            raise ValueError(f"Block index {block_idx} not found")

    def cross_device_transfer(self, x, block_idx):
        if self.block_idx_to_device[max(block_idx - 1, 0)] != self.block_idx_to_device[block_idx]:
            x = x.to(self.block_idx_to_device[block_idx])
        return x

    def stateful_forward(self, x, inference_params_dict=None):
        for block_idx, block in enumerate(self.blocks):
            inference_params = inference_params_dict[self.block_idx_to_name(block_idx)]

            if self.print_activations:
                activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

            # Ishan: commenting out now-redundant manual device management
            # x = self.cross_device_transfer(x, block_idx)
            x, _ = block(x, inference_params=inference_params)

            if self.print_activations:
                activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

        return x, inference_params_dict

    def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for block_idx, block in enumerate(self.blocks):
            if self.print_activations:
                activations_logger.info(f"pre block {block_idx}: {x}, {x.min()}, {x.max()} {block.__class__}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/pre_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"pre block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

            # Ishan: commenting out now-redundant manual device management
            # x = self.cross_device_transfer(x, block_idx)
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)

            if self.print_activations:
                activations_logger.info(f"post block {block_idx}: {x}, {x.min()}, {x.max()}")
                if self.ground_truth_activations_path:
                    x_savanna = torch.load(f"{self.ground_truth_activations_path}/post_block_{block_idx}.pt")
                    activation_diff = (x - x_savanna.squeeze()).abs()
                    activations_logger.info(
                        f"post block {block_idx} activation_diff: {activation_diff.max()}, {activation_diff.mean()}"
                    )

        return x, None
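
    # ------------------------------------------------------------------
    # Editorial usage sketch (not part of the original vortex source),
    # assuming a constructed `model` and an input id tensor `x`; the exact
    # cache bookkeeping (e.g. advancing seqlen_offset between steps) is left
    # to the caller / generation utilities:
    #
    #     params = model.initialize_inference_params(max_seqlen=8192)
    #     logits, params = model(x, inference_params_dict=params)            # prefill
    #     logits, params = model(next_token, inference_params_dict=params)   # decode step
    #
    # Passing inference_params_dict=None runs the stateless path instead.
    # ------------------------------------------------------------------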

    def initialize_inference_params(self, max_seqlen=None):
        ## Input seqlen takes priority over config!
        ## WARNING: This avoids potential errors but means the model can be used beyond the length it was trained at
        config_seqlen = self.config.get("max_seqlen", None)
        if config_seqlen is None:
            print("No max_seqlen found in config!!! using default value of 8192")
            config_seqlen = 8192
        new_max_seqlen = max_seqlen if max_seqlen is not None else config_seqlen
        # self.config["max_seqlen"] = new_max_seqlen
        ## Note: changing the stored config max_seqlen will change the max_seqlen used in flash attention, leading to minor logit differences

        print(f"Initializing inference params with max_seqlen={new_max_seqlen}")

        inference_params_dict = {
            "mha": InferenceParams(
                max_seqlen=new_max_seqlen,
                max_batch_size=self.config.get("max_batch_size", 1),
                seqlen_offset=0,
            ),
            "hcl": HyenaCascadeIIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                state_dim=self.config.state_size,
                seqlen_offset=0,
            ),
            "hcm": HyenaCascadeFIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                fir_inner_filter_length=self.config.hcm_filter_length,
                seqlen_offset=0,
            ),
            "hcs": HyenaCascadeFIRInferenceParams(
                fir_filter_length=self.config.short_filter_length,
                fir_inner_filter_length=self.config.hcs_filter_length,
                seqlen_offset=0,
            ),
        }
        return inference_params_dict

    def precompute_filters(self, L, device):
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == HyenaCascade:
                    L = block.filter.long_fir_threshold or L
                    print_rank_0(f"Precomputing filters, L={L}...")

                    filter_dtype = torch.float16 if L >= 2048 else torch.float32

                    block.filter.update_time(L, device)

                    residues, poles = (
                        block.filter.residues.to(torch.float16),
                        block.filter.poles.to(torch.float16),
                    )

                    block.filter.h = (residues * poles**block.filter.t).real.sum(1)[None]
                    block.filter.h = block.filter.h.to(dtype=filter_dtype)

    def load_poles_residues(self, path):
        "Load different poles and residues for each layer."
        for block_idx, block in enumerate(self.blocks):
            if type(block) == ParallelGatedConvBlock:
                if type(block.filter) == HyenaCascade:
                    self.logger.info(f"Loading approximate poles and residues for block {block_idx}")
                    poles = torch.load(path + f"/approx_poles_{block_idx+1}.pt", map_location="cpu")
                    poles = torch.view_as_real(poles)
                    residues = torch.load(path + f"/approx_residues_{block_idx+1}.pt", map_location="cpu")
                    residues = torch.view_as_real(residues)
                    poles = poles.permute(1, 0, 2).unsqueeze(-2)
                    residues = residues.permute(1, 0, 2).unsqueeze(-2)

                    block.filter.poles = nn.Parameter(poles)
                    block.filter.residues = nn.Parameter(residues)

    def custom_load_state_dict(self, state_dict, strict=True):
        """
        Post-processes the state_dict to convert savanna checkpoints to vortex checkpoints.
""" self.logger.debug(f"Loading state dict: {state_dict}, (ignoring extra keys) with strict: {strict}") model_dict = self.state_dict() # Find keys that are in model_dict but not in state_dict missing_in_state_dict = model_dict.keys() - state_dict.keys() # Find keys that are in state_dict but not in model_dict extra_in_state_dict = state_dict.keys() - model_dict.keys() if missing_in_state_dict: print(f"Keys missing in state_dict: {missing_in_state_dict}") if extra_in_state_dict: print(f"Extra keys in state_dict: {extra_in_state_dict}") filtered_dict = {k: v for k, v in state_dict.items() if k in model_dict} if all("._extra_state" in k for k in missing_in_state_dict): self.logger.info("Checkpoint has no FP8 extra state, will be using initial state.") for k in missing_in_state_dict: filtered_dict[k] = None self.load_state_dict(filtered_dict, strict=strict) fixup_fp8_extra_states(self) if self.config.get("column_split", True): self.logger.info("Adjusting Wqkv for column split (permuting rows)") for layer_idx, block in enumerate(self.blocks): if type(block) == AttentionBlock: target_device = block.inner_mha_cls.Wqkv.weight.device Wqkv = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.weight"] try: bias = state_dict[f"blocks.{layer_idx}.inner_mha_cls.Wqkv.bias"] except: bias = None size_att_head = block.hidden_size_per_attention_head Wqkv = Wqkv.permute(1, 0) Wqkv = Wqkv.reshape(block.hidden_size, block.num_attention_heads, 3, size_att_head) Wq, Wk, Wv = Wqkv.unbind(dim=-2) Wq = Wq.reshape(block.hidden_size, -1) Wk = Wk.reshape(block.hidden_size, -1) Wv = Wv.reshape(block.hidden_size, -1) Wqkv = torch.cat([Wq, Wk, Wv], dim=-1) Wqkv = Wqkv.permute(1, 0) # Single device transfer at the end block.inner_mha_cls.Wqkv.weight.data = Wqkv.to(target_device) if bias is not None: bias = bias.cpu() # Process on CPU bias = bias.reshape(block.num_attention_heads, 3, size_att_head) bias_q, bias_k, bias_v = bias.unbind(dim=-2) bias_q = bias_q.reshape(block.hidden_size) bias_k = bias_k.reshape(block.hidden_size) bias_v = bias_v.reshape(block.hidden_size) bias = torch.cat([bias_q, bias_k, bias_v], dim=0) try: block.inner_mha_cls.Wqkv.bias.data = bias.to(target_device) except: pass def to_bfloat16_except_pr_lc(self, to_float32=False): """Convert all parameters to bfloat16 except for the poles and residues. Particularly important for longer prompts. """ excluded_shapes = [(4096, 1, 128)] for k, p in self.named_parameters(): if "projections" not in k: # avoid TE linears if "log_poles" not in k and "residues" not in k and p.shape not in excluded_shapes: p.data = p.data.to(torch.bfloat16) else: if to_float32: p.data = p.data.to(torch.float32) for k, b in self.named_buffers(): if "inv_freq" in k: if to_float32: b.data = b.data.to(torch.float32)