|
|
""" |
|
|
Generate all visualizations for the speculative decoding paper (Figures 3-6 and Table 1).
|
|
|
|
|
Creates publication-quality figures matching PAPER_OUTLINE.md specifications. |
|
|
|
|
|
Author: Claude Code |
|
|
Date: 2025-11-30 |
|
|
""" |
|
|
|
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
|
# Global look shared by every figure in this script. The base style is applied
# first, then the palette, then the rcParams overrides, so later settings are
# not reset by the style sheet.
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("colorblind")
plt.rcParams.update({
    # 300 dpi both on screen and when saving, for print-quality output.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    # Base font plus slightly larger labels/titles and smaller tick labels.
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})
|
|
|
|
|
|
|
|
# Project root is two directories above this script; inputs and outputs are
# resolved relative to it.
_PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = _PROJECT_ROOT / "data"
FIGURES_DIR = _PROJECT_ROOT.joinpath("paper", "figures")
# Create the output directory at import time so every figure function can save
# straight into it.
FIGURES_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
|
|
|
def figure3_rejection_by_domain(df: pd.DataFrame):
    """Bar chart: Rejection rates by domain.

    Plots the mean of ``df['is_rejected']`` per ``df['domain']`` as a
    percentage, with bars ordered from lowest to highest rate, and saves the
    figure to ``FIGURES_DIR / "figure3_rejection_by_domain.png"``.

    Args:
        df: Per-draft records with at least ``domain`` and ``is_rejected``
            columns. ``is_rejected`` is averaged, so it presumably holds
            booleans or 0/1 -- TODO confirm against the data schema.
    """
    print("\n📊 Generating Figure 3: Rejection by Domain...")

    # Mean rejection flag per domain, ascending so the bars read low -> high.
    rejection_rates = df.groupby('domain')['is_rejected'].mean().sort_values()
    percentages = rejection_rates.values * 100

    fig, ax = plt.subplots(figsize=(8, 5))

    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22']
    positions = range(len(rejection_rates))
    bars = ax.bar(positions, percentages, color=colors)

    ax.set_xlabel('Domain')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rates by Domain')
    ax.set_xticks(positions)
    ax.set_xticklabels([d.replace('_', '-').title() for d in rejection_rates.index],
                       rotation=15, ha='right')

    # Keep the paper's 0-40% axis, but grow it when the tallest bar (plus the
    # value label drawn above it) would otherwise be clipped off the top.
    peak = rejection_rates.max() * 100  # NaN when there is no data
    ax.set_ylim(0, 40 if pd.isna(peak) else max(40, peak * 1.15))
    ax.grid(axis='y', alpha=0.3)

    # Annotate each bar with its percentage, 1 point above the bar top.
    for bar, pct in zip(bars, percentages):
        ax.text(bar.get_x() + bar.get_width() / 2, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure3_rejection_by_domain.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
|
|
|
|
|
|
|
|
def figure4_rejection_vs_position(df: pd.DataFrame):
    """Line plot: Rejection rate vs token position.

    Splits ``df['token_position']`` into 20 equal-width bins and plots the mean
    ``is_rejected`` (as a percentage) at each bin's midpoint. It shades the
    early (<20) and, when the data reaches it, late (>100) position regions.
    The figure is saved to ``FIGURES_DIR / "figure4_rejection_vs_position.png"``.

    ``df`` is left unmodified.

    Args:
        df: Per-draft records with ``token_position`` and ``is_rejected``
            columns.
    """
    print("\n📊 Generating Figure 4: Rejection vs Position...")

    # 20 equal-width position bins, kept in a local Series so the caller's
    # DataFrame is not mutated with a helper column.
    position_bins = pd.cut(df['token_position'], bins=20)
    # observed=False keeps every bin in order; empty bins yield NaN and show as
    # gaps. This is explicit because pandas is changing the implicit default.
    position_rates = df['is_rejected'].groupby(position_bins, observed=False).mean()

    # x coordinate of each bin = midpoint of its interval.
    bin_centers = [(interval.left + interval.right) / 2 for interval in position_rates.index]
    rates_pct = position_rates * 100

    fig, ax = plt.subplots(figsize=(10, 5))

    ax.plot(bin_centers, rates_pct.values, marker='o', linewidth=2, markersize=6,
            color='#3498db', label='Rejection Rate')

    # Shade the early (<20) and late (>100) position regions.
    ax.axvspan(0, 20, alpha=0.1, color='red', label='Early (<20)')
    late_end = max(bin_centers)
    # Only shade the late region when the data actually extends past 100;
    # otherwise axvspan(100, late_end) would draw a reversed span over
    # positions below 100 and label them "Late".
    if late_end > 100:
        ax.axvspan(100, late_end, alpha=0.1, color='green', label='Late (>100)')

    ax.set_xlabel('Token Position in Sequence')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rate by Token Position')
    # Default to the paper's 20-35% window, widening it only if a bin falls
    # outside so no point is silently clipped. NaN (empty bins) is ignored.
    low, high = rates_pct.min(), rates_pct.max()
    y_bottom = 20 if pd.isna(low) or low >= 20 else low - 1
    y_top = 35 if pd.isna(high) or high <= 35 else high + 1
    ax.set_ylim(y_bottom, y_top)
    ax.grid(alpha=0.3)
    ax.legend()

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure4_rejection_vs_position.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
|
|
|
|
|
|
|
|
def figure5_mask_performance_heatmap(df: pd.DataFrame):
    """Heatmap: Mask performance by domain.

    Each cell shows the mean ``is_accepted`` (as a percentage) for one
    (``domain``, ``mask_type``) pair, laid out in a fixed presentation order.
    Saved to ``FIGURES_DIR / "figure5_mask_performance_heatmap.png"``.

    Args:
        df: Ablation records with ``domain``, ``mask_type`` and
            ``is_accepted`` columns.

    Raises:
        ValueError: if none of the expected domains, or none of the expected
            masks, occur in ``df``.
    """
    print("\n📊 Generating Figure 5: Mask Performance Heatmap...")

    # Rows = domain, columns = mask_type, values = acceptance rate in %.
    pivot = df.groupby(['domain', 'mask_type'])['is_accepted'].mean().unstack() * 100

    # Fixed presentation order. Values not in these lists are left out of the
    # figure. Listed values absent from the data are skipped with a notice
    # instead of making .loc raise KeyError.
    preferred_masks = ['causal', 'tidar', 'bidirectional', 'windowed', 'strided']
    preferred_domains = ['code', 'math', 'translation']
    mask_order = [m for m in preferred_masks if m in pivot.columns]
    domain_order = [d for d in preferred_domains if d in pivot.index]

    missing = ([d for d in preferred_domains if d not in domain_order]
               + [m for m in preferred_masks if m not in mask_order])
    if missing:
        print(f"  Note: not present in data, omitted from heatmap: {', '.join(missing)}")
    if not domain_order or not mask_order:
        raise ValueError(
            "Figure 5 needs at least one of the domains "
            f"{preferred_domains} and one of the masks {preferred_masks} in the data"
        )
    pivot = pivot.loc[domain_order, mask_order]

    fig, ax = plt.subplots(figsize=(10, 5))

    # Colour scale is fixed to 5-35% so figures stay comparable; values outside
    # it saturate in colour but the annotation still shows the exact number.
    sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn', vmin=5, vmax=35,
                cbar_kws={'label': 'Acceptance Rate (%)'}, ax=ax, linewidths=0.5)

    ax.set_xlabel('Attention Mask Type')
    ax.set_ylabel('Domain')
    ax.set_title('Acceptance Rate by Domain and Attention Mask')
    # Tick labels follow the same (possibly reduced) order as the pivot rows/cols.
    ax.set_yticklabels([d.replace('_', '-').title() for d in domain_order], rotation=0)
    ax.set_xticklabels([m.title() for m in mask_order], rotation=15, ha='right')

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure5_mask_performance_heatmap.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
|
|
|
|
|
|
|
|
def figure6_throughput_quality_tradeoff(ablation_df: pd.DataFrame):
    """Scatter plot: Throughput vs quality trade-off.

    Draws one labelled point per ``mask_type``. The x value is the mean
    ``throughput_tokens_per_sec``; the y value is the mean ``is_accepted`` as
    a percentage. Saved to
    ``FIGURES_DIR / "figure6_throughput_quality_tradeoff.png"``.

    Args:
        ablation_df: Ablation records with ``mask_type``,
            ``throughput_tokens_per_sec`` and ``is_accepted`` columns.
    """
    print("\n📊 Generating Figure 6: Throughput-Quality Trade-off...")

    # One row per mask with its mean throughput and mean acceptance flag.
    mask_stats = ablation_df.groupby('mask_type').agg({
        'throughput_tokens_per_sec': 'mean',
        'is_accepted': 'mean',
    }).reset_index()

    fig, ax = plt.subplots(figsize=(8, 6))

    # Fixed colour per known mask; any other mask is drawn in grey.
    colors = {'causal': '#3498db', 'tidar': '#9b59b6', 'bidirectional': '#2ecc71',
              'windowed': '#e74c3c', 'strided': '#e67e22'}

    for _, row in mask_stats.iterrows():
        throughput = row['throughput_tokens_per_sec']
        acceptance_pct = row['is_accepted'] * 100
        ax.scatter(throughput, acceptance_pct,
                   s=200, color=colors.get(row['mask_type'], 'gray'),
                   alpha=0.7, edgecolors='black', linewidth=1.5)
        # Label offset up and to the right so it does not sit on the marker.
        ax.text(throughput + 5, acceptance_pct + 1,
                row['mask_type'].title(), fontsize=9, fontweight='bold')

    ax.set_xlabel('Throughput (tokens/second)')
    ax.set_ylabel('Acceptance Rate (%)')
    ax.set_title('Throughput-Quality Trade-off Across Attention Masks')
    ax.grid(alpha=0.3)

    # Keep the paper's 40-150 tokens/s window, widening it only when a marker,
    # or the label drawn to its right, would fall outside it. The 10%/15%
    # margins leave room for the marker and label text.
    x_min = mask_stats['throughput_tokens_per_sec'].min()
    x_max = mask_stats['throughput_tokens_per_sec'].max()
    x_left = 40 if pd.isna(x_min) else min(40, x_min * 0.9)
    x_right = 150 if pd.isna(x_max) else max(150, x_max * 1.15)
    ax.set_xlim(x_left, x_right)

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure6_throughput_quality_tradeoff.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
|
|
|
|
|
|
|
|
def _format_quality_text(value) -> str:
    """Render a merged text cell, or ``"N/A"`` when the left merge left it empty."""
    return "N/A" if pd.isna(value) else f"{value}"


