| """ |
| modeling_saber.py — Full PyTorch implementation of Eve-3-SABER-1B. |
| |
| Architecture highlights |
| ----------------------- |
| * Dense decoder-only transformer with pre-RMSNorm. |
| * RoPE (rotary position embeddings) applied to Q and K after head reshape. |
| * **Slip-Anchors**: learnable codebook biases K/V *after* RoPE, fully |
| compatible with FlashAttention / F.scaled_dot_product_attention. |
| * **Experience Stream**: a per-token, layer-traversing state with a curiosity |
| auxiliary loss (prediction-error on a stop-gradient summary). |
| * **Resonant FFN**: even-indexed layers augment SwiGLU with a learned |
| sinusoidal modulation blended by a trainable scalar alpha. |
| * Weight-tied LM head. |
| * Gradient-checkpointing support. |
| |
| Intended usage (HuggingFace Trainer / SFTTrainer compatible): |
| from configuration_saber import SABERConfig |
| from modeling_saber import SABERForCausalLM |
| |
| config = SABERConfig() |
| model = SABERForCausalLM(config) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
|
|
| from transformers import PreTrainedModel |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast |
| from transformers.utils import logging |
|
|
| from configuration_saber import SABERConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
| |
| |
|
|
class SABERRMSNorm(nn.Module):
    """RMS layer normalization with a learnable gain and no bias.

    The normalization is carried out in float32 for numerical stability and
    the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), expressed as a multiply by rsqrt.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self._norm(x.float()) * self.weight.float()
        return normed.to(x.dtype)
