# Tavernari's picture
# Upload folder using huggingface_hub
# 148b631 verified
"""
plot_results.py - Generate visualizations from benchmark results.
Creates publication-quality plots comparing RippleGPT vs VanillaGPT2.
"""
import json
import argparse
from pathlib import Path
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
# Colour palette shared by every figure: one colour per model series, an
# accent colour, and the dark theme's background / text / grid colours.
COLORS = dict(
    ripple="#4CAF50",      # green  - RippleGPT series
    baseline="#2196F3",    # blue   - VanillaGPT2 series
    highlight="#FF9800",   # orange - reference lines / accents
    background="#1a1a2e",  # dark figure and axes background
    text="#ffffff",        # white annotation and table text
    grid="#333355",        # grid lines, axis edges, table header
)
# Style configuration: start from matplotlib's built-in dark preset, then
# override typography and colours so every figure matches COLORS.
plt.style.use('dark_background')

# Typography.
_FONT_RC = {
    'font.family': 'sans-serif',
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
}
# Backgrounds, axis edges and a faint grid on every axes.
_THEME_RC = {
    'figure.facecolor': COLORS['background'],
    'axes.facecolor': COLORS['background'],
    'savefig.facecolor': COLORS['background'],
    'axes.edgecolor': COLORS['grid'],
    'axes.grid': True,
    'grid.color': COLORS['grid'],
    'grid.alpha': 0.3,
}
plt.rcParams.update({**_FONT_RC, **_THEME_RC})
def load_results(results_dir: Path) -> List[Dict]:
    """Load every ``benchmark_*.json`` file located directly in *results_dir*.

    Files are read in sorted filename order. ``Path.glob`` yields entries in
    arbitrary, filesystem-dependent order, and that order would otherwise leak
    into the order of bars, subplots and table rows in the generated figures.

    Args:
        results_dir: Directory containing the benchmark result files.

    Returns:
        One parsed JSON object per matching file, in filename order; an empty
        list when nothing matches.

    Raises:
        json.JSONDecodeError: If a matching file does not contain valid JSON.
    """
    results = []
    for result_file in sorted(results_dir.glob("benchmark_*.json")):
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(result_file, encoding="utf-8") as fp:
            results.append(json.load(fp))
    return results
def plot_parameter_comparison(results: List[Dict], output_path: Path):
    """Draw a grouped bar chart of RippleGPT vs VanillaGPT2 parameter counts.

    One pair of bars is drawn per benchmark result. Each pair is labelled
    ``<dataset>_<size>``, and each bar is annotated with its count in millions.
    The figure is written to ``output_path / 'parameter_comparison.png'``.

    Args:
        results: Parsed benchmark results. Each one must provide
            ``r['metadata']['dataset']``, ``r['metadata']['size']`` and
            ``r['parameters']['ripple']`` / ``r['parameters']['baseline']``.
        output_path: Existing directory the PNG is saved into.
    """
    labels = [f"{r['metadata']['dataset']}_{r['metadata']['size']}" for r in results]
    ripple_params = [r['parameters']['ripple'] / 1e6 for r in results]
    baseline_params = [r['parameters']['baseline'] / 1e6 for r in results]

    fig, ax = plt.subplots(figsize=(10, 6))
    x = np.arange(len(labels))
    width = 0.35
    bars_ripple = ax.bar(x - width / 2, ripple_params, width,
                         label='RippleGPT', color=COLORS['ripple'], alpha=0.9)
    bars_baseline = ax.bar(x + width / 2, baseline_params, width,
                           label='VanillaGPT2', color=COLORS['baseline'], alpha=0.9)

    ax.set_ylabel('Parameters (Millions)')
    # No emoji prefix: the previous one was mis-encoded text, and matplotlib's
    # default sans-serif font has no emoji glyphs anyway.
    ax.set_title('Parameter Comparison: RippleGPT vs VanillaGPT2')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.legend()

    # Label every bar with its value, using identical styling for both series.
    for bars, values in ((bars_ripple, ripple_params), (bars_baseline, baseline_params)):
        for bar, val in zip(bars, values):
            ax.annotate(f'{val:.1f}M',
                        xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9, color=COLORS['text'])

    fig.tight_layout()
    out_file = output_path / 'parameter_comparison.png'
    fig.savefig(out_file, dpi=150)
    plt.close(fig)
    print(f"✅ Saved: {out_file}")
def _curve_xy(curve):
    """Split a sequence of ``(iteration, loss)`` pairs into x and y lists."""
    return [point[0] for point in curve], [point[1] for point in curve]


