"""Model definition for analysis/plotting of grid-run checkpoints.

Same architecture as ``model_tbyt_train.py`` but with:

- attention-weight storage (``raw_attn``, ``attn``) in ``CasualSelfAttention``
  (the class keeps its historical spelling so existing imports keep working);
- a ``GPTIntervention`` helper for attention-intervention experiments.

Weight-compatible with ``model_tbyt_train.py`` checkpoints, so parameter and
buffer names/shapes in this file must not change.
"""

import types

import torch
import torch.nn as nn
import torch.nn.functional as F


def _snapshot_attention(scores):
    """Return a detached copy of a ``(B, H, T, T)`` attention tensor.

    For a single sequence with a single head (``B * H == 1``) the copy is
    squeezed to ``(T, T)``, which is what the plotting code and
    ``GPTIntervention.read_attention`` index into.  Otherwise the full 4-D
    tensor is kept so batched / multi-head forward passes do not crash.
    """
    snap = scores.detach().clone()
    batch, heads, seq, _ = snap.shape
    return snap.view(seq, seq) if batch * heads == 1 else snap


class MLP(nn.Module):
    """Feed-forward sub-block: Linear(C -> 3C) -> GELU(tanh) -> Linear(3C -> C)."""

    def __init__(self, config):
        super().__init__()
        self.fc_1 = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.gelu = nn.GELU(approximate='tanh')
        self.fc_2 = nn.Linear(config.n_embd * 3, config.n_embd)
        # NOTE(review): GPT._init_weights only checks nn.Linear modules for a
        # ``NANOGPT_SCALE_INIT`` attribute, so this differently-named flag on
        # the MLP itself has no effect on initialisation. Left as-is so random
        # init matches the training model (presumably it has the same line --
        # confirm); it is irrelevant once a checkpoint is loaded.
        self.NANO_SCALE_GPT = True

    def forward(self, x):
        return self.fc_2(self.gelu(self.fc_1(x)))


class CasualSelfAttention(nn.Module):
    """Causal multi-head self-attention that records its attention matrices.

    After every forward pass the module exposes two detached copies:

    * ``raw_attn`` -- masked pre-softmax scores (``-inf`` above the diagonal);
    * ``attn``     -- post-softmax attention weights.

    Both are ``(T, T)`` when ``B == 1`` and ``n_heads == 1``, otherwise
    ``(B, n_heads, T, T)``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_heads = config.n_heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Lower-triangular causal mask sized for block_size unsorted tokens,
        # one separator and block_size sorted tokens. Registered as ``bias``
        # so the state-dict key matches the checkpoints.
        seq_len = config.block_size * 2 + 1
        self.register_buffer('bias', torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len))
        # Picked up by GPT._init_weights to down-scale the residual projection.
        self.c_proj.NANOGPT_SCALE_INIT = True
        self.config = config

    def forward(self, x, layer_n=-1):
        """Attend over ``x`` of shape ``(B, T, C)``; ``layer_n`` is unused."""
        B, T, C = x.size()
        head_dim = C // self.n_heads
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_heads, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, head_dim).transpose(1, 2)
        # Note the extra 0.1 factor on top of 1/sqrt(head_dim); keep it in
        # sync with the training model.
        attn = q @ k.transpose(-1, -2) * 0.1 / (k.size(-1)) ** 0.5
        attn = attn.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        self.raw_attn = _snapshot_attention(attn)
        attn = F.softmax(attn, dim=-1)
        self.attn = _snapshot_attention(attn)
        y = attn @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)


class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(ln_1(x)), then + mlp(ln_2(x))."""

    def __init__(self, config):
        super().__init__()
        self.c_attn = CasualSelfAttention(config)
        self.c_fc = MLP(config)
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.ln_2 = nn.LayerNorm(config.n_embd)

    def forward(self, x, layer_n=-1):
        x = x + self.c_attn(self.ln_1(x), layer_n=layer_n)
        return x + self.c_fc(self.ln_2(x))


