# Page header residue (not code): noblebarkrr's picture / "Updated to Dzeta" / commit 4f175c5
from functools import partial
import torch
from torch import nn
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, pack, unpack
from typing import Tuple
from functools import wraps
from packaging import version
from collections import namedtuple
import os
# Installed PyTorch release as a (major, minor) pair of ints; only the first
# two dot-separated components of ``torch.__version__`` are inspected.
TORCH_VERSION = tuple(int(part) for part in torch.__version__.split('.')[:2])
IS_TORCH_LT_2_0 = TORCH_VERSION < (2, 0)
IS_TORCH_LT_2_5 = TORCH_VERSION < (2, 5)

# The flex attention helpers are imported only on PyTorch >= 2.5. On older
# releases the two names are still defined, but bound to None.
if IS_TORCH_LT_2_5:
    FlexAttention = None
    generate_sliding_window_with_sinks = None
else:
    from .flex_attention_utils import (
        FlexAttention,
        generate_sliding_window_with_sinks,
    )
def pack_one(t, pattern):
    """Pack the single tensor ``t`` with einops ``pack`` under ``pattern``.

    Returns the pair produced by ``pack`` for the one-element list ``[t]``:
    the packed tensor and the packed-shape information that ``unpack`` needs
    to reverse the operation.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Unpack ``t`` with einops ``unpack`` and return only the first piece.

    ``ps`` is the packed-shape information and ``pattern`` the packing
    pattern, both forwarded unchanged to ``unpack``.
    """
    pieces = unpack(t, ps, pattern)
    return pieces[0]
class RMSNorm(Module):
    """Normalization over the last dimension with a learned per-feature gain.

    The input is L2-normalized along its last axis, multiplied by
    ``sqrt(dim)`` and then by the learnable vector ``gamma`` (initialised to
    ones).
    """

    def __init__(self, dim):
        super().__init__()
        # Constant factor applied after the unit-length normalization.
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return the normalized, rescaled and gained version of ``x``."""
        unit_length = F.normalize(x, dim=-1)
        return unit_length * self.scale * self.gamma
class FeedForward(Module):
    """Pre-normalized two-layer feed-forward block.

    Pipeline: RMSNorm -> Linear(dim, dim * mult) -> GELU -> Dropout ->
    Linear(dim * mult, dim) -> Dropout. The hidden width is
    ``int(dim * mult)``.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mult)
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the feed-forward pipeline."""
        return self.net(x)
# Manual SDPA for PyTorch < 2.0
def manual_scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False, scale=None, dropout_p=0.0):
    """Scaled dot-product attention fallback for PyTorch < 2.0.

    Args:
        q: Query tensor whose last two axes are (q_len, head_dim).
        k: Key tensor whose last two axes are (k_len, head_dim).
        v: Value tensor whose last two axes are (k_len, value_dim).
        attn_mask: Optional mask broadcastable to the attention scores.
            A boolean mask keeps positions that are True; any other dtype
            is added to the scores.
        is_causal: When True, query position ``i`` may only attend to key
            positions ``j <= i`` (mask aligned to the top-left corner).
        scale: Multiplier for the raw scores; defaults to
            ``head_dim ** -0.5``.
        dropout_p: Dropout probability on the attention weights. It is
            applied only when it is positive and ``q.requires_grad`` is
            True.

    Returns:
        The attention output, ``softmax(scores) @ v``.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
    if is_causal:
        # The mask must be (q_len, k_len); a square (k_len, k_len) mask does
        # not broadcast against the scores when the two lengths differ.
        q_len, k_len = attn_weights.shape[-2], attn_weights.shape[-1]
        causal_mask = torch.triu(
            torch.ones(q_len, k_len, device=attn_weights.device), diagonal=1
        ).bool()
        attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_weights = attn_weights.masked_fill(~attn_mask, float('-inf'))
        else:
            attn_weights = attn_weights + attn_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    # requires_grad is used here as the signal that dropout should be active.
    if dropout_p > 0.0 and q.requires_grad:
        attn_weights = torch.dropout(attn_weights, dropout_p, train=True)
    return torch.matmul(attn_weights, v)
