"""
Generate all visualizations for speculative decoding paper.

Creates publication-quality figures matching PAPER_OUTLINE.md specifications.

Author: Claude Code
Date: 2025-11-30
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Dict, List

# Set publication style. The seaborn styles bundled with matplotlib were
# renamed with a "v0_8" prefix in matplotlib 3.6; fall back to the old name
# so the script still runs on older installs.
try:
    plt.style.use('seaborn-v0_8-paper')
except OSError:
    plt.style.use('seaborn-paper')
sns.set_palette("colorblind")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})

# Directories
DATA_DIR = Path(__file__).parent.parent / "data"
FIGURES_DIR = Path(__file__).parent.parent / "paper" / "figures"
FIGURES_DIR.mkdir(parents=True, exist_ok=True)


def _pretty_domain(name: str) -> str:
    """Turn a domain key such as ``'machine_translation'`` into a display label."""
    return name.replace('_', '-').title()


def _axis_limits(default_lo: float, default_hi: float, values, pad: float):
    """Return ``(lo, hi)`` axis limits that keep the defaults unless data falls outside.

    The paper figures use fixed ranges so they line up visually, but a fixed
    range silently hides out-of-range data. Each bound is kept at its default
    when every finite value lies within it; otherwise it is moved *pad* units
    beyond the most extreme value. NaN/inf values are ignored.
    """
    vals = np.asarray(values, dtype=float)
    vals = vals[np.isfinite(vals)]
    if vals.size == 0:
        return default_lo, default_hi
    lo = default_lo if vals.min() >= default_lo else vals.min() - pad
    hi = default_hi if vals.max() <= default_hi else vals.max() + pad
    return lo, hi


def _save_figure(fig, filename: str, *, tight_layout: bool = True, **savefig_kwargs):
    """Save *fig* to ``FIGURES_DIR / filename``, close it, and report the path.

    Args:
        fig: The matplotlib Figure to write.
        filename: File name (with extension) inside FIGURES_DIR.
        tight_layout: Whether to call ``fig.tight_layout()`` before saving.
        **savefig_kwargs: Extra keyword arguments forwarded to ``fig.savefig``.
    """
    if tight_layout:
        fig.tight_layout()
    output_path = FIGURES_DIR / filename
    fig.savefig(output_path, bbox_inches='tight', **savefig_kwargs)
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    print(f"   ✅ Saved: {output_path}")


def _format_score(value) -> str:
    """Format a quality score: 2 decimals below 1, 1 decimal otherwise, N/A if missing."""
    if pd.isna(value):
        return "N/A"
    return f"{value:.2f}" if value < 1 else f"{value:.1f}"


def _format_cell(value) -> str:
    """Format a generic table cell, rendering missing values as N/A.

    Integral floats are shown without a trailing ``.0``. A left merge that
    finds no match promotes an int column to float, which would otherwise
    turn counts like ``120`` into ``120.0``.
    """
    if pd.isna(value):
        return "N/A"
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        return str(int(value))
    return str(value)


def figure3_rejection_by_domain(df: pd.DataFrame):
    """Bar chart: rejection rates by domain.

    Args:
        df: Per-record data with a ``domain`` column and an ``is_rejected``
            column (presumably 0/1 or bool; its per-domain mean is plotted as a
            percentage).

    Writes ``figure3_rejection_by_domain.png`` into FIGURES_DIR.
    """
    print("\n📊 Generating Figure 3: Rejection by Domain...")

    # Ascending order so bars read from lowest to highest rejection.
    rejection_pct = df.groupby('domain')['is_rejected'].mean().sort_values() * 100
    positions = range(len(rejection_pct))

    fig, ax = plt.subplots(figsize=(8, 5))
    # Colors follow the (sorted) bar order, not a fixed domain mapping.
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22']
    bars = ax.bar(positions, rejection_pct.values, color=colors)

    # Labels
    ax.set_xlabel('Domain')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rates by Domain')
    ax.set_xticks(positions)
    ax.set_xticklabels([_pretty_domain(d) for d in rejection_pct.index],
                       rotation=15, ha='right')
    # Keep the paper's 0-40% range unless a bar would be clipped.
    ax.set_ylim(*_axis_limits(0, 40, rejection_pct.values, pad=5))
    ax.grid(axis='y', alpha=0.3)

    # Value labels just above each bar.
    for bar, pct in zip(bars, rejection_pct.values):
        ax.text(bar.get_x() + bar.get_width() / 2, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    _save_figure(fig, "figure3_rejection_by_domain.png")


def figure4_rejection_vs_position(df: pd.DataFrame):
    """Line plot: rejection rate vs token position.

    Args:
        df: Per-record data with ``token_position`` (numeric) and
            ``is_rejected`` (presumably 0/1 or bool) columns. Not modified.

    Writes ``figure4_rejection_vs_position.png`` into FIGURES_DIR.
    """
    print("\n📊 Generating Figure 4: Rejection vs Position...")

    # Bin positions into 20 equal-width bins for a smoother curve. The bins are
    # kept in a local Series so the caller's DataFrame is not mutated.
    position_bins = pd.cut(df['token_position'], bins=20)
    # observed=False keeps empty bins (plotted as NaN gaps) regardless of
    # pandas' version-dependent default for categorical groupers.
    position_rates = df['is_rejected'].groupby(position_bins, observed=False).mean()

    bin_centers = [interval.mid for interval in position_rates.index]
    rejection_pct = position_rates.values * 100

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bin_centers, rejection_pct, marker='o', linewidth=2, markersize=6,
            color='#3498db', label='Rejection Rate')

    # Highlight regions
    ax.axvspan(0, 20, alpha=0.1, color='red', label='Early (<20)')
    late_end = max(bin_centers)
    # Only shade the "late" region when the data actually extends past 100;
    # otherwise the span would be inverted and its legend entry misleading.
    if late_end > 100:
        ax.axvspan(100, late_end, alpha=0.1, color='green', label='Late (>100)')

    ax.set_xlabel('Token Position in Sequence')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rate by Token Position')
    # Keep the paper's 20-35% range unless a point would be clipped.
    ax.set_ylim(*_axis_limits(20, 35, rejection_pct, pad=2))
    ax.grid(alpha=0.3)
    ax.legend()

    _save_figure(fig, "figure4_rejection_vs_position.png")


def figure5_mask_performance_heatmap(df: pd.DataFrame):
    """Heatmap: acceptance rate by domain and attention mask.

    Args:
        df: Ablation data with ``domain``, ``mask_type`` and ``is_accepted``
            (presumably 0/1 or bool) columns.

    Rows/columns follow a fixed display order; entries absent from *df* are
    omitted instead of raising ``KeyError``. If no expected domain/mask is
    present the figure is skipped. Writes
    ``figure5_mask_performance_heatmap.png`` into FIGURES_DIR.
    """
    print("\n📊 Generating Figure 5: Mask Performance Heatmap...")

    # Pivot table: domain x mask → acceptance rate (%)
    rates = df.groupby(['domain', 'mask_type'])['is_accepted'].mean().unstack() * 100

    # Reorder for better display, keeping only entries present in the data.
    mask_order = ['causal', 'tidar', 'bidirectional', 'windowed', 'strided']
    domain_order = ['code', 'math', 'translation']
    domains = [d for d in domain_order if d in rates.index]
    masks = [m for m in mask_order if m in rates.columns]
    if not domains or not masks:
        print("   ⚠️ Skipped: no expected domain/mask combinations in data")
        return
    pivot = rates.loc[domains, masks]

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn',
                vmin=5, vmax=35, cbar_kws={'label': 'Acceptance Rate (%)'},
                ax=ax, linewidths=0.5)

    ax.set_xlabel('Attention Mask Type')
    ax.set_ylabel('Domain')
    ax.set_title('Acceptance Rate by Domain and Attention Mask')
    ax.set_yticklabels([_pretty_domain(d) for d in domains], rotation=0)
    ax.set_xticklabels([m.title() for m in masks], rotation=15, ha='right')

    _save_figure(fig, "figure5_mask_performance_heatmap.png")


def figure6_throughput_quality_tradeoff(ablation_df: pd.DataFrame):
    """Scatter plot: mean throughput vs mean acceptance rate per attention mask.

    Args:
        ablation_df: Ablation data with ``mask_type``,
            ``throughput_tokens_per_sec`` and ``is_accepted`` columns.

    Writes ``figure6_throughput_quality_tradeoff.png`` into FIGURES_DIR.
    """
    print("\n📊 Generating Figure 6: Throughput-Quality Trade-off...")

    # Aggregate by mask
    mask_stats = ablation_df.groupby('mask_type').agg({
        'throughput_tokens_per_sec': 'mean',
        'is_accepted': 'mean',
    }).reset_index()

    fig, ax = plt.subplots(figsize=(8, 6))

    colors = {'causal': '#3498db', 'tidar': '#9b59b6', 'bidirectional': '#2ecc71',
              'windowed': '#e74c3c', 'strided': '#e67e22'}

    for _, row in mask_stats.iterrows():
        throughput = row['throughput_tokens_per_sec']
        acceptance_pct = row['is_accepted'] * 100
        # Unknown mask types fall back to gray.
        ax.scatter(throughput, acceptance_pct, s=200,
                   color=colors.get(row['mask_type'], 'gray'),
                   alpha=0.7, edgecolors='black', linewidth=1.5)
        ax.text(throughput + 5, acceptance_pct + 1, row['mask_type'].title(),
                fontsize=9, fontweight='bold')

    ax.set_xlabel('Throughput (tokens/second)')
    ax.set_ylabel('Acceptance Rate (%)')
    ax.set_title('Throughput-Quality Trade-off Across Attention Masks')
    ax.grid(alpha=0.3)
    # Keep the paper's 40-150 tok/s range unless a point would be hidden.
    ax.set_xlim(*_axis_limits(40, 150, mask_stats['throughput_tokens_per_sec'], pad=10))

    _save_figure(fig, "figure6_throughput_quality_tradeoff.png")


def figure_domain_comparison_table(df: pd.DataFrame, quality_df: pd.DataFrame):
    """Render Table 1 (rejection rate and quality metric per domain) as an image.

    Args:
        df: Per-record data with ``domain`` and ``is_rejected`` columns.
        quality_df: Quality metrics with ``domain``, ``metric``, ``value`` and
            ``samples`` columns, left-merged on ``domain``. Domains without a
            quality row are shown with ``N/A`` cells.

    Writes ``table1_domain_comparison.png`` into FIGURES_DIR.
    """
    print("\n📊 Generating Table 1: Domain Comparison...")

    # Per-domain rejection rate, then attach quality metrics.
    domain_stats = (
        df.groupby('domain', as_index=False)['is_rejected'].mean()
        .merge(quality_df, on='domain', how='left')
    )

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')

    table_data = [
        [
            _pretty_domain(row['domain']),
            f"{row['is_rejected'] * 100:.1f}%",
            _format_cell(row['metric']),
            _format_score(row['value']),
            _format_cell(row['samples']),
        ]
        for _, row in domain_stats.iterrows()
    ]

    headers = ['Domain', 'Rejection Rate', 'Quality Metric', 'Score', 'Samples']

    table = ax.table(cellText=table_data, colLabels=headers,
                     loc='center', cellLoc='center',
                     colWidths=[0.2, 0.2, 0.2, 0.15, 0.15])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Style header (row 0 holds the column labels)
    for col in range(len(headers)):
        table[(0, col)].set_facecolor('#3498db')
        table[(0, col)].set_text_props(weight='bold', color='white')

    # Alternate row colors (shade even data rows)
    for row_idx in range(2, len(table_data) + 1, 2):
        for col in range(len(headers)):
            table[(row_idx, col)].set_facecolor('#ecf0f1')

    ax.set_title('Table 1: Domain-Specific Rejection Rates and Quality Metrics',
                 fontsize=12, fontweight='bold', pad=20)

    _save_figure(fig, "table1_domain_comparison.png", tight_layout=False, dpi=300)


def main():
    """Load the experiment CSVs from DATA_DIR and generate every figure and table."""
    print("=" * 60)
    print("Generating Publication-Quality Visualizations")
    print("=" * 60)

    # Load data
    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    quality_df = pd.read_csv(DATA_DIR / "quality_metrics.csv")
    print("✅ Data loaded")

    # Generate figures
    figure3_rejection_by_domain(cross_domain_df)
    figure4_rejection_vs_position(cross_domain_df)
    figure5_mask_performance_heatmap(ablation_df)
    figure6_throughput_quality_tradeoff(ablation_df)
    figure_domain_comparison_table(cross_domain_df, quality_df)

    print("\n" + "=" * 60)
    print("✅ All figures generated!")
    print(f"   Saved to: {FIGURES_DIR}")
    print("=" * 60)

    print("\n=== Generated Figures ===")
    for fig_path in sorted(FIGURES_DIR.glob("*.png")):
        print(f"  - {fig_path.name}")


if __name__ == "__main__":
    main()