# File size: 8,944 Bytes
# 167c746
# (Line-number gutter 1-266 emitted by the web file viewer removed.)
"""
Generate all visualizations for speculative decoding paper.
Creates publication-quality figures matching PAPER_OUTLINE.md specifications.
Author: Claude Code
Date: 2025-11-30
"""
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Dict, List
# Publication styling: matplotlib's seaborn "paper" style sheet, seaborn's
# "colorblind" palette, 300-dpi figures, and explicit font sizes.
plt.style.use('seaborn-v0_8-paper')
sns.set_palette("colorblind")
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'font.size': 10,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
})

# Input and output locations, both relative to the parent of the directory
# that contains this script.
DATA_DIR = Path(__file__).parent.parent.joinpath("data")
FIGURES_DIR = Path(__file__).parent.parent.joinpath("paper", "figures")
# Created at import time so the figure functions can save without checking.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)
def figure3_rejection_by_domain(df: pd.DataFrame):
    """Save a bar chart of the mean rejection rate per domain.

    Args:
        df: Frame with a ``domain`` column and an ``is_rejected`` column.
            The per-domain mean of ``is_rejected`` is plotted as a
            percentage, sorted from lowest to highest.

    Side effects:
        Writes ``figure3_rejection_by_domain.png`` into ``FIGURES_DIR`` and
        prints progress messages.
    """
    # Progress messages are plain ASCII: the previous glyphs were mis-encoded
    # and one of them split the "Saved" string literal across two lines.
    print("\nGenerating Figure 3: Rejection by Domain...")

    # Mean of is_rejected per domain, as a percentage, lowest first.
    rejection_pct = df.groupby('domain')['is_rejected'].mean().sort_values() * 100

    fig, ax = plt.subplots(figsize=(8, 5))
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#e67e22']
    positions = range(len(rejection_pct))
    bars = ax.bar(positions, rejection_pct.values, color=colors)

    # Labels
    ax.set_xlabel('Domain')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rates by Domain')
    ax.set_xticks(positions)
    ax.set_xticklabels(
        [str(domain).replace('_', '-').title() for domain in rejection_pct.index],
        rotation=15, ha='right',
    )
    # Keep the 0-40% axis by default, but grow it when the tallest bar (plus
    # its value label) would otherwise be clipped. With no data, max() is NaN
    # and the comparison leaves the default of 40 in place.
    y_top = max(40.0, float(rejection_pct.max()) + 5.0)
    ax.set_ylim(0, y_top)
    ax.grid(axis='y', alpha=0.3)

    # Value label centred just above each bar.
    for bar, pct in zip(bars, rejection_pct.values):
        ax.text(bar.get_x() + bar.get_width() / 2, pct + 1, f'{pct:.1f}%',
                ha='center', va='bottom', fontsize=9, fontweight='bold')

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure3_rejection_by_domain.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")
def figure4_rejection_vs_position(df: pd.DataFrame):
    """Save a line plot of rejection rate against binned token position.

    Args:
        df: Frame with a numeric ``token_position`` column and an
            ``is_rejected`` column. Positions are cut into 20 equal-width
            bins and the mean of ``is_rejected`` is plotted per bin.
            The frame is not modified.

    Side effects:
        Writes ``figure4_rejection_vs_position.png`` into ``FIGURES_DIR`` and
        prints progress messages.
    """
    print("\nGenerating Figure 4: Rejection vs Position...")

    # Bin positions for a smoother plot. The bins live in a local Series so
    # the caller's DataFrame does not gain a 'position_bin' column.
    position_bins = pd.cut(df['token_position'], bins=20)
    # observed=False keeps empty bins (as NaN) regardless of the pandas
    # default, so the x positions always cover the full range.
    position_pct = (
        df.groupby(position_bins, observed=False)['is_rejected'].mean() * 100
    )

    # Get bin centers
    bin_centers = [
        (interval.left + interval.right) / 2 for interval in position_pct.index
    ]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bin_centers, position_pct.values, marker='o', linewidth=2,
            markersize=6, color='#3498db', label='Rejection Rate')

    # Highlight regions
    ax.axvspan(0, 20, alpha=0.1, color='red', label='Early (<20)')
    last_center = max(bin_centers)
    # Only shade the "late" region when the data actually extends past 100;
    # otherwise the span would run backwards over unrelated positions.
    if last_center > 100:
        ax.axvspan(100, last_center, alpha=0.1, color='green',
                   label='Late (>100)')

    ax.set_xlabel('Token Position in Sequence')
    ax.set_ylabel('Rejection Rate (%)')
    ax.set_title('Draft Rejection Rate by Token Position')
    # Default window is 20-35%; widen it if any bin falls outside so points
    # are never silently clipped. NaN extremes leave the defaults in place.
    y_low = min(20.0, float(position_pct.min()) - 1.0)
    y_high = max(35.0, float(position_pct.max()) + 1.0)
    ax.set_ylim(y_low, y_high)
    ax.grid(alpha=0.3)
    ax.legend()

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure4_rejection_vs_position.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")
def figure5_mask_performance_heatmap(df: pd.DataFrame):
    """Save a heatmap of acceptance rate by domain and attention mask.

    Args:
        df: Frame with ``domain``, ``mask_type`` and ``is_accepted`` columns.
            The mean of ``is_accepted`` per (domain, mask) pair is shown as
            a percentage.

    Raises:
        KeyError: If none of the expected domains, or none of the expected
            mask types, appear in ``df``.

    Side effects:
        Writes ``figure5_mask_performance_heatmap.png`` into ``FIGURES_DIR``
        and prints progress messages.
    """
    print("\nGenerating Figure 5: Mask Performance Heatmap...")

    # Pivot table: domain x mask -> acceptance rate (%)
    acceptance_pct = (
        df.groupby(['domain', 'mask_type'])['is_accepted'].mean().unstack() * 100
    )

    # Preferred display order. Entries absent from the data are skipped
    # instead of letting .loc raise for the whole figure.
    preferred_masks = ['causal', 'tidar', 'bidirectional', 'windowed', 'strided']
    preferred_domains = ['code', 'math', 'translation']
    mask_order = [m for m in preferred_masks if m in acceptance_pct.columns]
    domain_order = [d for d in preferred_domains if d in acceptance_pct.index]
    if not mask_order or not domain_order:
        raise KeyError(
            "No expected domains/mask types found in data; expected domains "
            f"{preferred_domains} and masks {preferred_masks}"
        )
    missing = [d for d in preferred_domains if d not in domain_order]
    missing += [m for m in preferred_masks if m not in mask_order]
    if missing:
        print(f"  [WARN] Not present in data, skipped: {missing}")
    acceptance_pct = acceptance_pct.loc[domain_order, mask_order]

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(acceptance_pct, annot=True, fmt='.1f', cmap='RdYlGn',
                vmin=5, vmax=35, cbar_kws={'label': 'Acceptance Rate (%)'},
                ax=ax, linewidths=0.5)

    ax.set_xlabel('Attention Mask Type')
    ax.set_ylabel('Domain')
    ax.set_title('Acceptance Rate by Domain and Attention Mask')
    # Tick labels come from the rows/columns actually plotted so they can
    # never drift out of step with the heatmap cells.
    ax.set_yticklabels(
        [d.replace('_', '-').title() for d in domain_order], rotation=0
    )
    ax.set_xticklabels(
        [m.title() for m in mask_order], rotation=15, ha='right'
    )

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure5_mask_performance_heatmap.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")
def figure6_throughput_quality_tradeoff(ablation_df: pd.DataFrame):
    """Save a scatter plot of mean throughput against mean acceptance rate.

    Args:
        ablation_df: Frame with ``mask_type``, ``throughput_tokens_per_sec``
            and ``is_accepted`` columns. One labelled point is drawn per
            mask type, at the per-mask means of the other two columns.

    Side effects:
        Writes ``figure6_throughput_quality_tradeoff.png`` into
        ``FIGURES_DIR`` and prints progress messages.
    """
    print("\nGenerating Figure 6: Throughput-Quality Trade-off...")

    # Aggregate by mask
    mask_stats = ablation_df.groupby('mask_type').agg({
        'throughput_tokens_per_sec': 'mean',
        'is_accepted': 'mean',
    }).reset_index()

    fig, ax = plt.subplots(figsize=(8, 6))
    colors = {'causal': '#3498db', 'tidar': '#9b59b6', 'bidirectional': '#2ecc71',
              'windowed': '#e74c3c', 'strided': '#e67e22'}

    for _, row in mask_stats.iterrows():
        throughput = row['throughput_tokens_per_sec']
        acceptance_pct = row['is_accepted'] * 100
        mask_name = row['mask_type']
        # Mask types without an assigned colour fall back to gray.
        ax.scatter(throughput, acceptance_pct, s=200,
                   color=colors.get(mask_name, 'gray'),
                   alpha=0.7, edgecolors='black', linewidth=1.5)
        # Name label, offset up and to the right of the marker.
        ax.text(throughput + 5, acceptance_pct + 1, mask_name.title(),
                fontsize=9, fontweight='bold')

    ax.set_xlabel('Throughput (tokens/second)')
    ax.set_ylabel('Acceptance Rate (%)')
    ax.set_title('Throughput-Quality Trade-off Across Attention Masks')
    ax.grid(alpha=0.3)
    # Default window is 40-150 tokens/s; widen it when a point (or its label)
    # would fall outside. NaN extremes leave the defaults in place.
    throughput_means = mask_stats['throughput_tokens_per_sec']
    x_low = min(40.0, float(throughput_means.min()) - 10.0)
    x_high = max(150.0, float(throughput_means.max()) + 20.0)
    ax.set_xlim(x_low, x_high)

    fig.tight_layout()
    output_path = FIGURES_DIR / "figure6_throughput_quality_tradeoff.png"
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")
def figure_domain_comparison_table(df: pd.DataFrame, quality_df: pd.DataFrame):
    """Save a rendered table of per-domain rejection rate and quality metrics.

    Args:
        df: Frame with ``domain`` and ``is_rejected`` columns; the mean of
            ``is_rejected`` per domain becomes the "Rejection Rate" column.
        quality_df: Frame with ``domain``, ``metric``, ``value`` and
            ``samples`` columns, left-joined onto the per-domain stats.
            Domains with no matching row show "N/A" in those cells.

    Side effects:
        Writes ``table1_domain_comparison.png`` into ``FIGURES_DIR`` and
        prints progress messages.
    """
    print("\nGenerating Table 1: Domain Comparison...")

    # Aggregate stats (only the rejection rate is displayed).
    domain_stats = df.groupby('domain')['is_rejected'].mean().reset_index()

    # Merge with quality metrics; a left join keeps every domain from df.
    domain_stats = domain_stats.merge(quality_df, on='domain', how='left')

    # Format table
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')

    table_data = []
    for _, row in domain_stats.iterrows():
        metric, value, samples = row['metric'], row['value'], row['samples']

        metric_text = "N/A" if pd.isna(metric) else f"{metric}"

        # Scores below 1 get two decimals, larger scores get one.
        if pd.isna(value):
            score_text = "N/A"
        elif value < 1:
            score_text = f"{value:.2f}"
        else:
            score_text = f"{value:.1f}"

        # A left join with gaps turns an integer column into floats, so
        # whole-number floats are shown without the trailing ".0".
        if pd.isna(samples):
            samples_text = "N/A"
        elif isinstance(samples, float) and samples.is_integer():
            samples_text = str(int(samples))
        else:
            samples_text = f"{samples}"

        table_data.append([
            row['domain'].replace('_', '-').title(),
            f"{row['is_rejected'] * 100:.1f}%",
            metric_text,
            score_text,
            samples_text,
        ])

    headers = ['Domain', 'Rejection Rate', 'Quality Metric', 'Score', 'Samples']
    table = ax.table(cellText=table_data, colLabels=headers, loc='center',
                     cellLoc='center', colWidths=[0.2, 0.2, 0.2, 0.15, 0.15])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Style header (row 0 of the table holds the column labels).
    for col in range(len(headers)):
        table[(0, col)].set_facecolor('#3498db')
        table[(0, col)].set_text_props(weight='bold', color='white')

    # Alternate row colors: shade every even-numbered body row.
    for row_idx in range(2, len(table_data) + 1, 2):
        for col in range(len(headers)):
            table[(row_idx, col)].set_facecolor('#ecf0f1')

    ax.set_title('Table 1: Domain-Specific Rejection Rates and Quality Metrics',
                 fontsize=12, fontweight='bold', pad=20)

    output_path = FIGURES_DIR / "table1_domain_comparison.png"
    fig.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"  [OK] Saved: {output_path}")
