|
|
import math |
|
|
import os |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
import lm_eval as evaluator |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from safetensors.torch import load_file |
|
|
from torchtune.modules import RotaryPositionalEmbeddings |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoModelForCausalLM, |
|
|
PreTrainedModel, |
|
|
PretrainedConfig, |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutput |
|
|
|
|
|
try: |
|
|
from flashfftconv import FlashFFTConv |
|
|
|
|
|
flash_fft_available = True |
|
|
except ImportError as e: |
|
|
print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.") |
|
|
flash_fft_available = False |
|
|
|
|
|
try: |
|
|
from flash_attn import flash_attn_func |
|
|
except ImportError as e: |
|
|
print(f"Unable to import Triton-based flash attention: {e}. No alternative currently available.") |
|
|
|
|
|
# Allow lm-eval to execute model-generated code for code-eval benchmarks.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
# Silence the tokenizers library's fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Shared criterion used by FlashSTU.forward for the LM loss.
loss_fn = nn.CrossEntropyLoss()
|
|
|
|
|
|
|
|
def nearest_power_of_two(n: int, round_up: bool = False) -> int:
    """Return the power of two nearest to ``n``.

    Args:
        n: Input integer.
        round_up: If True, round up to the smallest power of two >= n;
            otherwise round down to the largest power of two <= n.

    Returns:
        A power of two; always 1 for any ``n <= 1``.
    """
    if n <= 1:
        return 1
    if round_up:
        shift = (n - 1).bit_length()
    else:
        shift = n.bit_length() - 1
    return 1 << shift
|
|
|
|
|
|
|
|
def find_multiple(n: int, k: int) -> int:
    """Return the smallest multiple of ``k`` that is >= ``n``."""
    remainder = n % k
    if remainder == 0:
        return n
    return n + (k - remainder)
|
|
|
|
|
|
|
|
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """Construct the (seq_len x seq_len) spectral Hankel matrix in float64.

    Entry (i, j) depends only on s = (i+1) + (j+1):
      - default:   Z[i, j] = 2 / (s^3 - s)
      - Hankel-L:  Z[i, j] = ((-1)^(s-2) + 1) * 8 / ((s+3)(s-1)(s+1))

    Args:
        seq_len: Matrix dimension (sequence length).
        use_hankel_L: Select the Hankel-L variant.

    Returns:
        A (seq_len, seq_len) float64 tensor.

    Raises:
        ValueError: If ``use_hankel_L`` is not a bool.
    """
    # BUGFIX: the original `if x: ... elif not x: ... else: raise` made the
    # ValueError branch unreachable (truthiness is exhaustive); validate
    # explicitly up front instead.
    if not isinstance(use_hankel_L, bool):
        raise ValueError("use_hankel_L must be a boolean")

    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries.reshape(-1, 1) + entries.reshape(1, -1)

    if use_hankel_L:
        # Zero for odd i+j, 2 for even i+j.
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    else:
        Z = 2.0 / (i_plus_j**3 - i_plus_j)

    return Z
|
|
|
|
|
|
|
|
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.float64,
) -> torch.Tensor:
    """Compute the top-K spectral filters of the Hankel matrix.

    Takes the K eigenvectors of the (seq_len x seq_len) Hankel matrix with
    the largest eigenvalues and scales each by its eigenvalue**0.25.

    Returns:
        A (seq_len, K) tensor on ``device`` with the requested ``dtype``.
    """
    Z = get_hankel(seq_len, use_hankel_L).to(device)
    # eigh returns eigenvalues in ascending order, so the top-K live at the end.
    eigenvalues, eigenvectors = torch.linalg.eigh(Z)
    top_values = eigenvalues[-K:]
    filters = eigenvectors[:, -K:] * top_values**0.25
    return filters.to(device=device, dtype=dtype)
