| """ |
| Representation Learning Dynamics Experiment |
| ============================================ |
| How does a model's internal representation respond to continued training |
| on the same task vs. learning a new one? |
| |
| Experiment design: |
| Phase 1: Train model on Task A (modular addition) until convergence |
| Phase 2: Fork into two branches: |
| Branch A→A: Continue training on Task A (same task, same training set, more epochs) |
| Branch A→B: Switch to Task B (modular subtraction) |
| Track: CKA, subspace angles, gradient alignment, attention entropy, |
| representation variance explained, probing accuracy — all per layer, |
| at every checkpoint. |
| |
| The key contrast reveals what "learning" looks like at the representation |
| level vs. what "forgetting" looks like — and the precise moment they diverge. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import numpy as np |
| import json |
| import os |
| import copy |
| import time |
| from pathlib import Path |
| from collections import defaultdict |
| from typing import Dict, List, Optional |
|
|
| from model import SmallTransformer, TransformerConfig |
| from tasks import ( |
| ModularArithmeticDataset, get_probe_data, get_dataloaders, |
| DEFAULT_P, NUM_SPECIAL |
| ) |
| from representation_tracker import ( |
| linear_CKA, svcca, subspace_angles, mean_subspace_angle_degrees, |
| gradient_alignment, attention_entropy, task_variance_explained, |
| parameter_delta_cosine, weight_change_magnitude_per_layer, |
| cka_heatmap |
| ) |
|
|
|
|
def evaluate(model, dataloader, device) -> Dict[str, float]:
    """Evaluate mean loss and last-position accuracy over a dataset.

    Args:
        model: Callable as ``model(input_ids, labels=labels)`` that returns a
            mapping with a scalar ``'loss'`` and a ``'logits'`` tensor indexed
            as ``[batch, position, vocab]``. It is switched to eval mode and
            left in eval mode on return.
        dataloader: Iterable of batches, each a mapping with ``'input_ids'``
            and ``'labels'`` tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        ``{'loss': ..., 'accuracy': ...}``, both averaged per example. Both
        values are NaN when the dataloader yields no examples (for instance an
        empty test split), instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_count = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            out = model(input_ids, labels=labels)
            batch_size = input_ids.shape[0]

            # Weight the batch loss by batch size so a short final batch does
            # not skew the average (assumes out['loss'] is a per-batch mean —
            # TODO confirm against the model implementation).
            total_loss += out['loss'].item() * batch_size

            # Only the final sequence position is scored.
            preds = out['logits'][:, -1, :].argmax(dim=-1)
            targets = labels[:, -1]
            total_correct += (preds == targets).sum().item()
            total_count += batch_size

    if total_count == 0:
        # Nothing was evaluated: report "undefined" rather than crash.
        undefined = float('nan')
        return {'loss': undefined, 'accuracy': undefined}

    return {
        'loss': total_loss / total_count,
        'accuracy': total_correct / total_count,
    }
|
|
|
|
def collect_representations(model, probe_input_ids, device,
                            token_position: int = -1) -> Dict:
    """
    Run the probe batch through the model once and gather its internals.

    The model is put in eval mode and called with ``return_internals=True``
    under ``torch.no_grad()``. Hidden states and MLP activations are sliced at
    ``token_position`` along their second axis; attention weights are kept
    whole. Every returned tensor has been moved to the CPU.

    Returns a dict with the keys ``'hidden_states'``, ``'attn_weights'`` and
    ``'mlp_hidden'``, each mapping to a list of tensors.
    """
    def at_position(tensors):
        # Keep a single token position from each per-layer activation.
        return [t[:, token_position, :].cpu() for t in tensors]

    model.eval()
    with torch.no_grad():
        internals = model(probe_input_ids.to(device), return_internals=True)

    return {
        'hidden_states': at_position(internals['hidden_states']),
        'attn_weights': [weights.cpu() for weights in internals['attn_weights']],
        'mlp_hidden': at_position(internals['mlp_hidden']),
    }