def main():
    """Load the three input CSVs and generate every figure and table.

    Reads ``phase1_cross_domain.csv``, ``phase3_ablation.csv`` and
    ``quality_metrics.csv`` from ``DATA_DIR``, calls each figure function,
    then lists every PNG found in ``FIGURES_DIR``.
    """
    banner = "=" * 60
    print(banner)
    print("Generating Publication-Quality Visualizations")
    print(banner)

    # Load data
    print("\nLoading data...")
    cross_domain_df = pd.read_csv(DATA_DIR / "phase1_cross_domain.csv")
    ablation_df = pd.read_csv(DATA_DIR / "phase3_ablation.csv")
    quality_df = pd.read_csv(DATA_DIR / "quality_metrics.csv")
    # Status messages are plain ASCII: the previous glyph was mis-encoded and
    # split these string literals across two lines.
    print("[OK] Data loaded")

    # Generate figures
    figure3_rejection_by_domain(cross_domain_df)
    figure4_rejection_vs_position(cross_domain_df)
    figure5_mask_performance_heatmap(ablation_df)
    figure6_throughput_quality_tradeoff(ablation_df)
    figure_domain_comparison_table(cross_domain_df, quality_df)

    print("\n" + banner)
    print("[OK] All figures generated!")
    print(f" Saved to: {FIGURES_DIR}")
    print(banner)

    # Lists every PNG in the output directory, including ones left there by
    # earlier runs.
    print("\n=== Generated Figures ===")
    for fig_path in sorted(FIGURES_DIR.glob("*.png")):
        print(f" - {fig_path.name}")
# Generate the figures only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
# (Stray "|" separator left by the web file viewer; not part of the script.)