|
|
|
|
|
|
|
|
class BaseConfigForCausalLM(PretrainedConfig):
    """Base PretrainedConfig class to be decorated with dataclass"""

    # Overridden by concrete subclasses (e.g. FlashSTUConfig).
    model_type = "base_model"

    def __init__(self, **kwargs):
        # Pure delegation: all standard HF config kwargs pass straight through.
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
@dataclass
class FlashSTUConfig(BaseConfigForCausalLM):
    """HF-style configuration for the FlashSTU causal language model.

    NOTE(review): the @dataclass decorator does not generate __init__ here
    because the class defines its own; the annotated class attributes below
    serve as documented defaults mirrored by the __init__ parameters.
    """

    model_type = "FlashSTU"

    # Defaults — keep in sync with the __init__ parameter list below.
    bsz: int = 1  # default batch size
    dim: int = 1024  # model / embedding width
    r: int = 1024  # NOTE(review): stored by STU as self.r but never read there — confirm purpose
    num_heads: int = 12  # attention (query) head count
    num_local_heads: Optional[int] = -1  # -1 sentinel → resolved to num_heads in __post_init__
    num_layers: int = 12  # total layers; even idx STU, odd idx attention (if use_attn)
    seq_len: int = 4096  # maximum sequence length
    n: int = 8191  # FFT length for STU convolutions (2 * seq_len - 1)
    window_size: int = 2048  # sliding-window attention span (past tokens)
    vocab_size: int = 200064
    inter_dim: Optional[int] = 3072  # MLP hidden size; None → derived in __post_init__
    mlp_scale: Optional[float] = 12.0  # used to derive inter_dim when it is None
    weight_tying: Optional[bool] = True  # tie tok_emb and lm_head weights
    bias: Optional[bool] = False  # NOTE(review): not consumed by any module in this file — confirm
    rope_theta: Optional[float] = 10000.0  # RoPE base frequency
    softcap: Optional[float] = 50.0  # NOTE(review): stored by attention but unused in forward — confirm
    num_eigh: Optional[int] = 24  # number of spectral filters K
    use_hankel_L: Optional[bool] = False  # alternate Hankel-L filter construction
    use_flash_fft: Optional[bool] = True  # prefer FlashFFTConv when importable
    use_tensordot: Optional[bool] = True  # low-rank STU parameterization
    use_attn: Optional[bool] = True  # interleave attention layers on odd indices
    use_alibi: Optional[bool] = False  # ALiBi slopes instead of RoPE
    torch_dtype: torch.dtype = torch.bfloat16
    device: torch.device = None

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 1024,
        r: int = 1024,
        num_heads: int = 12,
        num_local_heads: Optional[int] = -1,
        num_layers: int = 12,
        seq_len: int = 4096,
        n: int = 8191,
        window_size: int = 2048,
        vocab_size: int = 200064,
        inter_dim: Optional[int] = 3072,
        mlp_scale: Optional[float] = 12.0,
        weight_tying: Optional[bool] = True,
        bias: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        softcap: Optional[float] = 50.0,
        num_eigh: Optional[int] = 24,
        use_hankel_L: Optional[bool] = False,
        use_flash_fft: Optional[bool] = True,
        use_tensordot: Optional[bool] = True,
        use_attn: Optional[bool] = True,
        use_alibi: Optional[bool] = False,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        """Store every field verbatim; extra kwargs go to PretrainedConfig."""
        super().__init__(**kwargs)

        self.bsz = bsz
        self.dim = dim
        self.r = r
        self.num_heads = num_heads
        self.num_local_heads = num_local_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.n = n
        self.window_size = window_size
        self.vocab_size = vocab_size
        self.inter_dim = inter_dim
        self.mlp_scale = mlp_scale
        self.weight_tying = weight_tying
        self.bias = bias
        self.rope_theta = rope_theta
        self.softcap = softcap
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_tensordot = use_tensordot
        self.use_attn = use_attn
        self.use_alibi = use_alibi
        self.torch_dtype = torch_dtype
        self.device = device

        # Explicit call: @dataclass only invokes __post_init__ from its own
        # generated __init__, which is suppressed by this hand-written one.
        self.__post_init__()

    def __post_init__(self):
        """Normalize field types and derive dependent fields."""

        # Accept torch_dtype given as a string (e.g. "bfloat16" from a JSON config).
        if isinstance(self.torch_dtype, str):
            try:
                self.torch_dtype = getattr(torch, self.torch_dtype)
            except AttributeError:
                raise ValueError(f"Invalid torch_dtype string: {self.torch_dtype}")

        # -1 sentinel: local (K/V) head count defaults to the query head count.
        if self.num_local_heads == -1:
            self.num_local_heads = self.num_heads
        if self.inter_dim is None:
            # 2/3 of (mlp_scale * dim), rounded up to a multiple of 256.
            hidden_dim = self.mlp_scale * self.dim
            num_hidden = int(2 * hidden_dim / 3)
            self.inter_dim = find_multiple(num_hidden, 256)
        self.head_dim = self.dim // self.num_heads

    @classmethod
    def from_name(cls, name: str):
        # Stub: intended as an alternate constructor building a preset config
        # from a model name; currently only prints and returns None.
        print("Not yet implemented")
        pass
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Two-layer feed-forward block with tanh-approximated GELU activation."""

    def __init__(self, config: FlashSTUConfig) -> None:
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.inter_dim)
        self.w2 = nn.Linear(config.inter_dim, config.dim)
        # Flag consumed by FlashSTU.init_weights to shrink this layer's init std.
        self.w2.SCALE_INIT = 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply GELU (tanh approximation), project back down."""
        hidden = self.w1(x)
        activated = F.gelu(hidden, approximate="tanh")
        return self.w2(activated)
