"""Minimal PyTorch Needle model for Cactus conversion.""" from __future__ import annotations import math from typing import Any import torch from torch import nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from .configuration_needle import NeedleConfig class NeedleRMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float) -> None: super().__init__() self.weight = nn.Parameter(torch.zeros(hidden_size)) self.eps = float(eps) def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype variance = x.float().pow(2).mean(dim=-1, keepdim=True) x = x.float() * torch.rsqrt(variance + self.eps) return x.to(dtype=dtype) * (1.0 + self.weight.to(dtype=dtype)) def _padding_mask(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor: return (input_ids != int(pad_token_id))[:, None, None, :] def _causal_mask(seq_len: int, device: torch.device) -> torch.Tensor: return torch.ones((seq_len, seq_len), dtype=torch.bool, device=device).tril()[None, None, :, :] def _build_inv_freq(head_dim: int, theta: float) -> torch.Tensor: return 1.0 / (float(theta) ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / float(head_dim))) def _rotate_half(x: torch.Tensor) -> torch.Tensor: half = x.shape[-1] // 2 return torch.cat((-x[..., half:], x[..., :half]), dim=-1) def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: cos = cos.unsqueeze(2) sin = sin.unsqueeze(2) return (x * cos) + (_rotate_half(x) * sin) def _rotary_tables( inv_freq: torch.Tensor, batch_size: int, seq_len: int, device: torch.device, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1) inv_freq = inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(device) freqs = (inv_freq @ position_ids[:, None, :].float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype) def _rotary_tables_for_position_ids( inv_freq: torch.Tensor, position_ids: torch.Tensor, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: inv_freq = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(position_ids.device) freqs = (inv_freq @ position_ids[:, None, :].float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) return emb.cos().to(dtype=dtype), emb.sin().to(dtype=dtype) def _add_clipped(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return torch.clamp(a + b, min=-65500.0, max=65500.0) class NeedleAttention(nn.Module): def __init__(self, config: NeedleConfig) -> None: super().__init__() self.hidden_size = int(config.hidden_size) self.num_heads = int(config.num_attention_heads) self.num_key_value_heads = int(config.num_key_value_heads) self.head_dim = self.hidden_size // self.num_heads kv_size = self.num_key_value_heads * self.head_dim self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.k_proj = nn.Linear(self.hidden_size, kv_size, bias=False) self.v_proj = nn.Linear(self.hidden_size, kv_size, bias=False) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.q_norm = NeedleRMSNorm(self.head_dim, config.rms_norm_eps) self.k_norm = NeedleRMSNorm(self.head_dim, config.rms_norm_eps) self.scale = 1.0 / math.sqrt(float(self.head_dim)) def forward( self, hidden_states: torch.Tensor, key_value_states: torch.Tensor, attention_mask: torch.Tensor | None, rope: tuple[torch.Tensor, torch.Tensor] | None, ) -> torch.Tensor: batch, q_len, _ = hidden_states.shape kv_len = key_value_states.shape[1] q = self.q_proj(hidden_states).view(batch, q_len, self.num_heads, self.head_dim) k = self.k_proj(key_value_states).view(batch, kv_len, self.num_key_value_heads, self.head_dim) v = self.v_proj(key_value_states).view(batch, kv_len, self.num_key_value_heads, self.head_dim) q = self.q_norm(q) k = self.k_norm(k) if rope is not None: cos, sin = rope q = _apply_rope(q, cos, sin) k = _apply_rope(k, cos, sin) out = F.scaled_dot_product_attention( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=self.scale, enable_gqa=self.num_heads != self.num_key_value_heads, ) return self.out_proj(out.transpose(1, 2).contiguous().view(batch, q_len, self.hidden_size)) def project_kv(self, key_value_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: batch, kv_len, _ = key_value_states.shape k = self.k_proj(key_value_states).view(batch, kv_len, self.num_key_value_heads, self.head_dim) v = self.v_proj(key_value_states).view(batch, kv_len, self.num_key_value_heads, self.head_dim) k = self.k_norm(k) return k.contiguous(), v.contiguous() def forward_with_kv( self, hidden_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: torch.Tensor | None, rope: tuple[torch.Tensor, torch.Tensor] | None, ) -> torch.Tensor: batch, q_len, _ = hidden_states.shape q = self.q_proj(hidden_states).view(batch, q_len, self.num_heads, self.head_dim) q = self.q_norm(q) if rope is not None: cos, sin = rope q = _apply_rope(q, cos, sin) out = F.scaled_dot_product_attention( q.transpose(1, 2), key_states.transpose(1, 2), value_states.transpose(1, 2), attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=self.scale, enable_gqa=self.num_heads != self.num_key_value_heads, ) return self.out_proj(out.transpose(1, 2).contiguous().view(batch, q_len, self.hidden_size)) class NeedleEncoderLayer(nn.Module): def __init__(self, config: NeedleConfig) -> None: super().__init__() self.input_layernorm = NeedleRMSNorm(config.hidden_size, config.rms_norm_eps) self.self_attn = NeedleAttention(config) self.attn_gate = nn.Parameter(torch.zeros(1)) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: normed = self.input_layernorm(hidden_states) attn = self.self_attn(normed, normed, attention_mask, rope) return _add_clipped(hidden_states, torch.sigmoid(self.attn_gate).to(dtype=attn.dtype) * attn) class NeedleDecoderLayer(nn.Module): def __init__(self, config: NeedleConfig) -> None: super().__init__() self.input_layernorm = NeedleRMSNorm(config.hidden_size, config.rms_norm_eps) self.self_attn = NeedleAttention(config) self.self_attn_gate = nn.Parameter(torch.zeros(1)) self.encoder_attn_layer_norm = NeedleRMSNorm(config.hidden_size, config.rms_norm_eps) self.encoder_attn = NeedleAttention(config) self.cross_attn_gate = nn.Parameter(torch.zeros(1)) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, self_mask: torch.Tensor, encoder_mask: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor], ) -> torch.Tensor: normed = self.input_layernorm(hidden_states) attn = self.self_attn(normed, normed, self_mask, rope) hidden_states = _add_clipped(hidden_states, torch.sigmoid(self.self_attn_gate).to(dtype=attn.dtype) * attn) attn = self.encoder_attn(self.encoder_attn_layer_norm(hidden_states), encoder_hidden_states, encoder_mask, None) return _add_clipped(hidden_states, torch.sigmoid(self.cross_attn_gate).to(dtype=attn.dtype) * attn) class NeedleEncoder(nn.Module): def __init__(self, config: NeedleConfig) -> None: super().__init__() self.layers = nn.ModuleList([NeedleEncoderLayer(config) for _ in range(config.num_encoder_layers)]) self.final_norm = NeedleRMSNorm(config.hidden_size, config.rms_norm_eps) self.head_dim = config.hidden_size // config.num_attention_heads self.rope_theta = float(config.rope_theta) self.register_buffer("inv_freq", _build_inv_freq(self.head_dim, self.rope_theta), persistent=False) def reset_rope(self) -> None: self.inv_freq = _build_inv_freq(self.head_dim, self.rope_theta).to(device=self.inv_freq.device) def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: inv_freq = _build_inv_freq(self.head_dim, self.rope_theta).to(device=hidden_states.device) rope = _rotary_tables( inv_freq, hidden_states.shape[0], hidden_states.shape[1], hidden_states.device, hidden_states.dtype, ) for layer in self.layers: hidden_states = layer(hidden_states, attention_mask, rope) return self.final_norm(hidden_states) class NeedleDecoder(nn.Module): def __init__(self, config: NeedleConfig) -> None: super().__init__() self.layers = nn.ModuleList([NeedleDecoderLayer(config) for _ in range(config.num_decoder_layers)]) self.norm = NeedleRMSNorm(config.hidden_size, config.rms_norm_eps) self.head_dim = config.hidden_size // config.num_attention_heads self.rope_theta = float(config.rope_theta) self.register_buffer("inv_freq", _build_inv_freq(self.head_dim, self.rope_theta), persistent=False) def reset_rope(self) -> None: self.inv_freq = _build_inv_freq(self.head_dim, self.rope_theta).to(device=self.inv_freq.device) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, self_mask: torch.Tensor, encoder_mask: torch.Tensor, ) -> torch.Tensor: inv_freq = _build_inv_freq(self.head_dim, self.rope_theta).to(device=hidden_states.device) rope = _rotary_tables( inv_freq, hidden_states.shape[0], hidden_states.shape[1], hidden_states.device, hidden_states.dtype, ) for layer in self.layers: hidden_states = layer(hidden_states, encoder_hidden_states, self_mask, encoder_mask, rope) return self.norm(hidden_states) class NeedleModel(PreTrainedModel): config_class = NeedleConfig base_model_prefix = "model" main_input_name = "input_ids" def __init__(self, config: NeedleConfig) -> None: super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.embed_scale = math.sqrt(float(config.hidden_size)) self.encoder = NeedleEncoder(config) self.decoder = NeedleDecoder(config) self.post_init() self.reset_rope() def reset_rope(self) -> None: self.encoder.reset_rope() self.decoder.reset_rope() def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: self.embed_tokens = value def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, decoder_input_ids: torch.Tensor | None = None, **_: Any, ) -> BaseModelOutput: decoder_input_ids = input_ids if decoder_input_ids is None else decoder_input_ids encoder_mask = _padding_mask(input_ids, self.config.pad_token_id) if attention_mask is not None: encoder_mask = encoder_mask & attention_mask[:, None, None, :].to(dtype=torch.bool) encoder_hidden = self.embed_tokens(input_ids) * self.embed_scale encoder_hidden = self.encoder(encoder_hidden, encoder_mask) self_mask = _causal_mask(decoder_input_ids.shape[1], decoder_input_ids.device) decoder_hidden = self.embed_tokens(decoder_input_ids) * self.embed_scale decoder_hidden = self.decoder(decoder_hidden, encoder_hidden, self_mask, encoder_mask) return BaseModelOutput(last_hidden_state=decoder_hidden) class NeedleForCausalLM(PreTrainedModel): config_class = NeedleConfig base_model_prefix = "model" main_input_name = "input_ids" _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} def __init__(self, config: NeedleConfig) -> None: super().__init__(config) self.model = NeedleModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() self.model.reset_rope() self.tie_weights() def get_encoder(self) -> NeedleEncoder: return self.model.encoder def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: self.model.embed_tokens = value def get_output_embeddings(self) -> nn.Linear: return self.lm_head def set_output_embeddings(self, value: nn.Linear) -> None: self.lm_head = value def tie_weights(self, *args: Any, **kwargs: Any) -> None: del args, kwargs if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, decoder_input_ids: torch.Tensor | None = None, **kwargs: Any, ) -> Seq2SeqLMOutput: hidden_states = self.model( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, **kwargs, ).last_hidden_state return Seq2SeqLMOutput(logits=self.lm_head(hidden_states)) def cactus_source_encode( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: base_mask = _padding_mask(input_ids, self.config.pad_token_id) base_mask = base_mask & attention_mask[:, None, None, :].to(dtype=torch.bool) encoder_mask = base_mask.expand( -1, int(self.config.num_attention_heads), input_ids.shape[1], -1, ).contiguous() encoder_hidden = self.model.embed_tokens(input_ids) * self.model.embed_scale encoder_hidden = self.model.encoder(encoder_hidden, encoder_mask) decoder_mask = base_mask.expand( -1, int(self.config.num_attention_heads), 1, -1, ).contiguous() return encoder_hidden, decoder_mask.to(dtype=encoder_hidden.dtype) def cactus_decoder_cross_kv( self, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, ) -> tuple[torch.Tensor, ...]: del encoder_attention_mask outputs: list[torch.Tensor] = [] for layer in self.model.decoder.layers: k, v = layer.encoder_attn.project_kv(encoder_hidden_states) outputs.extend((k, v)) return tuple(outputs) def cactus_decoder_step( self, decoder_input_ids: torch.Tensor, position_ids: torch.Tensor, encoder_attention_mask: torch.Tensor, *cross_kv: torch.Tensor, ) -> torch.Tensor: encoder_attention_mask = encoder_attention_mask != 0 hidden_states = self.model.embed_tokens(decoder_input_ids) * self.model.embed_scale inv_freq = _build_inv_freq( self.model.decoder.head_dim, self.model.decoder.rope_theta, ).to(device=hidden_states.device) rope = _rotary_tables_for_position_ids( inv_freq, position_ids.to(dtype=torch.long), hidden_states.dtype, ) for layer_index, layer in enumerate(self.model.decoder.layers): normed = layer.input_layernorm(hidden_states) attn = layer.self_attn(normed, normed, None, rope) hidden_states = _add_clipped(hidden_states, torch.sigmoid(layer.self_attn_gate).to(dtype=attn.dtype) * attn) cross_attn = layer.encoder_attn.forward_with_kv( layer.encoder_attn_layer_norm(hidden_states), cross_kv[layer_index * 2], cross_kv[layer_index * 2 + 1], encoder_attention_mask, None, ) hidden_states = _add_clipped(hidden_states, torch.sigmoid(layer.cross_attn_gate).to(dtype=cross_attn.dtype) * cross_attn) hidden_states = self.model.decoder.norm(hidden_states) return self.lm_head(hidden_states) def _init_weights(self, module: nn.Module) -> None: if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, NeedleRMSNorm): nn.init.zeros_(module.weight)