| """ |
| Visualization for Representation Learning Dynamics experiment. |
| ================================================================ |
| Generates publication-quality figures from experiment results. |
| """ |
|
|
| import json |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import matplotlib.gridspec as gridspec |
| from pathlib import Path |
| from typing import Dict, List, Optional |
| import argparse |
|
|
|
|
def load_results(results_path: str) -> Dict:
    """Load experiment results from a JSON file.

    Args:
        results_path: Path to the results JSON file.

    Returns:
        The decoded JSON object.
    """
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # uses the locale encoding and can mis-decode non-ASCII text (e.g. "→").
    with open(results_path, encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
def extract_metric_series(history: List[Dict], metric_name: str) -> tuple:
    """Return ``(steps, values)`` arrays for every history entry logging *metric_name*.

    Entries that do not contain *metric_name* are skipped entirely.
    """
    matching = [entry for entry in history if metric_name in entry]
    step_array = np.array([entry['step'] for entry in matching])
    value_array = np.array([entry[metric_name] for entry in matching])
    return step_array, value_array
|
|
|
|
def _training_series(history: List[Dict], key: str, offset=0) -> tuple:
    """Return (steps, values) for entries of *history* that log *key*, steps shifted by *offset*."""
    # Entries that did not log *key* are skipped rather than plotted as 0, so
    # sparsely logged metrics (e.g. periodic evals) do not show fake drops.
    steps = [h['step'] + offset for h in history if key in h]
    values = [h[key] for h in history if key in h]
    return steps, values


def _plot_training_phases(ax, results: Dict, key: str, transition, labels) -> None:
    """Draw Phase 1 (black), A→A (blue) and A→B (red) curves of *key* on *ax*."""
    specs = (
        (results['phase1_history'], 0, 'k-', labels[0]),
        # Phase 2 steps restart at 0, so shift them to follow Phase 1.
        (results['phase2_aa_history'], transition, 'b-', labels[1]),
        (results['phase2_ab_history'], transition, 'r-', labels[2]),
    )
    for history, offset, style, label in specs:
        steps, values = _training_series(history, key, offset)
        if steps:
            ax.plot(steps, values, style, label=label, linewidth=2)


def plot_training_curves(results: Dict, output_dir: str):
    """Plot training loss, task accuracies and gradient alignment across all phases.

    Produces a 2x2 figure (loss, Task A accuracy, Task B accuracy, gradient
    alignment) with Phase 1 followed by the two Phase 2 branches, whose steps
    are offset by the last Phase 1 step. Saves ``training_curves.png`` in
    *output_dir*.

    Args:
        results: Experiment results containing ``phase1_history``,
            ``phase2_aa_history`` and ``phase2_ab_history`` lists of dicts
            with a ``'step'`` key.
        output_dir: Existing directory to write the PNG into.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    p1 = results['phase1_history']
    # Guard against an empty Phase 1 history instead of raising IndexError.
    transition = p1[-1]['step'] if p1 else 0
    short_labels = ('Phase 1', 'A→A', 'A→B')

    # Training loss (log scale).
    ax = axes[0, 0]
    _plot_training_phases(
        ax, results, 'train_loss', transition,
        ('Phase 1 (Add)', 'A→A (Continue Add)', 'A→B (Switch to Sub)'),
    )
    ax.axvline(x=transition, color='gray', linestyle='--',
               alpha=0.5, label='Phase transition')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss')
    ax.legend()
    ax.set_yscale('log')

    # Per-task test accuracy.
    accuracy_panels = (
        (axes[0, 1], 'eval/add_test_acc', 'Task A (Addition) Accuracy'),
        (axes[1, 0], 'eval/subtract_test_acc', 'Task B (Subtraction) Accuracy'),
    )
    for ax, key, title in accuracy_panels:
        _plot_training_phases(ax, results, key, transition, short_labels)
        ax.axvline(x=transition, color='gray', linestyle='--', alpha=0.5)
        ax.set_xlabel('Training Step')
        ax.set_ylabel('Accuracy')
        ax.set_title(title)
        ax.legend()
        ax.set_ylim(-0.05, 1.05)

    # Gradient alignment between the two tasks.
    ax = axes[1, 1]
    _plot_training_phases(ax, results, 'gradient_alignment_a_vs_b',
                          transition, short_labels)
    ax.axvline(x=transition, color='gray', linestyle='--', alpha=0.5)
    ax.axhline(y=0, color='gray', linestyle=':', alpha=0.3)
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Cosine Similarity')
    ax.set_title('Gradient Alignment (Task A vs Task B)')
    ax.legend()

    plt.tight_layout()
    plt.savefig(f'{output_dir}/training_curves.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_dir}/training_curves.png")
|
|
|
|
def plot_cka_dynamics(results: Dict, output_dir: str):
    """Plot each layer's CKA similarity to the Phase 1 end state during Phase 2.

    Left panel shows branch A→A, right panel branch A→B; each panel draws one
    line per layer (index 0 labelled 'Embedding'). Saves ``cka_dynamics.png``
    in *output_dir*.
    """
    fig, (ax_aa, ax_ab) = plt.subplots(1, 2, figsize=(14, 5))
    layer_count = results['config']['n_layers'] + 1

    branch_panels = ((ax_aa, 'phase2_aa_history'), (ax_ab, 'phase2_ab_history'))
    for idx in range(layer_count):
        key = f'layer_{idx}/cka_vs_phase1'
        name = 'Embedding' if idx == 0 else f'Layer {idx}'
        for panel, branch in branch_panels:
            logged = [rec for rec in results[branch] if key in rec]
            panel.plot(
                [rec['step'] for rec in logged],
                [rec[key] for rec in logged],
                '-', label=name, linewidth=1.5,
            )

    panel_titles = (
        (ax_aa, 'Branch A→A: CKA vs Phase 1 End'),
        (ax_ab, 'Branch A→B: CKA vs Phase 1 End'),
    )
    for panel, heading in panel_titles:
        panel.set_title(heading)
        panel.set_xlabel('Training Step')
        panel.set_ylabel('CKA Similarity')
        panel.legend()
        panel.set_ylim(0, 1.05)

    plt.tight_layout()
    plt.savefig(f'{output_dir}/cka_dynamics.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Saved: {output_dir}/cka_dynamics.png")
|
|
|
|
def plot_attention_entropy(results: Dict, output_dir: str):
    """Plot per-head attention entropy for every layer over Phase 2.

    One row per layer, two columns (A→A left, A→B right). Metric keys use
    1-based layer numbers (``layer_{k}/head_{h}_entropy``), matching the row
    titles. Saves ``attention_entropy.png`` in *output_dir*.
    """
    cfg = results['config']
    layer_total = cfg['n_layers']
    head_total = cfg['n_heads']

    # squeeze=False keeps the Axes array 2-D even when there is a single layer.
    fig, grid = plt.subplots(layer_total, 2, figsize=(14, 4 * layer_total),
                             squeeze=False)

    for row in range(layer_total):
        layer_no = row + 1
        left, right = grid[row, 0], grid[row, 1]
        for head in range(head_total):
            key = f'layer_{layer_no}/head_{head}_entropy'
            for panel, branch in ((left, 'phase2_aa_history'),
                                  (right, 'phase2_ab_history')):
                logged = [rec for rec in results[branch] if key in rec]
                panel.plot(
                    [rec['step'] for rec in logged],
                    [rec[key] for rec in logged],
                    label=f'Head {head}',
                )

        left.set_title(f'Layer {layer_no} — A→A')
        left.set_ylabel('Entropy (bits)')
        left.legend()
        right.set_title(f'Layer {layer_no} — A→B')
        right.legend()

    for panel in grid[-1]:
        panel.set_xlabel('Training Step')

    plt.tight_layout()
    plt.savefig(f'{output_dir}/attention_entropy.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Saved: {output_dir}/attention_entropy.png")
|
|
|
|
def plot_cka_heatmaps(results: Dict, output_dir: str):
    """Plot cross-layer CKA heatmaps comparing the final models.

    Draws three annotated heatmaps (A→A vs A→B, A→A vs Phase 1 end,
    A→B vs Phase 1 end) from ``results['cka_heatmaps']`` and saves
    ``cka_heatmaps.png`` in *output_dir*. Each heatmap is read as a 2-D
    matrix indexed ``[layer of model 1, layer of model 2]``. The color scale
    is fixed to [0, 1].
    """
    heatmaps = results['cka_heatmaps']

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    panels = (
        ('A→A vs A→B', 'aa_vs_ab'),
        ('A→A vs Phase 1 End', 'aa_vs_p1'),
        ('A→B vs Phase 1 End', 'ab_vs_p1'),
    )

    for ax, (title, key) in zip(axes, panels):
        hm = np.array(heatmaps[key])
        im = ax.imshow(hm, cmap='viridis', vmin=0, vmax=1, aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Layer (model 2)')
        ax.set_ylabel('Layer (model 1)')
        # Layer indices are discrete; without explicit ticks the default
        # locator can label fractional positions such as 0.5.
        ax.set_xticks(range(hm.shape[1]))
        ax.set_yticks(range(hm.shape[0]))

        # Annotate each cell; use light text on the dark low end of viridis.
        for (i, j), value in np.ndenumerate(hm):
            ax.text(j, i, f'{value:.2f}', ha='center', va='center',
                    fontsize=8, color='white' if value < 0.5 else 'black')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.savefig(f'{output_dir}/cka_heatmaps.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_dir}/cka_heatmaps.png")
|
|
|
|
def plot_subspace_angles(results: Dict, output_dir: str):
    """Overlay per-layer subspace-angle drift from the Phase 1 end for both branches.

    A→A curves are dashed and semi-transparent; A→B curves are solid. A curve
    is drawn only if that branch logged the layer's metric. Saves
    ``subspace_angles.png`` in *output_dir*.
    """
    layer_count = results['config']['n_layers'] + 1

    _, ax = plt.subplots(figsize=(10, 5))

    for idx in range(layer_count):
        key = f'layer_{idx}/subspace_angle_vs_phase1'
        aa_logged = [rec for rec in results['phase2_aa_history'] if key in rec]
        ab_logged = [rec for rec in results['phase2_ab_history'] if key in rec]
        name = 'Embedding' if idx == 0 else f'Layer {idx}'

        if aa_logged:
            ax.plot([rec['step'] for rec in aa_logged],
                    [rec[key] for rec in aa_logged],
                    '--', label=f'{name} (A→A)', alpha=0.7, linewidth=1.5)
        if ab_logged:
            ax.plot([rec['step'] for rec in ab_logged],
                    [rec[key] for rec in ab_logged],
                    '-', label=f'{name} (A→B)', linewidth=2)

    ax.set_xlabel('Training Step')
    ax.set_ylabel('Mean Subspace Angle (degrees)')
    ax.set_title('Subspace Angle Drift from Phase 1 End')
    # Place the legend outside the axes; bbox_inches='tight' keeps it in the PNG.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(f'{output_dir}/subspace_angles.png', dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Saved: {output_dir}/subspace_angles.png")
|
|
|
|
def plot_weight_changes(results: Dict, output_dir: str):
    """Plot per-block weight change relative to the end of Phase 1.

    For each Phase 2 branch (A→A left, A→B right), draws one line per block
    from the ``block_{i}/weight_change_from_phase1`` metric, using only the
    history entries that logged it. Saves ``weight_changes.png`` in
    *output_dir*.
    """
    n_blocks = results['config']['n_layers']

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    branches = (
        (axes[0], results['phase2_aa_history'], 'A→A: Weight Change from Phase 1'),
        (axes[1], results['phase2_ab_history'], 'A→B: Weight Change from Phase 1'),
    )

    for ax, history, title in branches:
        for block_idx in range(n_blocks):
            metric = f'block_{block_idx}/weight_change_from_phase1'
            logged = [h for h in history if metric in h]
            ax.plot([h['step'] for h in logged], [h[metric] for h in logged],
                    label=f'Block {block_idx}', linewidth=2)

        ax.set_title(title)
        ax.set_xlabel('Training Step')
        ax.set_ylabel('L2 Norm of Weight Delta')
        ax.legend()

    plt.tight_layout()
    plt.savefig(f'{output_dir}/weight_changes.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_dir}/weight_changes.png")
|
|
|
|
def generate_all_plots(results_path: str, output_dir: Optional[str] = None) -> None:
    """Load experiment results and write every figure as a PNG.

    Args:
        results_path: Path to the results JSON file.
        output_dir: Directory for the figures. Defaults to the directory
            containing *results_path*; it is created if it does not exist.
    """
    results = load_results(results_path)
    if output_dir is None:
        output_dir = str(Path(results_path).parent)

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    plotters = (
        plot_training_curves,
        plot_cka_dynamics,
        plot_attention_entropy,
        plot_cka_heatmaps,
        plot_subspace_angles,
        plot_weight_changes,
    )
    for plot_fn in plotters:
        plot_fn(results, output_dir)

    print(f"\nAll plots saved to {output_dir}/")
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate figures from Representation Learning Dynamics '
                    'experiment results.')
    parser.add_argument('--results', type=str,
                        default='results/experiment_results.json',
                        help='Path to the experiment results JSON file.')
    parser.add_argument('--output-dir', type=str, default=None,
                        help='Directory for the PNG figures '
                             '(default: the directory containing --results).')
    args = parser.parse_args()
    generate_all_plots(args.results, args.output_dir)
|
|