|
|
|
|
|
|
|
|
class SlidingWindowAttention(nn.Module):
    """Multi-head causal sliding-window attention backed by flash_attn_func.

    Positional information comes from rotary embeddings by default, or from
    ALiBi slopes when ``config.use_alibi`` is set — the two are exclusive
    (RoPE is skipped whenever ALiBi slopes exist).
    """

    def __init__(self, config):
        super().__init__()
        self.wq = nn.Linear(config.dim, config.dim)
        self.wk = nn.Linear(config.dim, config.dim)
        self.wv = nn.Linear(config.dim, config.dim)
        self.wo = nn.Linear(config.dim, config.dim)
        # Flag consumed by FlashSTU.init_weights to shrink this layer's init std.
        self.wo.SCALE_INIT = 1

        self.dim = config.dim
        self.head_dim = config.head_dim
        self.num_heads = config.num_heads
        self.num_local_heads = config.num_local_heads
        self.window_size = config.window_size
        self.softcap = config.softcap  # NOTE(review): stored but never used in forward — confirm

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.rotary_emb = RotaryPositionalEmbeddings(
            dim=self.head_dim,
            max_seq_len=config.seq_len,
            base=config.rope_theta,
        )

    def forward(self, x):
        # x: (bsz, seq_len, dim) -> (bsz, seq_len, dim)
        bsz, seq_len, dim = x.shape

        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # Split heads; K/V use num_local_heads (may differ from query heads).
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_local_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_local_heads, self.head_dim)

        # RoPE only when ALiBi is disabled — the two position schemes are exclusive.
        if self.alibi_slopes is None:
            q, k = self.rotary_emb(q), self.rotary_emb(k)

        # window_size=(W, 0): attend to at most W past tokens, none ahead.
        y = flash_attn_func(
            q=q,
            k=k,
            v=v,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,
        )

        # Merge heads back into the model dimension and project out.
        out = y.reshape(bsz, seq_len, -1)
        out = self.wo(out)

        return out

    def _generate_slopes(self, n: int):
        # Geometric ALiBi slope schedule: start = 2^(-2^-(log2(n)-3)),
        # i-th slope = start^(i+1).
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """Return per-head ALiBi slopes as a CUDA tensor of length num_heads."""
        # Power-of-two head counts use the closed-form schedule directly.
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Non-power-of-two: use the schedule for the nearest smaller power
            # of two, then fill the remainder with every other entry of the
            # 2n schedule (ALiBi paper's interpolation recipe).
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        # NOTE(review): device is hard-coded to CUDA — this breaks CPU-only
        # runs; presumably acceptable because flash-attn itself requires CUDA.
        slopes = torch.tensor(slopes, device=torch.device("cuda"))
        # Damped slopes for sliding-window use.
        slopes = slopes * interpolation_factor
        return slopes
