""" plot_results.py - Generate visualizations from benchmark results. Creates publication-quality plots comparing RippleGPT vs VanillaGPT2. """ import json import argparse from pathlib import Path from typing import Dict, List, Optional import matplotlib.pyplot as plt import matplotlib.patches as mpatches import numpy as np # Color scheme COLORS = { "ripple": "#4CAF50", # Green "baseline": "#2196F3", # Blue "highlight": "#FF9800", # Orange "background": "#1a1a2e", # Dark background "text": "#ffffff", # White text "grid": "#333355" # Grid lines } # Style configuration plt.style.use('dark_background') plt.rcParams.update({ 'font.family': 'sans-serif', 'font.size': 11, 'axes.titlesize': 14, 'axes.labelsize': 12, 'figure.facecolor': COLORS['background'], 'axes.facecolor': COLORS['background'], 'savefig.facecolor': COLORS['background'], 'axes.edgecolor': COLORS['grid'], 'axes.grid': True, 'grid.color': COLORS['grid'], 'grid.alpha': 0.3 }) def load_results(results_dir: Path) -> List[Dict]: """Load all benchmark result files from directory.""" results = [] for f in results_dir.glob("benchmark_*.json"): with open(f) as fp: results.append(json.load(fp)) return results def plot_parameter_comparison(results: List[Dict], output_path: Path): """Bar chart comparing parameter counts.""" fig, ax = plt.subplots(figsize=(10, 6)) datasets = [] sizes = [] ripple_params = [] baseline_params = [] for r in results: label = f"{r['metadata']['dataset']}_{r['metadata']['size']}" datasets.append(label) ripple_params.append(r['parameters']['ripple'] / 1e6) baseline_params.append(r['parameters']['baseline'] / 1e6) x = np.arange(len(datasets)) width = 0.35 bars1 = ax.bar(x - width/2, ripple_params, width, label='RippleGPT', color=COLORS['ripple'], alpha=0.9) bars2 = ax.bar(x + width/2, baseline_params, width, label='VanillaGPT2', color=COLORS['baseline'], alpha=0.9) ax.set_ylabel('Parameters (Millions)') ax.set_title('šŸ“Š Parameter Comparison: RippleGPT vs VanillaGPT2') ax.set_xticks(x) ax.set_xticklabels(datasets, rotation=15, ha='right') ax.legend() # Add value labels for bar, val in zip(bars1, ripple_params): ax.annotate(f'{val:.1f}M', xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=9, color=COLORS['text']) for bar, val in zip(bars2, baseline_params): ax.annotate(f'{val:.1f}M', xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=9, color=COLORS['text']) plt.tight_layout() plt.savefig(output_path / 'parameter_comparison.png', dpi=150) plt.close() print(f"āœ… Saved: {output_path / 'parameter_comparison.png'}") def plot_loss_curves(results: List[Dict], output_path: Path): """Plot training loss curves for all benchmarks.""" n_results = len(results) cols = min(2, n_results) rows = (n_results + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 4*rows)) if n_results == 1: axes = [axes] else: axes = axes.flatten() if n_results > 2 else list(axes) for idx, r in enumerate(results): ax = axes[idx] ripple_curve = r['ripple']['training']['loss_curve'] baseline_curve = r['baseline']['training']['loss_curve'] r_iters = [x[0] for x in ripple_curve] r_losses = [x[1] for x in ripple_curve] b_iters = [x[0] for x in baseline_curve] b_losses = [x[1] for x in baseline_curve] ax.plot(r_iters, r_losses, color=COLORS['ripple'], linewidth=2, label='RippleGPT', marker='o', markersize=4) ax.plot(b_iters, b_losses, color=COLORS['baseline'], linewidth=2, label='VanillaGPT2', marker='s', markersize=4) title = f"{r['metadata']['dataset'].capitalize()} ({r['metadata']['size']})" ax.set_title(f"šŸ“‰ {title}") ax.set_xlabel('Iteration') ax.set_ylabel('Loss') ax.legend(loc='upper right') # Hide unused subplots for idx in range(len(results), len(axes)): axes[idx].set_visible(False) plt.suptitle('Training Loss Curves', fontsize=16, y=1.02) plt.tight_layout() plt.savefig(output_path / 'loss_curves.png', dpi=150) plt.close() print(f"āœ… Saved: {output_path / 'loss_curves.png'}") def plot_extrapolation(results: List[Dict], output_path: Path): """Plot extrapolation capability comparison.""" # Filter results that have extrapolation data extrap_results = [r for r in results if r['ripple'].get('extrapolation')] if not extrap_results: print("āš ļø No extrapolation data found in results") return fig, ax = plt.subplots(figsize=(10, 6)) for idx, r in enumerate(extrap_results): extrap = r['ripple']['extrapolation'] train_block = r['metadata']['model_config']['block_size'] # Collect data points sizes = sorted([int(k) for k in extrap.keys()]) ppls = [extrap[str(s)] for s in sizes] ratios = [s / train_block for s in sizes] # Add training point (estimate from final loss) train_loss = r['ripple']['training']['final_loss'] train_ppl = np.exp(train_loss) all_sizes = [train_block] + sizes all_ppls = [train_ppl] + ppls all_ratios = [1.0] + ratios label = f"{r['metadata']['dataset']} ({r['metadata']['size']})" ax.plot(all_ratios, all_ppls, marker='o', linewidth=2, label=label, markersize=8) ax.axhline(y=train_ppl, color=COLORS['highlight'], linestyle='--', alpha=0.5, label='Training baseline') ax.axvline(x=1.0, color=COLORS['grid'], linestyle=':', alpha=0.5) ax.set_xlabel('Context Ratio (relative to training)') ax.set_ylabel('Perplexity') ax.set_title('šŸ“ RippleGPT Extrapolation Capability\n(Lower is better, <1.0x = shorter, >1.0x = longer than training)') ax.legend() # Add annotation ax.annotate('Training\nContext', xy=(1.0, ax.get_ylim()[0]), xytext=(1.0, ax.get_ylim()[0] + 0.5), ha='center', fontsize=9, color=COLORS['text']) plt.tight_layout() plt.savefig(output_path / 'extrapolation.png', dpi=150) plt.close() print(f"āœ… Saved: {output_path / 'extrapolation.png'}") def plot_summary_table(results: List[Dict], output_path: Path): """Create a summary table as an image.""" fig, ax = plt.subplots(figsize=(12, 4)) ax.axis('off') # Prepare data columns = ['Dataset', 'Size', 'Ripple Params', 'GPT2 Params', 'Ripple Loss', 'GPT2 Loss', 'Winner'] rows = [] for r in results: r_params = f"{r['parameters']['ripple']/1e6:.1f}M" b_params = f"{r['parameters']['baseline']/1e6:.1f}M" r_loss = f"{r['ripple']['training']['final_loss']:.4f}" b_loss = f"{r['baseline']['training']['final_loss']:.4f}" # Determine winner (lower loss wins) winner = "RippleGPT" if r['ripple']['training']['final_loss'] < r['baseline']['training']['final_loss'] else "VanillaGPT2" rows.append([ r['metadata']['dataset'].capitalize(), r['metadata']['size'].capitalize(), r_params, b_params, r_loss, b_loss, winner ]) table = ax.table( cellText=rows, colLabels=columns, loc='center', cellLoc='center', colColours=[COLORS['grid']] * len(columns) ) table.auto_set_font_size(False) table.set_fontsize(10) table.scale(1.2, 1.5) # Style header for (row, col), cell in table.get_celld().items(): if row == 0: cell.set_text_props(weight='bold', color=COLORS['text']) cell.set_facecolor(COLORS['grid']) else: cell.set_facecolor(COLORS['background']) cell.set_text_props(color=COLORS['text']) ax.set_title('šŸ“‹ Benchmark Summary', fontsize=14, pad=20) plt.tight_layout() plt.savefig(output_path / 'summary_table.png', dpi=150, bbox_inches='tight') plt.close() print(f"āœ… Saved: {output_path / 'summary_table.png'}") def generate_all_plots(results_dir: str): """Generate all plots from benchmark results.""" results_path = Path(results_dir) if not results_path.exists(): print(f"āŒ Results directory not found: {results_path}") return results = load_results(results_path) if not results: print(f"āŒ No benchmark results found in {results_path}") return print(f"\nšŸ“Š Found {len(results)} benchmark results") # Create plots directory plots_dir = results_path / 'plots' plots_dir.mkdir(exist_ok=True) # Generate plots print("\nšŸŽØ Generating plots...") plot_parameter_comparison(results, plots_dir) plot_loss_curves(results, plots_dir) plot_extrapolation(results, plots_dir) plot_summary_table(results, plots_dir) print(f"\nāœ… All plots saved to: {plots_dir}") if __name__ == '__main__': parser = argparse.ArgumentParser(description="Generate benchmark plots") parser.add_argument( "--results", type=str, default="validation/benchmarks/results", help="Path to results directory" ) args = parser.parse_args() generate_all_plots(args.results)