def _format_score(value) -> str:
    """Render a quality score: 2 decimals below 1, 1 decimal otherwise, ``"N/A"`` if missing."""
    if pd.isna(value):
        return "N/A"
    return f"{value:.2f}" if value < 1 else f"{value:.1f}"


def _format_count(value) -> str:
    """Render a sample count, undoing the int->float upcast a left merge causes when some rows are missing."""
    if pd.isna(value):
        return "N/A"
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return f"{value}"


def figure_domain_comparison_table(df: pd.DataFrame, quality_df: pd.DataFrame):
    """Generate formatted table image for domain comparison.

    Joins the per-domain mean of ``df['is_rejected']`` with ``quality_df`` on
    ``domain`` and renders the result as a matplotlib table saved to
    ``FIGURES_DIR / "table1_domain_comparison.png"``.

    Args:
        df: Per-draft records with ``domain`` and ``is_rejected`` columns.
        quality_df: Rows with ``domain``, ``metric``, ``value`` and ``samples``
            columns. A domain with several rows produces several table rows.
            Domains present in ``df`` but absent here are still listed, with
            ``N/A`` in the quality columns.
    """
    print("\n📊 Generating Table 1: Domain Comparison...")

    # Per-domain rejection rate -- the only df-derived column the table shows.
    domain_stats = df.groupby('domain').agg({'is_rejected': 'mean'}).reset_index()

    # Left join keeps every domain from df even when it has no quality metrics.
    domain_stats = domain_stats.merge(quality_df, on='domain', how='left')

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')

    headers = ['Domain', 'Rejection Rate', 'Quality Metric', 'Score', 'Samples']
    table_data = [
        [
            row['domain'].replace('_', '-').title(),
            f"{row['is_rejected']*100:.1f}%",
            _format_quality_text(row['metric']),
            _format_score(row['value']),
            _format_count(row['samples']),
        ]
        for _, row in domain_stats.iterrows()
    ]

    table = ax.table(cellText=table_data, colLabels=headers, loc='center',
                     cellLoc='center', colWidths=[0.2, 0.2, 0.2, 0.15, 0.15])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Header row (row 0): blue background with bold white text.
    for col in range(len(headers)):
        header_cell = table[(0, col)]
        header_cell.set_facecolor('#3498db')
        header_cell.set_text_props(weight='bold', color='white')

    # Zebra-stripe the body: data rows start at index 1; shade the even ones.
    for row_idx in range(2, len(table_data) + 1, 2):
        for col in range(len(headers)):
            table[(row_idx, col)].set_facecolor('#ecf0f1')

    ax.set_title('Table 1: Domain-Specific Rejection Rates and Quality Metrics',
                 fontsize=12, fontweight='bold', pad=20)

    output_path = FIGURES_DIR / "table1_domain_comparison.png"
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)

    print(f" ✅ Saved: {output_path}")