class GPT(nn.Module):
    """Decoder-only transformer used for the sorting task.

    ``lm_head`` shares its weight with ``wte`` (shape ``(vocab_size + 1,
    n_embd)``), so logits have ``vocab_size + 1`` columns. ``wpe`` is
    registered but ``forward`` never adds positional embeddings; presumably
    it is kept only so checkpoints load -- confirm against the training model.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_layers = config.n_layers
        self.alpha = 100.0
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size + 1, config.n_embd),
            wpe=nn.Embedding(config.block_size * 4 + 1, config.n_embd),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layers)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight = self.transformer.wte.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init; Linear layers flagged NANOGPT_SCALE_INIT use std * (2 * n_layers) ** -0.5."""
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=std)

    def forward(self, idx, targets=None, flag=False):
        """Run the model on token ids ``idx`` of shape ``(B, T)``.

        Returns ``(logits, loss)``. The loss is next-token cross-entropy where
        logits at positions ``block_size .. T-2`` predict ``idx`` at positions
        ``block_size+1 .. T-1``. ``targets`` and ``flag`` are ignored and
        exist only for call-site compatibility.
        """
        _, T = idx.size()
        x = self.transformer.wte(idx)
        for layer_n, block in enumerate(self.transformer.h):
            x = block(x, layer_n)
        if self.config.with_layer_norm:
            x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        start = self.config.block_size
        preds = logits[:, start:T - 1, :].contiguous().view(-1, logits.size(-1))
        labels = idx[:, start + 1:].contiguous().view(-1)
        loss = F.cross_entropy(preds, labels)
        return logits, loss


class GPTConfig:
    """Hyper-parameters for GPT; the class attributes are the defaults."""

    block_size: int = 32
    vocab_size: int = 128
    n_layers = 2
    n_heads = 1
    n_embd = 64
    with_layer_norm: bool = False

    def __init__(self, block_size=None, vocab_size=None, with_layer_norm=False,
                 n_layers=None, n_heads=None, n_embd=None):
        # Falsy block_size / vocab_size keep the class defaults (original behaviour).
        if block_size:
            self.block_size = block_size
        if vocab_size:
            self.vocab_size = vocab_size
        self.with_layer_norm = with_layer_norm
        # Optional architecture overrides; None keeps the class default.
        if n_layers is not None:
            self.n_layers = n_layers
        if n_heads is not None:
            self.n_heads = n_heads
        if n_embd is not None:
            self.n_embd = n_embd