def plot_loss_curves(results: List[Dict], output_path: Path):
    """Plot RippleGPT and VanillaGPT2 training-loss curves, one subplot per result.

    Subplots are laid out two per row. Unused grid cells are hidden. The
    figure is written to ``output_path / 'loss_curves.png'``.

    Args:
        results: Parsed benchmark results. Each one must provide
            ``r['ripple']['training']['loss_curve']`` and
            ``r['baseline']['training']['loss_curve']`` (sequences of
            ``(iteration, loss)`` pairs), plus ``r['metadata']['dataset']`` and
            ``r['metadata']['size']`` for the subplot title.
        output_path: Existing directory the PNG is saved into.
    """
    if not results:
        # Without this guard the layout math below divides by zero.
        print("⚠️ No results available for loss curves")
        return

    n_results = len(results)
    cols = min(2, n_results)
    rows = (n_results + cols - 1) // cols  # ceiling division
    # squeeze=False always returns a 2-D array of Axes, so a single code path
    # handles one, two or many subplots.
    fig, axes_grid = plt.subplots(rows, cols, figsize=(6 * cols, 4 * rows), squeeze=False)
    axes = axes_grid.ravel()

    for ax, r in zip(axes, results):
        r_iters, r_losses = _curve_xy(r['ripple']['training']['loss_curve'])
        b_iters, b_losses = _curve_xy(r['baseline']['training']['loss_curve'])
        ax.plot(r_iters, r_losses, color=COLORS['ripple'],
                linewidth=2, label='RippleGPT', marker='o', markersize=4)
        ax.plot(b_iters, b_losses, color=COLORS['baseline'],
                linewidth=2, label='VanillaGPT2', marker='s', markersize=4)
        ax.set_title(f"{r['metadata']['dataset'].capitalize()} ({r['metadata']['size']})")
        ax.set_xlabel('Iteration')
        ax.set_ylabel('Loss')
        ax.legend(loc='upper right')

    # Hide grid cells left over when the result count is odd.
    for ax in axes[n_results:]:
        ax.set_visible(False)

    fig.suptitle('Training Loss Curves', fontsize=16, y=1.02)
    fig.tight_layout()
    out_file = output_path / 'loss_curves.png'
    # The suptitle sits partly above the figure's top edge (y > 1).
    # bbox_inches='tight' enlarges the saved area to include it rather than
    # cropping it.
    fig.savefig(out_file, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ Saved: {out_file}")
def _extrapolation_series(r: Dict):
    """Return ``(context_ratios, perplexities, train_ppl)`` for one benchmark result.

    The first point is the training context (ratio 1.0). Its perplexity is
    estimated as ``exp(final training loss)``. The remaining points come from
    ``r['ripple']['extrapolation']``, whose keys are context lengths stored as
    strings. They are sorted by length and expressed as a ratio of the
    training ``block_size``.
    """
    extrap = r['ripple']['extrapolation']
    train_block = r['metadata']['model_config']['block_size']
    sizes = sorted(int(k) for k in extrap)
    train_ppl = np.exp(r['ripple']['training']['final_loss'])
    ratios = [1.0] + [s / train_block for s in sizes]
    ppls = [train_ppl] + [extrap[str(s)] for s in sizes]
    return ratios, ppls, train_ppl


def plot_extrapolation(results: List[Dict], output_path: Path):
    """Plot RippleGPT perplexity against context length relative to training.

    Only results whose ``r['ripple']`` has a non-empty ``'extrapolation'``
    entry are plotted. Each gets one curve plus a dashed reference line at its
    estimated training perplexity. The figure is written to
    ``output_path / 'extrapolation.png'``. If no result carries extrapolation
    data, a warning is printed and nothing is saved.

    Args:
        results: Parsed benchmark results.
        output_path: Existing directory the PNG is saved into.
    """
    extrap_results = [r for r in results if r['ripple'].get('extrapolation')]
    if not extrap_results:
        print("⚠️ No extrapolation data found in results")
        return

    fig, ax = plt.subplots(figsize=(10, 6))
    single = len(extrap_results) == 1
    for r in extrap_results:
        ratios, ppls, train_ppl = _extrapolation_series(r)
        label = f"{r['metadata']['dataset']} ({r['metadata']['size']})"
        (line,) = ax.plot(ratios, ppls, marker='o', linewidth=2,
                          label=label, markersize=8)
        # Every run has its own training perplexity, so every run gets its own
        # reference line. With several runs the line takes the curve's colour
        # and a per-run label so the lines can be told apart.
        if single:
            ref_color, ref_label = COLORS['highlight'], 'Training baseline'
        else:
            ref_color, ref_label = line.get_color(), f'{label} training baseline'
        ax.axhline(y=train_ppl, color=ref_color, linestyle='--',
                   alpha=0.5, label=ref_label)

    ax.axvline(x=1.0, color=COLORS['grid'], linestyle=':', alpha=0.5)
    ax.set_xlabel('Context Ratio (relative to training)')
    ax.set_ylabel('Perplexity')
    ax.set_title('RippleGPT Extrapolation Capability\n'
                 '(Lower is better, <1.0x = shorter, >1.0x = longer than training)')
    ax.legend()

    # Mark the training context just above the bottom of the final y-range.
    y_bottom = ax.get_ylim()[0]
    ax.annotate('Training\nContext', xy=(1.0, y_bottom),
                xytext=(1.0, y_bottom + 0.5),
                ha='center', fontsize=9, color=COLORS['text'])

    fig.tight_layout()
    out_file = output_path / 'extrapolation.png'
    fig.savefig(out_file, dpi=150)
    plt.close(fig)
    print(f"✅ Saved: {out_file}")
def _summary_row(r: Dict) -> List[str]:
    """Format one benchmark result as a row of cell strings for the summary table."""
    ripple_loss = r['ripple']['training']['final_loss']
    baseline_loss = r['baseline']['training']['final_loss']
    # Lower final loss wins. Exactly equal losses are reported as a tie rather
    # than being silently credited to the baseline.
    if ripple_loss == baseline_loss:
        winner = "Tie"
    elif ripple_loss < baseline_loss:
        winner = "RippleGPT"
    else:
        winner = "VanillaGPT2"
    return [
        r['metadata']['dataset'].capitalize(),
        r['metadata']['size'].capitalize(),
        f"{r['parameters']['ripple']/1e6:.1f}M",
        f"{r['parameters']['baseline']/1e6:.1f}M",
        f"{ripple_loss:.4f}",
        f"{baseline_loss:.4f}",
        winner,
    ]


def plot_summary_table(results: List[Dict], output_path: Path):
    """Render a table of parameters, final losses and the winner per result.

    The image is written to ``output_path / 'summary_table.png'``. With no
    results a warning is printed and nothing is saved, because ``ax.table``
    cannot build a table without rows.

    Args:
        results: Parsed benchmark results (metadata, parameter counts and final
            training losses for both models).
        output_path: Existing directory the PNG is saved into.
    """
    if not results:
        print("⚠️ No results to summarize")
        return

    columns = ['Dataset', 'Size', 'Ripple Params', 'GPT2 Params',
               'Ripple Loss', 'GPT2 Loss', 'Winner']
    rows = [_summary_row(r) for r in results]

    # Grow the figure with the row count so a long table does not spill over
    # the title. 4 inches remains the minimum (rough heuristic).
    fig_height = max(4.0, 1.0 + 0.4 * (len(rows) + 1))
    fig, ax = plt.subplots(figsize=(12, fig_height))
    ax.axis('off')

    table = ax.table(
        cellText=rows,
        colLabels=columns,
        loc='center',
        cellLoc='center',
        colColours=[COLORS['grid']] * len(columns),
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.5)

    # Row 0 is the header: bold on the grid colour; body cells use the figure background.
    for (row, _col), cell in table.get_celld().items():
        if row == 0:
            cell.set_text_props(weight='bold', color=COLORS['text'])
            cell.set_facecolor(COLORS['grid'])
        else:
            cell.set_facecolor(COLORS['background'])
            cell.set_text_props(color=COLORS['text'])

    ax.set_title('Benchmark Summary', fontsize=14, pad=20)
    fig.tight_layout()
    out_file = output_path / 'summary_table.png'
    fig.savefig(out_file, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ Saved: {out_file}")
def generate_all_plots(results_dir: str):
    """Load benchmark results from *results_dir* and write every figure.

    Figures go into a ``plots`` sub-directory of *results_dir*, which is
    created if missing. A missing directory, or a directory without any
    ``benchmark_*.json`` file, is reported on stdout and the function returns
    without plotting.

    Args:
        results_dir: Directory containing ``benchmark_*.json`` result files.
    """
    results_path = Path(results_dir)
    # is_dir() rather than exists(): a path to a regular file is reported as a
    # bad directory instead of falling through to "no results found".
    if not results_path.is_dir():
        print(f"❌ Results directory not found: {results_path}")
        return

    results = load_results(results_path)
    if not results:
        print(f"❌ No benchmark results found in {results_path}")
        return
    print(f"\n📊 Found {len(results)} benchmark results")

    plots_dir = results_path / 'plots'
    plots_dir.mkdir(exist_ok=True)

    print("\n🎨 Generating plots...")
    plot_parameter_comparison(results, plots_dir)
    plot_loss_curves(results, plots_dir)
    plot_extrapolation(results, plots_dir)
    plot_summary_table(results, plots_dir)
    print(f"\n✅ All plots saved to: {plots_dir}")
if __name__ == '__main__':
    # Command-line entry point; the results directory is the only option.
    cli = argparse.ArgumentParser(description="Generate benchmark plots")
    cli.add_argument(
        "--results",
        type=str,
        default="validation/benchmarks/results",
        help="Path to results directory",
    )
    generate_all_plots(cli.parse_args().results)