class Attention(Module):
    """Pre-normalized multi-head self-attention with per-head output gates.

    The input is RMS-normalized, projected to queries, keys and values,
    optionally rotated by ``rotary_embed``, and passed to an ``Attend``
    module. Each head's output is multiplied by a sigmoid gate computed from
    the normalized input before the heads are merged and projected back to
    ``dim``.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
        wsa_window_len=None,
        n_wsa_sinks=None,
    ):
        super().__init__()
        inner_width = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_embed = rotary_embed

        # Flash attention is never requested on PyTorch < 2.0.
        flash_allowed = flash and not IS_TORCH_LT_2_0
        self.attend = Attend(
            flash=flash_allowed,
            dropout=dropout,
            wsa_window_len=wsa_window_len,
            n_wsa_sinks=n_wsa_sinks,
        )

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner_width * 3, bias=False)
        self.to_gates = nn.Linear(dim, heads)
        self.to_out = nn.Sequential(
            nn.Linear(inner_width, dim, bias=False),
            nn.Dropout(dropout),
        )

    def forward(self, x, return_attn=False):
        """Apply gated self-attention to ``x``.

        When ``return_attn`` is True the attend module is called with
        ``return_attn=True`` and is expected to return an ``(out, attn)``
        pair; the result is then ``(output, attn)``. Otherwise only the
        output tensor is returned.
        """
        normed = self.norm(x)

        qkv = rearrange(
            self.to_qkv(normed), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )
        q, k, v = qkv.unbind(dim=0)

        rotary = self.rotary_embed
        if rotary is not None:
            q = rotary.rotate_queries_or_keys(q)
            k = rotary.rotate_queries_or_keys(k)

        attn = None
        if return_attn:
            out, attn = self.attend(q, k, v, return_attn=True)
        else:
            out = self.attend(q, k, v)

        # One gate per (position, head), broadcast over the head dimension.
        head_gates = rearrange(self.to_gates(normed), "b n h -> b h n 1").sigmoid()
        merged = rearrange(out * head_gates, "b h n d -> b n (h d)")
        result = self.to_out(merged)

        return (result, attn) if return_attn else result
class Transformer(Module):
    """Stack of ``depth`` residual attention + feed-forward layers.

    Every layer applies ``Attention`` and then ``FeedForward``, each with a
    residual connection. The output passes through a final ``RMSNorm``
    unless ``norm_output`` is False, in which case it is returned as is.
    """

    def __init__(
        self,
        *,
        dim,
        depth: int = 1,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        use_flash=True,
        wsa_window_len=None,
        n_wsa_sinks=None,
    ):
        super().__init__()

        # Settings shared by the attention module of every layer.
        attention_settings = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=attn_dropout,
            rotary_embed=rotary_embed,
            flash=use_flash,
            wsa_window_len=wsa_window_len,
            n_wsa_sinks=n_wsa_sinks,
        )

        self.layers = ModuleList(
            [
                ModuleList(
                    [
                        Attention(**attention_settings),
                        FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        """Run ``x`` through all layers and the optional output norm."""
        hidden = x
        for attention, feed_forward in self.layers:
            hidden = attention(hidden) + hidden
            hidden = feed_forward(hidden) + hidden
        return self.norm(hidden)
class BandSplit(Module):
    """Project variable-width bands of the last axis to a common width.

    The last dimension of the input is split into consecutive chunks whose
    sizes are given by ``dim_inputs``. Each chunk has its own
    ``RMSNorm -> Linear(chunk_size, dim)`` projection, and the projected
    chunks are stacked along a new second-to-last axis.
    """

    def __init__(self, dim: int, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList(
            [
                nn.Sequential(RMSNorm(band_width), nn.Linear(band_width, dim))
                for band_width in dim_inputs
            ]
        )

    def forward(self, x):
        """Split ``x`` into bands, project each one and stack the results."""
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            project(band) for band, project in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build a multi-layer perceptron as an ``nn.Sequential``.

    The network has ``depth`` hidden layers of width ``dim_hidden``
    (defaulting to ``dim_in``) followed by an output layer of width
    ``dim_out``. A fresh ``activation()`` instance follows every linear
    layer except the last one.
    """
    if dim_hidden is None:
        dim_hidden = dim_in

    widths = [dim_in] + [dim_hidden] * depth + [dim_out]
    transitions = list(zip(widths, widths[1:]))

    layers = []
    for position, (fan_in, fan_out) in enumerate(transitions, start=1):
        layers.append(nn.Linear(fan_in, fan_out))
        # No activation after the final (output) linear layer.
        if position < len(transitions):
            layers.append(activation())
    return nn.Sequential(*layers)
