import torch import torch.nn as nn import torch.nn.functional as F import math from functools import partial from einops import rearrange, repeat from typing import Optional, Tuple, Union Linear = partial(nn.Linear, bias=False) LayerNorm = partial(nn.LayerNorm, bias=False) def rotate_half(x, interleaved=False): if not interleaved: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) else: x1, x2 = x[..., ::2], x[..., 1::2] return rearrange( torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 ) def apply_rotary_emb_torch(x, cos, sin, interleaved=False, _inplace=False): """ x: (batch_size, seqlen, nheads, headdim) cos, sin: (seqlen, rotary_dim / 2) """ ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] seqlen = x.size(1) cos = cos[:seqlen] sin = sin[:seqlen] cos = repeat(cos, "s d -> s 1 (2 d)") sin = repeat(sin, "s d -> s 1 (2 d)") return torch.cat( [ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:], ], dim=-1, ) class RotaryEmbedding(torch.nn.Module): def __init__( self, dim: int, base=10000.0, interleaved=False, scale_base=None, scaling_factor=1.0, pos_idx_in_fp32=True, device=None, ): super().__init__() self.dim = dim self.base = float(base) self.pos_idx_in_fp32 = pos_idx_in_fp32 # Generate and save the inverse frequency buffer (non trainable) self.interleaved = interleaved self.scale_base = scale_base self.scaling_factor = scaling_factor self.device = device self._seq_len_cached = 0 self._cos_cached = None self._sin_cached = None self._cos_k_cached = None self._sin_k_cached = None self.reset_parameters() def reset_parameters(self): inv_freq = self._compute_inv_freq(self.device) self.register_buffer("inv_freq", inv_freq, persistent=False) arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32) scale = ( (arange + 0.4 * self.dim) / (1.4 * self.dim) if self.scale_base is not None else None ) self.register_buffer("scale", scale) def _compute_inv_freq(self, device=None): return 1 / ( self.base ** ( torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim ) ) def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): if ( seqlen > self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype or (self.training and self._cos_cached.is_inference()) ): self._seq_len_cached = seqlen if self.pos_idx_in_fp32: t = torch.arange(seqlen, device=device, dtype=torch.float32) t /= self.scaling_factor if self.inv_freq.dtype != torch.float32: inv_freq = self.inv_freq.to(torch.float32) else: inv_freq = self.inv_freq else: t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) t /= self.scaling_factor inv_freq = self.inv_freq freqs = torch.outer(t, inv_freq) if self.scale is None: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: power = ( torch.arange( seqlen, dtype=self.scale.dtype, device=self.scale.device ) - seqlen // 2 ) / self.scale_base scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) self._cos_cached = (torch.cos(freqs) * scale).to(dtype) self._sin_cached = (torch.sin(freqs) * scale).to(dtype) self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ q: (batch, seqlen, nheads, headdim) k: (batch, seqlen, nheads, headdim) """ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype) assert self._cos_cached is not None assert self._sin_cached is not None if self.scale is None: return ( apply_rotary_emb_torch( q, self._cos_cached, self._sin_cached, self.interleaved, True, # inplace=True ), apply_rotary_emb_torch( k, self._cos_cached, self._sin_cached, self.interleaved, True, # inplace=True ), ) # type: ignore else: assert False class MultiHeadAttention(nn.Module): def __init__(self, hidden_size: int, n_heads: int, rotary: bool = True): super().__init__() self.hidden_size = hidden_size self.n_heads = n_heads self.d_head = self.hidden_size // self.n_heads self.layernorm_qkv = nn.Sequential( LayerNorm(hidden_size), Linear(hidden_size, hidden_size * 3) ) self.out_proj = Linear(hidden_size, hidden_size) self.q_ln = LayerNorm(hidden_size, bias=False) self.k_ln = LayerNorm(hidden_size, bias=False) self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads) self.rotary = RotaryEmbedding(hidden_size // n_heads) if rotary else None def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: q = q.unflatten(-1, (self.n_heads, self.d_head)) k = k.unflatten(-1, (self.n_heads, self.d_head)) q, k = self.rotary(q, k) q = q.flatten(-2, -1) k = k.flatten(-2, -1) return q, k def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # attention mask already prepped for sdpa shape (bs, 1, seq_len, seq_len) qkv = self.layernorm_qkv(x) # (bs, seq_len, d_model * 3) q, k, v = torch.chunk(qkv, 3, dim=-1) # (bs, seq_len, hidden_size) q, k = self.q_ln(q).to(q.dtype), self.k_ln(k).to(q.dtype) if self.rotary: q, k = self._apply_rotary(q, k) q, k, v = map(self.reshaper, (q, k, v)) # (bs, n_heads, seq_len, d_head) a = F.scaled_dot_product_attention(q, k, v, attention_mask) # (bs, n_heads, seq_len, d_head) a = rearrange(a, "b h s d -> b s (h d)") # (bs, seq_len, n_heads * d_head) return self.out_proj(a) # (bs, seq_len, hidden_size) class PAttention(nn.Module): """ Cross-attention mechanism for token-parameter-attention (b, L, d) -> (b, L, n_tokens) -> (b, L, d) """ def __init__( self, hidden_size: int, n_tokens: int, dropout: float = 0.2, ): super(PAttention, self).__init__() self.n_tokens = n_tokens self.Wq = Linear(hidden_size, hidden_size) self.Pk = nn.Parameter(torch.randn(1, n_tokens, hidden_size)) self.Pv = nn.Parameter(torch.randn(1, n_tokens, hidden_size)) self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: b, L, _ = x.size() if attention_mask is not None: attention_mask = attention_mask[:, None, :].expand(b, self.n_token, self.L).bool() q = self.Wq(x) # (b, L, d) out = F.scaled_dot_product_attention(q, self.Pk, self.Pv, attn_mask=attention_mask, is_causal=False) # (b, L, d) return self.dropout(out) class AttentionLogitsSequence(nn.Module): """ Cross-attention mechanism for token-parameter-attention (b, L, d) -> (b, L, num_labels) -> (b, num_labels) """ def __init__(self, hidden_size: int, num_labels: int = 1, sim_type: str = 'dot'): super(AttentionLogitsSequence, self).__init__() self.num_labels = num_labels self.Wp = nn.Parameter(torch.randn(1, hidden_size, num_labels)) self.Wx = Linear(hidden_size, hidden_size) self.sim_type = sim_type def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d) if attention_mask is None: return emb.mean(dim=1) else: return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) # (b, d) def dot_product(self, x: torch.Tensor, p: torch.Tensor): # (b, L, d) * (b, d, num_labels) -> (b, L, num_labels) return torch.matmul(x, p) def euclidean_distance(self, x: torch.Tensor, p: torch.Tensor): # (b, L, d) * (b, d, num_labels) -> (b, L, num_labels) # x: (b, L, d), p: (b, d, num_labels) x_exp = x.unsqueeze(-1) # (b, L, d, 1) p_exp = p.unsqueeze(1) # (b, 1, d, num_labels) dist = torch.abs(torch.norm(x_exp - p_exp, p=2, dim=2)) # (b, L, num_labels) return -dist def cosine_similarity( self, x: torch.Tensor, p: torch.Tensor, attention_mask: torch.Tensor = None, ) -> torch.Tensor: # (b, L, d) * (b, d, num_labels) -> (b, L, num_labels) x = x * attention_mask x = F.normalize(x, p=2, dim=-1) p = F.normalize(p, p=2, dim=1) cos_sims = torch.matmul(x, p) assert cos_sims.max().item() <= 1.0 and cos_sims.min().item() >= -1.0, "Cosine similarity values should be between -1 and 1" return cos_sims def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: b, l, _ = x.size() p = self.Wp.expand(b, -1, -1) # (b, d, num_labels) x = self.Wx(x) # (b, L, d) if attention_mask is None: attention_mask = torch.ones(b, l, device=x.device, dtype=x.dtype) attention_mask = attention_mask.unsqueeze(-1) if self.sim_type == 'dot': y = self.dot_product(x, p) elif self.sim_type == 'euclidean': y = self.euclidean_distance(x, p) elif self.sim_type == 'cosine': y = self.cosine_similarity(x, p, attention_mask) else: raise ValueError(f"Invalid similarity type: {self.sim_type}") # y (b, L, num_labels) logits = self.mean_pooling(y, attention_mask) # (b, num_labels) return logits, y, x class AttentionLogitsToken(nn.Module): """ Cross-attention mechanism for token-parameter-attention (b, L, d) -> (b, L, num_labels) """ def __init__(self, hidden_size: int, num_labels: int = 1, sim_type: str = 'dot'): super(AttentionLogitsToken, self).__init__() self.num_labels = num_labels self.Wp = nn.Parameter(torch.randn(1, hidden_size, num_labels)) self.Wx = Linear(hidden_size, hidden_size) self.sim_type = sim_type def dot_product(self, x: torch.Tensor, p: torch.Tensor): return torch.matmul(x, p) def euclidean_distance(self, x: torch.Tensor, p: torch.Tensor): return torch.norm(x - p, p=2, dim=-1) def cosine_similarity( self, x: torch.Tensor, p: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: if attention_mask is not None: x = x * attention_mask.unsqueeze(-1) x = F.normalize(x, p=2, dim=-1) p = F.normalize(p, p=2, dim=1) cos_sims = torch.matmul(x, p) assert cos_sims.max().item() <= 1.0 and cos_sims.min().item() >= -1.0, "Cosine similarity values should be between -1 and 1" return cos_sims def forward( self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: b, L, d = x.size() p = self.Wp.expand(b, -1, -1) # (b, d, num_labels) x = self.Wx(x) # (b, L, d) if self.sim_type == 'dot': logits = self.dot_product(x, p) elif self.sim_type == 'euclidean': logits = self.euclidean_distance(x, p) elif self.sim_type == 'cosine': logits = self.cosine_similarity(x, p, attention_mask) else: raise ValueError(f"Invalid similarity type: {self.sim_type}") return logits # (b, L, num_labels) class MultiHeadPAttention(nn.Module): def __init__( self, hidden_size: int, n_heads: int, n_tokens: int, dropout: float = 0.2, rotary: bool = True, causal: bool = False, ): super().__init__() self.hidden_size = hidden_size self.n_heads = n_heads self.d_head = self.hidden_size // self.n_heads self.Wq = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout) self.Wk = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout) self.Wv = PAttention(hidden_size, n_tokens=n_tokens, dropout=dropout) self.out_proj = Linear((hidden_size // n_heads) * n_heads, hidden_size) self.q_ln = LayerNorm(hidden_size) self.k_ln = LayerNorm(hidden_size) self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads) self.rotary = RotaryEmbedding(hidden_size // n_heads) if rotary else None self.causal = causal def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: q = q.unflatten(-1, (self.n_heads, self.d_head)) k = k.unflatten(-1, (self.n_heads, self.d_head)) q, k = self.rotary(q, k) q = q.flatten(-2, -1) k = k.flatten(-2, -1) return q, k def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: # attention mask already prepped for sdpa shape (bs, 1, seq_len, seq_len) b, L, _ = x.shape if attention_mask is not None and attention_mask.dim() == 2: attention_mask = attention_mask[:, None, None, :].expand(b, 1, L, L).bool() q = self.Wq(x) k = self.Wk(x) v = self.Wv(x) q, k = self.q_ln(q).to(q.dtype), self.k_ln(k).to(q.dtype) if self.rotary: q, k = self._apply_rotary(q, k) q, k, v = map(self.reshaper, (q, k, v)) # (bs, n_heads, seq_len, d_head) a = F.scaled_dot_product_attention(q, k, v, attention_mask if not self.causal else None, is_causal=self.causal) # (bs, n_heads, seq_len, d_head) a = rearrange(a, "b h s d -> b s (h d)") # (bs, seq_len, n_heads * d_head) return self.out_proj(a) # (bs, seq_len, hidden_size)