| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple |
| from dataclasses import dataclass |
| from einops import rearrange, repeat |
|
|
| from flash_attn import flash_attn_func |
| from .liger_rope import LigerRopeFunction |
| from .rms_norm import LlamaRMSNorm |
| from .config import LlamaConfig |
|
|
class CPLinear(nn.Module):
    """Produce per-head q, k, v as averaged sums of factor products.

    For every token, each stream (q, k, v) is projected into two factors:
    an ``A`` factor of shape ``(n_head, rank)`` and a ``B`` factor of shape
    ``(rank, head_dim)``. The head vectors are ``(A @ B) / rank``. Queries
    use ``q_rank`` factor pairs; keys and values each use ``kv_rank``.

    Args:
        in_features: Size of the last dimension of the input.
        n_head: Number of heads produced per stream.
        head_dim: Size of each head vector.
        kv_rank: Factor rank for keys and values.
        q_rank: Factor rank for queries.
    """

    def __init__(self, in_features, n_head, head_dim, kv_rank=2, q_rank=6):
        super().__init__()
        # Creation order is kept fixed: it determines the RNG draws made by
        # nn.Linear's default init and the parameter order in state_dict().
        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_k = nn.Linear(in_features, kv_rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * kv_rank, bias=False)
        self.W_B_v = nn.Linear(in_features, kv_rank * head_dim, bias=False)

        # Re-initialise every factor projection with Xavier-uniform, same order.
        for proj in (
            self.W_A_q,
            self.W_B_q,
            self.W_A_k,
            self.W_B_k,
            self.W_A_v,
            self.W_B_v,
        ):
            nn.init.xavier_uniform_(proj.weight)

        self.n_head = n_head
        self.q_rank = q_rank
        self.head_dim = head_dim
        self.kv_rank = kv_rank

    def forward(self, x):
        """Return ``(q, k, v)``, each shaped ``(batch, seq, n_head, head_dim)``.

        ``x`` must be 3-D ``(batch, seq, in_features)`` (enforced by the
        unpacking of ``x.size()``).
        """
        batch, length, _ = x.size()

        def contract(proj_a, proj_b, rank):
            # (b, s, n_head, rank) @ (b, s, rank, head_dim) -> (b, s, n_head, head_dim),
            # averaged over the rank dimension.
            left = proj_a(x).view(batch, length, self.n_head, rank)
            right = proj_b(x).view(batch, length, rank, self.head_dim)
            return torch.matmul(left, right) / rank

        q = contract(self.W_A_q, self.W_B_q, self.q_rank)
        k = contract(self.W_A_k, self.W_B_k, self.kv_rank)
        v = contract(self.W_A_v, self.W_B_v, self.kv_rank)
        return q, k, v
|
|
class CausalTensorProductSelfAttn(nn.Module):
    """Causal self-attention whose q/k/v come from a ``CPLinear`` factorisation.

    Pipeline: ``CPLinear`` -> rotary embedding via ``LigerRopeFunction`` ->
    ``flash_attn_func`` (always causal) -> ``subln`` (a ``LlamaRMSNorm`` sized
    to ``head_dim``) -> merge heads -> ``o_proj``.

    Args:
        config: Object providing ``num_attention_heads``, ``hidden_size``,
            ``max_position_embeddings``, ``rope_theta`` and, optionally,
            ``using_groupnorm``.
        kv_rank: Factor rank for keys/values, forwarded to ``CPLinear``.
        q_rank: Factor rank for queries, forwarded to ``CPLinear``.
    """

    def __init__(self, config, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.n_embd = config.hidden_size
        self.rank = kv_rank
        self.q_rank = q_rank
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.rank, self.q_rank)
        self.o_proj = nn.Linear(self.n_head * self.head_dim, self.n_embd, bias=False)

        # Build the rotary tables once; both are non-persistent buffers, so they
        # follow .to(device) but are not written to state_dict().
        cos, sin = self._compute_rope_embeddings(
            self.max_position_embeddings,
            self.head_dim,
            self.rope_theta,
            dtype=torch.float32,
            device=self.o_proj.weight.device,
        )
        self.register_buffer("cos_cached", cos, persistent=False)
        self.register_buffer("sin_cached", sin, persistent=False)

        # NOTE(review): ``using_groupnorm`` is stored but never consulted;
        # ``subln`` is applied unconditionally in ``forward``. Confirm whether the
        # flag was meant to gate it (changing that would alter trained models).
        self.using_groupnorm = getattr(config, 'using_groupnorm', False)
        self.subln = LlamaRMSNorm(self.head_dim, eps=1e-5)

    def _compute_rope_embeddings(self, max_position_embeddings, head_dim, base=10000, dtype=None, device=None):
        """Return ``(cos, sin)`` tables, each ``(1, max_position_embeddings, head_dim)``.

        Inverse frequencies are ``1 / base ** (arange(0, head_dim, 2) / head_dim)``;
        the ``(positions, head_dim // 2)`` angle matrix is concatenated with
        itself along the last dim so it spans ``head_dim``. Everything is
        computed in float32 and cast to ``dtype`` only when one is given.
        """
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_position_embeddings, device=device, dtype=torch.float32)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        if dtype is not None:
            cos = cos.to(dtype)
            sin = sin.to(dtype)
        return cos.unsqueeze(0), sin.unsqueeze(0)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Apply causal tensor-product attention.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input.
            attention_mask: Accepted for interface compatibility but NOT
                applied: ``flash_attn_func`` is called without a mask and
                attention is always causal. Padding must be handled by the
                caller.
            position_ids: ``(batch, seq_len)`` or ``(seq_len,)`` indices into the
                rotary tables; defaults to ``0 .. seq_len - 1`` for every sequence.

        Returns:
            Tensor of shape ``(batch, seq_len, hidden_size)``.
        """
        bsz, seq_len, _ = hidden_states.size()

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device)
        if position_ids.dim() == 1:
            # A shared 1-D position vector is broadcast over the batch so the
            # gathered cos/sin below are 3-D (batch, seq_len, head_dim).
            position_ids = repeat(position_ids, 'l -> b l', b=bsz)

        # Each of q, k, v: (bsz, seq_len, n_head, head_dim).
        q, k, v = self.c_qkv(hidden_states)

        # cos_cached is (1, max_pos, head_dim); gathering with (b, l) ids gives
        # (1, b, l, head_dim), squeezed to (b, l, head_dim) for the rope call.
        cos = self.cos_cached[:, position_ids]
        sin = self.sin_cached[:, position_ids]

        # NOTE(review): upstream Liger-Kernel's LigerRopeFunction documents q/k
        # as (bsz, n_head, seq_len, head_dim), but q/k here are
        # (bsz, seq_len, n_head, head_dim). Confirm the local ``.liger_rope``
        # expects this layout; otherwise wrap the call in transpose(1, 2).
        q, k = LigerRopeFunction.apply(
            q,
            k,
            cos.squeeze(0),
            sin.squeeze(0),
            position_ids,
        )

        # ``causal`` is fixed to True: this layer is causal by construction, and
        # the mask is never forwarded to flash_attn_func, so deriving ``causal``
        # from the mask's presence would make any masked call bidirectional.
        attn_out = flash_attn_func(
            q,
            k,
            v,
            dropout_p=0.0,
            causal=True,
        )

        # Normalise each head over head_dim, then merge heads and project out.
        attn_out = self.subln(attn_out)
        attn_out = rearrange(attn_out, "b s h d -> b s (h d)")
        attn_out = self.o_proj(attn_out)
        return attn_out