# Atomic-VSA / scripts/generate_paper_charts.py
# Commit fa6bd30 (verified) by marshad180: "Update Atomic VSA deployment"
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
# Figures are written to <directory of this script>/../papers.
output_dir = os.path.join(os.path.dirname(__file__), os.pardir, "papers")
# Create the folder up front (no-op if it already exists) so the plot
# functions can save into it directly.
os.makedirs(output_dir, exist_ok=True)

# Publication styling. Start from matplotlib's stock 'default' style rather
# than a seaborn style sheet ('seaborn-whitegrid' may not be available, or
# seaborn may not be installed); grid and other aesthetics are set manually
# inside each plotting function.
plt.style.use('default')
plt.rcParams.update({
    'font.family': 'sans-serif',
    # Candidates are tried in order; the first one installed is used.
    'font.sans-serif': ['Arial', 'DejaVu Sans', 'Liberation Sans'],
    'figure.dpi': 300,
})

# Shared colour palette (the first four colours of matplotlib's tab10 cycle).
c_blue = '#1f77b4'
c_orange = '#ff7f0e'
c_green = '#2ca02c'
c_red = '#d62728'
# --- Metric 1: F1 Improvement (Bar Chart) ---
def plot_f1_improvement():
    """Render Figure 1: F1 score and recall for each optimization strategy.

    Draws a grouped bar chart (F1 and recall side by side for each strategy),
    annotates every bar with its value, saves the chart as
    ``fig1_optimization_trajectory.png`` in the module-level ``output_dir``
    and prints the saved path. The figure is always closed, even if drawing
    or saving fails.
    """
    data = {
        'Strategy': ['Baseline (Fixed τ)', 'Optimized (Adaptive τ)', 'Breakthrough (Argmax)'],
        'F1 Score': [81.3, 98.4, 99.6],
        'Recall': [75.8, 100.0, 99.5],
    }
    df = pd.DataFrame(data)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        x = np.arange(len(df['Strategy']))
        width = 0.35
        rects1 = ax.bar(x - width / 2, df['F1 Score'], width, label='F1 Score',
                        color=c_blue, alpha=0.8, edgecolor='black')
        rects2 = ax.bar(x + width / 2, df['Recall'], width, label='Recall',
                        color=c_green, alpha=0.8, edgecolor='black')

        ax.set_ylabel('Performance (%)', fontsize=12, fontweight='bold')
        ax.set_title('Figure 1: Atomic VSA Optimization Trajectory',
                     fontsize=14, fontweight='bold', pad=15)
        ax.set_xticks(x)
        ax.set_xticklabels(df['Strategy'], fontsize=10, rotation=0)
        # The y-axis is truncated at 60, so every bar (all start at 0) fills the
        # bottom of the axes. A lower-corner legend would cover the bars; the
        # upper-left corner sits above the shortest (Baseline) group instead.
        ax.set_ylim(60, 105)
        ax.legend(loc='upper left')
        ax.grid(axis='y', linestyle='--', alpha=0.5)

        # Value labels: centred above each bar, offset 3 points upward.
        for rect in (*rects1, *rects2):
            height = rect.get_height()
            ax.annotate(f'{height:.1f}%',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),
                        textcoords="offset points",
                        ha='center', va='bottom', fontweight='bold')

        fig.tight_layout()
        save_path = os.path.join(output_dir, "fig1_optimization_trajectory.png")
        fig.savefig(save_path)
        print(f"Generated: {save_path}")
    finally:
        # Release the figure even on failure so repeated calls don't leak figures.
        plt.close(fig)
