# interpgpt-standard-23M / modeling_interpgpt.py
# Author: connaaa
# Phase 1 release: InterpGPT matched-pair checkpoint
# Revision: 378744c (verified)
"""
HuggingFace PreTrainedModel wrapper for InterpGPT / TaskGPT.
Weights map 1:1 to the original gpt_model.TaskGPT state dict, so the same
.pt checkpoints produced during Phase 1 load here without remapping.
Usage (after upload):
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("connaaa/interpgpt-standard-23M",
trust_remote_code=True)
# Or for the analysis pipeline:
from transformer_lens import HookedTransformer
hooked = HookedTransformer.from_pretrained("connaaa/interpgpt-standard-23M",
hf_model=model,
...)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_interpgpt import InterpGPTConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-feature gain.

    Each vector along the last dimension is divided by
    ``sqrt(mean(x ** 2) + eps)`` and then multiplied element-wise by
    ``weight``. There is no mean subtraction and no bias term.

    Args:
        d_model: size of the last input dimension (length of ``weight``).
        eps: constant added under the square root for numerical stability.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        # Gain starts at 1 so the layer initially performs pure normalisation.
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
class RotaryPositionalEncoding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    For position ``p`` and frequency index ``j`` the rotation angle is
    ``p * base ** (-2j / d_model)``. The tables cover positions
    ``0 .. max_seq_len - 1`` and are stored as the buffers ``inv_freq``,
    ``cos_cached`` and ``sin_cached``. These are registered with the default
    ``persistent=True``, so they appear in the state dict.

    Args:
        d_model: rotary dimension; must be even.
        max_seq_len: number of positions to precompute.
        base: frequency base.
    """

    def __init__(self, d_model: int, max_seq_len: int = 512, base: float = 10000.0):
        super().__init__()
        assert d_model % 2 == 0
        exponents = torch.arange(0, d_model, 2).float() / d_model
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq)
        positions = torch.arange(max_seq_len, dtype=torch.float)
        # Outer product: angles[p, j] = positions[p] * inv_freq[j],
        # shape (max_seq_len, d_model // 2).
        angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
        self.register_buffer("cos_cached", angles.cos())
        self.register_buffer("sin_cached", angles.sin())

    def forward(self, seq_len: int):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions.

        Requests longer than ``max_seq_len`` are silently truncated to
        ``max_seq_len`` rows by the slice.
        """
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by the given angles.

    The last dimension is split into a first half ``a`` and a second half
    ``b``, and the result is ``cat(a*cos - b*sin, b*cos + a*sin)``.

    Args:
        x: 4-D tensor whose dim 2 is the sequence axis (the attention caller
            passes ``(batch, heads, seq, head_dim)``).
        cos, sin: tables whose leading axis is position. Only the first
            ``x.shape[2]`` rows are used, broadcast over the first two dims of
            ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    seq_len = x.shape[2]
    half = x.shape[-1] // 2
    cos_b = cos[:seq_len][None, None]
    sin_b = sin[:seq_len][None, None]
    first = x[..., :half]
    second = x[..., half:]
    rotated_first = first * cos_b - second * sin_b
    rotated_second = second * cos_b + first * sin_b
    return torch.cat((rotated_first, rotated_second), dim=-1)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    A fused ``qkv`` projection produces queries, keys and values. These are
    split into ``n_heads`` heads of size ``d_model // n_heads``, and RoPE is
    applied to queries and keys.

    Without a KV cache, PyTorch's fused ``scaled_dot_product_attention`` is
    used when available. Otherwise (or whenever a cache dict is supplied),
    attention is computed explicitly with a precomputed lower-triangular mask.
    """

    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=config.bias)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # RoPE is applied per head, so its tables are sized by head_dim.
        self.rope = RotaryPositionalEncoding(self.head_dim, config.max_seq_len)
        # (1, 1, max_seq_len, max_seq_len) lower-triangular mask; 1 = may attend.
        mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
        self.register_buffer("causal_mask", mask.view(1, 1, config.max_seq_len, config.max_seq_len))

    def forward(self, x, kv_cache=None):
        """Attend over ``x`` plus any previously cached keys/values.

        Args:
            x: input of shape ``(B, T, d_model)``.
            kv_cache: optional dict, mutated in place.
                - If it already holds ``"k"``/``"v"`` tensors of shape
                  ``(B, n_heads, T_past, head_dim)``, the new keys/values are
                  appended along the sequence axis.
                - The new tokens are placed at positions
                  ``T_past .. T_past + T - 1``.
                - On return it holds the concatenated keys/values.

        Returns:
            Tensor of shape ``(B, T, d_model)``.

        Raises:
            ValueError: if ``T_past + T`` exceeds the precomputed
                ``max_seq_len``.
        """
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Cached tokens occupy positions 0 .. past_len - 1, so the new tokens
        # must be rotated with the angles for past_len .. past_len + T - 1.
        # Rotating them at 0 .. T - 1 would give every incrementally decoded
        # token position 0 and break relative-position information.
        past_len = 0
        if kv_cache is not None and "k" in kv_cache:
            past_len = kv_cache["k"].size(2)
        total_len = past_len + T
        max_len = self.causal_mask.size(-1)
        if total_len > max_len:
            raise ValueError(
                f"sequence length {total_len} (cached {past_len} + new {T}) "
                f"exceeds max_seq_len={max_len}"
            )
        cos, sin = self.rope(total_len)
        cos, sin = cos[past_len:], sin[past_len:]
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        if kv_cache is not None:
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=2)
                v = torch.cat([kv_cache["v"], v], dim=2)
            kv_cache["k"] = k
            kv_cache["v"] = v

        if hasattr(F, "scaled_dot_product_attention") and kv_cache is None:
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.attn_dropout.p if self.training else 0.0,
                is_causal=True,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            attn = torch.matmul(q, k.transpose(-2, -1)) * scale
            T_k = k.size(2)
            # Queries are the last T of the T_k key positions, so take the
            # matching rows of the causal mask.
            causal = self.causal_mask[:, :, T_k - T : T_k, :T_k]
            attn = attn.masked_fill(causal == 0, float("-inf"))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.resid_dropout(self.out_proj(out))