class MaskEstimator(Module):
    """Map per-band features back to per-band outputs and concatenate them.

    For every entry ``dim_in`` of ``dim_inputs`` there is one sub-network:
    an ``MLP`` from ``dim`` to ``dim_in * 2`` (hidden width
    ``dim * mlp_expansion_factor``) followed by a GLU over the last axis,
    which halves the width back to ``dim_in``.
    """

    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        hidden_width = dim * mlp_expansion_factor

        self.to_freqs = ModuleList(
            [
                nn.Sequential(
                    MLP(dim, band_width * 2, dim_hidden=hidden_width, depth=depth),
                    nn.GLU(dim=-1),
                )
                for band_width in dim_inputs
            ]
        )

    def forward(self, x):
        """Apply each band's network to its slice of ``x`` (second-to-last
        axis) and concatenate the results along the last axis."""
        per_band = x.unbind(dim=-2)
        estimates = [
            band_net(features)
            for features, band_net in zip(per_band, self.to_freqs)
        ]
        return torch.cat(estimates, dim=-1)
# Triple of boolean switches naming which attention kernels are enabled:
# flash, math and memory-efficient, in that order.
FlashAttentionConfig = namedtuple(
    "FlashAttentionConfig",
    "enable_flash enable_math enable_mem_efficient",
)
def once(fn):
    """Wrap the one-argument callable ``fn`` so it runs at most one time.

    The first call is forwarded to ``fn`` and its result returned. Every
    later call does nothing and returns None, whatever argument it is given.
    """
    already_ran = False

    @wraps(fn)
    def inner(x):
        nonlocal already_ran
        if not already_ran:
            already_ran = True
            return fn(x)
        return None

    return inner