|
|
|
|
|
|
|
|
class STU(nn.Module):
    """Spectral Transform Unit.

    Convolves the input with fixed spectral (Hankel eigenvector) filters and
    mixes the result through learned matrices. The filters themselves are not
    parameters: they are injected after construction by
    ``FlashSTU.setup_filters``.
    """

    def __init__(self, config):
        super().__init__()

        # Populated later by FlashSTU.setup_filters().
        self.stu_filters = None
        self.stu_filters_fft = None  # NOTE(review): assigned but never read in this class — confirm

        self.n = config.n  # FFT length for the convolutions below
        self.num_eigh = config.num_eigh  # number of spectral filters K
        self.d_in = config.dim
        self.d_out = config.dim
        self.r = config.r  # NOTE(review): stored but unused here — confirm purpose
        self.use_hankel_L = config.use_hankel_L
        self.use_tensordot = config.use_tensordot
        # FlashFFTConv path (bfloat16); falls back to the PyTorch FFT path
        # when the library is unavailable or disabled.
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16) if config.use_flash_fft and flash_fft_available else None
        )

        if self.use_tensordot:
            # Low-rank ("tensordot") parameterization: separate input and
            # filter projections instead of a full per-filter matrix.
            self.M_inputs = nn.Parameter(torch.zeros(self.d_in, self.d_out))
            self.M_filters = nn.Parameter(torch.zeros(self.num_eigh, self.d_in))
        else:
            # Full parameterization: one (d_in x d_out) matrix per filter;
            # the "minus" branch is only needed without Hankel-L filters.
            self.M_phi_plus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(torch.zeros(self.num_eigh, self.d_in, self.d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral transform to ``x`` of shape (B, L, D)."""
        B, L, D = x.shape

        if self.use_tensordot:
            # Project inputs and filters first, then run a single convolution.
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = self.flash_conv(x_proj, phi_proj, self.flash_fft, self.use_tensordot)
            else:
                spectral_plus, spectral_minus = self.conv(x_proj, phi_proj, self.n, self.use_tensordot)

        else:
            # Convolve against every filter, then contract over (K, D).
            if self.flash_fft:
                U_plus, U_minus = self.flash_conv(x, self.stu_filters, self.flash_fft, self.use_tensordot)
            else:
                U_plus, U_minus = self.conv(x, self.stu_filters, self.n, self.use_tensordot)

            B, L, K, D = U_plus.shape
            spectral_plus = U_plus.reshape(B, L, K * D) @ self.M_phi_plus.reshape(K * D, self.d_out)
            if not self.use_hankel_L:
                spectral_minus = U_minus.reshape(B, L, K * D) @ self.M_phi_minus.reshape(K * D, self.d_out)

        # Hankel-L filters only use the "plus" branch.
        out = spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
        return out

    def conv(
        self, u: torch.Tensor, v: torch.Tensor, n: int, use_tensordot: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Performs convolution via FFT with causal alignment using a negative featurization.

        The input tensor u is modulated by an alternating sign tensor (sgn) that multiplies every other
        time step by -1. This "negative featurization" modulates the phase so that in this implementation
        the correct causal output is obtained by simply slicing the first L elements (i.e. [:seq_len]).
        Note: Using a conventional slice [seq_len-1:2*seq_len-1] would yield a flipped alignment, resulting in leakage.

        Args:
            u: Input tensor of shape (bsz, seq_len, d_in).
            v: Kernel tensor; expected shape is (seq_len, d_out) if use_tensordot is True.
            n: FFT length (typically set to 2*seq_len - 1 for linear convolution with implicit right zero-padding).
            use_tensordot: Boolean flag to control kernel reshaping.

        Returns:
            A tuple (U_plus, U_minus) where:
            - U_plus is the primary convolution output.
            - U_minus is the secondary output, corrected by the sign tensor.
        """
        bsz, seq_len, d_in = u.shape

        # Alternating sign along the time axis: +1, -1, +1, ...
        sgn = torch.full((1, seq_len, 1), 1, device=u.device)
        sgn[:, 1::2] *= -1

        if use_tensordot:
            _, d_out = v.shape
            v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
        else:
            _, K = v.shape
            sgn = sgn.unsqueeze(-1)
            v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()
            # Broadcast u across the K filter dimension.
            u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

        v_fft = torch.fft.rfft(v.to(torch.float32), n=n, dim=1)

        # Stack [u, u * sgn] so one FFT pass computes both branches.
        U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()

        U_fft = torch.fft.rfft(U.to(torch.float32), n=n, dim=1)

        # Multiply in frequency space, invert, and take the first seq_len
        # steps (causal slice — see the docstring note above).
        U_conv = torch.fft.irfft(v_fft * U_fft, n=n, dim=1)[:, :seq_len].to(u.dtype)
        U_plus, U_minus = torch.unbind(U_conv, dim=-1)
        # Undo the sign modulation on the minus branch.
        U_minus = U_minus * sgn

        return U_plus.type_as(u), U_minus.type_as(u)

    def flash_conv(
        self,
        u: torch.Tensor,
        v: torch.Tensor,
        flash_fft: FlashFFTConv,
        use_tensordot: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Flash FFT convolution.

        Args:
            u (torch.Tensor): Input tensor of shape `(B, L, d_in)`, where:
                - `B` is the batch size,
                - `L` is the sequence length,
                - `d_in` is the input dimension.
            v (torch.Tensor): Filter tensor of shape `(K, d_in)`, where:
                - `K` is the number of filters,
                - `d_in` is the input dimension.
            flash_fft (FlashFFTConv): An instance of the FlashFFTConv module, used to perform the convolution.
            use_tensordot (bool, optional): If `True`, performs the tensordot approximation (default is `True`).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple `(U_plus, U_minus)`:
                - `U_plus`: Convolved output tensor with positive eigenvalues.
                - Shape depends on `use_tensordot`:
                    - If `use_tensordot=True`: `(B, L, d_in)`
                    - If `use_tensordot=False`: `(B, L, K, d_in)`
                - `U_minus`: Convolved output tensor with negative eigenvalues.
                - Shape depends on `use_tensordot`:
                    - If `use_tensordot=True`: `(B, L, d_in)`
                    - If `use_tensordot=False`: `(B, L, K, d_in)`

        Raises:
            ValueError: If the input tensor shapes do not conform to the expected dimensions.

        Example:
            >>> u = torch.randn(4, 16, 32) # (B, L, d_in)
            >>> v = torch.randn(8, 32) # (K, d_in)
            >>> flash_fft = FlashFFTConv(n=16, dtype=torch.float32)
            >>> U_plus, U_minus = flash_convolve(u, v, flash_fft, use_tensordot=True)
            >>> print(U_plus.shape, U_minus.shape)
            torch.Size([4, 16, 32]) torch.Size([4, 16, 32])

        """
        bsz, seq_len, d_in = u.shape
        _, K = v.shape

        # FlashFFTConv wants power-of-two lengths; right-pad up to one.
        padded_len = nearest_power_of_two(seq_len, round_up=True)
        pad_len = padded_len - seq_len

        # Alternating sign along the (padded) time axis: +1, -1, +1, ...
        sgn = torch.full((1, 1, padded_len), 1, device=u.device)
        sgn[:, :, 1::2] = -1

        if use_tensordot:
            u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32)
            # Stack [u, u * sgn] along the batch axis so one call computes both branches.
            u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
        else:
            # Replicate channels K times so each filter convolves each channel.
            u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).repeat_interleave(K, dim=1)
            v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1)
            u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

        U_conv = flash_fft(u_conv.to(torch.bfloat16), v_padded.to(torch.float32))

        # Trim the padding back off.
        U_conv = U_conv[..., :seq_len]
        # Split the stacked batch back into the plus/minus branches.
        u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

        if use_tensordot:
            # Undo the sign modulation on the minus branch.
            u_minus = u_minus * sgn[:, :, :seq_len]
            U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
        else:
            sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
            U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
            U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

        return U_plus, U_minus
|
|
|
|
|
|
|
|
class SlidingWindowAttentionLayer(nn.Module):
    """Pre-norm residual block: sliding-window attention followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        self.swa_norm = nn.LayerNorm(config.dim)
        self.swa = SlidingWindowAttention(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        """Two residual sub-blocks, each normalized before its transform."""
        attn_out = self.swa(self.swa_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
|
|
|
|
|
|
|
|
class STULayer(nn.Module):
    """Pre-norm residual block: spectral transform unit followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        self.stu_norm = nn.LayerNorm(config.dim)
        self.stu = STU(config)
        self.mlp_norm = nn.LayerNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x):
        """Two residual sub-blocks, each normalized before its transform."""
        stu_out = self.stu(self.stu_norm(x))
        x = x + stu_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out
|
|
|
|
|
|
|
|
class FlashSTU(nn.Module):
    """FlashSTU backbone: token embedding, alternating STU/attention layers,
    a final LayerNorm, and an (optionally weight-tied) LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList()

        for layer_idx in range(config.num_layers):
            # Even layers are always STU; odd layers are attention when
            # config.use_attn is set, otherwise STU as well.
            # (Rewritten from a conditional-expression-as-statement for clarity.)
            if layer_idx % 2 == 0 or not config.use_attn:
                self.layers.append(STULayer(config))
            else:
                self.layers.append(SlidingWindowAttentionLayer(config))

        self.norm_f = nn.LayerNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        if self.config.weight_tying:
            # Share a single weight matrix between the embedding and the LM head.
            self.tok_emb.weight = self.lm_head.weight

        # Base init std: 1 / sqrt(dim).
        self.std = self.config.dim**-0.5

    def init_weights(self, module):
        """Init hook for ``nn.Module.apply``: normal(0, 1/sqrt(dim)) weights;
        Linear layers flagged with SCALE_INIT get an extra
        (2 * num_layers)**-0.5 factor; Linear biases are zeroed."""
        std = self.std
        if isinstance(module, nn.Linear):
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.config.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor = None, **kwargs) -> CausalLMOutput:
        """Run the backbone and LM head; compute cross-entropy loss if labels given.

        NOTE(review): labels are used as-is with no internal shift — this
        assumes the caller supplies already-shifted labels; confirm against
        the training pipeline.
        """
        x = self.tok_emb(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = loss_fn(logits.flatten(0, 1), labels.flatten(0, 1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def setup_filters(
        self,
        spectral_filters: torch.Tensor,
        spectral_filters_fft: torch.Tensor,
    ):
        """Inject the precomputed spectral filters into every STU layer."""
        for layer in self.layers:
            if isinstance(layer, STULayer):
                layer.stu.stu_filters = spectral_filters
                layer.stu.stu_filters_fft = spectral_filters_fft

    def get_num_params(self):
        """Return the total number of parameters in the model."""
        # Fixed docstring: the old one claimed position embeddings were
        # subtracted, but no subtraction ever happened (and this model has
        # no learned position embeddings).
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
|
|
|
|
|
|
|
|
def create_base_model_components(model_name_or_path=None, **kwargs):
    """Build a FlashSTU config together with its spectral filters.

    Loads the config from the hub/path when ``model_name_or_path`` is given;
    otherwise constructs a fresh one from keyword overrides.
    """
    if model_name_or_path is None:
        config = FlashSTUConfig(**kwargs)
    else:
        config = FlashSTUConfig.from_pretrained(model_name_or_path, **kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    filters = get_spectral_filters(
        seq_len=config.seq_len,
        K=config.num_eigh,
        use_hankel_L=config.use_hankel_L,
        device=device,
        dtype=config.torch_dtype,
    )
    assert filters.dtype == config.torch_dtype, f"filters dtype is {filters.dtype}, expected {config.torch_dtype}"
    return config, filters
|
|
|
|
|
|
|
|
class FlashSTUForCausalLM(PreTrainedModel):
    """Thin wrapper to comply with HuggingFace's expected interface"""

    config_class = FlashSTUConfig
    base_model_prefix = "FlashSTU"

    def __init__(self, config):
        super().__init__(config)

        self.flash_stu = FlashSTU(config)
        self.flash_stu.apply(self.flash_stu.init_weights)

        # Prefer the device pinned in the config; otherwise use CUDA if present.
        device = (
            config.device
            if config.device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        torch_dtype = config.torch_dtype

        # Filters are deterministic in (seq_len, num_eigh), so they are
        # recomputed here rather than stored in checkpoints.
        spectral_filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=device,
        )
        spectral_filters_fft = torch.fft.rfft(spectral_filters, n=config.n, dim=1)

        self.flash_stu.setup_filters(
            spectral_filters.to(dtype=torch_dtype), spectral_filters_fft.to(dtype=torch_dtype)
        )

    def forward(
        self, input_ids: torch.Tensor, labels: torch.Tensor = None, attention_mask: torch.Tensor = None, **kwargs
    ) -> CausalLMOutput:
        """Delegate to the FlashSTU backbone.

        ``attention_mask`` is accepted for HF-API compatibility but ignored —
        the backbone does not consume it.
        """
        outputs = self.flash_stu(input_ids, labels=labels, **kwargs)
        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        max_length: int = 32,
        num_return_sequences: int = 4,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.2,
        seed: int = 42,
    ) -> torch.Tensor:
        """Generate text using top-k and nucleus sampling with temperature and repetition penalty.

        Args:
            input_ids: Input token ids of shape (batch_size, seq_len)
            max_length: Maximum length of generated sequence
            num_return_sequences: Number of sequences to generate per input
            temperature: Sampling temperature. Higher = more random, lower = more focused
            top_k: Number of highest probability tokens to keep for top-k sampling
            top_p: Cumulative probability cutoff for nucleus sampling
            repetition_penalty: Penalty factor for repeating tokens. 1.0 = no penalty
            seed: Random seed for reproducibility

        Returns:
            Generated token ids of shape (num_return_sequences, max_length)
        """
        self.eval()
        device = input_ids.device

        # One row per requested sample.
        input_ids = input_ids.repeat(num_return_sequences, 1)
        generated = input_ids

        # Dedicated generator so sampling is reproducible.
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(seed)

        with torch.no_grad():
            while generated.size(1) < max_length:
                outputs = self.flash_stu(generated)
                next_token_logits = outputs.logits[:, -1, :]

                if repetition_penalty != 1.0:
                    # BUGFIX: the old code checked `token in next_token_logits[i]`,
                    # which tests whether the token *id* equals any logit *value*
                    # (meaningless), and it always divided — making repeated
                    # tokens with negative logits MORE likely. Apply the
                    # CTRL-style sign-aware penalty to each previously
                    # generated token id instead.
                    for i in range(generated.shape[0]):
                        seen = torch.unique(generated[i])
                        seen_logits = next_token_logits[i, seen]
                        next_token_logits[i, seen] = torch.where(
                            seen_logits > 0,
                            seen_logits / repetition_penalty,
                            seen_logits * repetition_penalty,
                        )

                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature

                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)

                # Top-k: zero out everything below the k-th largest probability.
                if top_k > 0:
                    indices_to_remove = probs < torch.topk(probs, top_k)[0][..., -1, None]
                    probs[indices_to_remove] = 0

                # Nucleus (top-p): drop the tail once cumulative mass exceeds top_p.
                if top_p < 1.0:
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the token that crosses the threshold is kept;
                    # always keep the most probable token.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    probs[indices_to_remove] = 0

                # Renormalize; clamp guards against an all-zero row.
                probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-8)

                next_token = torch.multinomial(probs, num_samples=1, generator=sample_rng)

                generated = torch.cat([generated, next_token], dim=1)

        return generated

    def get_num_params(self):
        """Return the backbone's total parameter count."""
        return self.flash_stu.get_num_params()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a FlashSTU checkpoint from the HF hub (safetensors only),
        reconstructing tied weights if only one of the pair was saved."""
        config, _ = create_base_model_components(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        weights_path = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="model.safetensors",
            cache_dir=kwargs.get("cache_dir"),
            force_download=kwargs.get("force_download", False),
            proxies=kwargs.get("proxies", None),
            local_files_only=kwargs.get("local_files_only", False),
            use_auth_token=kwargs.get("use_auth_token", None),
            revision=kwargs.get("revision", None),
            subfolder=kwargs.get("subfolder", ""),
        )

        state_dict = load_file(weights_path)

        # safetensors drops duplicated (tied) tensors, so a checkpoint may
        # contain only one of the embedding / LM-head pair — relink it.
        tok_emb_key = "tok_emb.weight"
        lm_head_key = "lm_head.weight"

        tok_emb_present = tok_emb_key in state_dict
        lm_head_present = lm_head_key in state_dict

        if tok_emb_present and not lm_head_present:
            print(f"Reconstructing weight tying: Linking missing '{lm_head_key}' to existing '{tok_emb_key}'")
            state_dict[lm_head_key] = state_dict[tok_emb_key]
        elif lm_head_present and not tok_emb_present:
            print(f"Reconstructing weight tying: Linking missing '{tok_emb_key}' to existing '{lm_head_key}'")
            state_dict[tok_emb_key] = state_dict[lm_head_key]
        elif not tok_emb_present and not lm_head_present:
            print(
                f"Warning: Neither '{tok_emb_key}' nor '{lm_head_key}' found in state_dict. Weight tying cannot be reconstructed."
            )

        # Checkpoint keys are relative to the backbone; add the wrapper prefix.
        final_state_dict = {f"flash_stu.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(final_state_dict)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        num_params = model.get_num_params()
        print(f"\nModel loaded: {pretrained_model_name_or_path}")
        print(f"Parameter count: {num_params / 1e6:.2f}M")

        return model
|
|
|
|
|
|
|
|
|
|
|
# Build a default config (and its spectral filters) at import time.
# NOTE(review): this runs eigendecomposition on module import — confirm the
# side effect is intended.
config, filters = create_base_model_components()

# Register FlashSTU with the HF Auto* factories so AutoModel /
# AutoModelForCausalLM can resolve the "FlashSTU" model_type.
AutoConfig.register("FlashSTU", FlashSTUConfig)
AutoModel.register(FlashSTUConfig, FlashSTU)
AutoModelForCausalLM.register(FlashSTUConfig, FlashSTUForCausalLM)

print("Registered FlashSTU model and configuration.")
|
|
|
|
|
|
|
|
def run_model_diagnostics(model, tokenizer, device):
    """Probe the model on a few prompts: report per-token loss and sample
    generations at several temperatures."""
    print("\nRunning model diagnostics...")

    prompts = [
        "2 + 2 =",
        "The capital of France is Paris. The capital of Germany is",
        "If a train travels 120 kilometers in 2 hours, its average speed is",
        "1, 2, 3, 4,",
        "The following is a detailed explanation of photosynthesis: Plants use sunlight to",
    ]

    with torch.no_grad():
        for prompt in prompts:
            print(f"\nAnalyzing prompt: {prompt}")

            encoded = tokenizer(prompt, return_tensors="pt")
            ids = encoded["input_ids"].to(device)

            outputs = model.flash_stu(ids, labels=ids)

            # Standard next-token shift for per-position loss.
            targets = ids.clone()
            logits_shifted = outputs.logits[..., :-1, :].contiguous()
            targets_shifted = targets[..., 1:].contiguous()

            criterion = nn.CrossEntropyLoss(reduction="none")
            flat_losses = criterion(
                logits_shifted.view(-1, logits_shifted.size(-1)), targets_shifted.view(-1)
            )
            per_token = flat_losses.view(targets_shifted.size())

            token_strs = tokenizer.convert_ids_to_tokens(ids[0])
            print("\nToken-by-token loss:")
            for tok, tok_loss in zip(token_strs[1:], per_token[0]):
                print(f"{tok}: {tok_loss.item():.3f}")

            print(f"Average loss: {per_token.mean().item():.3f}")

            print("\nGeneration temperature comparison:")
            for temp in [0.5, 0.7, 1.0]:
                gen_ids = model.generate(
                    ids,
                    max_length=25,
                    num_return_sequences=1,
                    temperature=temp,
                    top_p=0.9,
                    repetition_penalty=1.5,
                    seed=42,
                )
                gen_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                print(f"\nTemp {temp}: {gen_text}")
|
|
|
|
|
|
|
|
def validate_model_generation():
    """Load the published FlashSTU checkpoint and run the diagnostic suite."""
    print("\nRunning generation validation test...")

    try:
        from transformers import AutoTokenizer

        model_id = "Hazan-Lab/FlashSTU-340M-0428"
        model = FlashSTUForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device=device, dtype=torch.bfloat16)
        model.eval()

        param_count = model.get_num_params()
        print(f"\nModel loaded: {model_id}")
        print(f"Parameter count: {param_count / 1e6:.2f}M")

        run_model_diagnostics(model, tokenizer, device)

    except Exception as exc:
        # Report, then propagate so the caller sees the failure.
        print(f"\nError during validation: {str(exc)}")
        raise
|
|
|
|
|
|
|
|
|
|
|
# Benchmark suites to run through lm-eval; extend this list to add tasks.
tasks = [
    "hellaswag",
]

# Per-task few-shot counts; tasks absent here fall back to the harness
# default (see the -1 sentinel check in the loop below).
tasks_fewshot = {
    "hellaswag": 0,
}

all_results = {}

# Smoke-test loading and generation before spending time on benchmarks.
validate_model_generation()

print("\nStarting evaluation tasks...")
for task in tasks:
    print(f"\nEvaluating task: {task}")
    eval_kwargs = dict(
        model="hf",
        model_args=(
            # Comma-separated HF model args parsed by lm-eval.
            "pretrained=Hazan-Lab/FlashSTU-340M-0428,"
            "trust_remote_code=True,"
            "dtype=bfloat16,"
            "cache_dir=/scratch/gpfs/mn4560/hazan-lab/tensorized_filters/tensorized_filters/eval/cache"
        ),
        tasks=[task],
        batch_size="auto",
        device="cuda:0",
    )
    # Only pass num_fewshot when this task has an explicit setting.
    few_shot_value = tasks_fewshot.get(task, -1)
    if few_shot_value != -1:
        eval_kwargs["num_fewshot"] = few_shot_value
    results = evaluator.simple_evaluate(**eval_kwargs)
    task_result = results["results"].get(task, {})
    all_results[task] = task_result
    print(f"Results for {task}:")
    print(task_result)
    print("\n" + "=" * 50 + "\n")

print("All Evaluation Results:")
for task, result in all_results.items():
    print(f"{task}: {result}")
|
|
|