# --- Metric 2: SNR Gap Analysis (Horizontal Bar Chart) ---
def plot_snr_gap():
    """Render Figure 2: resonance gap (SNR) per condition as horizontal bars.

    Bars are green when the gap exceeds 0.1, orange when it is positive but at
    most 0.1, and red otherwise. A dashed vertical line marks the 0.05 noise
    floor, and each bar is labelled with its value to three decimals. The
    chart is saved as ``fig2_snr_analysis.png`` in the module-level
    ``output_dir`` and the saved path is printed. The figure is always closed,
    even if drawing or saving fails.
    """
    # Data from Table 1 in the paper
    conditions = [
        'Acanthamoebiasis', 'Cholera', 'Resp. Tuberculosis',
        'Malaria (P. falc)', 'Plasmodium w/ Complications',
        'Acute Hep B', 'Acute Hep A', 'Typhoid Fever',
    ]
    snr_gaps = [0.603, 0.482, 0.324, 0.229, 0.188, 0.017, 0.000, 0.000]
    noise_floor = 0.05
    colors = [c_green if x > 0.1 else (c_orange if x > 0 else c_red) for x in snr_gaps]

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        y_pos = np.arange(len(conditions))
        ax.barh(y_pos, snr_gaps, color=colors, edgecolor='black', alpha=0.8)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(conditions)
        ax.invert_yaxis()  # labels read top-to-bottom

        ax.set_xlabel('Resonance Gap (SNR)', fontsize=12, fontweight='bold')
        ax.set_title('Figure 2: The Holographic Limit (Resonance Gap Analysis)',
                     fontsize=14, fontweight='bold', pad=15)
        ax.axvline(x=noise_floor, color='red', linestyle='--',
                   label=f'Noise Floor ({noise_floor})')
        ax.legend()
        ax.grid(axis='x', linestyle='--', alpha=0.5)
        # 15% headroom past the longest bar (or the noise-floor line, whichever
        # is further right) so the value label beside it stays inside the axes.
        ax.set_xlim(0, max(max(snr_gaps), noise_floor) * 1.15)

        # Value labels just right of each bar, vertically centred on it.
        for i, v in enumerate(snr_gaps):
            ax.text(v + 0.01, i, f'{v:.3f}', va='center', color='black', fontweight='bold')

        fig.tight_layout()
        save_path = os.path.join(output_dir, "fig2_snr_analysis.png")
        fig.savefig(save_path)
        print(f"Generated: {save_path}")
    finally:
        # Release the figure even on failure so repeated calls don't leak figures.
        plt.close(fig)
# --- Metric 3: Speedup (Log Scale) ---
def plot_speedup():
    """Render Figure 3: inference time of Atomic VSA vs. a neural net (log scale).

    Draws two horizontal bars on a logarithmic x-axis, labels each with its
    time (in µs below 1000, otherwise in ms), saves the chart as
    ``fig3_speedup.png`` in the module-level ``output_dir`` and prints the
    saved path. The figure is always closed, even if drawing or saving fails.
    """
    labels = ['Atomic VSA', 'Neural Net (Inference)']
    times_us = [42, 50000]  # 42us vs 50ms (50,000us)

    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        y_pos = np.arange(len(labels))
        ax.barh(y_pos, times_us, color=[c_green, c_red], edgecolor='black')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels)
        ax.invert_yaxis()

        ax.set_xlabel('Inference Time (microseconds) - Log Scale', fontsize=12, fontweight='bold')
        ax.set_title('Figure 3: Computational Efficiency (Log Scale)', fontsize=14, fontweight='bold')
        ax.set_xscale('log')
        # Pin the limits. The bars start at 0, which a log axis cannot show, so
        # a fixed 1 µs baseline makes each bar's length proportional to
        # log(time) instead of depending on the autoscaled lower limit. One
        # extra decade on the right leaves room for the label drawn at v * 1.1.
        ax.set_xlim(1, max(times_us) * 10)
        ax.grid(axis='x', linestyle='--', alpha=0.5)

        for i, v in enumerate(times_us):
            label = f"{v} µs" if v < 1000 else f"{v/1000} ms"
            ax.text(v * 1.1, i, label, va='center', fontweight='bold')

        fig.tight_layout()
        save_path = os.path.join(output_dir, "fig3_speedup.png")
        fig.savefig(save_path)
        print(f"Generated: {save_path}")
    finally:
        # Release the figure even on failure so repeated calls don't leak figures.
        plt.close(fig)
if __name__ == "__main__":
    # Generate every paper figure, in figure order (1, 2, 3).
    for make_figure in (plot_f1_improvement, plot_snr_gap, plot_speedup):
        make_figure()