import math
import os
from dataclasses import dataclass
from typing import Optional

from huggingface_hub import hf_hub_download
import lm_eval as evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutput

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    flash_fft_available = False

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(f"Unable to import Triton-based flash attention: {e}. No alternative currently available.")

os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Shared mean-reduction cross-entropy used by FlashSTU.forward.
loss_fn = nn.CrossEntropyLoss()


def nearest_power_of_two(n: int, round_up: bool = False) -> int:
    """Return the power of two nearest to ``n``.

    Args:
        n: Positive integer (values <= 1 map to 1).
        round_up: If True, round up to the next power of two; otherwise
            round down to the largest power of two <= n.

    Returns:
        A power of two as an int.
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() rounds up; n.bit_length() - 1 rounds down.
    # Both are exact for n that is already a power of two.
    return 1 << ((n - 1).bit_length() if round_up else n.bit_length() - 1)


def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``."""
    if n % k == 0:
        return n
    return n + k - (n % k)


def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """Build the (seq_len x seq_len) Hankel matrix used to derive spectral filters.

    Args:
        seq_len: Side length of the matrix.
        use_hankel_L: If True, use the Hankel-L variant
            Z[i, j] = ((-1)^(i+j-2) + 1) * 8 / ((i+j+3)(i+j-1)(i+j+1));
            otherwise Z[i, j] = 2 / ((i+j)^3 - (i+j)) with 1-based i, j.

    Returns:
        float64 tensor of shape (seq_len, seq_len).
    """
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries.reshape(-1, 1) + entries.reshape(1, -1)
    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    else:
        # BUG FIX: the original `elif not use_hankel_L: ... else: raise ValueError`
        # made the ValueError branch unreachable (the if/elif pair already covers
        # every truthiness). A plain else expresses the real control flow.
        Z = 2.0 / (i_plus_j**3 - i_plus_j)
    return Z


def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Return the top-K eigenvectors of the Hankel matrix, scaled by sigma^(1/4).

    Args:
        seq_len: Sequence length (Hankel matrix side).
        K: Number of filters (eigenvectors) to keep.
        use_hankel_L: Selects the Hankel-L variant (see ``get_hankel``).
        device: Device for the eigendecomposition and the result.
        dtype: Output dtype of the filters.

    Returns:
        Tensor of shape (seq_len, K) on ``device`` with dtype ``dtype``.
    """
    Z = get_hankel(seq_len, use_hankel_L).to(device)
    # eigh returns eigenvalues in ascending order, so the top-K are the last K.
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    phi_k *= sigma_k**0.25
    return phi_k.to(device=device, dtype=dtype)
class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass"""

    model_type = "base_model"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