|
|
|
|
def compute_all_metrics(
    model, model_init_state, model_phase1_state,
    reps_current, reps_at_init, reps_at_phase1_end,
    probe_input_ids_a, probe_labels_a,
    probe_input_ids_b, probe_labels_b,
    device, config
) -> Dict:
    """
    Compute the full suite of representation metrics at a single checkpoint.

    Args:
        model: Model whose ``state_dict()`` is compared against the two
            reference states.
        model_init_state: Reference state dict captured at initialisation.
        model_phase1_state: Reference state dict captured at the end of
            phase 1.
        reps_current, reps_at_init, reps_at_phase1_end: Dicts shaped like the
            output of ``collect_representations``; ``'hidden_states'`` must
            hold at least ``config.n_layers + 1`` entries.
        probe_labels_a: Task A probe labels; must support ``.tolist()``.
        probe_input_ids_a, probe_input_ids_b, probe_labels_b, device:
            Accepted so all callers share one signature; not read here.
        config: Object exposing ``n_layers``.

    Returns:
        Flat dict of metrics keyed ``layer_{i}/...`` (CKA, SVCCA, subspace
        angle, attention entropy, task-A variance explained) and
        ``block_{i}/...`` (weight change from init and from phase 1).
    """
    metrics = {}
    # 'hidden_states' carries one more entry than there are blocks.
    n_layers = config.n_layers + 1

    # Similarity of each layer's current representation to both references.
    for layer_idx in range(n_layers):
        prefix = f'layer_{layer_idx}'
        curr = reps_current['hidden_states'][layer_idx]
        init = reps_at_init['hidden_states'][layer_idx]
        p1 = reps_at_phase1_end['hidden_states'][layer_idx]

        metrics[f'{prefix}/cka_vs_init'] = linear_CKA(curr, init)
        metrics[f'{prefix}/cka_vs_phase1'] = linear_CKA(curr, p1)
        metrics[f'{prefix}/svcca_vs_phase1'] = svcca(curr, p1)

        # Subspace dimension is capped by half the sample count and by the
        # feature width; with too few samples the angle is reported as 0.0.
        k = min(10, curr.shape[0] // 2, curr.shape[1])
        if k > 0:
            metrics[f'{prefix}/subspace_angle_vs_phase1'] = \
                mean_subspace_angle_degrees(curr, p1, k=k)
        else:
            metrics[f'{prefix}/subspace_angle_vs_phase1'] = 0.0

    # Attention entropy. attn_weights[i] is reported under layer i + 1, so its
    # keys line up with the hidden state at index i + 1.
    for layer_idx, aw in enumerate(reps_current['attn_weights']):
        ent = attention_entropy(aw)
        metrics[f'layer_{layer_idx+1}/attn_entropy_mean'] = ent['mean_entropy']
        for h, he in enumerate(ent['per_head_entropy']):
            metrics[f'layer_{layer_idx+1}/head_{h}_entropy'] = he

    # Task-A variance explained is skipped when every probe label is the same.
    # The check does not depend on the layer, so do it once.
    labels_vary = len(set(probe_labels_a.tolist())) > 1
    if labels_vary:
        for layer_idx in range(n_layers):
            curr = reps_current['hidden_states'][layer_idx]
            # Fresh float copy per layer. as_tensor(...).to(copy=True) avoids
            # the copy-construct UserWarning that torch.tensor(existing_tensor)
            # emits when the labels are already a tensor.
            labels_float = torch.as_tensor(probe_labels_a).to(
                dtype=torch.float, copy=True
            )
            tve = task_variance_explained(curr, labels_float, n_components=10)
            metrics[f'layer_{layer_idx}/task_a_var_explained'] = tve['weighted_r2']

    # Per-block weight movement relative to both reference states.
    current_state = {k: v.cpu() for k, v in model.state_dict().items()}
    wc_from_init = weight_change_magnitude_per_layer(model_init_state, current_state)
    wc_from_p1 = weight_change_magnitude_per_layer(model_phase1_state, current_state)

    for block_idx in range(config.n_layers):
        metrics[f'block_{block_idx}/weight_change_from_init'] = \
            _block_weight_change(wc_from_init, block_idx)
        metrics[f'block_{block_idx}/weight_change_from_phase1'] = \
            _block_weight_change(wc_from_p1, block_idx)

    return metrics


def _block_weight_change(changes, block_idx: int):
    """Sum the values in ``changes`` whose key names block ``block_idx``."""
    import re

    # A plain substring test would let 'blocks.1' also match 'blocks.10',
    # 'blocks.11', ...; the negative lookahead rejects a trailing digit.
    pattern = re.compile(rf'blocks\.{block_idx}(?!\d)')
    return sum(value for key, value in changes.items() if pattern.search(key))
|
|
|
|
def train_phase(
    model, optimizer, dataloader, n_epochs: int,
    device, phase_name: str,
    model_init_state, model_phase1_state,
    reps_at_init, reps_at_phase1_end,
    probe_input_ids_a, probe_labels_a,
    probe_input_ids_b, probe_labels_b,
    eval_loaders: Dict,
    config: "TransformerConfig",
    checkpoint_every: int = 50,
    output_dir: str = 'results',
) -> List[Dict]:
    """
    Train for n_epochs, collecting representation metrics periodically.

    Args:
        model, optimizer: Model to train and the optimizer stepping it.
        dataloader: Iterable of training batches (mappings with
            ``'input_ids'`` and ``'labels'``).
        n_epochs: Number of passes over ``dataloader``.
        device: Device batches are moved to.
        phase_name: Label stored in every metrics record and log line.
        model_init_state, model_phase1_state, reps_at_init,
        reps_at_phase1_end, probe_*: Reference data forwarded to
            ``compute_all_metrics``; representations are collected on
            ``probe_input_ids_a`` only.
        eval_loaders: Name -> loader; each one is evaluated at every
            checkpoint. Must contain ``'add_test'`` and ``'subtract_test'``,
            whose first batches feed the gradient-alignment metric.
        config: Model configuration forwarded to ``compute_all_metrics``.
        checkpoint_every: Metrics are recorded whenever the step counter is a
            multiple of this value. The counter starts at zero on each call.
        output_dir: Directory created if missing; nothing is written to it
            by this function.

    Returns:
        One metrics dict per checkpoint, in chronological order. The recorded
        ``'train_loss'`` is the running mean over the current epoch so far.
    """
    import warnings

    history = []
    global_step = 0
    os.makedirs(output_dir, exist_ok=True)

    # Loss closure for gradient_alignment. It does not depend on the step, so
    # it is built once rather than at every checkpoint.
    def loss_fn(m, b):
        return m(b['input_ids'].to(device),
                 labels=b['labels'].to(device))['loss']

    for epoch in range(n_epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)

            out = model(input_ids, labels=labels)
            loss = out['loss']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1
            global_step += 1

            if global_step % checkpoint_every == 0:
                model.eval()
                running_loss = epoch_loss / n_batches

                # Representations are always probed on the Task A inputs.
                reps_current = collect_representations(
                    model, probe_input_ids_a, device
                )

                step_metrics = compute_all_metrics(
                    model, model_init_state, model_phase1_state,
                    reps_current, reps_at_init, reps_at_phase1_end,
                    probe_input_ids_a, probe_labels_a,
                    probe_input_ids_b, probe_labels_b,
                    device, config
                )

                for name, loader in eval_loaders.items():
                    eval_res = evaluate(model, loader, device)
                    step_metrics[f'eval/{name}_loss'] = eval_res['loss']
                    step_metrics[f'eval/{name}_acc'] = eval_res['accuracy']

                # Gradient alignment between one batch of each task.
                batch_a = next(iter(eval_loaders['add_test']))
                batch_b = next(iter(eval_loaders['subtract_test']))

                try:
                    ga = gradient_alignment(model, batch_a, batch_b, loss_fn)
                except Exception as exc:
                    # Best-effort metric: keep training, but make the failure
                    # visible and record NaN. A silent 0.0 would look like a
                    # genuine "orthogonal gradients" measurement.
                    warnings.warn(
                        f"[{phase_name}] gradient_alignment failed at step "
                        f"{global_step}: {exc!r}; recording NaN",
                        RuntimeWarning,
                    )
                    ga = float('nan')
                step_metrics['gradient_alignment_a_vs_b'] = ga

                step_metrics['phase'] = phase_name
                step_metrics['epoch'] = epoch
                step_metrics['step'] = global_step
                step_metrics['train_loss'] = running_loss

                history.append(step_metrics)

                print(f"[{phase_name}] Step {global_step} | "
                      f"Loss: {running_loss:.4f} | "
                      f"Add acc: {step_metrics.get('eval/add_test_acc', 0):.3f} | "
                      f"Sub acc: {step_metrics.get('eval/subtract_test_acc', 0):.3f} | "
                      f"CKA(L1 vs P1): {step_metrics.get('layer_1/cka_vs_phase1', 0):.3f} | "
                      f"Grad align: {step_metrics.get('gradient_alignment_a_vs_b', 0):.3f}")

                # The helpers above leave the model in eval mode.
                model.train()

        # An empty dataloader yields no batches; report NaN instead of
        # dividing by zero.
        avg_loss = epoch_loss / n_batches if n_batches else float('nan')
        print(f"[{phase_name}] Epoch {epoch+1}/{n_epochs} complete, "
              f"avg loss: {avg_loss:.4f}")

    return history
|
|
|
|
def run_experiment(
    p: int = DEFAULT_P,
    n_layers: int = 2,
    d_model: int = 128,
    n_heads: int = 4,
    d_mlp: int = 512,
    phase1_epochs: int = 100,
    phase2_epochs: int = 100,
    lr: float = 1e-3,
    weight_decay: float = 1.0,
    batch_size: int = 512,
    train_frac: float = 0.5,
    checkpoint_every: int = 50,
    output_dir: str = 'results',
    seed: int = 42,
):
    """
    Run the full two-phase experiment.

    Phase 1 trains a fresh model on addition. Phase 2 forks the phase-1
    checkpoint twice, each fork with a newly created AdamW optimizer: branch
    A→A keeps training on the same addition training loader, branch A→B
    trains on subtraction. Phase 3 compares the two final models with each
    other and with the phase-1 model.

    Args:
        p: Modulus passed to the datasets; also sets the vocabulary size
            together with ``NUM_SPECIAL``.
        n_layers, d_model, n_heads, d_mlp: Model size, forwarded to
            ``TransformerConfig``.
        phase1_epochs, phase2_epochs: Epochs for phase 1 and for each branch.
        lr, weight_decay: AdamW settings, used for all three optimizers.
        batch_size, train_frac: Forwarded to the data helpers.
        checkpoint_every: Metric-recording interval in optimizer steps.
        output_dir: Directory for the checkpoint, final weights and the
            results JSON; created if missing.
        seed: Seeds ``torch`` and ``numpy`` and is forwarded to the data
            helpers.

    Returns:
        The results dict that is also written to
        ``<output_dir>/experiment_results.json``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # The checkpoint is written here and re-read by both branches, so create
    # the directory up front rather than relying on train_phase to do it.
    os.makedirs(output_dir, exist_ok=True)
    checkpoint_path = os.path.join(output_dir, 'phase1_checkpoint.pt')

    # --- Model -----------------------------------------------------------
    config = TransformerConfig(
        vocab_size=p + NUM_SPECIAL,
        n_layers=n_layers,
        d_model=d_model,
        n_heads=n_heads,
        d_mlp=d_mlp,
        max_seq_len=5,
    )

    model = SmallTransformer(config).to(device)
    print(f"Model parameters: {model.count_parameters():,}")

    # CPU snapshot of the untrained weights, used as the "init" reference.
    model_init_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

    # --- Data ------------------------------------------------------------
    loaders = get_dataloaders(p=p, batch_size=batch_size,
                              train_frac=train_frac, seed=seed)

    # Probe sets are drawn from the test split of each task.
    ds_a = ModularArithmeticDataset('add', p=p, split='test', train_frac=train_frac, seed=seed)
    ds_b = ModularArithmeticDataset('subtract', p=p, split='test', train_frac=train_frac, seed=seed)
    probe_ids_a, probe_labels_a = get_probe_data(ds_a, n_samples=min(500, len(ds_a)))
    probe_ids_b, probe_labels_b = get_probe_data(ds_b, n_samples=min(500, len(ds_b)))

    reps_at_init = collect_representations(model, probe_ids_a, device)

    # Arguments that are identical for every train_phase call below.
    shared_kwargs = dict(
        device=device,
        model_init_state=model_init_state,
        reps_at_init=reps_at_init,
        probe_input_ids_a=probe_ids_a,
        probe_labels_a=probe_labels_a,
        probe_input_ids_b=probe_ids_b,
        probe_labels_b=probe_labels_b,
        eval_loaders=loaders,
        config=config,
        checkpoint_every=checkpoint_every,
        output_dir=output_dir,
    )

    # --- Phase 1 ---------------------------------------------------------
    print("\n" + "=" * 60)
    print("PHASE 1: Training on Task A (Modular Addition)")
    print("=" * 60)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # No phase-1 endpoint exists yet, so the init snapshot stands in for it.
    # During this phase every "*_vs_phase1" / "*_from_phase1" metric is
    # therefore measured against initialisation.
    phase1_history = train_phase(
        model=model,
        optimizer=optimizer,
        dataloader=loaders['add_train'],
        n_epochs=phase1_epochs,
        phase_name='phase1_add',
        model_phase1_state=model_init_state,
        reps_at_phase1_end=reps_at_init,
        **shared_kwargs,
    )

    # Snapshot the phase-1 endpoint: weights, representations and accuracy.
    model_phase1_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
    reps_at_phase1_end = collect_representations(model, probe_ids_a, device)
    phase1_final_eval = evaluate(model, loaders['add_test'], device)
    print(f"\nPhase 1 final — Add accuracy: {phase1_final_eval['accuracy']:.3f}")

    torch.save(model.state_dict(), checkpoint_path)

    # --- Phase 2 ---------------------------------------------------------
    def train_branch(phase_name, train_loader_key):
        """Fork the phase-1 checkpoint and train the fork on one loader."""
        branch_model = SmallTransformer(config).to(device)
        branch_model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
        # A new optimizer per branch: no optimizer state carries over from
        # phase 1, and this holds equally for both branches.
        branch_optimizer = optim.AdamW(branch_model.parameters(), lr=lr,
                                       weight_decay=weight_decay)
        branch_history = train_phase(
            model=branch_model,
            optimizer=branch_optimizer,
            dataloader=loaders[train_loader_key],
            n_epochs=phase2_epochs,
            phase_name=phase_name,
            model_phase1_state=model_phase1_state,
            reps_at_phase1_end=reps_at_phase1_end,
            **shared_kwargs,
        )
        return branch_model, branch_history

    print("\n" + "=" * 60)
    print("PHASE 2a: Branch A→A (Continue training on Addition)")
    print("=" * 60)
    model_aa, history_aa = train_branch('phase2_aa', 'add_train')

    print("\n" + "=" * 60)
    print("PHASE 2b: Branch A→B (Switch to Subtraction)")
    print("=" * 60)
    model_ab, history_ab = train_branch('phase2_ab', 'subtract_train')

    # --- Phase 3 ---------------------------------------------------------
    print("\n" + "=" * 60)
    print("PHASE 3: Cross-model representation comparison")
    print("=" * 60)

    # Both branches are probed on the same Task A inputs.
    reps_aa = collect_representations(model_aa, probe_ids_a, device)
    reps_ab = collect_representations(model_ab, probe_ids_a, device)

    cross_metrics = {}
    for layer_idx in range(config.n_layers + 1):
        ha = reps_aa['hidden_states'][layer_idx]
        hb = reps_ab['hidden_states'][layer_idx]
        hp1 = reps_at_phase1_end['hidden_states'][layer_idx]

        cross_metrics[f'layer_{layer_idx}/cka_aa_vs_ab'] = linear_CKA(ha, hb)
        cross_metrics[f'layer_{layer_idx}/cka_aa_vs_p1'] = linear_CKA(ha, hp1)
        cross_metrics[f'layer_{layer_idx}/cka_ab_vs_p1'] = linear_CKA(hb, hp1)

        # Same k > 0 guard as compute_all_metrics: with fewer than two probe
        # samples no subspace can be formed, so report 0.0.
        k = min(10, ha.shape[0] // 2, ha.shape[1])
        cross_metrics[f'layer_{layer_idx}/subspace_angle_aa_vs_ab'] = \
            mean_subspace_angle_degrees(ha, hb, k=k) if k > 0 else 0.0

    # Layer-by-layer CKA heatmaps between each pair of models.
    heatmap_aa_vs_ab = cka_heatmap(reps_aa['hidden_states'], reps_ab['hidden_states'])
    heatmap_aa_vs_p1 = cka_heatmap(reps_aa['hidden_states'],
                                   reps_at_phase1_end['hidden_states'])
    heatmap_ab_vs_p1 = cka_heatmap(reps_ab['hidden_states'],
                                   reps_at_phase1_end['hidden_states'])

    # Parameter-space comparison of the update directions.
    params_init = list(model_init_state.values())
    params_aa = [v.cpu() for v in model_aa.state_dict().values()]
    params_ab = [v.cpu() for v in model_ab.state_dict().values()]
    params_p1 = list(model_phase1_state.values())

    cross_metrics['param_delta_cosine_aa_vs_ab'] = \
        parameter_delta_cosine(params_p1, params_aa, params_ab)
    cross_metrics['param_delta_cosine_aa_vs_p1_from_init'] = \
        parameter_delta_cosine(params_init, params_p1, params_aa)

    print("\n=== Cross-model metrics ===")
    for k, v in sorted(cross_metrics.items()):
        print(f"  {k}: {v:.4f}")

    # --- Persist ---------------------------------------------------------
    results = {
        'config': {
            'p': p, 'n_layers': n_layers, 'd_model': d_model,
            'n_heads': n_heads, 'd_mlp': d_mlp,
            'phase1_epochs': phase1_epochs, 'phase2_epochs': phase2_epochs,
            'lr': lr, 'weight_decay': weight_decay, 'batch_size': batch_size,
            'train_frac': train_frac, 'seed': seed,
            'n_parameters': model.count_parameters(),
        },
        'phase1_history': phase1_history,
        'phase2_aa_history': history_aa,
        'phase2_ab_history': history_ab,
        'cross_metrics': cross_metrics,
        'cka_heatmaps': {
            'aa_vs_ab': heatmap_aa_vs_ab.tolist(),
            'aa_vs_p1': heatmap_aa_vs_p1.tolist(),
            'ab_vs_p1': heatmap_ab_vs_p1.tolist(),
        },
    }

    results_path = os.path.join(output_dir, 'experiment_results.json')
    with open(results_path, 'w') as f:
        # Tensors and numpy values are written as numbers, not as their
        # string form; anything else unknown still falls back to str().
        json.dump(results, f, indent=2, default=_json_default)
    print(f"\nResults saved to {results_path}")

    torch.save(model_aa.state_dict(), os.path.join(output_dir, 'model_aa_final.pt'))
    torch.save(model_ab.state_dict(), os.path.join(output_dir, 'model_ab_final.pt'))

    return results


def _json_default(obj):
    """Convert tensors and numpy values to plain JSON types; else ``str``."""
    if isinstance(obj, torch.Tensor):
        return obj.item() if obj.numel() == 1 else obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)
|
|
|
|
if __name__ == '__main__':
    import argparse

    # (flag, type, default) for every command-line option, in help order.
    # argparse turns '--n-layers' into the dest 'n_layers', so the parsed
    # namespace maps one-to-one onto run_experiment's keyword arguments.
    cli_options = (
        ('--p', int, DEFAULT_P),
        ('--n-layers', int, 2),
        ('--d-model', int, 128),
        ('--n-heads', int, 4),
        ('--d-mlp', int, 512),
        ('--phase1-epochs', int, 100),
        ('--phase2-epochs', int, 100),
        ('--lr', float, 1e-3),
        ('--weight-decay', float, 1.0),
        ('--batch-size', int, 512),
        ('--train-frac', float, 0.5),
        ('--checkpoint-every', int, 50),
        ('--output-dir', str, 'results'),
        ('--seed', int, 42),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in cli_options:
        cli.add_argument(flag, type=value_type, default=fallback)

    run_experiment(**vars(cli.parse_args()))
|
|