class GPTIntervention:
    """Run a GPT once on ``idx`` and patch attention scores for experiments.

    Assumes ``idx`` holds a single sequence (batch size 1) laid out as
    ``block_size`` unsorted tokens, one separator token, then the sorted
    tokens -- the layout implied by the slicing in ``intervent_attention``.
    """

    def __init__(self, gpt, idx):
        self.config = gpt.config
        self.gpt = gpt
        self.idx = idx
        # One forward pass populates each layer's raw_attn / attn snapshots.
        with torch.no_grad():
            self.gpt(self.idx)
        layers = self.gpt.transformer.h
        self.attn = [layers[i].c_attn.attn for i in range(self.config.n_layers)]
        self.raw_attn = [layers[i].c_attn.raw_attn for i in range(self.config.n_layers)]
        # Original (unpatched) forward of each layer, recorded on first patch.
        self.old_attention_forward = [None] * self.config.n_layers
        # Query position of the last intervention; read by check_if_still_works.
        self.location = None

    def read_attention(self, layer, loc1, loc2):
        """Return the stored pre-softmax score of query ``loc1`` for key ``loc2``."""
        return self.raw_attn[layer][loc1, loc2]

    def check_if_still_works(self):
        """Return ``(predicted_token, true_next_token)`` at the intervened position.

        Raises:
            RuntimeError: if ``intervent_attention`` has not been called yet.
        """
        if self.location is None:
            raise RuntimeError("call intervent_attention before check_if_still_works")
        with torch.no_grad():
            logits, _ = self.gpt(self.idx)
        predicted = torch.argmax(logits, dim=-1)[0, self.location].item()
        return predicted, self.idx[0, self.location + 1].item()

    @staticmethod
    def _sample(candidates, num, name):
        """Return ``num`` randomly chosen entries of ``candidates``; raise if too few exist."""
        if len(candidates) < num:
            raise Exception(f"Not enough numbers for {name}")
        return candidates[torch.randperm(len(candidates))[:num]]

    def intervent_attention(self, attention_layer_num, location, unsorted_lb, unsorted_ub,
                            unsorted_lb_num, unsorted_ub_num, unsorted_intensity_inc,
                            sorted_lb, sorted_num, sorted_intensity_inc):
        """Patch one layer so query ``location`` gets boosted scores on chosen keys.

        Keys are sampled at random from three pools:

        * unsorted tokens in ``[target - unsorted_lb, target]``;
        * unsorted tokens in ``(target, target + unsorted_ub]``;
        * sorted tokens within ``sorted_lb`` of the target.

        The two unsorted pools exclude the true next number. Each chosen
        pre-softmax score is set to the stored score for the true next number
        plus the matching ``*_intensity_inc``.

        Calling this again on an already patched layer replaces the patch.
        ``revert_attention`` always restores the original, unpatched forward.

        Raises:
            Exception: if a pool has fewer candidates than requested.
            IndexError: if the next number does not occur in the unsorted half.
        """
        self.location = location
        block_size = self.config.block_size
        target_val = self.idx[0, location].item()
        next_number = self.idx[0, location + 1].item()
        unsorted_part = self.idx[0, :block_size]
        sorted_part = self.idx[0, block_size + 1:2 * block_size + 1]

        unsorted_lb_mask = ((unsorted_part >= target_val - unsorted_lb)
                            & (unsorted_part <= target_val)
                            & (unsorted_part != next_number))
        unsorted_lb_selected = self._sample(
            torch.where(unsorted_lb_mask)[0], unsorted_lb_num, "unsorted_lb_num")

        unsorted_ub_mask = ((unsorted_part > target_val)
                            & (unsorted_part <= target_val + unsorted_ub)
                            & (unsorted_part != next_number))
        unsorted_ub_selected = self._sample(
            torch.where(unsorted_ub_mask)[0], unsorted_ub_num, "unsorted_ub_num")

        sorted_mask = torch.abs(sorted_part - target_val) <= sorted_lb
        sorted_selected = self._sample(torch.where(sorted_mask)[0], sorted_num, "sorted_num")
        # Shift from positions within the sorted half to positions in idx.
        sorted_actual_indices = sorted_selected + block_size + 1

        next_number_location = torch.where(unsorted_part == next_number)[0][0].item()
        main_attention_val = self.read_attention(attention_layer_num, location, next_number_location).item()

        # Loop-invariant values captured by the patched forward.
        unsorted_cols = unsorted_lb_selected.tolist() + unsorted_ub_selected.tolist()
        sorted_cols = sorted_actual_indices.tolist()
        unsorted_value = main_attention_val + unsorted_intensity_inc
        sorted_value = main_attention_val + sorted_intensity_inc

        def new_forward(self_attn, x, layer_n=-1):
            # Same computation as CasualSelfAttention.forward, plus score
            # overrides; it does not update raw_attn / attn.
            B, T, C = x.size()
            head_dim = C // self_attn.n_heads
            q, k, v = self_attn.c_attn(x).split(self_attn.n_embd, dim=2)
            q = q.view(B, T, self_attn.n_heads, head_dim).transpose(1, 2)
            k = k.view(B, T, self_attn.n_heads, head_dim).transpose(1, 2)
            v = v.view(B, T, self_attn.n_heads, head_dim).transpose(1, 2)
            attn = q @ k.transpose(-1, -2) * 0.1 / (k.size(-1)) ** 0.5
            for col in unsorted_cols:
                attn[:, :, location, col] = unsorted_value
            for col in sorted_cols:
                attn[:, :, location, col] = sorted_value
            # Masking after the overrides keeps future positions excluded.
            attn = attn.masked_fill(self_attn.bias[:, :, :T, :T] == 0, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            y = attn @ v
            y = y.transpose(1, 2).contiguous().view(B, T, C)
            return self_attn.c_proj(y)

        attention_module = self.gpt.transformer.h[attention_layer_num].c_attn
        # Record only the original forward. Re-patching must not overwrite it,
        # or revert_attention would restore an earlier patch instead.
        if self.old_attention_forward[attention_layer_num] is None:
            self.old_attention_forward[attention_layer_num] = attention_module.forward
        attention_module.forward = types.MethodType(new_forward, attention_module)

    def revert_attention(self, attention_layer_num):
        """Restore the original forward of a layer patched by ``intervent_attention``."""
        original = self.old_attention_forward[attention_layer_num]
        if original is None:
            raise Exception("No old attention forward found")
        self.gpt.transformer.h[attention_layer_num].c_attn.forward = original