@dataclass
class FlashSTUConfig(BaseConfigForCausalLM):
    # NOTE(review): an explicit __init__ is defined below, so @dataclass does not
    # synthesize one; the decorator mainly documents the field list (it may still
    # add __repr__/__eq__ over PretrainedConfig's — confirm this is intended).
    model_type = "FlashSTU"

    # Define fields with defaults (as before)
    bsz: int = 1                               # batch size hint
    dim: int = 1024                            # model/hidden dimension
    r: int = 1024                              # dimensionality-reduction rank (see STU TODO)
    num_heads: int = 12
    num_local_heads: Optional[int] = -1        # -1 means "same as num_heads" (see __post_init__)
    num_layers: int = 12
    seq_len: int = 4096
    n: int = 8191                              # FFT length for STU convolutions
    window_size: int = 2048                    # sliding-attention lookback window
    vocab_size: int = 200064
    inter_dim: Optional[int] = 3072            # MLP hidden dim; derived in __post_init__ if None
    mlp_scale: Optional[float] = 12.0
    weight_tying: Optional[bool] = True
    bias: Optional[bool] = False
    rope_theta: Optional[float] = 10000.0
    softcap: Optional[float] = 50.0
    num_eigh: Optional[int] = 24               # number of spectral filters K
    use_hankel_L: Optional[bool] = False
    use_flash_fft: Optional[bool] = True
    use_tensordot: Optional[bool] = True
    use_attn: Optional[bool] = True            # interleave sliding-window attention layers
    use_alibi: Optional[bool] = False
    torch_dtype: torch.dtype = torch.bfloat16
    device: torch.device = None

    # Explicit __init__ to handle **kwargs for PretrainedConfig compatibility
    def __init__(
        self,
        bsz: int = 1,
        dim: int = 1024,
        r: int = 1024,
        num_heads: int = 12,
        num_local_heads: Optional[int] = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        n: int = 8191,
        window_size: int = 2048,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = 3072,
        mlp_scale: Optional[float] = 12.0,
        weight_tying: Optional[bool] = True,
        bias: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        softcap: Optional[float] = 50.0,
        num_eigh: Optional[int] = 24,
        use_hankel_L: Optional[bool] = False,
        use_flash_fft: Optional[bool] = True,
        use_tensordot: Optional[bool] = True,
        use_attn: Optional[bool] = True,
        use_alibi: Optional[bool] = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,  # Catch extra arguments like model_type
    ):
        super().__init__(**kwargs)  # Pass kwargs to parent __init__

        # Assign fields from arguments
        self.bsz = bsz
        self.dim = dim
        self.r = r
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n = n
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.softcap = softcap
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.use_alibi = use_alibi
        self.torch_dtype = torch_dtype
        self.device = device

        # Explicitly call __post_init__ if defined and needed
        # (dataclass would normally do this; here __init__ is hand-written).
        self.__post_init__()

    def __post_init__(self):
        """Normalize/derive fields after assignment."""
        # Ensure torch_dtype is a torch.dtype object, not a string
        # (PretrainedConfig serialization round-trips it as e.g. "bfloat16").
        if isinstance(self.torch_dtype, str):
            try:
                self.torch_dtype = getattr(torch, self.torch_dtype)
            except AttributeError:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}")
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            # SwiGLU-style sizing: 2/3 of mlp_scale * dim, rounded up to a multiple of 256.
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            self.inter_dim = find_multiple(num_hidden, 256)
        self.head_dim = self.dim // self.num_heads

    @classmethod
    def from_name(cls, name: str):
        """Placeholder for named presets; currently unimplemented and returns None."""
        # presets = {
        #     "tiny": dict(dim=128, num_heads=4, num_layers=2, vocab_size=10000),
        #     "small": dict(dim=256, num_heads=8, num_layers=4, vocab_size=20000),
        #     "gpt2-small": dict(dim=768, num_heads=12, num_layers=12, vocab_size=50257),
        #     # add more as needed
        # }
        # if name not in presets:
        #     raise ValueError(f"Unknown model config name: {name}")
        # return cls(**presets[name])
        print("Not yet implemented")
        pass


class MLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximated GELU."""

    def __init__(self, config: FlashSTUConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Marks the output projection for depth-scaled init (see FlashSTU.init_weights).
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x), approximate="tanh"))
class SlidingWindowAttention(nn.Module):
    """Causal sliding-window attention using flash-attn, with RoPE or ALiBi positions."""

    def __init__(self, config):
        super().__init__()
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Marks the output projection for depth-scaled init (see FlashSTU.init_weights).
        self.wo.SCALE_INIT = 1
        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.window_size = config.window_size
        self.softcap = config.softcap
        # ALiBi and RoPE are mutually exclusive: slopes are only built when
        # use_alibi is set, and RoPE is applied only when slopes are absent.
        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        """x: (bsz, seq_len, dim) -> (bsz, seq_len, dim)."""
        bsz, seq_len, dim = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use RoPE
            q, k = self.rotary_emb(q), self.rotary_emb(k)

        # window_size=(w, 0): attend to w tokens left, none right (causal window).
        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
            window_size=(self.window_size, 0),
            # softcap=self.softcap,
            alibi_slopes=self.alibi_slopes,
        )

        out = y.reshape(bsz, seq_len, -1)
        out = self.wo(out)
        return out

    def _generate_slopes(self, n: int):
        """Geometric ALiBi slope sequence for n heads (n a power of two)."""
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """ALiBi slopes for num_heads heads, scaled by interpolation_factor."""
        # If n_heads is a power of 2, generate slopes directly
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            # Generate extra slopes
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"))  # FA ALiBi must be on CUDA
        slopes = slopes * interpolation_factor  # https://arxiv.org/pdf/2310.13017
        return slopes


class STU(nn.Module):
    """Spectral Transform Unit: convolves inputs with precomputed spectral filters."""

    def __init__(self, config):
        super().__init__()
        # Set at top-level post- model init (see FlashSTU.setup_filters).
        self.stu_filters = None
        self.stu_filters_fft = None  # TODO: Optimization: Precompute FFT of filters
        self.n = config.n
        self.num_eigh = config.num_eigh
        self.d_in = config.dim
        self.d_out = config.dim
        self.r = config.r
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) if config.use_flash_fft and flash_fft_available else None
        )
        # TODO: Add dimensionality reduction `r` here.
        if self.use_tensordot:
            # Tensordot approximation: separate input and filter projections.
            self.M_inputs = nn.Parameter(torch.zeros(self.d_in, self.d_out))
            self.M_filters = nn.Parameter(torch.zeros(self.num_eigh, self.d_in))
        else:
            # Full (K, d_in, d_out) mixing matrices; M_phi_minus only needed
            # when the negative branch is used (i.e. not Hankel-L).
            self.M_phi_plus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, L, dim) -> (B, L, dim). Requires stu_filters to be set beforehand."""
        B, L, D = x.shape
        if self.use_tensordot:
            # Contract inputs and filters over (K, D) dims first, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = self.flash_conv(x_proj, phi_proj, self.flash_fft, self.use_tensordot)
            else:
                spectral_plus, spectral_minus = self.conv(x_proj, phi_proj, self.n, self.use_tensordot)
        else:
            # Convolve inputs and filters first, then contract over (K, D) dims
            if self.flash_fft:
                U_plus, U_minus = self.flash_conv(x, self.stu_filters, self.flash_fft, self.use_tensordot)
            else:
                U_plus, U_minus = self.conv(x, self.stu_filters, self.n, self.use_tensordot)
            B, L, K, D = U_plus.shape
            spectral_plus = U_plus.reshape(B, L, K * D) @ self.M_phi_plus.reshape(K * D, self.d_out)
            if not self.use_hankel_L:
                spectral_minus = U_minus.reshape(B, L, K * D) @ self.M_phi_minus.reshape(K * D, self.d_out)
        # NOTE: when use_hankel_L is True, spectral_minus is never computed or used.
        out = spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
        return out

    def conv(
        self, u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Performs convolution via FFT with causal alignment using a negative featurization.

        The input tensor u is modulated by an alternating sign tensor (sgn) that multiplies
        every other time step by -1. This "negative featurization" modulates the phase so
        that in this implementation the correct causal output is obtained by simply slicing
        the first L elements (i.e. [:seq_len]). Note: Using a conventional slice
        [seq_len-1:2*seq_len-1] would yield a flipped alignment, resulting in leakage.

        Args:
            u: Input tensor of shape (bsz, seq_len, d_in).
            v: Kernel tensor; expected shape is (seq_len, d_out) if use_tensordot is True.
            n: FFT length (typically set to 2*seq_len - 1 for linear convolution with
               implicit right zero-padding).
            use_tensordot: Boolean flag to control kernel reshaping.

        Returns:
            A tuple (U_plus, U_minus) where:
              - U_plus is the primary convolution output.
              - U_minus is the secondary output, corrected by the sign tensor.
        """
        bsz, seq_len, d_in = u.shape

        sgn = torch.full((1, seq_len, 1), 1, device=u.device)
        sgn[:, 1::2] *= -1  # Apply negative featurization: multiply every other element by -1.

        if use_tensordot:
            _, d_out = v.shape
            v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
        else:
            _, K = v.shape
            sgn = sgn.unsqueeze(-1)
            v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()  # (bsz, seq_len, K, d_in, stack)
            # Broadcast u across the K filters.
            u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

        # Cast kernel to float32 for FFT
        v_fft = torch.fft.rfft(v.to(torch.float32), n=n, dim=1)
        # Stack the plain and sign-modulated inputs so both branches share one FFT.
        U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
        # Cast input stack to float32 for FFT
        U_fft = torch.fft.rfft(U.to(torch.float32), n=n, dim=1)

        # Slicing the first seq_len outputs yields the proper causal convolution given the negative modulation.
        # Perform convolution in float32 and cast back
        U_conv = torch.fft.irfft(v_fft * U_fft, n=n, dim=1)[:, :seq_len].to(u.dtype)
        U_plus, U_minus = torch.unbind(U_conv, dim=-1)
        # Undo the sign modulation on the negative branch.
        U_minus = U_minus * sgn

        return U_plus.type_as(u), U_minus.type_as(u)

    def flash_conv(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        flash_fft: FlashFFTConv,
        use_tensordot: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Flash FFT convolution.

        Args:
            u (torch.Tensor): Input tensor of shape `(B, L, d_in)`, where:
                - `B` is the batch size,
                - `L` is the sequence length,
                - `d_in` is the input dimension.
            v (torch.Tensor): Filter tensor of shape `(K, d_in)`, where:
                - `K` is the number of filters,
                - `d_in` is the input dimension.
            flash_fft (FlashFFTConv): An instance of the FlashFFTConv module, used to
                perform the convolution.
            use_tensordot (bool, optional): If `True`, performs the tensordot
                approximation (default is `True`).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple `(U_plus, U_minus)`:
                - `U_plus`: Convolved output tensor with positive eigenvalues.
                  - Shape depends on `use_tensordot`:
                    - If `use_tensordot=True`: `(B, L, d_in)`
                    - If `use_tensordot=False`: `(B, L, K, d_in)`
                - `U_minus`: Convolved output tensor with negative eigenvalues.
                  - Shape depends on `use_tensordot`:
                    - If `use_tensordot=True`: `(B, L, d_in)`
                    - If `use_tensordot=False`: `(B, L, K, d_in)`

        Raises:
            ValueError: If the input tensor shapes do not conform to the expected dimensions.

        Example:
            >>> u = torch.randn(4, 16, 32)  # (B, L, d_in)
            >>> v = torch.randn(8, 32)  # (K, d_in)
            >>> flash_fft = FlashFFTConv(n=16, dtype=torch.float32)
            >>> U_plus, U_minus = flash_convolve(u, v, flash_fft, use_tensordot=True)
            >>> print(U_plus.shape, U_minus.shape)
            torch.Size([4, 16, 32]) torch.Size([4, 16, 32])
        """
        bsz, seq_len, d_in = u.shape
        _, K = v.shape

        # FlashFFTConv wants power-of-two lengths; pad on the right.
        padded_len = nearest_power_of_two(seq_len, round_up=True)
        pad_len = padded_len - seq_len

        sgn = torch.full((1, 1, padded_len), 1, device=u.device)
        sgn[:, :, 1::2] = -1  # Negative featurization, same trick as in conv().

        if use_tensordot:
            u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32)
            # Stack plain and sign-modulated inputs along the batch dim (2*bsz).
            u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
        else:
            u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1)
            u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

        # Ensure inputs to flash_fft are bfloat16 (input) and float32 (kernel)
        U_conv = flash_fft(u_conv.to(torch.bfloat16), v_padded.to(torch.float32))

        # Trim the output back to the original sequence length
        U_conv = U_conv[..., :seq_len]
        # Split the doubled batch back into the plus/minus branches.
        u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

        if use_tensordot:
            u_minus = u_minus * sgn[:, :, :seq_len]
            U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
        else:
            sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
            U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
            U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

        return U_plus, U_minus


