|
|
""" |
|
|
Definitions of the building blocks of the VAR transformer model.
|
|
""" |
|
|
|
|
|
import math |
|
|
import os |
|
|
from functools import partial |
|
|
from typing import Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from timm.models.layers import DropPath, drop_path |
|
|
from torch.utils.checkpoint import checkpoint |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SAGE_ATTN_AVAILABLE = False |
|
|
FLASH_ATTN_AVAILABLE = False |
|
|
sageattn = None |
|
|
sageattn_varlen = None |
|
|
flash_attn_func = None |
|
|
flash_attn_varlen_kvpacked_func = None |
|
|
|
|
|
|
|
|
try: |
|
|
from sageattention import sageattn, sageattn_varlen |
|
|
SAGE_ATTN_AVAILABLE = True |
|
|
print("[INFO] SageAttention detected - will use for 2-5x speedup over FlashAttention") |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
|
|
|
try: |
|
|
from flash_attn import flash_attn_func |
|
|
from flash_attn import flash_attn_varlen_kvpacked_func |
|
|
FLASH_ATTN_AVAILABLE = True |
|
|
if not SAGE_ATTN_AVAILABLE: |
|
|
print("[INFO] FlashAttention detected - will use for optimized attention") |
|
|
except ImportError: |
|
|
pass |
|
|
|
|
|
|
|
|
if not SAGE_ATTN_AVAILABLE and not FLASH_ATTN_AVAILABLE: |
|
|
print("[INFO] Using PyTorch scaled_dot_product_attention (no SageAttention or FlashAttention detected)") |
|
|
print(" Install SageAttention for 2-5x speedup: pip install sageattention>=2.2.0 --no-build-isolation") |
|
|
|
|
|
from torch.nn.functional import scaled_dot_product_attention as slow_attn |
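
# Optional GGUF (quantized checkpoint) support: try to locate `infinity_gguf_utils` a few directories up
# from this file; if it cannot be imported, GGUF weights are simply treated as unavailable.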
|
|
|
|
|
|
|
|
try: |
|
|
import sys |
|
|
import os |
|
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
parent_dirs = [ |
|
|
os.path.join(current_dir, '../../..'), |
|
|
os.path.join(current_dir, '../../../..'), |
|
|
] |
|
|
for parent_dir in parent_dirs: |
|
|
if parent_dir not in sys.path: |
|
|
sys.path.insert(0, parent_dir) |
|
|
from infinity_gguf_utils import dequantize_gguf_tensor, GGUFParameter |
|
|
GGUF_AVAILABLE = True |
|
|
except ImportError: |
|
|
GGUF_AVAILABLE = False |
|
|
GGUFParameter = None |
|
|
|
|
|
def get_weight_for_linear(linear_layer, target_dtype=None): |
|
|
""" |
|
|
Helper function to get weight from a linear layer, dequantizing if it's a GGUF parameter. |
|
|
|
|
|
Args: |
|
|
linear_layer: nn.Linear or GGUFLinear layer |
|
|
target_dtype: Target dtype for dequantization |
|
|
|
|
|
Returns: |
|
|
Weight tensor ready for use in F.linear |
|
|
""" |
|
|
weight = linear_layer.weight |
|
|
if GGUF_AVAILABLE and isinstance(weight, GGUFParameter): |
|
|
|
|
|
return dequantize_gguf_tensor(weight, target_dtype=target_dtype) |
|
|
|
|
|
if target_dtype is not None and weight.dtype != target_dtype: |
|
|
return weight.to(dtype=target_dtype) |
|
|
return weight |
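
# Usage sketch (illustrative only; `layer` stands for any nn.Linear-like module whose .weight may be a
# GGUFParameter):
#   w = get_weight_for_linear(layer, target_dtype=x.dtype)
#   y = F.linear(x, w, layer.bias)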
|
|
|
|
|
|
|
|
|
|
|
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.rms_norm import dropout_add_rms_norm
    from flash_attn.ops.rms_norm import rms_norm as rms_norm_impl
    from flash_attn.ops.fused_dense import fused_mlp_func
    flash_fused_op_installed = True
except ImportError:
    dropout_add_layer_norm = dropout_add_rms_norm = fused_mlp_func = None
    flash_fused_op_installed = False

    def rms_norm_impl(x, weight, epsilon):
        # Pure-PyTorch fallback (only defined when the fused flash_attn kernel is unavailable):
        # scale by the reciprocal root-mean-square of the last dim, then apply the learnable weight.
        return (x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add_(epsilon))) * weight
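
# precompute_rope2d_freqs_grid builds cos/sin tables for 2-D rotary position embeddings over a
# (height, width) grid. For every aspect ratio in `dynamic_resolution_h_w`, the per-scale tables are
# concatenated along the sequence dim (optionally padded to `pad_to_multiplier`) and stored in a dict
# keyed by the stringified scale schedule, which is how `apply_rotary_emb` looks them up at runtime.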
|
|
|
|
|
|
|
|
def precompute_rope2d_freqs_grid(dim, dynamic_resolution_h_w, rope2d_normalized_by_hw, pad_to_multiplier=1, max_height=2048 // 16, max_width=2048 // 16, base=10000.0, device=None, scaling_factor=1.0): |
|
|
|
|
|
half_dim = dim // 2 |
|
|
inv_freq = 1.0 / (base ** (torch.arange(0, half_dim, 2, dtype=torch.int64).float().to(device) / half_dim)) |
|
|
t_height = torch.arange(max_height, device=device, dtype=torch.int64).type_as(inv_freq) |
|
|
t_width = torch.arange(max_width, device=device, dtype=torch.int64).type_as(inv_freq) |
|
|
t_height = t_height / scaling_factor |
|
|
freqs_height = torch.outer(t_height, inv_freq) |
|
|
t_width = t_width / scaling_factor |
|
|
freqs_width = torch.outer(t_width, inv_freq) |
|
|
freqs_grid_map = torch.concat([ |
|
|
freqs_height[:, None, :].expand(-1, max_width, -1), |
|
|
freqs_width[None, :, :].expand(max_height, -1, -1), |
|
|
], dim=-1) |
|
|
freqs_grid_map = torch.stack([torch.cos(freqs_grid_map), torch.sin(freqs_grid_map)], dim=0) |
|
|
|
|
|
|
|
|
rope2d_freqs_grid = {} |
|
|
for h_div_w in dynamic_resolution_h_w: |
|
|
scale_schedule = dynamic_resolution_h_w[h_div_w]['1M']['scales'] |
|
|
_, ph, pw = scale_schedule[-1] |
|
|
max_edge_length = freqs_grid_map.shape[1] |
|
|
if ph >= pw: |
|
|
uph, upw = max_edge_length, int(max_edge_length / ph * pw) |
|
|
else: |
|
|
uph, upw = int(max_edge_length / pw * ph), max_edge_length |
|
|
rope_cache_list = [] |
|
|
for (_, ph, pw) in scale_schedule: |
|
|
ph_mul_pw = ph * pw |
|
|
if rope2d_normalized_by_hw == 1: |
|
|
rope_cache = F.interpolate(freqs_grid_map[:, :uph, :upw, :].permute([0,3,1,2]), size=(ph, pw), mode='bilinear', align_corners=True) |
|
|
rope_cache = rope_cache.permute([0,2,3,1]) |
|
|
elif rope2d_normalized_by_hw == 2: |
|
|
_, uph, upw = scale_schedule[-1] |
|
|
indices = torch.stack([ |
|
|
(torch.arange(ph) * (uph / ph)).reshape(ph, 1).expand(ph, pw), |
|
|
(torch.arange(pw) * (upw / pw)).reshape(1, pw).expand(ph, pw), |
|
|
], dim=-1).round().int() |
|
|
indices = indices.reshape(-1, 2) |
|
|
rope_cache = freqs_grid_map[:, indices[:,0], indices[:,1], :] |
|
|
rope_cache = rope_cache.reshape(2, ph, pw, -1) |
|
|
elif rope2d_normalized_by_hw == 0: |
|
|
rope_cache = freqs_grid_map[:, :ph, :pw, :] |
|
|
else: |
|
|
raise ValueError(f'Unknown rope2d_normalized_by_hw: {rope2d_normalized_by_hw}') |
|
|
rope_cache_list.append(rope_cache.reshape(2, ph_mul_pw, -1)) |
|
|
cat_rope_cache = torch.cat(rope_cache_list, 1) |
|
|
if cat_rope_cache.shape[1] % pad_to_multiplier: |
|
|
pad = torch.zeros(2, pad_to_multiplier - cat_rope_cache.shape[1] % pad_to_multiplier, half_dim) |
|
|
cat_rope_cache = torch.cat([cat_rope_cache, pad], dim=1) |
|
|
cat_rope_cache = cat_rope_cache[:,None,None,None] |
|
|
for pn in dynamic_resolution_h_w[h_div_w]: |
|
|
scale_schedule = dynamic_resolution_h_w[h_div_w][pn]['scales'] |
|
|
tmp_scale_schedule = [(1, h, w) for _, h, w in scale_schedule] |
|
|
rope2d_freqs_grid[str(tuple(tmp_scale_schedule))] = cat_rope_cache |
|
|
return rope2d_freqs_grid |
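
# apply_rotary_emb applies the precomputed 2-D RoPE tables to q and k: channels are grouped into
# consecutive pairs (x0, x1) and rotated as (cos*x0 - sin*x1, sin*x0 + cos*x1). `scale_ind` offsets
# into the concatenated per-scale cache so that only the rows belonging to the current scale are used.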
|
|
|
|
|
|
|
|
def apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, pad_to_multiplier, rope2d_normalized_by_hw, scale_ind): |
|
|
qk = torch.stack((q, k), dim=0) |
|
|
device_type = qk.device.type |
|
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" |
|
|
with torch.autocast(device_type=device_type, enabled=False): |
|
|
seq_len = qk.shape[3] |
|
|
start = 0 |
|
|
if scale_ind >= 1: |
|
|
assert len(scale_schedule[0]) == 3 |
|
|
start = np.sum([item[0] * item[1] * item[2] for item in scale_schedule[:scale_ind]]) |
|
|
rope2d_freqs_grid[str(tuple(scale_schedule))] = rope2d_freqs_grid[str(tuple(scale_schedule))].to(qk.device) |
|
|
assert start+seq_len <= rope2d_freqs_grid[str(tuple(scale_schedule))].shape[4] |
|
|
rope_cache = rope2d_freqs_grid[str(tuple(scale_schedule))][:, :, :, :, start:start+seq_len] |
|
|
qk = qk.reshape(*qk.shape[:-1], -1, 2) |
|
|
qk = torch.stack([ |
|
|
rope_cache[0] * qk[...,0] - rope_cache[1] * qk[...,1], |
|
|
rope_cache[1] * qk[...,0] + rope_cache[0] * qk[...,1], |
|
|
], dim=-1) |
|
|
qk = qk.reshape(*qk.shape[:-2], -1) |
|
|
q, k = qk.unbind(dim=0) |
|
|
return q, k |
|
|
|
|
|
|
|
|
class FastRMSNorm(nn.Module): |
|
|
def __init__(self, C, eps=1e-6, elementwise_affine=True): |
|
|
super().__init__() |
|
|
self.C = C |
|
|
self.eps = eps |
|
|
self.elementwise_affine = elementwise_affine |
|
|
if self.elementwise_affine: |
|
|
self.weight = nn.Parameter(torch.ones(C)) |
|
|
else: |
|
|
self.register_buffer('weight', torch.ones(C)) |
|
|
|
|
|
def forward(self, x): |
|
|
src_type = x.dtype |
|
|
return rms_norm_impl(x.float(), self.weight, epsilon=self.eps).to(src_type) |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'C={self.C}, eps={self.eps:g}, elementwise_affine={self.elementwise_affine}' |
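
# Minimal usage sketch for FastRMSNorm (illustrative shapes only):
#   norm = FastRMSNorm(C=768)
#   y = norm(torch.randn(2, 16, 768))   # same shape; normalized by the RMS over the last dim, then scaled by `weight`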
|
|
|
|
|
|
|
|
def get_dropout_layer(p): |
|
|
return nn.Dropout(p, inplace=True) if p > 0 else nn.Identity() |
|
|
|
|
|
|
|
|
class FFN(nn.Module): |
|
|
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_mlp=False): |
|
|
super().__init__() |
|
|
self.fused_mlp_func = fused_mlp_func if fused_mlp else None |
|
|
out_features = out_features or in_features |
|
|
hidden_features = hidden_features or in_features |
|
|
self.fc1 = nn.Linear(in_features, hidden_features) |
|
|
self.act = nn.GELU(approximate='tanh') |
|
|
self.fc2 = nn.Linear(hidden_features, out_features) |
|
|
self.drop = get_dropout_layer(drop) |
|
|
self.heuristic = -1 |
|
|
|
|
|
def forward(self, x): |
|
|
if self.fused_mlp_func is not None: |
|
|
return self.drop(self.fused_mlp_func( |
|
|
x=x, |
|
|
weight1=self.fc1.weight, |
|
|
weight2=self.fc2.weight, |
|
|
bias1=self.fc1.bias, |
|
|
bias2=self.fc2.bias, |
|
|
activation='gelu_approx', |
|
|
save_pre_act=self.training, |
|
|
return_residual=False, |
|
|
checkpoint_lvl=0, |
|
|
heuristic=self.heuristic, |
|
|
process_group=None, |
|
|
)) |
|
|
else: |
|
|
return self.drop(self.fc2( self.act(self.fc1(x)) )) |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'fused_mlp={self.fused_mlp_func is not None}' |
|
|
|
|
|
|
|
|
class FFNSwiGLU(nn.Module): |
|
|
def __init__(self, in_features, hidden_features, out_features=None, drop=0., fused_mlp=False): |
|
|
super().__init__() |
|
|
self.fused_mlp_func = None |
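
        # SwiGLU convention: shrink the hidden width to roughly 2/3 of the GELU-MLP width (SwiGLU uses
        # three weight matrices, so this keeps the parameter count comparable) and round to a multiple of 256.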
|
|
hidden_features = round(2 * hidden_features / 3 / 256) * 256 |
|
|
|
|
|
out_features = out_features or in_features |
|
|
self.fcg = nn.Linear(in_features, hidden_features, bias=False) |
|
|
self.fc1 = nn.Linear(in_features, hidden_features, bias=False) |
|
|
self.fc2 = nn.Linear(hidden_features, out_features, bias=False) |
|
|
self.drop = get_dropout_layer(drop) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.drop(self.fc2( F.silu(self.fcg(x), inplace=True).mul_(self.fc1(x)) )) |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'fused_mlp={self.fused_mlp_func is not None}' |
|
|
|
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
def __init__( |
|
|
self, embed_dim=768, num_heads=12, |
|
|
proj_drop=0., tau=1, cos_attn=False, customized_flash_attn=True, use_flex_attn=False, |
|
|
batch_size=2, pad_to_multiplier=1, rope2d_normalized_by_hw=0, |
|
|
): |
|
|
""" |
|
|
:param embed_dim: model's width |
|
|
:param num_heads: num heads of multi-head attention |
|
|
:param proj_drop: always 0 for testing |
|
|
:param tau: always 1 |
|
|
:param cos_attn: always True: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11 |
|
|
:param customized_flash_attn: |
|
|
""" |
|
|
super().__init__() |
|
|
assert embed_dim % num_heads == 0 |
|
|
self.using_flash = customized_flash_attn |
|
|
|
|
|
self.num_heads, self.head_dim = num_heads, embed_dim // num_heads |
|
|
self.tau, self.cos_attn = tau, cos_attn |
|
|
if self.cos_attn: |
|
|
self.scale = 1 |
|
|
size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1) |
|
|
|
|
|
self.scale_mul_1H11 = nn.Parameter(torch.full(size=size, fill_value=4.0).log(), requires_grad=True) |
|
|
self.max_scale_mul = torch.log(torch.tensor(100)).item() |
|
|
else: |
|
|
self.scale = 1 / math.sqrt(self.head_dim) / self.tau |
|
|
|
|
|
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) |
|
|
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim)) |
|
|
self.register_buffer('zero_k_bias', torch.zeros(embed_dim)) |
|
|
|
|
|
self.proj = nn.Linear(embed_dim, embed_dim) |
|
|
self.proj_drop = get_dropout_layer(proj_drop) |
|
|
|
|
|
self.caching = False |
|
|
self.cached_k = None |
|
|
self.cached_v = None |
|
|
|
|
|
self.batch_size = batch_size |
|
|
self.use_flex_attn = use_flex_attn |
|
|
self.pad_to_multiplier = pad_to_multiplier |
|
|
|
|
|
self.rope2d_normalized_by_hw = rope2d_normalized_by_hw |
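
    # KV caching: when enabled via kv_caching(True), keys/values from earlier scales are kept and each
    # new scale's k/v are concatenated along the sequence dim, so autoregressive inference only
    # recomputes queries for the tokens of the current scale.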
|
|
|
|
|
|
|
|
def kv_caching(self, enable: bool): |
|
|
self.caching = enable |
|
|
self.cached_k = None |
|
|
self.cached_v = None |
|
|
|
|
|
|
|
|
def forward(self, x, attn_bias_or_two_vector: Union[torch.Tensor, Tuple[torch.IntTensor, torch.IntTensor]], attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0): |
|
|
""" |
|
|
        :param (fp32) x: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be sharded
|
|
:param (fp32) attn_bias_or_two_vector: |
|
|
if not using_flash: |
|
|
a block-wise, lower-triangle matrix, like: |
|
|
[[[[0, -, -, -, -, -, -, -, -, -, -, -, -, -], |
|
|
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -], |
|
|
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -], |
|
|
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -], |
|
|
[0, 0, 0, 0, 0, -, -, -, -, -, -, -, -, -], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
|
|
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]] |
|
|
where 0 means visible and - means invisible (-inf) |
|
|
else: |
|
|
                a tuple of two 1-dim int vectors (VAR_visible_kvlen, VAR_invisible_qlen)
|
|
        :return: shaped (B or batch_size, L or seq_length, C or hidden_dim); if seq-parallel is used, the `L` dim would be sharded
|
|
""" |
|
|
|
|
|
B, L, C = x.shape |
|
|
|
|
|
|
|
|
qkv = F.linear(input=x, weight=get_weight_for_linear(self.mat_qkv, target_dtype=x.dtype), bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) |
|
|
if self.using_flash: q, k, v = qkv.unbind(dim=2); L_dim = 1 |
|
|
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); L_dim = 2 |
|
|
|
|
|
if self.cos_attn: |
|
|
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() |
|
|
q = F.normalize(q, dim=-1, eps=1e-12).mul(scale_mul).contiguous() |
|
|
k = F.normalize(k, dim=-1, eps=1e-12).contiguous() |
|
|
v = v.contiguous() |
|
|
else: |
|
|
q = q.contiguous() |
|
|
k = k.contiguous() |
|
|
v = v.contiguous() |
|
|
if rope2d_freqs_grid is not None: |
|
|
q, k = apply_rotary_emb(q, k, scale_schedule, rope2d_freqs_grid, self.pad_to_multiplier, self.rope2d_normalized_by_hw, scale_ind) |
|
|
if self.caching: |
|
|
if self.cached_k is None: self.cached_k = k; self.cached_v = v |
|
|
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=L_dim); v = self.cached_v = torch.cat((self.cached_v, v), dim=L_dim) |
|
|
|
|
|
if self.using_flash: |
|
|
|
|
|
if SAGE_ATTN_AVAILABLE and attn_bias_or_two_vector is None: |
|
|
try: |
|
|
|
|
|
|
|
|
q_sage = q.transpose(1, 2) |
|
|
k_sage = k.transpose(1, 2) |
|
|
v_sage = v.transpose(1, 2) |
|
|
|
|
|
|
|
|
target_dtype = torch.bfloat16 if v.dtype == torch.float32 else v.dtype |
|
|
q_sage = q_sage.to(target_dtype) |
|
|
k_sage = k_sage.to(target_dtype) |
|
|
v_sage = v_sage.to(target_dtype) |
|
|
|
|
|
|
|
|
oup = sageattn(q_sage, k_sage, v_sage, tensor_layout="HND", is_causal=False) |
|
|
oup = oup.transpose(1, 2).reshape(B, L, C) |
|
|
if target_dtype != v.dtype: |
|
|
oup = oup.to(v.dtype) |
|
|
except Exception as e: |
|
|
print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch") |
|
|
|
|
|
if FLASH_ATTN_AVAILABLE: |
|
|
kw = dict() if attn_bias_or_two_vector is None else dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1]) |
|
|
oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C) |
|
|
else: |
|
|
q_torch = q.transpose(1, 2) |
|
|
k_torch = k.transpose(1, 2) |
|
|
v_torch = v.transpose(1, 2) |
|
|
oup = slow_attn(query=q_torch, key=k_torch, value=v_torch, scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C) |
|
|
|
|
|
|
|
|
elif FLASH_ATTN_AVAILABLE: |
|
|
if attn_bias_or_two_vector is not None: |
|
|
kw = dict(VAR_visible_kvlen=attn_bias_or_two_vector[0], VAR_invisible_qlen=attn_bias_or_two_vector[1]) |
|
|
else: |
|
|
kw = dict() |
|
|
oup = flash_attn_func(q.to(v.dtype), k.to(v.dtype), v, dropout_p=0, softmax_scale=self.scale, **kw).view(B, L, C) |
|
|
|
|
|
|
|
|
else: |
|
|
q_torch = q.transpose(1, 2) |
|
|
k_torch = k.transpose(1, 2) |
|
|
v_torch = v.transpose(1, 2) |
|
|
oup = slow_attn(query=q_torch, key=k_torch, value=v_torch, scale=self.scale, dropout_p=0).transpose(1, 2).reshape(B, L, C) |
|
|
else: |
|
|
|
|
|
|
|
|
if self.use_flex_attn and attn_fn is not None: |
|
|
oup = attn_fn(q, k, v, scale=self.scale).transpose(1, 2).reshape(B, L, C) |
|
|
else: |
|
|
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias_or_two_vector, dropout_p=0).transpose(1, 2).reshape(B, L, C) |
|
|
|
|
|
|
|
|
return self.proj_drop(self.proj(oup)) |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
tail = '' |
|
|
return f'using_flash={self.using_flash}, tau={self.tau}, cos_attn={self.cos_attn}{tail}' |
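
# Minimal usage sketch for SelfAttention (illustrative; customized_flash_attn=False selects the plain
# torch scaled_dot_product_attention path regardless of which backends are installed):
#   attn = SelfAttention(embed_dim=768, num_heads=12, customized_flash_attn=False)
#   y = attn(torch.randn(2, 16, 768), attn_bias_or_two_vector=None)   # -> (2, 16, 768)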
|
|
|
|
|
|
|
|
class CrossAttention(nn.Module): |
|
|
def __init__( |
|
|
self, for_attn_pool=False, embed_dim=768, kv_dim=4096, num_heads=12, |
|
|
proj_drop=0., cos_attn=False, use_flash_attn=True, |
|
|
): |
|
|
""" |
|
|
:param for_attn_pool: only used in VAR.text_proj_for_sos |
|
|
:param embed_dim: Q's dim |
|
|
:param kv_dim: K's and V's dim |
|
|
:param num_heads: num heads of multi-head attention |
|
|
:param proj_drop: proj drop out |
|
|
:param cos_attn: during attention, q and k will be L2-normalized and scaled by a head-wise learnable parameter self.scale_mul_1H11 |
|
|
""" |
|
|
        cos_attn = False  # NOTE: cosine attention is intentionally disabled for cross-attention; plain scaled dot-product is always used
|
|
super().__init__() |
|
|
self.for_attn_pool = for_attn_pool |
|
|
self.embed_dim = embed_dim |
|
|
self.kv_dim = kv_dim |
|
|
assert embed_dim % num_heads == 0 |
|
|
self.num_heads, self.head_dim = num_heads, embed_dim // num_heads |
|
|
self.cos_attn = cos_attn |
|
|
self.use_flash_attn = use_flash_attn |
|
|
if self.cos_attn: |
|
|
self.scale = 1 |
|
|
self.scale_mul_1H1 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True) |
|
|
self.max_scale_mul = torch.log(torch.tensor(100)).item() |
|
|
else: |
|
|
self.scale = 1 / math.sqrt(self.head_dim) |
|
|
|
|
|
if for_attn_pool: |
|
|
q = torch.empty(1, self.num_heads, self.head_dim) |
|
|
nn.init.trunc_normal_(q, mean=0, std=math.sqrt(1 / embed_dim / 3)) |
|
|
self.mat_q = nn.Parameter(q) |
|
|
else: |
|
|
self.mat_q = nn.Linear(embed_dim, embed_dim, bias=True) |
|
|
self.mat_kv = nn.Linear(kv_dim, embed_dim*2, bias=False) |
|
|
self.v_bias = nn.Parameter(torch.zeros(embed_dim)) |
|
|
self.register_buffer('zero_k_bias', torch.zeros(embed_dim)) |
|
|
|
|
|
self.proj = nn.Linear(embed_dim, embed_dim) |
|
|
self.proj_drop = get_dropout_layer(proj_drop) |
|
|
|
|
|
def forward(self, q, ca_kv): |
|
|
""" |
|
|
:param q: shaped as (batch, seq_len, Q_dim) |
|
|
:param ca_kv: contains several vectors, each of which is shaped as (len_i, KV_dim). We have [len_1xKV_dim, len_2xKV_dim, len_3xKV_dim, ...] and lens == [len_1, len_2, len_3, ...] |
|
|
- kv_compact: shaped as (sum(lens), KV_dim) |
|
|
- cu_seqlens_k: cumulated sum of lens |
|
|
- max_seqlen_k: int, max(lens) |
|
|
        NOTE: seq_len (num of Qs) can reach 10k, but each len_i (num of KVs) must be <= 256
|
|
|
|
|
:return: shaped as (batch, seq_len, Q_dim) |
|
|
""" |
|
|
kv_compact, cu_seqlens_k, max_seqlen_k = ca_kv |
|
|
N = kv_compact.shape[0] |
|
|
|
|
|
kv_compact = F.linear(kv_compact, weight=get_weight_for_linear(self.mat_kv, target_dtype=kv_compact.dtype), bias=torch.cat((self.zero_k_bias, self.v_bias))).view(N, 2, self.num_heads, self.head_dim) |
|
|
|
|
|
|
|
|
if not self.for_attn_pool: |
|
|
B, Lq = q.shape[:2] |
|
|
q_compact = self.mat_q(q).view(-1, self.num_heads, self.head_dim) |
|
|
else: |
|
|
B = cu_seqlens_k.shape[0] - 1 |
|
|
Lq = 1 |
|
|
|
|
|
mat_q_data = self.mat_q |
|
|
if GGUF_AVAILABLE and isinstance(mat_q_data, GGUFParameter): |
|
|
mat_q_data = dequantize_gguf_tensor(mat_q_data, target_dtype=kv_compact.dtype) |
|
|
q_compact = mat_q_data.repeat(B, 1, 1).to(dtype=kv_compact.dtype) |
|
|
|
|
|
if self.cos_attn: |
|
|
scale_mul = self.scale_mul_1H1.clamp_max(self.max_scale_mul).exp() |
|
|
k, v = kv_compact.unbind(dim=1) |
|
|
q_compact = F.normalize(q_compact, dim=-1).mul(scale_mul) |
|
|
k = F.normalize(k, dim=-1) |
|
|
kv_compact = torch.stack((k, v), dim=1) |
|
|
|
|
|
q_compact = q_compact.contiguous() |
|
|
kv_compact = kv_compact.contiguous() |
|
|
|
|
|
|
|
|
|
|
|
if self.use_flash_attn: |
|
|
cu_seqlens_q = torch.arange(0, Lq * (B+1), Lq, dtype=torch.int32, device=q_compact.device) |
|
|
oup = None |
|
|
|
|
|
|
|
|
if SAGE_ATTN_AVAILABLE: |
|
|
try: |
|
|
|
|
|
|
|
|
k_compact, v_compact = kv_compact.unbind(dim=1) |
|
|
|
|
|
|
|
|
target_dtype = torch.bfloat16 if q_compact.dtype == torch.float32 else q_compact.dtype |
|
|
q_sage = q_compact.to(target_dtype) |
|
|
k_sage = k_compact.to(target_dtype) |
|
|
v_sage = v_compact.to(target_dtype) |
|
|
|
|
|
|
|
|
oup = sageattn_varlen( |
|
|
q=q_sage, |
|
|
k=k_sage, |
|
|
v=v_sage, |
|
|
cu_seqlens_q=cu_seqlens_q, |
|
|
cu_seqlens_k=cu_seqlens_k, |
|
|
max_seqlen_q=Lq, |
|
|
max_seqlen_k=max_seqlen_k, |
|
|
is_causal=False, |
|
|
sm_scale=self.scale, |
|
|
smooth_k=True |
|
|
).reshape(B, Lq, -1) |
|
|
|
|
|
if target_dtype != q_compact.dtype: |
|
|
oup = oup.float() |
|
|
|
|
|
except Exception as e: |
|
|
print(f"[WARNING] SageAttention failed ({str(e)[:100]}), falling back to FlashAttention/PyTorch") |
|
|
oup = None |
|
|
|
|
|
|
|
|
if oup is None and FLASH_ATTN_AVAILABLE: |
|
|
try: |
|
|
if q_compact.dtype == torch.float32: |
|
|
oup = flash_attn_varlen_kvpacked_func(q=q_compact.to(dtype=torch.bfloat16), kv=kv_compact.to(dtype=torch.bfloat16), cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1) |
|
|
oup = oup.float() |
|
|
else: |
|
|
oup = flash_attn_varlen_kvpacked_func(q=q_compact, kv=kv_compact, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=Lq, max_seqlen_k=max_seqlen_k, dropout_p=0, softmax_scale=self.scale).reshape(B, Lq, -1) |
|
|
except Exception as e: |
|
|
print(f"[WARNING] FlashAttention failed ({str(e)[:100]}), falling back to PyTorch attention") |
|
|
oup = None |
|
|
|
|
|
|
|
|
if oup is None: |
|
|
self.use_flash_attn = False |
|
|
|
|
|
|
|
|
if not self.use_flash_attn: |
|
|
|
|
|
k, v = kv_compact.unbind(dim=1) |
|
|
|
|
|
|
|
|
k_batched = [] |
|
|
v_batched = [] |
|
|
for i in range(B): |
|
|
start = cu_seqlens_k[i].item() |
|
|
end = cu_seqlens_k[i+1].item() |
|
|
k_batched.append(k[start:end]) |
|
|
v_batched.append(v[start:end]) |
|
|
|
|
|
|
|
|
k_padded = torch.stack([ |
|
|
F.pad(k_i, (0, 0, 0, 0, 0, max_seqlen_k - k_i.shape[0])) if k_i.shape[0] < max_seqlen_k else k_i |
|
|
for k_i in k_batched |
|
|
]) |
|
|
v_padded = torch.stack([ |
|
|
F.pad(v_i, (0, 0, 0, 0, 0, max_seqlen_k - v_i.shape[0])) if v_i.shape[0] < max_seqlen_k else v_i |
|
|
for v_i in v_batched |
|
|
]) |
|
|
|
|
|
|
|
|
q_batched = q_compact.view(B, Lq, self.num_heads, self.head_dim) |
|
|
|
|
|
|
|
|
q_attn = q_batched.transpose(1, 2) |
|
|
k_attn = k_padded.transpose(1, 2) |
|
|
v_attn = v_padded.transpose(1, 2) |
|
|
|
|
|
|
|
|
attn_mask = torch.zeros(B, 1, Lq, max_seqlen_k, dtype=torch.bool, device=q_compact.device) |
|
|
for i in range(B): |
|
|
seq_len = cu_seqlens_k[i+1].item() - cu_seqlens_k[i].item() |
|
|
if seq_len < max_seqlen_k: |
|
|
attn_mask[i, :, :, seq_len:] = True |
|
|
|
|
|
|
|
|
oup = slow_attn( |
|
|
query=q_attn, |
|
|
key=k_attn, |
|
|
value=v_attn, |
|
|
attn_mask=~attn_mask, |
|
|
scale=self.scale, |
|
|
dropout_p=0.0 |
|
|
) |
|
|
|
|
|
|
|
|
oup = oup.transpose(1, 2).reshape(B, Lq, -1) |
|
|
|
|
|
return self.proj_drop(self.proj(oup)) |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'Cq={self.embed_dim}, Ckv={self.kv_dim}, cos_attn={self.cos_attn}' |
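
# The blocks below modulate activations with adaptive LayerNorm (AdaLN): a condition embedding yields six
# tensors (gamma1/2 gate the attention/FFN residuals, scale1/2 and shift1/2 modulate the pre-norm
# activations). With shared_aln=True a single learnable set (`ada_gss`) is shared and merely offset by the
# condition; otherwise each block computes its own set with a SiLU + Linear (`ada_lin`).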
|
|
|
|
|
|
|
|
class SelfAttnBlock(nn.Module): |
|
|
def __init__( |
|
|
self, embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial, |
|
|
num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False, |
|
|
swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False, |
|
|
): |
|
|
super(SelfAttnBlock, self).__init__() |
|
|
self.C, self.D = embed_dim, cond_dim |
|
|
self.drop_path_rate = drop_path |
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
        self.attn = SelfAttention(
            embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn,
        )
|
|
self.using_swiglu = swiglu |
|
|
self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp) |
|
|
|
|
|
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False) |
|
|
self.fused_norm_func = fused_norm_func |
|
|
self.norm_eps = norm_layer.keywords.get('eps', 1e-6) |
|
|
|
|
|
self.shared_aln = shared_aln |
|
|
if self.shared_aln: |
|
|
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5) |
|
|
else: |
|
|
lin = nn.Linear(cond_dim, 6*embed_dim) |
|
|
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin) |
|
|
|
|
|
|
|
|
def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector): |
|
|
with torch.cuda.amp.autocast(enabled=False): |
|
|
if self.shared_aln: |
|
|
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) |
|
|
else: |
|
|
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) |
|
|
|
|
|
        if self.fused_norm_func is None:
            x = x + self.drop_path(self.attn( self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1), attn_bias_or_two_vector=attn_bias_or_two_vector ).mul_(gamma1))
            x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2))
        else:
            x = x + self.drop_path(self.attn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1), attn_bias_or_two_vector=attn_bias_or_two_vector).mul_(gamma1))
            x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2))
        return x
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}' |
|
|
|
|
|
|
|
|
class CrossAttnBlock(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
embed_dim, kv_dim, cross_attn_layer_scale, cond_dim, act: bool, shared_aln: bool, norm_layer: partial, |
|
|
num_heads, mlp_ratio=4., drop=0., drop_path=0., tau=1, cos_attn=False, |
|
|
swiglu=False, customized_flash_attn=False, fused_mlp=False, fused_norm_func=None, checkpointing_sa_only=False, |
|
|
use_flex_attn=False, batch_size=2, pad_to_multiplier=1, apply_rope2d=False, rope2d_normalized_by_hw=False, |
|
|
): |
|
|
super(CrossAttnBlock, self).__init__() |
|
|
self.C, self.D = embed_dim, cond_dim |
|
|
self.drop_path_rate = drop_path |
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
self.sa = SelfAttention( |
|
|
embed_dim=embed_dim, num_heads=num_heads, proj_drop=drop, tau=tau, cos_attn=cos_attn, customized_flash_attn=customized_flash_attn, |
|
|
use_flex_attn=use_flex_attn, batch_size=batch_size, pad_to_multiplier=pad_to_multiplier, rope2d_normalized_by_hw=rope2d_normalized_by_hw, |
|
|
) |
|
|
self.ca = CrossAttention(embed_dim=embed_dim, kv_dim=kv_dim, num_heads=num_heads, proj_drop=drop, cos_attn=cos_attn) |
|
|
self.using_swiglu = swiglu |
|
|
self.ffn = (FFNSwiGLU if swiglu else FFN)(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio / 256) * 256, drop=drop, fused_mlp=fused_mlp) |
|
|
|
|
|
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False) |
|
|
self.fused_norm_func = fused_norm_func |
|
|
self.norm_eps = norm_layer.keywords.get('eps', 1e-6) |
|
|
self.ca_norm = norm_layer(embed_dim, elementwise_affine=True) |
|
|
|
|
|
self.shared_aln = shared_aln |
|
|
if self.shared_aln: |
|
|
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5) |
|
|
else: |
|
|
lin = nn.Linear(cond_dim, 6*embed_dim) |
|
|
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin) |
|
|
|
|
|
if cross_attn_layer_scale >= 0: |
|
|
self.ca_gamma = nn.Parameter(cross_attn_layer_scale * torch.ones(embed_dim), requires_grad=True) |
|
|
else: |
|
|
self.ca_gamma = 1 |
|
|
|
|
|
self.checkpointing_sa_only = checkpointing_sa_only |
|
|
|
|
|
|
|
|
def forward(self, x, cond_BD, ca_kv, attn_bias_or_two_vector, attn_fn=None, scale_schedule=None, rope2d_freqs_grid=None, scale_ind=0): |
|
|
with torch.cuda.amp.autocast(enabled=False): |
|
|
if self.shared_aln: |
|
|
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) |
|
|
else: |
|
|
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) |
|
|
|
|
|
if self.fused_norm_func is None: |
|
|
x_sa = self.ln_wo_grad(x.float()).mul(scale1.add(1)).add_(shift1) |
|
|
if self.checkpointing_sa_only and self.training: |
|
|
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False) |
|
|
else: |
|
|
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid) |
|
|
x = x + self.drop_path(x_sa.mul_(gamma1)) |
|
|
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma) |
|
|
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x.float()).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) |
|
|
else: |
|
|
x_sa = self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale1, shift=shift1) |
|
|
if self.checkpointing_sa_only and self.training: |
|
|
x_sa = checkpoint(self.sa, x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, use_reentrant=False) |
|
|
else: |
|
|
x_sa = self.sa(x_sa, attn_bias_or_two_vector, attn_fn, scale_schedule, rope2d_freqs_grid, scale_ind=scale_ind) |
|
|
x = x + self.drop_path(x_sa.mul_(gamma1)) |
|
|
x = x + self.ca(self.ca_norm(x), ca_kv).float().mul_(self.ca_gamma) |
|
|
x = x + self.drop_path(self.ffn(self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x, scale=scale2, shift=shift2)).mul(gamma2)) |
|
|
return x |
|
|
|
|
|
def extra_repr(self) -> str: |
|
|
return f'shared_aln={self.shared_aln}, fused_norm={self.fused_norm_func is not None}, ca_gamma={"<learnable>" if isinstance(self.ca_gamma, nn.Parameter) else self.ca_gamma}' |
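
# AdaLNBeforeHead: the final adaptive LayerNorm before the output head; it produces only scale and shift
# (2*C values) from the condition, with no residual gating.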
|
|
|
|
|
|
|
|
class AdaLNBeforeHead(nn.Module): |
|
|
def __init__(self, C, D, act: bool, norm_layer: partial, fused_norm_func=None): |
|
|
super().__init__() |
|
|
self.C, self.D = C, D |
|
|
self.ln_wo_grad = norm_layer(C, elementwise_affine=False) |
|
|
self.fused_norm_func = fused_norm_func |
|
|
self.norm_eps = norm_layer.keywords.get('eps', 1e-6) |
|
|
lin = nn.Linear(D, 2*C) |
|
|
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) if act else nn.Sequential(lin) |
|
|
|
|
|
def forward(self, x_BLC: torch.Tensor, cond_BD: Optional[torch.Tensor]): |
|
|
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2) |
|
|
if self.fused_norm_func is None: |
|
|
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift) |
|
|
else: |
|
|
return self.fused_norm_func(C=self.C, eps=self.norm_eps, x=x_BLC, scale=scale, shift=shift) |
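
# Quick smoke test for CrossAttention with variable-length, packed KV sequences; run this file directly
# to execute it on CPU.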
|
|
|
|
|
|
|
|
def main(): |
|
|
dev = 'cpu' |
|
|
rng = torch.Generator(device=dev) |
|
|
|
|
|
rng.manual_seed(0) |
|
|
B, H, cq, ckv = 4, 8, 64, 96 |
|
|
Cq = H*cq |
|
|
Ckv = H*ckv |
|
|
|
|
|
Li = [5, 4, 7, 6] |
|
|
Lq = 10 |
|
|
L = max(Li) |
|
|
attn_bias = torch.zeros(B, 1, Lq, L, device=dev) |
|
|
for i, x in enumerate(Li): |
|
|
attn_bias[i, 0, :, x:] = -torch.inf |
|
|
|
|
|
q = torch.randn(B, Lq, H, cq, generator=rng, device=dev) |
|
|
k = torch.randn(B, L, H, ckv, generator=rng, device=dev) |
|
|
v = torch.randn(B, L, H, ckv, generator=rng, device=dev) |
|
|
tq, tk, tv = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) |
|
|
|
|
|
    seqlen_k = torch.tensor(Li, dtype=torch.int32, device=dev)
    cu_seqlens_k = F.pad(torch.cumsum(seqlen_k, dim=0, dtype=torch.int32), (1, 0))
    # CrossAttention.forward expects q shaped (B, Lq, Cq) and a packed KV tensor shaped
    # (sum(Li), Ckv): flatten the per-head dims and keep only the valid tokens of each sample.
    kv_compact = torch.cat([k[i, :Li[i]].reshape(Li[i], Ckv) for i in range(B)], dim=0)

    ca = CrossAttention(for_attn_pool=False, embed_dim=Cq, kv_dim=Ckv, num_heads=H)
    ca(q.reshape(B, Lq, Cq), (kv_compact, cu_seqlens_k, max(Li))).mean().backward()
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
main() |
|
|
|