import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from functools import partial
from einops import rearrange, repeat
from typing import Optional, Tuple, Union


Linear = partial(nn.Linear, bias=False)
LayerNorm = partial(nn.LayerNorm, bias=False)


def rotate_half(x, interleaved=False):
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(
            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
        )


def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    # Broadcast cos/sin over the head dimension; the repeat pattern must match
    # the pairing used by rotate_half for the chosen (non-)interleaved layout.
    pattern = "s d -> s 1 (2 d)" if not interleaved else "s d -> s 1 (d 2)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )


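# Illustrative shape sketch (not part of the module): how apply_rotary_emb_torch
# is typically called, assuming a head dimension of 64 that is fully rotated.
#
#   x = torch.randn(2, 128, 8, 64)                                # (batch, seqlen, nheads, headdim)
#   inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
#   freqs = torch.outer(torch.arange(128).float(), inv_freq)      # (seqlen, headdim / 2)
#   out = apply_rotary_emb_torch(x, torch.cos(freqs), torch.sin(freqs))
#   assert out.shape == x.shape
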
class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        scaling_factor=1.0,
        pos_idx_in_fp32=True,
        device=None,
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
        self.device = device

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None
        self.reset_parameters()

    def reset_parameters(self):
        inv_freq = self._compute_inv_freq(self.device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
        scale = (
            (arange + 0.4 * self.dim) / (1.4 * self.dim)
            if self.scale_base is not None
            else None
        )
        self.register_buffer("scale", scale)

    def _compute_inv_freq(self, device=None):
        return 1 / (
            self.base
            ** (
                torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
                / self.dim
            )
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)

            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(
                        seqlen, dtype=self.scale.dtype, device=self.scale.device
                    )
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        """
        self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
        assert self._cos_cached is not None
        assert self._sin_cached is not None
        if self.scale is None:
            return (
                apply_rotary_emb_torch(
                    q,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,
                ),
                apply_rotary_emb_torch(
                    k,
                    self._cos_cached,
                    self._sin_cached,
                    self.interleaved,
                    True,
                ),
            )
        else:
            raise NotImplementedError(
                "Scaled (scale_base) rotary embeddings are not supported in this forward pass."
            )


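# Illustrative usage sketch (assumed sizes, not part of the module itself):
# RotaryEmbedding expects q and k laid out as (batch, seqlen, nheads, headdim)
# and returns rotated tensors of the same shape.
#
#   rope = RotaryEmbedding(dim=64)
#   q = torch.randn(2, 128, 8, 64)
#   k = torch.randn(2, 128, 8, 64)
#   q_rot, k_rot = rope(q, k)                         # both (2, 128, 8, 64)
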
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size: int, n_heads: int, rotary: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.d_head = self.hidden_size // self.n_heads
        self.layernorm_qkv = nn.Sequential(
            LayerNorm(hidden_size), Linear(hidden_size, hidden_size * 3)
        )
        self.out_proj = Linear(hidden_size, hidden_size)
        self.q_ln = LayerNorm(hidden_size)
        self.k_ln = LayerNorm(hidden_size)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(hidden_size // n_heads) if rotary else None

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, L, _ = x.shape
        # Expand a (b, L) padding mask to (b, 1, L, L) so it broadcasts across
        # heads inside scaled_dot_product_attention.
        if attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :].expand(b, 1, L, L).bool()
        qkv = self.layernorm_qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        q, k = self.q_ln(q).to(q.dtype), self.k_ln(k).to(q.dtype)
        if self.rotary:
            q, k = self._apply_rotary(q, k)
        q, k, v = map(self.reshaper, (q, k, v))
        a = F.scaled_dot_product_attention(q, k, v, attention_mask)
        a = rearrange(a, "b h s d -> b s (h d)")
        return self.out_proj(a)


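# Illustrative usage sketch (hidden size and mask semantics assumed: 1 = keep,
# 0 = padding), showing the expected input/output shapes of MultiHeadAttention.
#
#   mha = MultiHeadAttention(hidden_size=256, n_heads=8)
#   x = torch.randn(2, 64, 256)                       # (batch, seqlen, hidden)
#   mask = torch.ones(2, 64, dtype=torch.bool)        # (batch, seqlen) padding mask
#   y = mha(x, attention_mask=mask)                   # (2, 64, 256)
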
class PAttention(nn.Module):
    """
    Cross-attention mechanism for token-parameter attention (b, L, d) -> (b, L, n_tokens) -> (b, L, d)
    """
    def __init__(
        self,
        hidden_size: int,
        n_tokens: int,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.n_tokens = n_tokens
        self.Wq = Linear(hidden_size, hidden_size)
        self.Pk = nn.Parameter(torch.randn(1, n_tokens, hidden_size))
        self.Pv = nn.Parameter(torch.randn(1, n_tokens, hidden_size))
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, L, _ = x.size()
        if attention_mask is not None:
            # Broadcast the (b, L) padding mask to (b, L, n_tokens); the learned
            # key/value parameter tokens themselves are never padded.
            attention_mask = attention_mask[:, :, None].expand(b, L, self.n_tokens).bool()

        q = self.Wq(x)
        out = F.scaled_dot_product_attention(q, self.Pk, self.Pv, attn_mask=attention_mask, is_causal=False)
        return self.dropout(out)


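# Illustrative usage sketch (assumed sizes): PAttention attends each input
# position to a fixed set of learned key/value "parameter tokens", so the
# output keeps the (batch, seqlen, hidden) shape of the input.
#
#   pattn = PAttention(hidden_size=256, n_tokens=32)
#   x = torch.randn(2, 64, 256)
#   y = pattn(x)                                      # (2, 64, 256)
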
class AttentionLogitsSequence(nn.Module):
    """
    Cross-attention mechanism for token-parameter attention (b, L, d) -> (b, L, num_labels) -> (b, num_labels)
    """
    def __init__(self, hidden_size: int, num_labels: int = 1, sim_type: str = 'dot'):
        super().__init__()
        self.num_labels = num_labels
        self.Wp = nn.Parameter(torch.randn(1, hidden_size, num_labels))
        self.Wx = Linear(hidden_size, hidden_size)
        self.sim_type = sim_type

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        if attention_mask is None:
            return emb.mean(dim=1)
        else:
            return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def dot_product(self, x: torch.Tensor, p: torch.Tensor):
        return torch.matmul(x, p)

    def euclidean_distance(self, x: torch.Tensor, p: torch.Tensor):
        # (b, L, d, 1) - (b, 1, d, num_labels) -> norm over d -> (b, L, num_labels)
        x_exp = x.unsqueeze(-1)
        p_exp = p.unsqueeze(1)
        dist = torch.norm(x_exp - p_exp, p=2, dim=2)
        return -dist

    def cosine_similarity(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attention_mask is not None:
            x = x * attention_mask
        x = F.normalize(x, p=2, dim=-1)
        p = F.normalize(p, p=2, dim=1)
        cos_sims = torch.matmul(x, p)
        # Allow a small tolerance for floating-point error.
        assert (
            cos_sims.max().item() <= 1.0 + 1e-4 and cos_sims.min().item() >= -1.0 - 1e-4
        ), "Cosine similarity values should be between -1 and 1"
        return cos_sims

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        b, L, _ = x.size()
        p = self.Wp.expand(b, -1, -1)
        x = self.Wx(x)

        if attention_mask is None:
            attention_mask = torch.ones(b, L, device=x.device, dtype=x.dtype)
        attention_mask = attention_mask.unsqueeze(-1)

        if self.sim_type == 'dot':
            y = self.dot_product(x, p)
        elif self.sim_type == 'euclidean':
            y = self.euclidean_distance(x, p)
        elif self.sim_type == 'cosine':
            y = self.cosine_similarity(x, p, attention_mask)
        else:
            raise ValueError(f"Invalid similarity type: {self.sim_type}")

        logits = self.mean_pooling(y, attention_mask)
        return logits, y, x


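# Illustrative usage sketch (assumed sizes): AttentionLogitsSequence scores each
# position against num_labels learned prototype vectors and mean-pools over the
# sequence, returning sequence-level logits plus the per-token scores.
#
#   head = AttentionLogitsSequence(hidden_size=256, num_labels=3, sim_type='dot')
#   x = torch.randn(2, 64, 256)
#   logits, token_scores, projected = head(x)         # (2, 3), (2, 64, 3), (2, 64, 256)
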
class AttentionLogitsToken(nn.Module):
    """
    Cross-attention mechanism for token-parameter attention (b, L, d) -> (b, L, num_labels)
    """
    def __init__(self, hidden_size: int, num_labels: int = 1, sim_type: str = 'dot'):
        super().__init__()
        self.num_labels = num_labels
        self.Wp = nn.Parameter(torch.randn(1, hidden_size, num_labels))
        self.Wx = Linear(hidden_size, hidden_size)
        self.sim_type = sim_type

    def dot_product(self, x: torch.Tensor, p: torch.Tensor):
        return torch.matmul(x, p)

    def euclidean_distance(self, x: torch.Tensor, p: torch.Tensor):
        # (b, L, d, 1) - (b, 1, d, num_labels) -> norm over d -> (b, L, num_labels)
        dist = torch.norm(x.unsqueeze(-1) - p.unsqueeze(1), p=2, dim=2)
        return -dist

    def cosine_similarity(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
        x = F.normalize(x, p=2, dim=-1)
        p = F.normalize(p, p=2, dim=1)
        cos_sims = torch.matmul(x, p)
        # Allow a small tolerance for floating-point error.
        assert (
            cos_sims.max().item() <= 1.0 + 1e-4 and cos_sims.min().item() >= -1.0 - 1e-4
        ), "Cosine similarity values should be between -1 and 1"
        return cos_sims

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        b, L, d = x.size()
        p = self.Wp.expand(b, -1, -1)
        x = self.Wx(x)
        if self.sim_type == 'dot':
            logits = self.dot_product(x, p)
        elif self.sim_type == 'euclidean':
            logits = self.euclidean_distance(x, p)
        elif self.sim_type == 'cosine':
            logits = self.cosine_similarity(x, p, attention_mask)
        else:
            raise ValueError(f"Invalid similarity type: {self.sim_type}")
        return logits


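# Illustrative usage sketch (assumed sizes): AttentionLogitsToken produces
# per-token logits, one score per learned prototype column of Wp.
#
#   head = AttentionLogitsToken(hidden_size=256, num_labels=5, sim_type='cosine')
#   x = torch.randn(2, 64, 256)
#   mask = torch.ones(2, 64)
#   token_logits = head(x, attention_mask=mask)       # (2, 64, 5)
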
class MultiHeadPAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        n_heads: int,
        n_tokens: int,
        dropout: float = 0.2,
        rotary: bool = True,
        causal: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.d_head = self.hidden_size // self.n_heads
        self.Wq = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout)
        self.Wk = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout)
        self.Wv = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout)
        self.out_proj = Linear((hidden_size // n_heads) * n_heads, hidden_size)
        self.q_ln = LayerNorm(hidden_size)
        self.k_ln = LayerNorm(hidden_size)
        self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
        self.rotary = RotaryEmbedding(hidden_size // n_heads) if rotary else None
        self.causal = causal

    def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        q = q.unflatten(-1, (self.n_heads, self.d_head))
        k = k.unflatten(-1, (self.n_heads, self.d_head))
        q, k = self.rotary(q, k)
        q = q.flatten(-2, -1)
        k = k.flatten(-2, -1)
        return q, k

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, L, _ = x.shape
        if attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :].expand(b, 1, L, L).bool()
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)
        q, k = self.q_ln(q).to(q.dtype), self.k_ln(k).to(q.dtype)
        if self.rotary:
            q, k = self._apply_rotary(q, k)
        q, k, v = map(self.reshaper, (q, k, v))
        a = F.scaled_dot_product_attention(
            q, k, v, attention_mask if not self.causal else None, is_causal=self.causal
        )
        a = rearrange(a, "b h s d -> b s (h d)")
        return self.out_proj(a)

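# Illustrative usage sketch (assumed sizes): MultiHeadPAttention first projects
# q/k/v through PAttention blocks (token-parameter attention) and then runs
# standard multi-head self-attention over the sequence.
#
#   mhpa = MultiHeadPAttention(hidden_size=256, n_heads=8, n_tokens=32)
#   x = torch.randn(2, 64, 256)
#   mask = torch.ones(2, 64)
#   y = mhpa(x, attention_mask=mask)                  # (2, 64, 256)
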