# ``print`` limited to a single call for the lifetime of the module.
print_once = once(print)
class Attend(nn.Module):
    """Attention kernel dispatcher.

    One of three implementations is chosen when the module is called:

    * flex attention with a sliding window plus sink mask, when
      ``wsa_window_len`` and a positive ``n_wsa_sinks`` were given and
      PyTorch >= 2.5 is installed;
    * ``F.scaled_dot_product_attention`` when ``flash`` is True and
      PyTorch >= 2.0 is installed;
    * an explicit softmax(QK^T)V computation otherwise.

    Args:
        dropout: Dropout probability applied to the attention weights.
        flash: Request the fused scaled-dot-product-attention kernel.
        scale: Score multiplier; ``None`` means ``head_dim ** -0.5``.
        wsa_window_len: Sliding window length for windowed attention.
        n_wsa_sinks: Number of sink positions; windowed attention is only
            enabled when this is greater than zero.
    """

    def __init__(
        self,
        dropout=0.0,
        flash=False,
        scale=None,
        wsa_window_len=None,
        n_wsa_sinks=None,
    ):
        super().__init__()
        self.scale = scale
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)
        self.wsa_window_len = wsa_window_len
        self.n_wsa_sinks = n_wsa_sinks
        self.use_flash = flash

        # Flex attention only for PyTorch >= 2.5
        if wsa_window_len is not None and n_wsa_sinks is not None and n_wsa_sinks > 0:
            if IS_TORCH_LT_2_5:
                print_once(
                    f"Warning: WSA (windowed sliding attention) requires PyTorch >= 2.5.0, got {torch.__version__}. "
                    "Disabling WSA and falling back to standard attention."
                )
                self.flex_attn = None
                self.wsa_window_len = None
                self.n_wsa_sinks = None
            else:
                # IS_TORCH_LT_2_5 is the single source of truth for flex
                # availability (it also gates the module-level import). A
                # second check based on version.parse would reject
                # pre-release builds such as "2.5.0a0" that pass this gate.
                mask_mod = generate_sliding_window_with_sinks(wsa_window_len, n_wsa_sinks)
                self.flex_attn = FlexAttention(
                    mask_mod=mask_mod,
                    dropout=dropout,
                    scale=scale,
                    compile=True,
                )
        else:
            self.flex_attn = None

        # Flash attention warning for old PyTorch
        if self.use_flash and IS_TORCH_LT_2_0:
            print_once(
                f"Warning: Flash attention requires PyTorch >= 2.0.0, got {torch.__version__}. "
                "Falling back to standard attention."
            )
            self.use_flash = False

        self.cpu_config = FlashAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not self.use_flash:
            return

        # Choose the CUDA kernel set from the GPU's compute capability.
        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
        device_version = version.parse(
            f"{device_properties.major}.{device_properties.minor}"
        )

        if device_version >= version.parse("8.0"):
            if os.name == "nt":
                print_once(
                    "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
                )
                self.cuda_config = FlashAttentionConfig(False, True, True)
            else:
                print_once(
                    "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
                )
                self.cuda_config = FlashAttentionConfig(True, False, False)
        else:
            print_once(
                "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
            )
            self.cuda_config = FlashAttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        """Run fused scaled-dot-product attention on ``q``, ``k``, ``v``.

        A custom ``self.scale`` is folded into ``q`` because the fused
        kernel is called with its default ``head_dim ** -0.5`` scaling.
        Dropout is only active while the module is in training mode.
        """
        if self.scale is not None:
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        dropout_p = self.dropout if self.training else 0.0

        # For PyTorch < 2.0, use manual attention
        if IS_TORCH_LT_2_0:
            return manual_scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        config = self.cuda_config if q.is_cuda else self.cpu_config
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

    def forward(self, q, k, v, return_attn=False):
        """Compute attention over ``q``, ``k``, ``v``.

        Args:
            q, k, v: Tensors laid out as (batch, heads, length, head_dim).
            return_attn: When True, also return the attention weights.

        Returns:
            The attention output, or ``(output, attention_weights)`` when
            ``return_attn`` is True.

        Raises:
            NotImplementedError: If ``return_attn`` is True while flex
                attention is active; that path does not expose weights.
        """
        # Flex attention path
        if self.flex_attn is not None:
            if return_attn:
                raise NotImplementedError(
                    "return_attn=True is not supported when flex attention (WSA) is enabled"
                )
            return self.flex_attn(q, k, v)

        # Flash attention path (PyTorch >= 2.0). The fused kernel never
        # materialises the weights, so a request for them is served by the
        # explicit path below, which computes the same attention.
        if self.use_flash and not return_attn and not IS_TORCH_LT_2_0:
            return self.flash_attn(q, k, v)

        # Manual attention path (PyTorch < 2.0 or fallback)
        scale = self.scale if self.scale is not None else q.shape[-1] ** -0.5
        sim = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.einsum("b h i j, b h j d -> b h i d", attn, v)

        if return_attn:
            return out, attn
        return out