|
|
|
|
|
|
|
|
def main():
    """Generate all visualizations.

    Reads ``phase1_cross_domain.csv``, ``phase3_ablation.csv`` and
    ``quality_metrics.csv`` from ``DATA_DIR`` and renders Figures 3-6 and
    Table 1 into ``FIGURES_DIR``. It then lists every PNG found there.
    """
    print("=" * 60)
    print("Generating Publication-Quality Visualizations")
    print("=" * 60)

    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    quality_df = pd.read_csv(DATA_DIR / "quality_metrics.csv")
    print("✅ Data loaded")

    # Cross-domain data drives Figures 3-4 and Table 1; ablation data drives 5-6.
    figure3_rejection_by_domain(cross_domain_df)
    figure4_rejection_vs_position(cross_domain_df)
    figure5_mask_performance_heatmap(ablation_df)
    figure6_throughput_quality_tradeoff(ablation_df)
    figure_domain_comparison_table(cross_domain_df, quality_df)

    print("\n" + "=" * 60)
    print("✅ All figures generated!")
    print(f" Saved to: {FIGURES_DIR}")
    print("=" * 60)

    # Lists every PNG in the output directory, including any left over from
    # earlier runs, not only the ones written above.
    print("\n=== Generated Figures ===")
    for fig_path in sorted(FIGURES_DIR.glob("*.png")):
        print(f" - {fig_path.name}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: build every figure only when run as a script, not on import.
    main()
|
|
|