class FeedForward(nn.Module):
    """SwiGLU feed-forward sublayer: ``dropout(down(silu(gate(x)) * up(x)))``.

    The hidden width is ``int(2 * config.d_ff / 3)``, rounded up to the next
    multiple of 64.
    """

    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        width = int(2 * config.d_ff / 3)
        width = -(-width // 64) * 64  # ceiling to a multiple of 64
        use_bias = config.bias
        self.gate_proj = nn.Linear(config.d_model, width, bias=use_bias)
        self.up_proj = nn.Linear(config.d_model, width, bias=use_bias)
        self.down_proj = nn.Linear(width, config.d_model, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer with two residual sublayers.

    The attention sublayer applies RMSNorm, then attention, then a residual
    add. The feed-forward sublayer applies RMSNorm, then the feed-forward
    network, then a residual add.
    """

    def __init__(self, config: InterpGPTConfig):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = CausalSelfAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.ffn = FeedForward(config)

    def forward(self, x, kv_cache=None):
        """Apply both sublayers.

        ``kv_cache`` is passed straight through to the attention module.
        """
        hidden = x + self.attn(self.ln1(x), kv_cache)
        return hidden + self.ffn(self.ln2(hidden))
class InterpGPTModel(PreTrainedModel):
    """
    HF-wrapped InterpGPT / TaskGPT. State dict parameter names match the
    original gpt_model.TaskGPT exactly so Phase 1 .pt checkpoints load
    via state_dict without remapping.

    Architecture:
        - token embedding and dropout;
        - ``n_layers`` pre-norm transformer blocks;
        - a final RMSNorm;
        - an LM head whose weight is the same Parameter as the token
          embedding (tied in ``__init__``).
    """
    config_class = InterpGPTConfig
    base_model_prefix = "interpgpt"
    supports_gradient_checkpointing = False

    def __init__(self, config: InterpGPTConfig):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.ln_final = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight
        self.post_init()

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02).

        Biases are set to zero, and the embedding's padding row is zeroed.
        RMSNorm gains keep their constructor value of 1.
        """
        # NOTE(review): lm_head.weight is the same Parameter as
        # token_embedding.weight. If lm_head is visited after token_embedding
        # (registration order), the Linear branch re-draws the shared matrix
        # and the zeroed padding row may not survive a fresh init. Left as-is,
        # presumably to mirror the original TaskGPT init; confirm.
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                nn.init.zeros_(module.weight[module.padding_idx])

    def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None, **kwargs):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: token ids of shape ``(B, T)``.
            attention_mask: accepted for HuggingFace API compatibility but NOT
                used. Only causal masking is applied, so padding tokens are
                attended to like any other token.
            labels: optional ``(B, T)`` targets aligned with ``input_ids``.
                - The logits at position ``t`` are scored against
                  ``labels[:, t + 1]``.
                - Targets equal to ``config.pad_id`` or to ``-100`` (the
                  HuggingFace convention) are ignored.
            loss_mask: optional ``(B, T)`` per-position weights aligned with
                ``labels`` (presumably 0/1). They multiply the per-token loss
                of the shifted targets.
            **kwargs: ignored.

        Returns:
            A dict containing:
                - ``"logits"``: shape ``(B, T, vocab_size)``.
                - ``"loss"`` (only when ``labels`` is given): a scalar.
                  It is the weighted mean of the per-token cross-entropy over
                  non-ignored target positions, or 0 when there are none.
        """
        B, T = input_ids.shape
        x = self.drop(self.token_embedding(input_ids))
        for block in self.blocks:
            x = block(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        output = {"logits": logits}
        if labels is None:
            return output

        shift_logits = logits[:, :-1].contiguous()
        targets = labels[:, 1:].contiguous()
        # Fold pad_id onto -100 so both conventions are ignored. Passing
        # ignore_index=pad_id alone would make -100 labels an out-of-range
        # error, and a None pad_id would be rejected by cross_entropy.
        pad_id = self.config.pad_id
        if pad_id is not None:
            targets = targets.masked_fill(targets == pad_id, -100)
        token_loss = F.cross_entropy(
            shift_logits.view(-1, self.config.vocab_size),
            targets.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(B, T - 1)
        # Ignored positions contribute zero loss. Exclude them from the
        # denominator too, so padding does not dilute the average. The clamp
        # gives 0 (not NaN) when nothing is counted, e.g. T == 1.
        weights = (targets != -100).float()
        if loss_mask is not None:
            weights = weights * loss_mask[:, 1:].contiguous().float()
        output["loss"] = (token_loss * weights).sum() / weights.sum().clamp(min=1.0)
        return output