# FlashSTU-340M-0428 / evaluate.py
# Uploaded by windsornguyen
# Commit message: "add: eval script"
# Commit 41c5b75 (verified)
import math
import os
from dataclasses import dataclass
from typing import Optional
from huggingface_hub import hf_hub_download
import lm_eval as evaluator
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from torchtune.modules import RotaryPositionalEmbeddings
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
PreTrainedModel,
PretrainedConfig,
)
from transformers.modeling_outputs import CausalLMOutput
# Optional accelerated kernels. A failed import must not leave the imported
# names undefined: `FlashFFTConv` is referenced in a signature annotation later
# in this file, and an unbound name there would raise NameError while the module
# is still loading, defeating the fallback announced below.
try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    FlashFFTConv = None  # keep the name bound so annotations/references still resolve
    flash_fft_available = False

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(f"Unable to import Triton-based flash attention: {e}. No alternative currently available.")
    flash_attn_func = None  # keep the name bound; callers can test for None
# Process-wide switches set before any evaluation code runs:
#   HF_ALLOW_CODE_EVAL      -> "1"
#   TOKENIZERS_PARALLELISM  -> "false"
os.environ.update(
    {
        "HF_ALLOW_CODE_EVAL": "1",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Shared loss used by FlashSTU.forward when labels are supplied.
loss_fn = nn.CrossEntropyLoss()
def nearest_power_of_two(n: int, round_up: bool = False) -> int:
    """Return a power of two next to ``n``.

    Args:
        n: Value to round. Anything ``<= 1`` yields ``1``.
        round_up: When true, return the smallest power of two ``>= n``;
            otherwise return the largest power of two ``<= n``.

    Returns:
        The selected power of two.
    """
    if n <= 1:
        return 1
    if round_up:
        exponent = (n - 1).bit_length()
    else:
        exponent = n.bit_length() - 1
    return 1 << exponent
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to the nearest multiple of ``k``.

    Args:
        n: Value to round.
        k: Step size the result must be divisible by.

    Returns:
        ``n`` itself when it is already a multiple of ``k``, otherwise the
        next multiple above it.
    """
    remainder = n % k
    if remainder:
        return n + k - remainder
    return n
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """Build the ``seq_len x seq_len`` Hankel matrix used to derive spectral filters.

    Entries depend only on ``s = i + j`` with 1-based indices ``i`` and ``j``:

    * standard form: ``Z[i, j] = 2 / (s**3 - s)``
    * ``use_hankel_L`` form: ``Z[i, j] = ((-1)**(s - 2) + 1) * 8 / ((s + 3)(s - 1)(s + 1))``

    Args:
        seq_len: Side length of the square matrix.
        use_hankel_L: Selects the second form when truthy.

    Returns:
        A ``float64`` tensor of shape ``(seq_len, seq_len)``.
    """
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries.reshape(-1, 1) + entries.reshape(1, -1)

    # The flag is only ever evaluated for truthiness, so two branches cover
    # every input; the former third "must be a boolean" branch was unreachable.
    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)
    return 2.0 / (i_plus_j**3 - i_plus_j)
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Compute the ``K`` spectral filters from the Hankel matrix.

    The Hankel matrix is eigendecomposed with ``torch.linalg.eigh``; the last
    ``K`` eigenpairs of that decomposition are kept and each retained
    eigenvector is scaled by the fourth root of its eigenvalue.

    Args:
        seq_len: Side length of the Hankel matrix (rows of the result).
        K: Number of eigenpairs to keep (columns of the result).
        use_hankel_L: Forwarded to ``get_hankel``.
        device: Device on which the decomposition runs and the result lives.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(seq_len, K)`` on ``device`` with dtype ``dtype``.
    """
    hankel = get_hankel(seq_len, use_hankel_L).to(device)
    eigenvalues, eigenvectors = torch.linalg.eigh(hankel)

    top_values = eigenvalues[-K:]
    top_vectors = eigenvectors[:, -K:]
    # Scale each retained eigenvector in place by sigma ** (1/4).
    top_vectors.mul_(top_values**0.25)

    return top_vectors.to(device=device, dtype=dtype)
class BaseConfigForCausalLM(PretrainedConfig):
    """Common ``PretrainedConfig`` parent for configs declared with ``@dataclass``.

    It adds no fields of its own; construction simply forwards every keyword
    argument to ``PretrainedConfig``.
    """

    model_type = "base_model"

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
@dataclass
class FlashSTUConfig(BaseConfigForCausalLM):
    """Hyper-parameters for the FlashSTU model.

    The annotated class attributes below declare the fields and their defaults.
    The hand-written ``__init__`` repeats those defaults so that extra keyword
    arguments (for example ``model_type`` coming from a saved config) can be
    forwarded to ``PretrainedConfig`` through ``**kwargs``.

    Derived values are filled in by ``__post_init__``:

    * ``torch_dtype`` given as a string is converted to a ``torch.dtype``.
    * ``num_local_heads == -1`` is replaced by ``num_heads``.
    * ``inter_dim is None`` is replaced by ``2/3 * mlp_scale * dim`` rounded up
      to a multiple of 256.
    * ``head_dim`` is set to ``dim // num_heads``.
    """

    model_type = "FlashSTU"

    # Define fields with defaults (as before)
    bsz: int = 1
    dim: int = 1024
    r: int = 1024
    num_heads: int = 12
    num_local_heads: Optional[int] = -1
    num_layers: int = 12
    seq_len: int = 4096
    n: int = 8191
    window_size: int = 2048
    vocab_size: int = 200064
    inter_dim: Optional[int] = 3072
    mlp_scale: Optional[float] = 12.0
    weight_tying: Optional[bool] = True
    bias: Optional[bool] = False
    rope_theta: Optional[float] = 10000.0
    softcap: Optional[float] = 50.0
    num_eigh: Optional[int] = 24
    use_hankel_L: Optional[bool] = False
    use_flash_fft: Optional[bool] = True
    use_tensordot: Optional[bool] = True
    use_attn: Optional[bool] = True
    use_alibi: Optional[bool] = False
    torch_dtype: torch.dtype = torch.bfloat16
    device: torch.device = None

    # Explicit __init__ to handle **kwargs for PretrainedConfig compatibility
    def __init__(
        self,
        bsz: int = 1,
        dim: int = 1024,
        r: int = 1024,
        num_heads: int = 12,
        num_local_heads: Optional[int] = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        n: int = 8191,
        window_size: int = 2048,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = 3072,
        mlp_scale: Optional[float] = 12.0,
        weight_tying: Optional[bool] = True,
        bias: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        softcap: Optional[float] = 50.0,
        num_eigh: Optional[int] = 24,
        use_hankel_L: Optional[bool] = False,
        use_flash_fft: Optional[bool] = True,
        use_tensordot: Optional[bool] = True,
        use_attn: Optional[bool] = True,
        use_alibi: Optional[bool] = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,  # Catch extra arguments like model_type
    ):
        super().__init__(**kwargs)  # Pass kwargs to parent __init__

        # Assign fields from arguments
        self.bsz = bsz
        self.dim = dim
        self.r = r
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n = n
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.softcap = softcap
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.use_alibi = use_alibi
        self.torch_dtype = torch_dtype
        self.device = device

        # The explicit __init__ above is the one that runs, so the derived
        # fields have to be computed by calling __post_init__ by hand.
        self.__post_init__()

    def __post_init__(self):
        """Normalise ``torch_dtype`` and fill in the derived fields.

        Raises:
            ValueError: If ``torch_dtype`` is a string that does not name a
                ``torch.dtype``.
        """
        # Ensure torch_dtype is a torch.dtype object, not a string
        if isinstance(self.torch_dtype, str):
            # Accept both "bfloat16" and the fully qualified "torch.bfloat16".
            dtype_name = self.torch_dtype.rsplit(".", 1)[-1]
            resolved = getattr(torch, dtype_name, None)
            # A bare getattr would also accept unrelated attributes such as
            # "nn"; insist that the lookup really produced a dtype.
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}")
            self.torch_dtype = resolved

        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads

        if self.inter_dim is None:
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            self.inter_dim = find_multiple(num_hidden, 256)

        self.head_dim = self.dim // self.num_heads

    @classmethod
    def from_name(cls, name: str):
        """Placeholder for preset lookup by name.

        Not implemented: prints a notice and returns ``None`` for every
        ``name``.
        """
        print("Not yet implemented")
        return None
class MLP(nn.Module):
    """Two-layer feed-forward block: ``dim -> inter_dim -> dim`` with tanh-GELU."""

    def __init__(self, config: FlashSTUConfig) -> None:
        super().__init__()
        in_features, hidden_features = config.dim, config.inter_dim
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(hidden_features, in_features)
        # Marker read by FlashSTU.init_weights to shrink this layer's init std.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``w2(gelu_tanh(w1(x)))``."""
        hidden = self.w1(x)
        activated = F.gelu(hidden, approximate="tanh")
        return self.w2(activated)
class SlidingWindowAttention(nn.Module):
    """Causal sliding-window self-attention computed with ``flash_attn_func``.

    Positions are encoded either with rotary embeddings (default) or with
    ALiBi slopes when ``config.use_alibi`` is set; in the ALiBi case the rotary
    embedding is not applied to the queries and keys.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.dim
        self.wq = nn.Linear(model_dim, model_dim)
        self.wk = nn.Linear(model_dim, model_dim)
        self.wv = nn.Linear(model_dim, model_dim)
        self.wo = nn.Linear(model_dim, model_dim)
        # Marker read by FlashSTU.init_weights to shrink this layer's init std.
        self.wo.SCALE_INIT = 1

        self.dim = model_dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.window_size = config.window_size
        self.softcap = config.softcap

        if config.use_alibi:
            self.alibi_slopes = self._get_alibi_slopes(self.num_heads)
        else:
            self.alibi_slopes = None

        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        """Attend over ``x`` of shape ``(bsz, seq_len, dim)`` and project back to ``dim``."""
        bsz, seq_len, _ = x.shape

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # Split the projections into heads: (bsz, seq_len, heads, head_dim).
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)

        # Rotary embeddings are only used when ALiBi is disabled.
        if self.alibi_slopes is None:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)

        attended = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
            window_size=(self.window_size, 0),
            # softcap=self.softcap,
            alibi_slopes=self.alibi_slopes,
        )

        merged = attended.reshape(bsz, seq_len, -1)
        return self.wo(merged)

    def _generate_slopes(self, n: int):
        """Return the ``n`` geometric ALiBi slopes ``start ** (i + 1)``."""
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        slopes = []
        for i in range(n):
            slopes.append(start * (start**i))
        return slopes

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """Build the per-head ALiBi slope tensor on CUDA, scaled by ``interpolation_factor``."""
        if math.log2(num_heads).is_integer():
            # Power-of-two head counts use the geometric sequence directly.
            head_slopes = self._generate_slopes(num_heads)
        else:
            # Otherwise take the slopes for the power of two below num_heads and
            # top up with every other slope of the next power of two.
            base = nearest_power_of_two(num_heads, round_up=False)
            base_slopes = self._generate_slopes(base)
            every_other = self._generate_slopes(2 * base)[0::2]
            head_slopes = base_slopes + every_other[: num_heads - base]

        slope_tensor = torch.tensor(head_slopes, device=torch.device("cuda"))  # FA ALiBi must be on CUDA
        return slope_tensor * interpolation_factor  # https://arxiv.org/pdf/2310.13017
class STU(nn.Module):
    """Spectral layer: convolves the input with fixed spectral filters and mixes
    the result with learned projections.

    The filters are not created here. ``stu_filters`` (and ``stu_filters_fft``)
    start as ``None`` and are assigned from outside after construction (see
    ``FlashSTU.setup_filters``).

    Two parameterisations are supported:

    * ``use_tensordot=True``: project inputs with ``M_inputs`` and filters with
      ``M_filters`` first, then convolve.
    * ``use_tensordot=False``: convolve every input channel with every filter,
      then contract with ``M_phi_plus`` (and ``M_phi_minus`` unless
      ``use_hankel_L`` is set).
    """

    def __init__(self, config):
        super().__init__()
        # Set at top-level post- model init
        self.stu_filters = None
        self.stu_filters_fft = None  # TODO: Optimization: Precompute FFT of filters
        self.n = config.n
        self.num_eigh = config.num_eigh
        self.d_in = config.dim
        self.d_out = config.dim
        self.r = config.r
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot
        # Only build the fused kernel when it was requested AND could be imported.
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) if config.use_flash_fft and flash_fft_available else None
        )

        # TODO: Add dimensionality reduction `r` here.
        if self.use_tensordot:
            self.M_inputs = nn.Parameter(torch.zeros(self.d_in, self.d_out))
            self.M_filters = nn.Parameter(torch.zeros(self.num_eigh, self.d_in))
        else:
            self.M_phi_plus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral transform to ``x`` of shape ``(B, L, D)``.

        Returns:
            ``spectral_plus`` when ``use_hankel_L`` is set, otherwise
            ``spectral_plus + spectral_minus``.

        Raises:
            RuntimeError: If the spectral filters have not been assigned yet.
        """
        if self.stu_filters is None:
            raise RuntimeError("STU.stu_filters is not set; assign the spectral filters before calling forward().")

        B, L, D = x.shape

        if self.use_tensordot:
            # Contract inputs and filters over (K, D) dims first, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = self.flash_conv(x_proj, phi_proj, self.flash_fft, self.use_tensordot)
            else:
                spectral_plus, spectral_minus = self.conv(x_proj, phi_proj, self.n, self.use_tensordot)
        else:
            # Convolve inputs and filters first, then contract over (K, D) dims
            if self.flash_fft:
                U_plus, U_minus = self.flash_conv(x, self.stu_filters, self.flash_fft, self.use_tensordot)
            else:
                U_plus, U_minus = self.conv(x, self.stu_filters, self.n, self.use_tensordot)

            B, L, K, D = U_plus.shape
            spectral_plus = U_plus.reshape(B, L, K * D) @ self.M_phi_plus.reshape(K * D, self.d_out)
            # M_phi_minus only exists (and spectral_minus is only needed)
            # when the Hankel-L variant is not in use.
            if not self.use_hankel_L:
                spectral_minus = U_minus.reshape(B, L, K * D) @ self.M_phi_minus.reshape(K * D, self.d_out)

        out = spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
        return out

    def conv(
        self, u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """FFT-based causal convolution of ``u`` with ``v`` (pure PyTorch path).

        Two signals are convolved at once: ``u`` itself and ``u`` with every
        odd time step multiplied by -1 (the alternating-sign tensor ``sgn``).
        Both results are truncated to the first ``seq_len`` outputs, and the
        second one is multiplied by ``sgn`` again before being returned.

        Args:
            u: Input of shape ``(bsz, seq_len, d_in)``.
            v: Kernel with the time axis first: ``(filter_len, d_out)`` when
                ``use_tensordot`` is true, ``(filter_len, K)`` otherwise.
            n: FFT length; it must be at least ``seq_len + filter_len - 1`` for
                the truncated result to be free of circular wrap-around.
            use_tensordot: Selects how the kernel is broadcast against ``u``.

        Returns:
            ``(U_plus, U_minus)`` with the dtype of ``u``. Shape is
            ``(bsz, seq_len, d_in)`` when ``use_tensordot`` is true and
            ``(bsz, seq_len, K, d_in)`` otherwise.
        """
        bsz, seq_len, d_in = u.shape

        sgn = torch.full((1, seq_len, 1), 1, device=u.device)
        sgn[:, 1::2] *= -1  # Apply negative featurization: multiply every other element by -1.

        if use_tensordot:
            _, d_out = v.shape
            v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
        else:
            _, K = v.shape
            sgn = sgn.unsqueeze(-1)
            v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()  # (bsz, seq_len, K, d_in, stack)
            u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

        # Cast kernel to float32 for FFT
        v_fft = torch.fft.rfft(v.to(torch.float32), n=n, dim=1)

        # Last axis stacks the plain input and its sign-modulated copy.
        U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
        # Cast input stack to float32 for FFT
        U_fft = torch.fft.rfft(U.to(torch.float32), n=n, dim=1)

        # Keep only the first seq_len outputs of the convolution.
        # Perform convolution in float32 and cast back
        U_conv = torch.fft.irfft(v_fft * U_fft, n=n, dim=1)[:, :seq_len].to(u.dtype)
        U_plus, U_minus = torch.unbind(U_conv, dim=-1)
        U_minus = U_minus * sgn

        return U_plus.type_as(u), U_minus.type_as(u)

    def flash_conv(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        # Quoted on purpose: annotations are evaluated when the class body runs,
        # and the FlashFFTConv import is optional, so an unquoted name could
        # raise NameError and make the pure-PyTorch fallback unreachable.
        flash_fft: "FlashFFTConv",
        use_tensordot: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Causal convolution of ``u`` with ``v`` through a ``FlashFFTConv`` module.

        As in ``conv``, the input and a copy with every odd time step negated
        are convolved together (stacked along the batch axis), and the second
        result is multiplied by the same alternating sign afterwards.

        Args:
            u: Input of shape ``(B, L, d_in)``.
            v: Filters with the time axis first. The second axis is read as
                ``K``; in tensordot mode the caller passes the already
                projected filters, so that axis has size ``d_in``.
            flash_fft: ``FlashFFTConv`` instance that performs the convolution.
                It receives the input as bfloat16 and the kernel as float32.
            use_tensordot: If true, convolve channel-wise; otherwise convolve
                every input channel with every filter.

        Returns:
            ``(U_plus, U_minus)`` of shape ``(B, L, d_in)`` when
            ``use_tensordot`` is true and ``(B, L, K, d_in)`` otherwise.
        """
        bsz, seq_len, d_in = u.shape
        _, K = v.shape

        # Both operands are right-padded by the same amount, chosen so the
        # input length becomes a power of two.
        padded_len = nearest_power_of_two(seq_len, round_up=True)
        pad_len = padded_len - seq_len

        sgn = torch.full((1, 1, padded_len), 1, device=u.device)
        sgn[:, :, 1::2] = -1

        if use_tensordot:
            u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32)
            u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
        else:
            # Repeat channels so every (channel, filter) pair gets its own row.
            u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1)
            u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

        # Ensure inputs to flash_fft are bfloat16 (input) and float32 (kernel)
        U_conv = flash_fft(u_conv.to(torch.bfloat16), v_padded.to(torch.float32))

        # Trim the output back to the original sequence length
        U_conv = U_conv[..., :seq_len]

        # First half of the batch axis is the plain signal, second half the
        # sign-modulated one.
        u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

        if use_tensordot:
            u_minus = u_minus * sgn[:, :, :seq_len]
            U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
        else:
            sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
            U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
            U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

        return U_plus, U_minus
class SlidingWindowAttentionLayer(nn.Module):
    """Pre-norm residual block: sliding-window attention followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        width = config.dim
        self.swa_norm = nn.LayerNorm(width)
        self.swa = SlidingWindowAttention(config)
        self.mlp_norm = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.swa(self.swa_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
class STULayer(nn.Module):
    """Pre-norm residual block: spectral (STU) mixing followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        width = config.dim
        self.stu_norm = nn.LayerNorm(width)
        self.stu = STU(config)
        self.mlp_norm = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the STU and MLP residual updates."""
        stu_out = self.stu(self.stu_norm(x))
        x = x + stu_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
class FlashSTU(nn.Module):
    """Language model that interleaves STU layers with sliding-window attention.

    Even-indexed layers are always ``STULayer``. Odd-indexed layers are
    ``SlidingWindowAttentionLayer`` when ``config.use_attn`` is set and
    ``STULayer`` otherwise. The token embedding shares its weight with the
    output head when ``config.weight_tying`` is set.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            is_odd = layer_idx % 2 != 0
            if is_odd and config.use_attn:
                block = SlidingWindowAttentionLayer(config)
            else:
                block = STULayer(config)
            self.layers.append(block)

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        # Base standard deviation used by init_weights.
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Initialise one sub-module; intended for use with ``nn.Module.apply``.

        Linear and embedding weights are drawn from ``N(0, std)``. Linear
        layers tagged with ``SCALE_INIT`` use ``std / sqrt(2 * num_layers)``,
        and linear biases are zeroed.
        """
        std = self.std
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            return
        if not isinstance(module, nn.Linear):
            return

        if hasattr(module, "SCALE_INIT"):
            std *= (2 * self.config.num_layers) ** -0.5
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Run the model and optionally compute the cross-entropy loss.

        ``labels`` are compared position-by-position with the logits (no shift
        is applied here). Extra keyword arguments are accepted and ignored.
        """
        hidden = self.tok_emb(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(self.norm_f(hidden))

        if labels is None:
            loss = None
        else:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def setup_filters(
        self,
        spectral_filters: torch.Tensor,
        spectral_filters_fft: torch.Tensor,
    ):
        """Hand the same filter tensors to the STU module of every ``STULayer``."""
        stu_blocks = (block for block in self.layers if isinstance(block, STULayer))
        for block in stu_blocks:
            block.stu.stu_filters = spectral_filters
            block.stu.stu_filters_fft = spectral_filters_fft

    def get_num_params(self):
        """Return the total element count over ``self.parameters()``."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Build a ``FlashSTUConfig`` and the matching spectral filters.

    Args:
        model_name_or_path: When given, the config is loaded with
            ``FlashSTUConfig.from_pretrained``; otherwise it is constructed
            directly from ``kwargs``.
        **kwargs: Forwarded to whichever config constructor is used.

    Returns:
        ``(config, filters)`` where ``filters`` lives on CUDA when available
        (CPU otherwise) and has dtype ``config.torch_dtype``.
    """
    if model_name_or_path is None:
        config = FlashSTUConfig(**kwargs)
    else:
        config = FlashSTUConfig.from_pretrained(model_name_or_path, **kwargs)

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=torch.device(device_name),
        dtype=config.torch_dtype,
    )

    assert filters.dtype == config.torch_dtype, f"filters dtype is {filters.dtype}, expected {config.torch_dtype}"
    return config, filters
class FlashSTUForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface.

    Owns a ``FlashSTU`` instance (``self.flash_stu``), computes the spectral
    filters for it at construction time, and provides a simple sampling loop
    and a safetensors-based ``from_pretrained``.
    """

    config_class = FlashSTUConfig
    base_model_prefix = "FlashSTU"

    def __init__(self, config):
        super().__init__(config)
        self.flash_stu = FlashSTU(config)
        self.flash_stu.apply(self.flash_stu.init_weights)

        device = (
            config.device
            if config.device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        torch_dtype = config.torch_dtype  # Assumes __post_init__ already converted it to torch.dtype

        spectral_filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=device,
            # Note: get_spectral_filters returns float64, cast later
        )
        # NOTE(review): this FFT runs along dim=1 of the (seq_len, K) filters and
        # its complex result is cast to a real dtype below. Nothing visible in
        # this file reads stu_filters_fft, so it is left unchanged — confirm
        # the intended axis/dtype before relying on it.
        spectral_filters_fft = torch.fft.rfft(spectral_filters, n=config.n, dim=1)

        # Setup filters in the model, casting to the target dtype
        self.flash_stu.setup_filters(
            spectral_filters.to(dtype=torch_dtype), spectral_filters_fft.to(dtype=torch_dtype)
        )
        # Note: Moving the entire model to device happens later, after loading weights.

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to the wrapped ``FlashSTU``; ``attention_mask`` is accepted but not used."""
        outputs = self.flash_stu(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        The whole sequence is re-encoded at every step (no cache).

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Total length (prompt included) at which generation stops
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep for top-k sampling
                (values above the vocabulary size are clamped; <= 0 disables it)
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty applied to the logits of tokens that already
                occur in the sequence. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Token ids of shape (batch_size * num_return_sequences, length), where
            length is max_length, or the prompt length if that is already longer.
        """
        self.eval()  # Set to eval mode
        device = input_ids.device

        # Expand input for multiple sequences
        input_ids = input_ids.repeat(num_return_sequences, 1)
        generated = input_ids

        # Set up generator for reproducible sampling
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        # Generate tokens until we reach max_length
        with torch.no_grad():
            while generated.size(1) < max_length:
                # Get logits for next token
                outputs = self.flash_stu(generated)
                next_token_logits = outputs.logits[:, -1, :]

                # Apply repetition penalty to every token id already present in
                # each row. Penalising must lower the score whatever its sign:
                # positive logits are divided, negative ones multiplied.
                if repetition_penalty != 1.0:
                    seen_scores = torch.gather(next_token_logits, 1, generated)
                    seen_scores = torch.where(
                        seen_scores < 0,
                        seen_scores * repetition_penalty,
                        seen_scores / repetition_penalty,
                    )
                    next_token_logits = next_token_logits.scatter(1, generated, seen_scores)

                # Apply temperature
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                # Get probabilities
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                # Top-k sampling
                if top_k > 0:
                    # topk raises when k exceeds the vocabulary size.
                    k = min(top_k, probs.size(-1))
                    kth_prob = torch.topk(probs, k)[0][..., -1, None]
                    probs[probs < kth_prob] = 0

                # Nucleus (top-p) sampling
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    # Remove tokens with cumulative probability above the threshold
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the indices to the right to keep also the first token above the threshold
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    # Map the mask from sorted order back to vocabulary order.
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0

                # Renormalize probabilities
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                # Sample next token
                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)

                # Append to generated sequence
                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        """Return the parameter count of the wrapped ``FlashSTU``."""
        return self.flash_stu.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load config and ``model.safetensors`` from the Hub and return an eval-mode model.

        The checkpoint keys are expected without the ``flash_stu.`` prefix; it is
        added here. If only one of ``tok_emb.weight`` / ``lm_head.weight`` is
        stored, the other is filled in from it. The model is moved to CUDA when
        available (CPU otherwise) and cast to bfloat16. ``model_args`` is unused.
        """
        # Load only the config here: the model constructor computes the spectral
        # filters itself, so building them a second time just to discard them
        # would repeat a full eigendecomposition for nothing.
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # Download safetensors file from hub
        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )
        state_dict = load_file(weights_path)

        # Reconstruct weight tying for tok_emb and lm_head
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            # This case should ideally not happen if the file is valid
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )
        # If both are present, assume they are loaded correctly (or were never tied)

        # Prepend 'flash_stu.' to all keys to match wrapper's state dict
        final_state_dict = {f"flash_stu.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        # Move to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
# Create initial config and filters for registration
# NOTE(review): called with no arguments, this builds a default FlashSTUConfig
# and computes the spectral filters at import time. Neither `config` nor
# `filters` is read again in the visible part of this file, so the call may only
# be serving as an import-time smoke test — confirm before removing.
config, filters = create_base_model_components()
# Register models
# Makes the "FlashSTU" model type and its classes resolvable through the
# transformers Auto* factories.
AutoConfig.register("FlashSTU", FlashSTUConfig)
AutoModel.register(FlashSTUConfig, FlashSTU)
AutoModelForCausalLM.register(FlashSTUConfig, FlashSTUForCausalLM)
print("Registered FlashSTU model and configuration.")
def run_model_diagnostics(model, tokenizer, device):
    """Print per-token losses and sample generations for a few fixed prompts.

    For each prompt the next-token cross-entropy of every token after the first
    is printed, followed by one sampled continuation at each of three
    temperatures.

    Args:
        model: Object exposing ``flash_stu`` (called on the token ids) and
            ``generate``.
        tokenizer: Tokenizer used to encode the prompts and decode the samples.
        device: Device the token ids are moved to.
    """
    print("\nRunning model diagnostics...")

    # Test cases of varying difficulty and length
    test_cases = [
        # Simple completion
        "2 + 2 =",
        # Medium difficulty
        "The capital of France is Paris. The capital of Germany is",
        # Complex reasoning
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        # Pattern completion
        "1, 2, 3, 4,",
        # Long context
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]
    temps = [0.5, 0.7, 1.0]
    # Unreduced loss so every position keeps its own value.
    loss_fct = nn.CrossEntropyLoss(reduction="none")

    with torch.no_grad():
        for prompt in test_cases:
            print(f"\nAnalyzing prompt: {prompt}")

            # Tokenize
            encoded = tokenizer(prompt, return_tensors="pt")
            input_ids = encoded["input_ids"].to(device)

            outputs = model.flash_stu(input_ids, labels=input_ids)

            # Align predictions with targets: position t predicts token t + 1.
            labels = input_ids.clone()
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1)
            token_losses = loss_fct(flat_logits, flat_labels).view(shift_labels.size())

            # Print token-by-token analysis
            input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
            print("\nToken-by-token loss:")
            for token, loss in zip(input_tokens[1:], token_losses[0]):
                print(f"{token}: {loss.item():.3f}")

            print(f"Average loss: {token_losses.mean().item():.3f}")

            # Generate with different temperatures
            print("\nGeneration temperature comparison:")
            for temp in temps:
                gen_ids = model.generate(
                    input_ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
def validate_model_generation():
    """Load the published checkpoint and run the diagnostics as a smoke test.

    Any failure is reported on stdout and then re-raised.
    """
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        # Load model and tokenizer
        # model_id = "Hazan-Lab/Flash_STU_550M"
        model_id = "Hazan-Lab/FlashSTU-340M-0428"
        model = FlashSTUForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        # Move to GPU if available
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        # Print parameter count as a sanity check
        num_params = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        # Run additional diagnostics
        run_model_diagnostics(model, tokenizer, device)

    except Exception as e:
        print(f"\nError during validation: {str(e)}")
        raise
# Run evaluation tasks
# `tasks` lists the lm-eval task names to run, one evaluation call per entry.
# Commented-out names are kept as a menu of alternatives; only uncommented
# entries are evaluated.
# NOTE(review): "gms8k" below looks like a typo for "gsm8k" — confirm against
# the lm-eval task registry before enabling it.
tasks = [
# "mmlu",
"hellaswag",
# "piqa",
# "siqa",
# "boolq",
# "winogrande",
# "commonsense_qa",
# "openbookqa",
# "arc",
# "arc_easy",
# "arc_challenge",
# "triviaqa",
# "nq_open",
# "humaneval",
# "mbpp",
# "gms8k",
# "hendrycks_math",
# "mathqa",
# "minerva_math",
# "score",
# "asdiv",
# "agieval",
# "bigbench",
]
# Few-shot count per task. The evaluation loop passes `num_fewshot` only when
# the value is not -1; a task missing from this mapping is treated as -1.
tasks_fewshot = {
"hellaswag": 0,
# "mmlu": 5,
# "piqa": 0,
# "siqa": 0,
# "boolq": 0,
# "winogrande": -1,
# "commonsense_qa": 7,
# "openbookqa": -1,
# "arc": -1,
# "arc_easy": -1,
# "arc_challenge": -1,
# "triviaqa": 5,
# "nq_open": 5,
# "humaneval": -1,
# "mbpp": 3,
# "gms8k": -1,
# "hendrycks_math": 4,
# "mathqa": -1,
# "minerva_math": -1,
# "score": -1,
# "asdiv": -1,
# "agieval": -1,
# "bigbench": -1,
}
# Per-task result dictionaries, keyed by task name.
all_results = {}

# First validate generation works
validate_model_generation()

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")

    # Arguments for lm-eval's simple_evaluate; one task per call.
    eval_kwargs = {
        "model": "hf",
        "model_args": (
            # "pretrained=Hazan-Lab/Flash_STU_550M,"
            "pretrained=Hazan-Lab/FlashSTU-340M-0428,"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        "tasks": [task],
        "batch_size": "auto",
        "device": "cuda:0",
    }

    # -1 (or a missing entry) means "do not pass num_fewshot at all".
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value

    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result

    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

# Final summary across every evaluated task.
print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")