class SlidingWindowAttentionLayer(nn.Module):
    """Pre-norm residual block: sliding-window attention followed by MLP."""

    def __init__(self, config):
        super().__init__()
        self.swa_norm = nn.LayerNorm(config.dim)
        self.swa = SlidingWindowAttention(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.swa(self.swa_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
class STULayer(nn.Module):
    """Pre-norm residual block: STU mixing followed by MLP."""

    def __init__(self, config):
        super().__init__()
        self.stu_norm = nn.LayerNorm(config.dim)
        self.stu = STU(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.stu(self.stu_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x


class FlashSTU(nn.Module):
    """FlashSTU backbone: embedding, alternating STU/attention layers, LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # Even layers are STU; odd layers are attention when enabled,
            # otherwise STU as well (all-STU model when use_attn is False).
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            # CLEANUP: the original used a conditional-expression-as-statement
            # (`append(...) if config.use_attn else append(...)`); this branch
            # is behaviorally identical and readable.
            if layer_idx % 2 == 0 or not config.use_attn:
                self.layers.append(STULayer(config))
            else:
                self.layers.append(SlidingWindowAttentionLayer(config))

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            # Tie input embedding and output head to the same weight tensor.
            self.tok_emb.weight = self.lm_head.weight

        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Per-module init: normal(0, dim^-0.5), depth-scaled for SCALE_INIT projections."""
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                # Residual-stream output projections get a 1/sqrt(2*num_layers) factor.
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Run the backbone; computes mean cross-entropy loss when labels are given.

        NOTE: labels are compared position-for-position (no shift is applied here);
        callers that want next-token loss must shift logits/labels themselves.
        """
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def setup_filters(
        self,
        spectral_filters: torch.Tensor,
        spectral_filters_fft: torch.Tensor,
    ):
        """Attach precomputed spectral filters (and their FFT) to every STU layer."""
        for layer in self.layers:
            if isinstance(layer, STULayer):
                layer.stu.stu_filters = spectral_filters
                layer.stu.stu_filters_fft = spectral_filters_fft

    def get_num_params(self):
        """Return the total number of parameters in the model.

        DOC FIX: the original docstring claimed position embeddings were
        subtracted; nothing is subtracted here (the model has no separate
        position-embedding table), so this is the raw parameter count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params


def create_base_model_components(model_name_or_path=None, **kwargs):
    """Create config and filters needed for model initialization.

    Args:
        model_name_or_path: Optional HF repo/path to load the config from;
            when None, a default FlashSTUConfig is built from kwargs.
        **kwargs: Forwarded to FlashSTUConfig / from_pretrained.

    Returns:
        (config, filters): the config and spectral filters on the best
        available device, cast to config.torch_dtype.
    """
    if model_name_or_path is not None:
        config = FlashSTUConfig.from_pretrained(model_name_or_path, **kwargs)
    else:
        config = FlashSTUConfig(**kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=device,
        dtype=config.torch_dtype,
    )
    assert filters.dtype == config.torch_dtype, f"filters dtype is {filters.dtype}, expected {config.torch_dtype}"

    return config, filters


class FlashSTUForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = FlashSTUConfig
    base_model_prefix = "FlashSTU"

    def __init__(self, config):
        super().__init__(config)
        self.flash_stu = FlashSTU(config)
        self.flash_stu.apply(self.flash_stu.init_weights)

        device = (
            config.device
            if config.device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        torch_dtype = config.torch_dtype  # Assumes __post_init__ already converted it to torch.dtype

        spectral_filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=device,
            # Note: get_spectral_filters returns float64, cast later
        )
        spectral_filters_fft = torch.fft.rfft(spectral_filters, n=config.n, dim=1)

        # Setup filters in the model, casting to the target dtype
        self.flash_stu.setup_filters(
            spectral_filters.to(dtype=torch_dtype), spectral_filters_fft.to(dtype=torch_dtype)
        )
        # Note: Moving the entire model to device happens later, after loading weights.
def forward( self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs ) -> CausalLMOutput: outputs = self.flash_stu(input_ids, labels=labels, **kwargs) return outputs def generate( self, input_ids: torch.Tensor, max_length: int = 32, num_return_sequences: int = 4, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.2, seed: int = 42, ) -> torch.Tensor: """Generate text using top-k and nucleus sampling with temperature and repetition penalty. Args: input_ids: Input token ids of shape (batch_size, seq_len) max_length: Maximum length of generated sequence num_return_sequences: Number of sequences to generate per input temperature: Sampling temperature. Higher = more random, lower = more focused top_k: Number of highest probability tokens to keep for top-k sampling top_p: Cumulative probability cutoff for nucleus sampling repetition_penalty: Penalty factor for repeating tokens. 1.0 = no penalty seed: Random seed for reproducibility Returns: Generated token ids of shape (num_return_sequences, max_length) """ self.eval() # Set to eval mode device = input_ids.device # Expand input for multiple sequences input_ids = input_ids.repeat(num_return_sequences, 1) generated = input_ids # Set up generator for reproducible sampling sample_rng = torch.Generator(device=device) sample_rng.manual_seed(seed) # Generate tokens until we reach max_length with torch.no_grad(): while generated.size(1) < max_length: # Get logits for next token outputs = self.flash_stu(generated) next_token_logits = outputs.logits[:, -1, :] # Apply repetition penalty if repetition_penalty != 1.0: for i in range(generated.shape[0]): for token in generated[i]: if token in next_token_logits[i]: next_token_logits[i, token] /= repetition_penalty # Apply temperature if temperature != 1.0: next_token_logits = next_token_logits / temperature # Get probabilities probs = torch.nn.functional.softmax(next_token_logits, dim=-1) # 
Top-k sampling if top_k > 0: indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None] probs[indices_to_remove] = 0 # Nucleus (top-p) sampling if top_p < 1.0: sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) probs[indices_to_remove] = 0 # Renormalize probabilities probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8) # Sample next token next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng) # Append to generated sequence generated = torch.cat([generated, next_token], dim=1) return generated def get_num_params(self): return self.flash_stu.get_num_params() @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # Get config and create model config, _ = create_base_model_components(pretrained_model_name_or_path, **kwargs) model = cls(config) # Download safetensors file from hub weights_path = hf_hub_download( repo_id=pretrained_model_name_or_path, filename="model.safetensors", cache_dir=kwargs.get("cache_dir"), force_download=kwargs.get("force_download", False), proxies=kwargs.get("proxies", None), local_files_only=kwargs.get("local_files_only", False), use_auth_token=kwargs.get("use_auth_token", None), revision=kwargs.get("revision", None), subfolder=kwargs.get("subfolder", ""), ) state_dict = load_file(weights_path) # Reconstruct weight tying for tok_emb and lm_head tok_emb_key = "tok_emb.weight" lm_head_key = "lm_head.weight" tok_emb_present = tok_emb_key in state_dict lm_head_present = lm_head_key in state_dict 
# Create initial config and filters for registration
config, filters = create_base_model_components()

# Register models with the transformers Auto* factories so that
# AutoModelForCausalLM.from_pretrained can resolve model_type "FlashSTU".
AutoConfig.register("FlashSTU", FlashSTUConfig)
AutoModel.register(FlashSTUConfig, FlashSTU)
AutoModelForCausalLM.register(FlashSTUConfig, FlashSTUForCausalLM)
print("Registered FlashSTU model and configuration.")


def run_model_diagnostics(model, tokenizer, device):
    """Run detailed diagnostics to analyze model behavior."""
    print("\nRunning model diagnostics...")

    # Test cases of varying difficulty and length
    test_cases = [
        # Simple completion
        "2 + 2 =",
        # Medium difficulty
        "The capital of France is Paris. The capital of Germany is",
        # Complex reasoning
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        # Pattern completion
        "1, 2, 3, 4,",
        # Long context
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            # Tokenize
            tokens = tokenizer(prompt, return_tensors="pt")
            input_ids = tokens["input_ids"].to(device)

            outputs = model.flash_stu(input_ids, labels=input_ids)
            labels = input_ids.clone()
            # Shift so position t predicts token t+1 (the backbone's own loss
            # compares position-for-position, hence the manual shift here).
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss(reduction="none")
            token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
                shift_labels.size()
            )

            # Print token-by-token analysis
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for i, (token, loss) in enumerate(zip(input_tokens[1:], token_losses[0])):
                print(f"{token}: {loss.item():.3f}")
            print(f"Average loss: {token_losses.mean().item():.3f}")

            # Generate with different temperatures
            temps = [0.5, 0.7, 1.0]
            print("\nGeneration temperature comparison:")
            for temp in temps:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")


def validate_model_generation():
    """Smoke-test loading and generation, then run the diagnostics suite."""
    print("\nRunning generation validation test...")
    try:
        from transformers import AutoTokenizer

        # Load model and tokenizer
        # model_id = "Hazan-Lab/Flash_STU_550M"
        model_id = "Hazan-Lab/FlashSTU-340M-0428"
        model = FlashSTUForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Move to GPU if available
        # NOTE(review): from_pretrained already moves/casts the model; this repeats
        # that step harmlessly.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        # Run additional diagnostics
        run_model_diagnostics(model, tokenizer, device)

    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        raise


# Run evaluation tasks (uncomment entries to include them).
tasks = [
    # "mmlu",
    "hellaswag",
    # "piqa",
    # "siqa",
    # "boolq",
    # "winogrande",
    # "commonsense_qa",
    # "openbookqa",
    # "arc",
    # "arc_easy",
    # "arc_challenge",
    # "triviaqa",
    # "nq_open",
    # "humaneval",
    # "mbpp",
    # "gms8k",
    # "hendrycks_math",
    # "mathqa",
    # "minerva_math",
    # "score",
    # "asdiv",
    # "agieval",
    # "bigbench",
]

# Few-shot counts per task; -1 means "use the task's default" (key is skipped below).
tasks_fewshot = {
    "hellaswag": 0,
    # "mmlu": 5,
    # "piqa": 0,
    # "siqa": 0,
    # "boolq": 0,
    # "winogrande": -1,
    # "commonsense_qa": 7,
    # "openbookqa": -1,
    # "arc": -1,
    # "arc_easy": -1,
    # "arc_challenge": -1,
    # "triviaqa": 5,
    # "nq_open": 5,
    # "humaneval": -1,
    # "mbpp": 3,
    # "gms8k": -1,
    # "hendrycks_math": 4,
    # "mathqa": -1,
    # "minerva_math": -1,
    # "score": -1,
    # "asdiv": -1,
    # "agieval": -1,
    # "bigbench": -1,
}

all_results = {}

# First validate generation works
validate_model_generation()

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")

    eval_kwargs = dict(
        model="hf",
        model_args=(
            # "pretrained=Hazan-Lab/Flash_STU_550M,"
            "pretrained=Hazan-Lab/FlashSTU-340M-0428,"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )

    # Only pass num_fewshot when the task has an explicit (non -1) setting.
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value

    results = evaluator.simple_evaluate(**eval_kwargs)

    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")