import inspect
import math
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.utils.checkpoint import checkpoint

from transformers.activations import ACT2FN
from transformers.utils.import_utils import is_torch_fx_available


# Optional flash-attn (v2) kernels; fall back to the SDPA-based emulation below
# when the package is not installed.
try:
    from flash_attn import flash_attn_func as _flash_attn_func
    from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

    HAVE_FLASH_ATTN = True
except Exception:
    _flash_attn_func = None
    _flash_attn_varlen_func = None
    index_first_axis = None
    pad_input = None
    unpad_input = None
    HAVE_FLASH_ATTN = False


def _repeat_kv_for_gqa(x: torch.Tensor, repeat: int) -> torch.Tensor:
    """Expand [B, S, Hk, D] key/value heads to [B, S, Hk * repeat, D].

    Uses interleave-style expansion (query head h maps to kv head h // repeat),
    matching flash-attn's GQA head grouping.
    """
    if repeat == 1:
        return x
    B, S, Hk, D = x.shape
    # unsqueeze after the head dim so the flattened order is kv-head-major,
    # i.e. consecutive query heads share one kv head.
    x = x.unsqueeze(3).expand(B, S, Hk, repeat, D)
    return x.reshape(B, S, Hk * repeat, D)


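# Sanity check (sketch): interleave-style expansion groups adjacent query heads
# onto the same KV head, matching torch.repeat_interleave on the head dim.
#
#   x = torch.randn(1, 5, 2, 4)      # [B, S, Hk, D]
#   assert torch.equal(_repeat_kv_for_gqa(x, 3), torch.repeat_interleave(x, 3, dim=2))

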
@torch.no_grad()
def _build_window_mask(
    Sq: int, Sk: int, left: int, right: int, causal: bool, device: torch.device
) -> torch.Tensor:
    """
    FA2 window semantics:
      valid j for query i: j ∈ [i + Sk - Sq - left, i + Sk - Sq + right]
    FA2.1 causal alignment (bottom-right): additionally disallow j > i + Sk - Sq.
    A negative `left`/`right` means unbounded on that side, as in flash-attn.
    Returns a float additive mask [1, 1, Sq, Sk]: 0 to keep, a large negative
    value to mask.
    """
    # flash-attn treats -1 as "no limit" on that side of the window.
    if left < 0:
        left = Sk
    if right < 0:
        right = Sk
    i = torch.arange(Sq, device=device).view(-1, 1)
    j = torch.arange(Sk, device=device).view(1, -1)
    shift = Sk - Sq
    j_min = i + shift - left
    j_max = i + shift + right
    allowed = (j >= j_min) & (j <= j_max)
    if causal:
        # Bottom-right aligned causality: the last query row sees the last key.
        allowed &= j <= (i + shift)
    masked = ~allowed
    m = torch.full((Sq, Sk), 0.0, device=device)
    m[masked] = -torch.finfo(m.dtype).max
    return m.view(1, 1, Sq, Sk).contiguous()


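# Worked example (sketch): Sq = Sk = 4, window (left=1, right=0), causal=True.
# Here shift = 0, so query i may attend keys j in [i - 1, i]:
#   i=0 -> {0}        i=1 -> {0, 1}
#   i=2 -> {1, 2}     i=3 -> {2, 3}
# i.e. a causal band of width 2 along the diagonal; every other position gets
# the large negative additive bias.

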
@torch.no_grad()
def _build_causal_mask_fa21(
    Sq: int, Sk: int, device: torch.device
) -> torch.Tensor:
    """
    FA2.1 causal only (no window): mask positions with j > i + (Sk - Sq).
    Returns a float additive mask [1, 1, Sq, Sk]: 0 to keep, a large negative
    value to mask.
    """
    i = torch.arange(Sq, device=device).view(-1, 1)
    j = torch.arange(Sk, device=device).view(1, -1)
    shift = Sk - Sq
    allowed = j <= (i + shift)
    masked = ~allowed
    m = torch.full((Sq, Sk), 0.0, device=device)
    m[masked] = -torch.finfo(m.dtype).max
    return m.view(1, 1, Sq, Sk).contiguous()


def _sdpa_flash_attn_compat(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[torch.Tensor] = None,
    training: Optional[bool] = None,
) -> torch.Tensor:
    """
    SDPA path emulating flash_attn_func semantics (v2):
      - supports GQA (Hq divisible by Hk)
      - FA2.1 causal alignment when Sq != Sk
      - sliding window: j in [i + Sk - Sq - left, i + Sk - Sq + right]
      - ALiBi additive bias
    Returns: [B, Sq, Hq, D] in the input dtype.
    """
    assert q.dim() == k.dim() == v.dim() == 4, "Expect [B,S,H,D] tensors"
    B, Sq, Hq, D = q.shape
    Bk, Sk, Hk, Dk = k.shape
    assert (Bk, Dk) == (B, D), "Batch/Dim mismatch"
    assert v.shape[:3] == k.shape[:3] and v.shape[3] == D, "K/V mismatch"
    assert Hq % Hk == 0, "Hq must be divisible by Hk for GQA/MQA"
    repeat = Hq // Hk

    # Expand KV heads so SDPA sees matching head counts.
    k_exp = _repeat_kv_for_gqa(k, repeat)
    v_exp = _repeat_kv_for_gqa(v, repeat)

    # SDPA expects [B, H, S, D]; compute in float32 for numerical parity.
    qh = q.permute(0, 2, 1, 3).to(torch.float32)
    kh = k_exp.permute(0, 2, 1, 3).to(torch.float32)
    vh = v_exp.permute(0, 2, 1, 3).to(torch.float32)
    in_dtype = q.dtype
    device = q.device

    # SDPA applies its own 1/sqrt(D) scaling, so pre-multiplying q by
    # softmax_scale * sqrt(D) yields an effective scale of softmax_scale.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(D)
    qh = qh * (softmax_scale * math.sqrt(D))

    left, right = window_size
    use_window = (left, right) != (-1, -1)
    attn_bias = None

    if use_window:
        # flash-attn forces the right window to 0 in causal mode.
        if causal and right > 0:
            right = 0
        base = _build_window_mask(Sq, Sk, left, right, causal, device)
        attn_bias = base.expand(B, Hq, Sq, Sk)
    elif causal:
        # SDPA's is_causal uses top-left alignment, which diverges from FA2.1's
        # bottom-right alignment when Sq != Sk, so causality is always encoded
        # in an explicit additive mask instead.
        base = _build_causal_mask_fa21(Sq, Sk, device)
        attn_bias = base.expand(B, Hq, Sq, Sk)

    if alibi_slopes is not None:
        if alibi_slopes.dim() == 1:
            alibi = alibi_slopes.view(1, Hq, 1, 1).to(dtype=torch.float32, device=device)
            alibi = alibi.expand(B, Hq, 1, 1)
        elif alibi_slopes.dim() == 2:
            alibi = alibi_slopes.view(B, Hq, 1, 1).to(dtype=torch.float32, device=device)
        else:
            raise ValueError("alibi_slopes must be (Hq,) or (B,Hq)")
        i = torch.arange(Sq, device=device).view(1, 1, -1, 1)
        j = torch.arange(Sk, device=device).view(1, 1, 1, -1)
        shift = Sk - Sq
        dist = (i + shift - j).abs().to(torch.float32)
        alibi_term = -(alibi * dist)
        if attn_bias is None:
            attn_bias = alibi_term
        else:
            attn_bias = attn_bias + alibi_term

    # Only apply dropout in training mode; infer it from requires_grad when the
    # caller does not say.
    if training is None:
        training = (dropout_p > 0.0) and any(t.requires_grad for t in (q, k, v))
    dp = dropout_p if training else 0.0

    out = F.scaled_dot_product_attention(
        qh, kh, vh,
        attn_mask=attn_bias,
        dropout_p=dp,
        is_causal=False,  # causality, if any, is already encoded in attn_mask
    )
    return out.permute(0, 2, 1, 3).to(in_dtype).contiguous()


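# Usage sketch (illustrative shapes): emulate flash_attn_func with GQA and a
# causal sliding window; 8 query heads share 2 KV heads (repeat factor 4).
#
#   q = torch.randn(2, 128, 8, 64)
#   k = torch.randn(2, 128, 2, 64)
#   v = torch.randn(2, 128, 2, 64)
#   out = _sdpa_flash_attn_compat(q, k, v, causal=True, window_size=(64, 0))
#   assert out.shape == (2, 128, 8, 64)

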
def _attn_dispatch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *,
    causal: bool = True,
    window_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
    """
    Dispatches to either flash attention or the SDPA fallback. This function
    accepts and returns tensors shaped ``[batch, seq_len, num_heads, head_dim]``.
    """
    # Both backends expect a (left, right) tuple; (-1, -1) means no window.
    ws = window_size if window_size is not None else (-1, -1)
    # The flash-attn kernels only support fp16/bf16 CUDA tensors; anything else
    # takes the SDPA path.
    if HAVE_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return _flash_attn_func(
            q,
            k,
            v,
            causal=causal,
            window_size=ws,
        )
    return _sdpa_flash_attn_compat(
        q,
        k,
        v,
        causal=causal,
        window_size=ws,
    )


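# Usage sketch: callers keep the [B, S, H, D] layout on both sides of the call,
# and the same arguments work whether or not flash-attn is installed.
#
#   out = _attn_dispatch(q, k, v, causal=True, window_size=(128, 0))

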
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dimensions of the input: [x1, x2] -> [-x2, x1]."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: Optional[torch.Tensor],
    k: Optional[torch.Tensor],
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
    unsqueeze_dim: int = 1,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Applies rotary position embeddings to the query and key tensors. Either
    tensor may be None, in which case None is returned in its place.
    ``position_ids`` is unused and kept for interface compatibility.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin) if q is not None else None
    k_embed = (k * cos) + (rotate_half(k) * sin) if k is not None else None
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``. Converts
    hidden states from shape (batch, num_key_value_heads, seq_len, head_dim) to
    (batch, num_attention_heads, seq_len, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


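# Sanity check (sketch): repeat_kv matches repeat_interleave along the head dim.
#
#   x = torch.randn(1, 2, 5, 4)      # [B, Hk, S, D]
#   assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, 3, dim=1))

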
class RotaryEmbedding(nn.Module):
    """
    Computes rotary position embeddings. See
    https://arxiv.org/abs/2104.09864 for details.
    """

    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if position_ids is None:
            # Default to positions 0..seq_len-1 for every batch element.
            position_ids = torch.arange(x.shape[2], device=x.device, dtype=torch.int64).unsqueeze(0).expand(x.shape[0], -1)

        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Disable autocast so cos/sin are computed in full float32; mps has no
        # float32 autocast context, hence the cpu fallback for the device type.
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


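# Usage sketch (illustrative sizes): produce cos/sin for a [B, H, S, D] tensor
# and apply them with apply_rotary_pos_emb.
#
#   rope = RotaryEmbedding(dim=64)
#   q = torch.randn(2, 8, 128, 64)   # [B, H, S, D]
#   k = torch.randn(2, 8, 128, 64)
#   cos, sin = rope(q)               # each [B, S, D]
#   q, k = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over H

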
class RMSNorm(nn.Module):
    """
    Root Mean Square layer normalization, equivalent to T5LayerNorm:
    ``y = weight * x / sqrt(mean(x**2) + eps)``, computed in float32.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Attention(nn.Module):
    """
    Multi-head attention module with optional rotary positional embeddings and
    windowed attention. Uses flash attention when available, otherwise falls
    back to PyTorch's scaled dot product attention.
    """

    def __init__(
        self,
        num_attention_heads: int,
        num_key_value_heads: int,
        attention_head_size: int,
        attention_window_size: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_positional_embedding: bool = False,
        rope_base: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_head_size = attention_head_size
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
        self.attention_window_size = attention_window_size
        self.seq_length = seq_length
        self.use_positional_embedding = use_positional_embedding
        self.rope_base = rope_base
        if self.use_positional_embedding:
            # Fall back to the conventional RoPE base when none is given.
            base = self.rope_base if self.rope_base is not None else 10000
            self.rotary_emb = RotaryEmbedding(dim=self.attention_head_size, base=base)

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> torch.Tensor:
        bsz, q_len, _ = query_states.size()

        # [B, S, H * D] -> [B, H, S, D]
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_size).transpose(1, 2).contiguous()
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.attention_head_size).transpose(1, 2).contiguous()
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.attention_head_size).transpose(1, 2).contiguous()

        # Expand KV heads up front; the dispatcher then sees matching head counts.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if self.use_positional_embedding:
            cos, sin = self.rotary_emb(query_states)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Back to the [B, S, H, D] layout expected by _attn_dispatch.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if self.attention_window_size is not None:
            ws = (self.attention_window_size, self.attention_window_size)
        else:
            ws = None
        attn_outputs = _attn_dispatch(
            query_states,
            key_states,
            value_states,
            causal=True,
            window_size=ws,
        )

        attn_outputs = attn_outputs.reshape(bsz, q_len, self.num_attention_heads * self.attention_head_size).contiguous()
        return attn_outputs


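# Usage sketch (illustrative sizes): 12 query heads, 4 KV heads, head dim 64.
# Inputs are the already-projected q/k/v streams, each [B, S, H * D].
#
#   attn = Attention(12, 4, 64, attention_window_size=128, use_positional_embedding=True)
#   q = torch.randn(2, 256, 12 * 64)
#   kv = torch.randn(2, 256, 4 * 64)
#   out = attn(q, kv, kv)            # [2, 256, 768]

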
class Block(nn.Module):
    """
    Basic transformer block consisting of an input projection into query/key/value
    channels, a single attention layer, RMS normalization and an output
    projection.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        num_key_value_heads: int = 4,
        attention_window_size: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_positional_embedding: bool = False,
        rope_base: Optional[int] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = self.hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_head_size = self.intermediate_size // self.num_attention_heads

        # Query channels plus key and value channels for the (possibly fewer)
        # KV heads.
        self.latent_dim = self.intermediate_size + self.attention_head_size * self.num_key_value_heads * 2
        self.pre_avg_layernorm = RMSNorm(self.intermediate_size)
        self.in_proj = nn.Linear(self.hidden_size, self.latent_dim, bias=True)
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
        self.self_attn = Attention(
            self.num_attention_heads,
            self.num_key_value_heads,
            self.attention_head_size,
            attention_window_size,
            seq_length,
            use_positional_embedding,
            rope_base,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Split the joint projection into query/key/value along the channel dim.
        projected = self.in_proj(hidden_states)
        q, k, v = projected.tensor_split(
            (
                self.intermediate_size,
                self.intermediate_size + self.attention_head_size * self.num_key_value_heads,
            ),
            dim=-1,
        )

        attn_outputs = self.self_attn(
            query_states=q,
            key_states=k,
            value_states=v,
        )

        hidden_states = self.pre_avg_layernorm(attn_outputs)
        contextualized_states = self.out_proj(hidden_states)
        return contextualized_states
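

# Usage sketch (illustrative sizes): one block over a random sequence.
#
#   block = Block(hidden_size=768, num_attention_heads=12, num_key_value_heads=4)
#   x = torch.randn(2, 256, 768)
#   y = block(x)
#   assert y.shape == x.shape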