| | from flash_attn import flash_attn_func |
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange, repeat |
| |
|
| | from .extact import xATGLU |
| | from .liger_rope import LigerRopeFunction |
| | from .config import LlamaConfig |
| |
|
| | |
| | |
| |
|
class DifferentialAttention(nn.Module):
    """Multi-head differential attention with grouped-query K/V and RoPE.

    Each of the ``num_attention_heads`` logical heads owns two query/key
    sub-heads of size ``head_dim`` and a value of size ``2 * head_dim``.
    Two attention outputs are computed (one per query/key sub-head pair) and
    combined as ``attn1 - lambda_full * attn2``. The result is normalised per
    head with a non-affine LayerNorm, rescaled by ``1 - lambda_init`` and
    projected back to ``hidden_size``. The structure appears to follow the
    Differential Transformer reference design (Ye et al., 2024).

    Args:
        config: Model config; reads ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings`` and
            ``rope_theta``.
        layer_num: Layer index used to set
            ``lambda_init = 0.8 - 0.6 * exp(-0.3 * layer_num)``.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads`` (flash-attn grouped-query attention
            requires the query head count to be divisible by the KV head
            count).
    """

    def __init__(self, config: "LlamaConfig", layer_num):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        # Fail at construction instead of deep inside flash-attn at forward time.
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_kv_heads})"
            )
        self.n_rep = self.num_heads // self.num_kv_heads
        # Every logical head is split into two sub-heads, hence the factor 2.
        self.head_dim = self.hidden_size // (2 * self.num_heads)
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Softmax scale handed to flash-attn; equals its default 1/sqrt(head_dim).
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, 2 * self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, 2 * self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(2 * self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Depth-dependent constant; forward() uses
        # lambda_full = exp(q1 . k1) - exp(q2 . k2) + lambda_init.
        self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_num)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim).normal_(0, 0.1))

        # Normalises the concatenated 2 * head_dim output of each logical head.
        self.subln = nn.LayerNorm(2 * self.head_dim, elementwise_affine=False)

        # Build the RoPE tables once; non-persistent buffers follow .to()/.cuda()
        # but are not written to the state_dict.
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.q_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Build RoPE cosine/sine lookup tables.

        Args:
            max_position_embeddings: Number of positions (rows) to tabulate.
            head_dim: Rotary dimension; frequencies use every other index.
            base: RoPE frequency base (theta).
            dtype: dtype the tables are cast to.
            device: Device on which the tables are built.

        Returns:
            ``(cos, sin)``, each of shape ``(1, max_position_embeddings, head_dim)``
            for even ``head_dim``. The frequency vector is duplicated along the
            last axis (``cat((freqs, freqs))``) rather than interleaved.
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        positions = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype).unsqueeze(0), emb.sin().to(dtype).unsqueeze(0)

    def _flash(self, q, k, v, causal):
        """Run ``flash_attn_func`` without dropout at this module's softmax scale."""
        return flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=self.scaling, causal=causal)

    def forward(
        self,
        hidden_states,
        attention_mask,
        position_ids,
    ) -> torch.Tensor:
        """Apply differential attention.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, hidden_size)``.
            attention_mask: Only compared against ``None``. ``None`` selects
                causal attention; any other value selects bidirectional
                attention. The mask's contents are never applied.
            position_ids: Indices into the RoPE tables. If ``None``, they
                default to ``0..seq_len-1`` repeated to ``(batch, seq_len)``;
                supplied values are assumed to have that shape too.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        # NOTE(review): attention_mask only toggles causality. If a caller passes a
        # padding mask to a causal decoder, attention becomes bidirectional (future
        # tokens are visible) and the padding is still not masked. Confirm intent.
        causal = attention_mask is None

        q = rearrange(self.q_proj(hidden_states), 'b s (h d) -> b s h d', h=2 * self.num_heads, d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), 'b s (h d) -> b s h d', h=2 * self.num_kv_heads, d=self.head_dim)
        # Values keep the sub-head pair as an explicit axis: (b, s, kv_heads, 2, head_dim).
        v = rearrange(
            self.v_proj(hidden_states), 'b s (h g d) -> b s h g d', h=self.num_kv_heads, g=2, d=self.head_dim
        )

        # NOTE(review): cos_cached is (1, max_pos, head_dim), so indexing dim 1 with a
        # (batch, seq_len) position_ids gives a 4-D (1, batch, seq_len, head_dim) tensor.
        # Upstream Liger-Kernel documents cos/sin as (1|batch, seq_len, head_dim) and
        # q/k as (batch, heads, seq_len, head_dim); verify the local LigerRopeFunction
        # expects the (b, s, h, d) q/k and 4-D cos/sin passed here.
        cos = self.cos_cached[:, position_ids]
        sin = self.sin_cached[:, position_ids]
        q, k = LigerRopeFunction.apply(q, k, cos, sin, position_ids)

        # Sub-head 2i becomes component 0 and sub-head 2i+1 becomes component 1 of head i.
        q = rearrange(q, 'b s (h g) d -> b s h g d', h=self.num_heads, g=2)
        k = rearrange(k, 'b s (h g) d -> b s h g d', h=self.num_kv_heads, g=2)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]

        # The 2 * head_dim value is attended in two head_dim halves and concatenated,
        # presumably because flash_attn_func requires q, k and v to share head_dim.
        attn1 = torch.cat([self._flash(q1, k1, v1, causal), self._flash(q1, k1, v2, causal)], dim=-1)
        attn2 = torch.cat([self._flash(q2, k2, v1, causal), self._flash(q2, k2, v2, causal)], dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2

        attn = self.subln(attn) * (1 - self.lambda_init)

        return self.o_proj(rearrange(attn, "b s h d -> b s (h d)"))