# RyeCatcher's picture
# Upload folder using huggingface_hub
# 167c746 verified
# raw
# history blame
# 8.94 kB
"""
Generate all visualizations for speculative decoding paper.
Creates publication-quality figures matching PAPER_OUTLINE.md specifications.
Author: Claude Code
Date: 2025-11-30
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Dict, List
# Publication styling: seaborn "paper" theme, colour-blind-friendly palette,
# 300-dpi rendering/saving and a compact font-size hierarchy.
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("colorblind")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})
# Input data and figure output live under the parent of this script's
# directory (presumably the project root -- confirm against repo layout).
_PROJECT_ROOT = Path(__file__).parent.parent
DATA_DIR = _PROJECT_ROOT / "data"
FIGURES_DIR = _PROJECT_ROOT / "paper" / "figures"
# Created at import time so every figure function can save without checks.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
def figure3_rejection_by_domain(df: pd.DataFrame):
    """Bar chart of mean draft rejection rate (%) per domain.

    Bars are sorted ascending by rejection rate and annotated with their
    value. The y-axis spans 0-40% by default and is extended upward only when
    the tallest bar (plus its value label) would not fit, so high rejection
    rates are never clipped.

    Args:
        df: Per-token results; reads the ``domain`` and ``is_rejected``
            columns (``is_rejected`` assumed boolean/0-1 -- TODO confirm).

    Side effects:
        Writes ``figure3_rejection_by_domain.png`` into ``FIGURES_DIR``.
    """
    print("\nπŸ“Š Generating Figure 3: Rejection by Domain...")
    # Mean of the rejection flag per domain, sorted for a ranked chart.
    rejection_rates = df.groupby('domain')['is_rejected'].mean().sort_values()
    rates_pct = rejection_rates * 100
    positions = range(len(rejection_rates))

    fig, ax = plt.subplots(figsize=(8, 5))
    # NOTE(review): only four colours are listed; if there are more domains
    # matplotlib reuses them -- confirm palette size matches domain count.
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22']
    bars = ax.bar(positions, rates_pct.values, color=colors)

    # Labels
    ax.set_xlabel('Domain')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rates by Domain')
    ax.set_xticks(positions)
    ax.set_xticklabels([d.replace('_', '-').title() for d in rejection_rates.index],
                       rotation=15, ha='right')

    # Default 0-40%; grow the top when a bar plus its label (drawn 1 point
    # above the bar) would run past the axis and be clipped.
    y_top = 40.0
    max_pct = rates_pct.max()
    if pd.notna(max_pct) and max_pct > 35:
        y_top = float(max_pct) + 5
    ax.set_ylim(0, y_top)
    ax.grid(axis='y', alpha=0.3)

    # Value labels centred just above each bar.
    for bar, pct in zip(bars, rates_pct.values):
        ax.text(bar.get_x() + bar.get_width() / 2, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure3_rejection_by_domain.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f" βœ… Saved: {output_path}")
def figure4_rejection_vs_position(df: pd.DataFrame):
    """Line plot of mean rejection rate (%) across 20 token-position bins.

    Positions are split into 20 equal-width bins and the mean rejection flag
    is plotted at each bin centre. An "Early (<20)" band is always shaded; a
    "Late (>100)" band is shaded only when the data actually extends past
    position 100. The y-axis defaults to 20-35% and widens only if some bin
    falls outside that band.

    The caller's DataFrame is NOT modified. The bins are kept in a local
    Series rather than written back as a new column, because ``main`` reuses
    the same frame for later figures.

    Args:
        df: Per-token results; reads ``token_position`` (numeric) and
            ``is_rejected`` (assumed boolean/0-1 -- TODO confirm).

    Side effects:
        Writes ``figure4_rejection_vs_position.png`` into ``FIGURES_DIR``.
    """
    print("\nπŸ“Š Generating Figure 4: Rejection vs Position...")
    # Bin positions for a smoother plot (local Series, no mutation of df).
    position_bins = pd.cut(df['token_position'], bins=20)
    # observed=False keeps every bin (empty ones give NaN -> a gap in the
    # line) and is passed explicitly because the implicit default is
    # deprecated / changing in newer pandas.
    position_rates = df.groupby(position_bins, observed=False)['is_rejected'].mean()

    # Get bin centers
    bin_centers = [(interval.left + interval.right) / 2 for interval in position_rates.index]
    rates_pct = position_rates.to_numpy(dtype=float) * 100

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bin_centers, rates_pct, marker='o', linewidth=2, markersize=6,
            color='#3498db', label='Rejection Rate')

    # Highlight regions
    ax.axvspan(0, 20, alpha=0.1, color='red', label='Early (<20)')
    late_end = max(bin_centers)
    if late_end > 100:
        # Only shade when there is data beyond position 100; otherwise the
        # span would run backwards and mislabel a region with no late tokens.
        ax.axvspan(100, late_end, alpha=0.1, color='green', label='Late (>100)')

    ax.set_xlabel('Token Position in Sequence')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rate by Token Position')

    # Default 20-35% band; widen only if a bin would otherwise be clipped.
    y_low, y_high = 20.0, 35.0
    finite = rates_pct[np.isfinite(rates_pct)]
    if finite.size:
        if finite.min() < y_low:
            y_low = max(0.0, float(np.floor(finite.min())) - 1.0)
        if finite.max() > y_high:
            y_high = float(np.ceil(finite.max())) + 1.0
    ax.set_ylim(y_low, y_high)
    ax.grid(alpha=0.3)
    ax.legend()

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure4_rejection_vs_position.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f" βœ… Saved: {output_path}")
def figure5_mask_performance_heatmap(df: pd.DataFrame):
    """Heatmap of mean acceptance rate (%) per (domain, attention mask) pair.

    Rows and columns follow fixed preferred orders, and only the domains and
    masks listed there are plotted. A listed domain or mask that does not
    appear in ``df`` is skipped with a printed warning instead of aborting
    the run with ``KeyError``. Tick labels are taken from the rows/columns
    actually plotted, so they always line up with the cells.

    Args:
        df: Ablation results; reads ``domain``, ``mask_type`` and
            ``is_accepted`` (``is_accepted`` assumed boolean/0-1 -- TODO
            confirm).

    Raises:
        ValueError: if none of the preferred domains, or none of the
            preferred masks, are present, leaving nothing to plot.

    Side effects:
        Writes ``figure5_mask_performance_heatmap.png`` into ``FIGURES_DIR``.
    """
    print("\nπŸ“Š Generating Figure 5: Mask Performance Heatmap...")
    # Pivot table: domain x mask -> acceptance rate (%)
    pivot = df.groupby(['domain', 'mask_type'])['is_accepted'].mean().unstack() * 100

    # Preferred display order, restricted to what the data actually contains.
    mask_order = ['causal', 'tidar', 'bidirectional', 'windowed', 'strided']
    domain_order = ['code', 'math', 'translation']
    domains = [d for d in domain_order if d in pivot.index]
    masks = [m for m in mask_order if m in pivot.columns]
    absent = ([d for d in domain_order if d not in domains]
              + [m for m in mask_order if m not in masks])
    if absent:
        print(f"   Warning: not found in ablation data, skipped: {', '.join(absent)}")
    if not domains or not masks:
        raise ValueError("Figure 5: no expected domain/mask combinations found in data")
    pivot = pivot.loc[domains, masks]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Fixed 5-35% colour range keeps the scale stable across runs; the cell
    # annotations still show the exact values.
    sns.heatmap(pivot, annot=True, fmt='.1f', cmap='RdYlGn', vmin=5, vmax=35,
                cbar_kws={'label': 'Acceptance Rate (%)'}, ax=ax, linewidths=0.5)
    ax.set_xlabel('Attention Mask Type')
    ax.set_ylabel('Domain')
    ax.set_title('Acceptance Rate by Domain and Attention Mask')
    ax.set_yticklabels([d.replace('_', '-').title() for d in domains], rotation=0)
    ax.set_xticklabels([m.title() for m in masks], rotation=15, ha='right')

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure5_mask_performance_heatmap.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f" βœ… Saved: {output_path}")
def figure6_throughput_quality_tradeoff(ablation_df: pd.DataFrame):
    """Scatter of mean throughput vs mean acceptance rate, one point per mask.

    Each attention-mask type is aggregated to its mean
    ``throughput_tokens_per_sec`` and mean acceptance rate (%). Each point is
    drawn in the mask's colour, or grey for unknown masks, and labelled with
    the mask name. The x-axis defaults to 40-150 tokens/s and widens only
    when a point would otherwise fall outside it and be hidden.

    Args:
        ablation_df: Ablation results; reads ``mask_type``,
            ``throughput_tokens_per_sec`` and ``is_accepted``
            (``is_accepted`` assumed boolean/0-1 -- TODO confirm).

    Side effects:
        Writes ``figure6_throughput_quality_tradeoff.png`` into
        ``FIGURES_DIR``.
    """
    print("\nπŸ“Š Generating Figure 6: Throughput-Quality Trade-off...")
    # Aggregate by mask
    mask_stats = ablation_df.groupby('mask_type').agg({
        'throughput_tokens_per_sec': 'mean',
        'is_accepted': 'mean'
    }).reset_index()

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = {'causal': '#3498db', 'tidar': '#9b59b6', 'bidirectional': '#2ecc71',
              'windowed': '#e74c3c', 'strided': '#e67e22'}
    for _, row in mask_stats.iterrows():
        throughput = row['throughput_tokens_per_sec']
        acceptance_pct = row['is_accepted'] * 100
        ax.scatter(throughput, acceptance_pct,
                   s=200, color=colors.get(row['mask_type'], 'gray'),
                   alpha=0.7, edgecolors='black', linewidth=1.5)
        # Label offset is in data units: +5 tok/s right, +1 pct-point up.
        ax.text(throughput + 5, acceptance_pct + 1,
                row['mask_type'].title(), fontsize=9, fontweight='bold')

    ax.set_xlabel('Throughput (tokens/second)')
    ax.set_ylabel('Acceptance Rate (%)')
    ax.set_title('Throughput-Quality Trade-off Across Attention Masks')
    ax.grid(alpha=0.3)

    # Default 40-150 tok/s; widen (with padding) only if a point lies outside.
    x_low, x_high = 40.0, 150.0
    throughputs = mask_stats['throughput_tokens_per_sec'].dropna()
    if not throughputs.empty:
        if throughputs.min() < x_low:
            x_low = max(0.0, float(throughputs.min()) - 10)
        if throughputs.max() > x_high:
            x_high = float(throughputs.max()) + 10
    ax.set_xlim(x_low, x_high)

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure6_throughput_quality_tradeoff.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f" βœ… Saved: {output_path}")
def figure_domain_comparison_table(df: pd.DataFrame, quality_df: pd.DataFrame):
    """Render Table 1 (per-domain rejection rate + quality metric) as a PNG.

    The mean rejection rate per domain from ``df`` is left-joined with
    ``quality_df`` on ``domain``. Each resulting row becomes one table row:
    domain, rejection %, metric name, score and sample count. Domains with no
    matching quality row show ``n/a`` rather than ``nan``. Sample counts that
    the join promoted to float are printed as integers.

    Args:
        df: Per-token results; reads ``domain`` and ``is_rejected``.
        quality_df: Quality metrics; reads ``domain``, ``metric``, ``value``
            and ``samples`` (presumably one row per domain/metric -- confirm).

    Side effects:
        Writes ``table1_domain_comparison.png`` into ``FIGURES_DIR``.
    """
    print("\nπŸ“Š Generating Table 1: Domain Comparison...")
    missing = 'n/a'

    def _fmt_text(value) -> str:
        # Free-text cell; NaN comes from the left join when quality is absent.
        return missing if pd.isna(value) else f"{value}"

    def _fmt_score(value) -> str:
        # Sub-1 scores (ratios) get two decimals, larger scores one.
        if pd.isna(value):
            return missing
        return f"{value:.2f}" if value < 1 else f"{value:.1f}"

    def _fmt_samples(value) -> str:
        # A left join with unmatched rows turns an int column into float.
        if pd.isna(value):
            return missing
        if isinstance(value, float) and value.is_integer():
            return str(int(value))
        return f"{value}"

    # Only the rejection rate is displayed, so only that is aggregated.
    domain_stats = df.groupby('domain').agg({'is_rejected': 'mean'}).reset_index()
    # Merge with quality metrics
    domain_stats = domain_stats.merge(quality_df, on='domain', how='left')

    table_data = [
        [
            row['domain'].replace('_', '-').title(),
            f"{row['is_rejected']*100:.1f}%",
            _fmt_text(row['metric']),
            _fmt_score(row['value']),
            _fmt_samples(row['samples']),
        ]
        for _, row in domain_stats.iterrows()
    ]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')
    headers = ['Domain', 'Rejection Rate', 'Quality Metric', 'Score', 'Samples']
    table = ax.table(cellText=table_data, colLabels=headers, loc='center',
                     cellLoc='center', colWidths=[0.2, 0.2, 0.2, 0.15, 0.15])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Style header (row 0 of the matplotlib table is the header row).
    for col in range(len(headers)):
        table[(0, col)].set_facecolor('#3498db')
        table[(0, col)].set_text_props(weight='bold', color='white')
    # Shade every second data row (data rows are 1..len(table_data)).
    for row_idx in range(2, len(table_data) + 1, 2):
        for col in range(len(headers)):
            table[(row_idx, col)].set_facecolor('#ecf0f1')

    ax.set_title('Table 1: Domain-Specific Rejection Rates and Quality Metrics',
                 fontsize=12, fontweight='bold', pad=20)
    output_path = FIGURES_DIR / "table1_domain_comparison.png"
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f" βœ… Saved: {output_path}")
def main():
    """Load the experiment CSVs and render every figure and table.

    Reads ``phase1_cross_domain.csv``, ``phase3_ablation.csv`` and
    ``quality_metrics.csv`` from ``DATA_DIR``. It then calls each figure
    function in turn; each one writes its PNG into ``FIGURES_DIR``. Finally
    it lists every PNG present in that directory.
    """
    banner = "=" * 60
    print(banner)
    print("Generating Publication-Quality Visualizations")
    print(banner)

    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    quality_df = pd.read_csv(DATA_DIR / "quality_metrics.csv")
    print("βœ… Data loaded")

    # Figures 3-4 and Table 1 use the cross-domain run; 5-6 the ablation.
    figure3_rejection_by_domain(cross_domain_df)
    figure4_rejection_vs_position(cross_domain_df)
    figure5_mask_performance_heatmap(ablation_df)
    figure6_throughput_quality_tradeoff(ablation_df)
    figure_domain_comparison_table(cross_domain_df, quality_df)

    print("\n" + banner)
    print("βœ… All figures generated!")
    print(f" Saved to: {FIGURES_DIR}")
    print(banner)

    print("\n=== Generated Figures ===")
    for png in sorted(FIGURES_DIR.glob("*.png")):
        print(f" - {png.name}")
if __name__ == "__main__":
    main()