|
|
|
|
| |
| |
| |
|
|
class SABERRotaryEmbedding(nn.Module):
    """
    Standard RoPE using complex-number rotation (Llama / GPT-NeoX style).

    Frequencies are cached up to ``max_seq_len`` and extended on the fly if
    a longer sequence is encountered.
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        theta: float = 10_000.0,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta

        # inv_freq[i] = 1 / theta^(2i / head_dim): one rotation frequency per
        # pair of channels, computed in float32 for precision.
        inv_freq = 1.0 / (
            theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
                      / head_dim)
        )
        # Non-persistent: not saved in state_dict, rebuilt on construction.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len, device=device)

    def _build_cache(self, seq_len: int, device: Optional[torch.device] = None) -> None:
        """Build (or extend) the cos/sin cache up to ``seq_len`` positions."""
        t = torch.arange(seq_len, dtype=torch.float32,
                         device=self.inv_freq.device if device is None else device)
        # (seq_len, head_dim/2) angles, then duplicated to (seq_len, head_dim)
        # so the same angle applies to both halves used by _rotate_half.
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        # Shapes: (1, 1, seq_len, head_dim) for direct broadcast over (B, H).
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        self.max_seq_len = seq_len

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate the second half of the last dimension by -90°."""
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        return torch.cat([-x2, x1], dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: int,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to q and k.

        q, k: (batch, n_heads, seq_len, head_dim)
        position_ids: (batch, seq_len) or None

        NOTE(review): when ``position_ids`` is None, the cached tables are
        sliced to ``seq_len`` along dim 2, which assumes q/k really carry
        ``seq_len`` tokens. With a KV cache the caller must supply
        ``position_ids`` (SABERModel always does) — confirm for any direct
        callers.
        """
        if seq_len > self.max_seq_len:
            # Extend the cache lazily for longer sequences.
            self._build_cache(seq_len, device=q.device)

        if position_ids is not None:
            # Gather per-position rows: tables become (seq, head_dim), then
            # (B, 1, L, head_dim) after indexing with position_ids and adding
            # a broadcast head axis.
            cos_2d = self.cos_cached.squeeze(0).squeeze(0).to(q.dtype)
            sin_2d = self.sin_cached.squeeze(0).squeeze(0).to(q.dtype)
            cos = cos_2d[position_ids].unsqueeze(1)
            sin = sin_2d[position_ids].unsqueeze(1)
        else:
            cos = self.cos_cached[:, :, :seq_len, :].to(q.dtype)
            sin = self.sin_cached[:, :, :seq_len, :].to(q.dtype)

        # Standard RoPE rotation: x_rot = x*cos + rotate_half(x)*sin.
        q_rot = q * cos + self._rotate_half(q) * sin
        k_rot = k * cos + self._rotate_half(k) * sin
        return q_rot, k_rot
|
|
|
|
| |
| |
| |
|
|
class SlipAnchors(nn.Module):
    """
    Slip-anchor module: adds codebook-derived biases to K and V.

    A soft attention over a learnable codebook of ``n_anchors`` vectors
    produces a per-token context vector, which is projected into per-head
    key/value biases shared across all heads. Applied *after* RoPE, so
    FlashAttention / SDPA compatibility is preserved.

    Parameters
    ----------
    d_model : residual hidden dimension (2048)
    n_anchors : codebook size (64)
    d_anchor : anchor bottleneck dim (128)
    head_dim : per-head dimension (128)
    n_heads : number of attention heads (16)
    """

    def __init__(
        self,
        d_model: int,
        n_anchors: int,
        d_anchor: int,
        head_dim: int,
        n_heads: int,
    ) -> None:
        super().__init__()
        self.n_anchors = n_anchors
        self.d_anchor = d_anchor
        self.n_heads = n_heads
        self.head_dim = head_dim

        # Codebook and projections — created in this order so the parameter
        # iteration order matches existing checkpoints.
        self.anchors = nn.Parameter(torch.empty(n_anchors, d_anchor))
        self.W_anchor_down = nn.Linear(d_model, d_anchor, bias=False)
        self.U_k = nn.Linear(d_anchor, head_dim, bias=False)
        self.U_v = nn.Linear(d_anchor, head_dim, bias=False)

        self._init_weights()

    def _init_weights(self) -> None:
        # All slip-anchor weights share the same small-std normal init.
        for tensor in (
            self.anchors,
            self.W_anchor_down.weight,
            self.U_k.weight,
            self.U_v.weight,
        ):
            nn.init.normal_(tensor, std=0.02)

    def forward(
        self,
        h: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (K + k_bias, V + v_bias); biases broadcast over heads."""
        # Soft assignment of every token to the anchor codebook.
        assignment = torch.softmax(self.W_anchor_down(h) @ self.anchors.T, dim=-1)
        context = assignment @ self.anchors

        # Per-token biases, shared across heads via the inserted head axis.
        k_bias = self.U_k(context).unsqueeze(1)
        v_bias = self.U_v(context).unsqueeze(1)
        return K + k_bias, V + v_bias
|
|
|
|
| |
| |
| |
|
|
class SABERAttention(nn.Module):
    """
    Multi-head attention with:
      * No projection biases.
      * RoPE applied to Q and K after head reshape.
      * Slip-anchor modulation of K and V after RoPE.
      * F.scaled_dot_product_attention (FlashAttention 2 compatible).

    Fix vs. previous revision: slip anchors are now applied to the *freshly
    projected* K/V (length L, matching ``hidden_states``) before they are
    concatenated with the KV cache. Previously they were applied after
    concatenation, so during cached decoding the (B, 1, L, head_dim) bias
    could not broadcast against the (B, n_heads, kv_len, head_dim) tensors
    and the forward crashed; the cache also stored un-anchored K/V. Without
    a cache the two orderings are numerically identical.
    """

    def __init__(self, config: SABERConfig, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = config.head_dim

        # Bias-free projections (SDPA/FlashAttention friendly).
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        self.rotary_emb = SABERRotaryEmbedding(
            head_dim=self.head_dim,
            max_seq_len=config.max_position_embeddings,
            theta=config.rope_theta,
        )

        self.slip_anchors = SlipAnchors(
            d_model=self.d_model,
            n_anchors=config.n_anchors,
            d_anchor=config.d_anchor,
            head_dim=self.head_dim,
            n_heads=self.n_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Returns ``(attn_out[, present_kv][, None])`` depending on
        ``use_cache`` / ``output_attentions``. SDPA never materialises
        attention probabilities, so ``None`` stands in for them.
        """
        B, L, _ = hidden_states.shape

        Q = self.q_proj(hidden_states)
        K = self.k_proj(hidden_states)
        V = self.v_proj(hidden_states)

        # (B, L, d_model) -> (B, n_heads, L, head_dim)
        def _reshape(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, L, self.n_heads, self.head_dim).transpose(1, 2)

        Q, K, V = _reshape(Q), _reshape(K), _reshape(V)

        # Total KV length including any cached prefix; the RoPE cache must
        # cover positions up to kv_seq_len.
        kv_seq_len = L
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        Q, K = self.rotary_emb(Q, K, seq_len=kv_seq_len, position_ids=position_ids)

        # Slip anchors bias only the NEW tokens' K/V (shapes match
        # hidden_states). The anchored K/V then enter the cache, so decode
        # steps see consistently anchored keys/values.
        if getattr(self.config, 'enable_anchors', True):
            K, V = self.slip_anchors(hidden_states, K, V)

        if past_key_value is not None:
            K = torch.cat([past_key_value[0], K], dim=2)
            V = torch.cat([past_key_value[1], V], dim=2)

        present_kv = (K, V) if use_cache else None

        # When an explicit additive mask is supplied, causality must already
        # be encoded in the mask (SDPA ignores is_causal when attn_mask is
        # set); SABERModel builds a combined causal+padding mask.
        is_causal = attention_mask is None and L > 1
        attn_out = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )

        # (B, n_heads, L, head_dim) -> (B, L, d_model)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, self.d_model)
        attn_out = self.o_proj(attn_out)

        outputs: Tuple = (attn_out,)
        if use_cache:
            outputs += (present_kv,)
        if output_attentions:
            # Not available under SDPA.
            outputs += (None,)

        return outputs
|
|
|
|
| |
| |
| |
|
|
class ExperienceStream(nn.Module):
    """
    Per-token experience state with a curiosity auxiliary loss.

    The state flows layer-to-layer within a single forward pass and is
    reset to zeros at the start of each new sequence. The curiosity term is
    the prediction error between a linear forecast from the incoming state
    and a stop-gradient summary of the current hidden states.

    Parameters
    ----------
    d_model : residual hidden dimension
    d_exp : experience state dimension (256)
    """

    def __init__(self, d_model: int, d_exp: int) -> None:
        super().__init__()
        self.W_s = nn.Linear(d_model, d_exp, bias=False)      # hidden -> summary
        self.W_pred = nn.Linear(d_exp, d_exp, bias=False)     # state -> predicted summary
        self.W_e = nn.Linear(d_exp, d_exp, bias=False)        # summary -> state update
        # decay = sigmoid(3.0) ≈ 0.95 at init: the state is mostly retained.
        self.decay_raw = nn.Parameter(torch.full((d_exp,), 3.0))
        self.exp_norm = nn.LayerNorm(d_exp)

    def forward(
        self,
        h: torch.Tensor,
        experience_state: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns
        -------
        new_experience_state : (B, L, d_exp)
        curiosity_loss : scalar tensor
        """
        summary = self.W_s(h)

        # Prediction error against a stop-gradient target: gradients flow
        # into the predictor (and the state path), not into the summary.
        predicted = self.W_pred(experience_state)
        curiosity_loss = (summary.detach() - predicted).pow(2).mean()

        # Leaky-integrator state update followed by LayerNorm.
        retain = torch.sigmoid(self.decay_raw)
        updated = retain * experience_state + F.silu(self.W_e(summary))
        return self.exp_norm(updated), curiosity_loss
|
|
|
|
| |
| |
| |
|
|
class StandardFFN(nn.Module):
    """SwiGLU feed-forward block (used on odd-indexed layers)."""

    def __init__(self, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        self.W3 = nn.Linear(d_model, d_ff, bias=False)
        self.W2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: silu-gated linear unit followed by a down-projection.
        gate = F.silu(self.W1(x))
        return self.W2(gate * self.W3(x))
|
|
|
|
class ResonantFFN(nn.Module):
    """
    Resonant FFN (used on even-indexed layers): SwiGLU with a learned
    sinusoidal modulation.

        output = ffn_out * (alpha + (1 - alpha) * (1 + sin(W_freq x)))

    with alpha = sigmoid(alpha_raw) ≈ 0.95 at the default init
    (alpha_init = 3.0), so the block starts close to a plain SwiGLU FFN.
    """

    def __init__(self, d_model: int, d_ff: int, alpha_init: float = 3.0) -> None:
        super().__init__()
        # Standard SwiGLU projections.
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        self.W3 = nn.Linear(d_model, d_ff, bias=False)
        self.W2 = nn.Linear(d_ff, d_model, bias=False)
        # Learned frequencies for the sinusoidal modulation.
        self.W_freq = nn.Linear(d_model, d_model, bias=False)
        # Raw (pre-sigmoid) blending scalar.
        self.alpha_raw = nn.Parameter(torch.tensor(alpha_init))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swiglu = self.W2(F.silu(self.W1(x)) * self.W3(x))
        modulation = torch.sin(self.W_freq(x))
        blend = torch.sigmoid(self.alpha_raw)
        # Convex blend between the plain FFN output and its modulated copy.
        return blend * swiglu + (1.0 - blend) * (swiglu * (1.0 + modulation))
|
|
|
|
| |
| |
| |
|
|
class SABERBlock(nn.Module):
    """
    One SABER transformer layer (pre-norm).

        h = h + Attention(RMSNorm(h))
        h = h + FFN(RMSNorm(h))
        experience_state, curiosity = ExperienceStream(h, experience_state)
    """

    def __init__(self, config: SABERConfig, layer_idx: int) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.input_layernorm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)

        self.self_attn = SABERAttention(config, layer_idx=layer_idx)

        # Layers listed in config.resonant_layers use the resonant FFN,
        # everything else the plain SwiGLU FFN.
        if layer_idx in config.resonant_layers:
            self.ffn: nn.Module = ResonantFFN(
                d_model=config.d_model,
                d_ff=config.d_ff,
                alpha_init=config.resonant_alpha_init,
            )
        else:
            self.ffn = StandardFFN(d_model=config.d_model, d_ff=config.d_ff)

        self.experience_stream = ExperienceStream(
            d_model=config.d_model,
            d_exp=config.d_exp,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        experience_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ) -> Tuple:
        # Attention sub-block with residual connection.
        attn_results = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_results[0]

        # FFN sub-block with residual connection.
        hidden_states = hidden_states + self.ffn(
            self.post_attention_layernorm(hidden_states)
        )

        # Experience-stream update (toggleable via config.enable_experience).
        if getattr(self.config, 'enable_experience', True):
            experience_state, curiosity_loss = self.experience_stream(
                hidden_states, experience_state
            )
        else:
            curiosity_loss = torch.tensor(0.0, device=hidden_states.device)

        # Append any cache / attention extras produced by the attention module.
        return (hidden_states, experience_state, curiosity_loss) + attn_results[1:]
|
|
|
|
| |
| |
| |
|
|
class SABERModel(PreTrainedModel):
    """
    SABER base model: token embeddings → blocks → final RMSNorm.

    Does not include the LM head — use ``SABERForCausalLM`` for training.

    Notes
    -----
    * With ``return_dict=False`` the forward returns the 5-tuple
      ``(hidden_states, mean_curiosity, past_key_values,
      all_hidden_states, all_self_attns)``.
    * With ``return_dict=True`` it returns
      ``(BaseModelOutputWithPast, mean_curiosity)``; callers that need
      plain HF semantics should use ``return_dict=False`` (as
      ``SABERForCausalLM`` does).
    """

    config_class = SABERConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SABERBlock"]
    _supports_flash_attn_2 = True

    def __init__(self, config: SABERConfig) -> None:
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [SABERBlock(config, layer_idx=i) for i in range(config.n_layers)]
        )
        self.norm = SABERRMSNorm(config.d_model, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Init hook invoked by ``post_init`` on every submodule.

        NOTE(review): the nn.Linear layers *inside* SlipAnchors also match
        the first branch and are re-initialised with ``initializer_range``
        std, overriding SlipAnchors._init_weights for everything except the
        raw ``anchors`` parameter — confirm this is intended.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, SABERRMSNorm):
            nn.init.ones_(module.weight)
        elif isinstance(module, SlipAnchors):
            # SlipAnchors initialises its own codebook in __init__.
            pass

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[BaseModelOutputWithPast, Tuple]:

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions or False
        output_hidden_states = output_hidden_states or False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Provide either input_ids or inputs_embeds.")
            inputs_embeds = self.embed_tokens(input_ids)

        B, L, _ = inputs_embeds.shape

        # KV caching cannot be combined with gradient checkpointing: the
        # checkpointed layer call runs with past_key_value=None and
        # use_cache=False, so the "cache" would be a list of Nones.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. "
                "Setting `use_cache=False`."
            )
            use_cache = False

        past_len = past_key_values[0][0].shape[-2] if past_key_values else 0

        if position_ids is None:
            position_ids = torch.arange(
                past_len, past_len + L,
                dtype=torch.long,
                device=inputs_embeds.device,
            ).unsqueeze(0).expand(B, -1)

        # ------------------------------------------------------------------
        # Attention mask: build an additive float mask that encodes BOTH
        # causality and padding. Fix: the previous revision built a
        # padding-only mask; since SDPA disables `is_causal` whenever an
        # explicit mask is passed, any batch with a 2-D attention_mask could
        # attend to future tokens. The mask is also cast to the activation
        # dtype, as required by F.scaled_dot_product_attention.
        # ------------------------------------------------------------------
        causal_mask: Optional[torch.Tensor] = None
        if attention_mask is not None and attention_mask.dim() == 2:
            kv_len = past_len + L
            # allowed[i, j] is True iff key j is at or before query past_len+i.
            allowed = torch.ones(
                L, kv_len, dtype=torch.bool, device=inputs_embeds.device
            ).tril_(diagonal=past_len)
            # Combine with padding: attention_mask is (B, kv_len) with 1 for
            # real tokens and 0 for padding.
            allowed = allowed[None, None, :, :] & (
                attention_mask[:, None, None, :] != 0
            )
            causal_mask = torch.zeros(
                allowed.shape,
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            ).masked_fill_(~allowed, torch.finfo(inputs_embeds.dtype).min)
        elif attention_mask is not None:
            # Assume a caller-provided additive 4-D mask; pass through as-is.
            causal_mask = attention_mask

        # Experience state is reset to zeros for every forward pass; it is
        # carried layer-to-layer, not step-to-step.
        experience_state = torch.zeros(
            B, L, self.config.d_exp,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_cache = []
        total_curiosity = torch.tensor(0.0, device=inputs_embeds.device,
                                       dtype=inputs_embeds.dtype)

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_kv = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # experience_state is closed over (not passed positionally)
                # so the factory pins the value for this layer; non-reentrant
                # checkpointing still tracks its gradients.
                def _make_ckpt_fn(layer, experience_state):
                    def _fn(hidden_states, causal_mask, position_ids):
                        return layer(
                            hidden_states,
                            experience_state=experience_state,
                            attention_mask=causal_mask,
                            position_ids=position_ids,
                            past_key_value=None,
                            use_cache=False,
                            output_attentions=output_attentions,
                        )
                    return _fn

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    _make_ckpt_fn(layer, experience_state),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    experience_state=experience_state,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_kv,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            experience_state = layer_outputs[1]
            total_curiosity = total_curiosity + layer_outputs[2]

            if use_cache:
                # Index 3 holds (K, V) when the layer ran with use_cache.
                next_cache.append(layer_outputs[3] if len(layer_outputs) > 3 else None)

            if output_attentions:
                # Always the last element (None under SDPA).
                all_self_attns += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Average curiosity across layers so the auxiliary loss scale is
        # independent of depth.
        mean_curiosity = total_curiosity / self.config.n_layers

        next_cache_out = next_cache if use_cache else None

        if not return_dict:
            return (
                hidden_states,
                mean_curiosity,
                next_cache_out,
                all_hidden_states,
                all_self_attns,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache_out,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        ), mean_curiosity
|
|
|
|
| |
| |
| |
|
|
class SABERForCausalLM(PreTrainedModel, GenerationMixin):
    """
    Eve-3-SABER-1B for causal language modelling.

    Compatible with HuggingFace ``Trainer``, ``SFTTrainer``, PEFT, and
    standard ``generate()`` pipelines.

    Loss = L_CE + curiosity_coeff * L_curiosity
    """

    config_class = SABERConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SABERBlock"]
    _supports_flash_attn_2 = True

    # Weight tying: lm_head shares the embedding matrix (see tie_weights).
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: SABERConfig) -> None:
        super().__init__(config)
        self.model = SABERModel(config)
        # Bias-free head; its weight is tied to the embeddings via
        # tie_weights (called from post_init).
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    # ---- embedding plumbing required by the HF API ----------------------

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Linear:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head = new_embeddings

    def tie_weights(self, **kwargs) -> None:
        """Tie lm_head.weight ← embed_tokens.weight.

        NOTE(review): unconditional tying — config.tie_word_embeddings is
        not consulted; confirm that is intended.
        """
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[CausalLMOutputWithPast, Tuple]:
        """
        Run the base model, project to vocab logits, and — when ``labels``
        are given — compute ``CE + curiosity_coeff * curiosity``.

        labels : (B, L) unshifted token ids; positions set to -100 are
        ignored by the cross-entropy (HF convention, shifted internally).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The base model is always called with return_dict=False so it
        # returns the 5-tuple
        # (hidden, mean_curiosity, past_kv, all_hidden, all_attn).
        base_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
        )

        hidden_states = base_out[0]
        curiosity_loss = base_out[1]
        pkv = base_out[2] if len(base_out) > 2 else None
        all_hs = base_out[3] if len(base_out) > 3 else None
        all_attn = base_out[4] if len(base_out) > 4 else None

        # (B, L, d_model) -> (B, L, vocab_size)
        logits = self.lm_head(hidden_states)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Shift so token t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            ce_loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
            # Total training loss: CE plus the weighted curiosity auxiliary.
            loss = ce_loss + self.config.curiosity_coeff * curiosity_loss

        if not return_dict:
            out = (logits,)
            if loss is not None:
                out = (loss,) + out
            if pkv is not None:
                out += (pkv,)
            return out

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=pkv,
            hidden_states=all_hs,
            attentions=all_attn,
        )
        # Expose the individual loss components for logging/inspection;
        # ModelOutput supports item assignment for extra keys.
        if labels is not None:
            output["ce_loss"] = ce_loss
            output["curiosity_loss"] = curiosity_loss
        return output

    # ---- generation support ---------------------------------------------

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[List] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> dict:
        if past_key_values is not None:
            # With a cache, only the newest token needs to be fed forward.
            input_ids = input_ids[:, -1:]

        # Derive position ids from the padding mask so left-padded batches
        # get correct positions (padding positions are clamped to 1).
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values is not None:
                position_ids = position_ids[:, -1:]

        # inputs_embeds is only usable for the first (cache-less) step.
        model_inputs: dict = {}
        if inputs_embeds is not None and past_key_values is None:
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs["input_ids"] = input_ids

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache", True),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(
        past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
        beam_idx: torch.LongTensor,
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Re-order KV cache for beam search (select along the batch dim)."""
        return [
            (
                past_kv[0].index_select(0, beam_idx.to(past_kv[0].device)),
                past_kv[1].index_select(0, beam_idx.to(past_kv[1].device)),
            )
            for past_kv in past_key_values
        ]
|
|
|
|
| |
| |
| |
|
|
# Register with the Auto* machinery so checkpoints saved with
# `trust_remote_code=True` can be loaded via AutoConfig /
# AutoModelForCausalLM.
SABERConfig.register_for_auto_class("AutoConfig")
SABERForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
|