""" HuggingFace PreTrainedModel wrapper for InterpGPT / TaskGPT. Weights map 1:1 to the original gpt_model.TaskGPT state dict, so the same .pt checkpoints produced during Phase 1 load here without remapping. Usage (after upload): from transformers import AutoModel, AutoTokenizer model = AutoModel.from_pretrained("connaaa/interpgpt-standard-23M", trust_remote_code=True) # Or for the analysis pipeline: from transformer_lens import HookedTransformer hooked = HookedTransformer.from_pretrained("connaaa/interpgpt-standard-23M", hf_model=model, ...) """ import math import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from .configuration_interpgpt import InterpGPTConfig class RMSNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(d_model)) self.eps = eps def forward(self, x): norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * norm * self.weight class RotaryPositionalEncoding(nn.Module): def __init__(self, d_model: int, max_seq_len: int = 512, base: float = 10000.0): super().__init__() assert d_model % 2 == 0 inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_seq_len, dtype=torch.float) freqs = torch.einsum("i,j->ij", t, inv_freq) self.register_buffer("cos_cached", freqs.cos()) self.register_buffer("sin_cached", freqs.sin()) def forward(self, seq_len: int): return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def apply_rotary_emb(x, cos, sin): d_half = x.shape[-1] // 2 x1, x2 = x[..., :d_half], x[..., d_half:] cos = cos[: x.shape[2]].unsqueeze(0).unsqueeze(0) sin = sin[: x.shape[2]].unsqueeze(0).unsqueeze(0) out1 = x1 * cos - x2 * sin out2 = x2 * cos + x1 * sin return torch.cat([out1, out2], dim=-1) class CausalSelfAttention(nn.Module): def __init__(self, config: InterpGPTConfig): super().__init__() assert config.d_model % config.n_heads == 0 self.n_heads = config.n_heads self.head_dim = config.d_model // config.n_heads self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias) self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias) self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) self.rope = RotaryPositionalEncoding(self.head_dim, config.max_seq_len) mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len)) self.register_buffer("causal_mask", mask.view(1, 1, config.max_seq_len, config.max_seq_len)) def forward(self, x, kv_cache=None): B, T, D = x.shape qkv = self.qkv(x) q, k, v = qkv.chunk(3, dim=-1) q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) cos, sin = self.rope(T) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) if kv_cache is not None: if "k" in kv_cache: k = torch.cat([kv_cache["k"], k], dim=2) v = torch.cat([kv_cache["v"], v], dim=2) kv_cache["k"] = k kv_cache["v"] = v if hasattr(F, "scaled_dot_product_attention") and kv_cache is None: out = F.scaled_dot_product_attention( q, k, v, attn_mask=None, dropout_p=self.attn_dropout.p if self.training else 0.0, is_causal=True, ) else: scale = 1.0 / math.sqrt(self.head_dim) attn = torch.matmul(q, k.transpose(-2, -1)) * scale T_k = k.size(2) causal = self.causal_mask[:, :, T_k - T : T_k, :T_k] attn = attn.masked_fill(causal == 0, float("-inf")) attn = F.softmax(attn, dim=-1) attn = self.attn_dropout(attn) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(B, T, D) return self.resid_dropout(self.out_proj(out)) class FeedForward(nn.Module): def __init__(self, config: InterpGPTConfig): super().__init__() hidden = int(2 * config.d_ff / 3) hidden = 64 * ((hidden + 63) // 64) self.gate_proj = nn.Linear(config.d_model, hidden, bias=config.bias) self.up_proj = nn.Linear(config.d_model, hidden, bias=config.bias) self.down_proj = nn.Linear(hidden, config.d_model, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))) class TransformerBlock(nn.Module): def __init__(self, config: InterpGPTConfig): super().__init__() self.ln1 = RMSNorm(config.d_model) self.attn = CausalSelfAttention(config) self.ln2 = RMSNorm(config.d_model) self.ffn = FeedForward(config) def forward(self, x, kv_cache=None): x = x + self.attn(self.ln1(x), kv_cache) x = x + self.ffn(self.ln2(x)) return x class InterpGPTModel(PreTrainedModel): """ HF-wrapped InterpGPT / TaskGPT. State dict parameter names match the original gpt_model.TaskGPT exactly so Phase 1 .pt checkpoints load via state_dict without remapping. """ config_class = InterpGPTConfig base_model_prefix = "interpgpt" supports_gradient_checkpointing = False def __init__(self, config: InterpGPTConfig): super().__init__(config) self.config = config self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id) self.drop = nn.Dropout(config.dropout) self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) self.ln_final = RMSNorm(config.d_model) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) self.lm_head.weight = self.token_embedding.weight self.post_init() def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.padding_idx is not None: nn.init.zeros_(module.weight[module.padding_idx]) def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None, **kwargs): B, T = input_ids.shape x = self.drop(self.token_embedding(input_ids)) for block in self.blocks: x = block(x) x = self.ln_final(x) logits = self.lm_head(x) output = {"logits": logits} if labels is not None: shift_logits = logits[:, :-1].contiguous() shift_labels = labels[:, 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=self.config.pad_id, reduction="none", ).view(B, T - 1) if loss_mask is not None: shift_mask = loss_mask[:, 1:].contiguous().float() loss = (loss * shift_mask).sum() / shift_mask.sum().clamp(min=1.0) else: loss = loss.mean() output["loss"] = loss return output