serliezer's picture
Add scripts/make_all_figures.py
42c6078 verified
#!/usr/bin/env python3
"""Generate all figures from processed results."""
import os
import sys
import json
import glob
import argparse
import pandas as pd
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.plotting import (
plot_influence_vs_distance, plot_decay_ablation, plot_error_vs_radius,
plot_chi_vs_error, plot_interference_vs_chi, plot_runtime_vs_error,
plot_model_family_influence, plot_model_family_decay_mu,
plot_model_family_error_vs_radius, plot_model_family_proxy_vs_error,
plot_model_family_prior_noise_ablation
)
from src.utils import ensure_dir
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/default.yaml')
args = parser.parse_args()
proc_csv = 'results/processed/all_results.csv'
if not os.path.exists(proc_csv):
print(f"No processed results found at {proc_csv}. Run analyze_results.py first.")
return
df = pd.read_csv(proc_csv)
print(f"Loaded {len(df)} records")
fig_dir = ensure_dir('results/figures')
input_files = [proc_csv]
# Split
syn_df = df[df['dataset_type'] == 'synthetic'] if 'dataset_type' in df.columns else df
real_df = df[df['dataset_type'] == 'real'] if 'dataset_type' in df.columns else pd.DataFrame()
generated_figures = []
# ============================================================
# Synthetic figures
# ============================================================
if len(syn_df) > 0:
pg_df = syn_df[syn_df['model_family'] == 'poisson_gamma'] if 'model_family' in syn_df.columns else syn_df
if len(pg_df) > 0:
# S1: Influence vs distance
print("Generating Figure S1: influence vs distance...")
try:
paths = plot_influence_vs_distance(
pg_df, os.path.join(fig_dir, 'synthetic_influence_vs_distance.pdf'),
group_col='graph_type', title='Synthetic: Deletion Influence vs Distance',
input_files=input_files)
generated_figures.append(('S1', paths))
except Exception as e:
print(f" Error: {e}")
# S2: Decay ablation
print("Generating Figure S2: decay ablation...")
try:
paths = plot_decay_ablation(
pg_df, os.path.join(fig_dir, 'synthetic_decay_ablation.pdf'),
x_col='regime', y_col='empirical_decay_mu',
title='Synthetic: Empirical Decay Rate by Regime',
input_files=input_files)
generated_figures.append(('S2', paths))
except Exception as e:
print(f" Error: {e}")
# S3: Error vs radius
print("Generating Figure S3: error vs radius...")
try:
paths = plot_error_vs_radius(
pg_df, os.path.join(fig_dir, 'synthetic_error_vs_radius.pdf'),
group_col='graph_type', dataset_type='synthetic',
input_files=input_files)
generated_figures.append(('S3', paths))
except Exception as e:
print(f" Error: {e}")
# S4: chi vs error
print("Generating Figure S4: chi vs error...")
try:
paths = plot_chi_vs_error(
pg_df, os.path.join(fig_dir, 'synthetic_chi_vs_error.pdf'),
chi_col='chi_seed_max', error_col='rel_error_R2',
title='Synthetic: χ_max(z) vs Local Error (R=2)',
input_files=input_files)
generated_figures.append(('S4', paths))
except Exception as e:
print(f" Error: {e}")
# S5: Interference vs chi
print("Generating Figure S5: interference vs chi...")
try:
paths = plot_interference_vs_chi(
pg_df, os.path.join(fig_dir, 'synthetic_interference_vs_chi.pdf'),
chi_col='chi_seed_max', interf_col='interference_cosine_R2',
title='Synthetic: Interference vs χ_max(z)',
input_files=input_files)
generated_figures.append(('S5', paths))
except Exception as e:
print(f" Error: {e}")
# S6: Runtime vs error
print("Generating Figure S6: runtime vs error...")
try:
paths = plot_runtime_vs_error(
pg_df, os.path.join(fig_dir, 'synthetic_runtime_vs_error.pdf'),
title='Synthetic: Runtime vs Approximation Error',
input_files=input_files)
generated_figures.append(('S6', paths))
except Exception as e:
print(f" Error: {e}")
# ============================================================
# Real-data figures
# ============================================================
if len(real_df) > 0:
# R1: Influence vs distance
print("Generating Figure R1: real influence vs distance...")
try:
paths = plot_influence_vs_distance(
real_df, os.path.join(fig_dir, 'real_influence_vs_distance.pdf'),
group_col='dataset_name', title='Real Data: Deletion Influence vs Distance',
input_files=input_files)
generated_figures.append(('R1', paths))
except Exception as e:
print(f" Error: {e}")
# R2: Error vs radius
print("Generating Figure R2: real error vs radius...")
try:
paths = plot_error_vs_radius(
real_df, os.path.join(fig_dir, 'real_error_vs_radius.pdf'),
group_col='dataset_name', dataset_type='real',
input_files=input_files)
generated_figures.append(('R2', paths))
except Exception as e:
print(f" Error: {e}")
# R3: chi vs error
print("Generating Figure R3: real chi vs error...")
try:
paths = plot_chi_vs_error(
real_df, os.path.join(fig_dir, 'real_chi_vs_error.pdf'),
chi_col='chi_seed_max', error_col='rel_error_R2',
title='Real Data: χ_max(z) vs Local Error (R=2)',
input_files=input_files)
generated_figures.append(('R3', paths))
except Exception as e:
print(f" Error: {e}")
# R4: Runtime vs error
print("Generating Figure R4: real runtime vs error...")
try:
paths = plot_runtime_vs_error(
real_df, os.path.join(fig_dir, 'real_runtime_vs_error.pdf'),
title='Real Data: Runtime vs Approximation Error',
input_files=input_files)
generated_figures.append(('R4', paths))
except Exception as e:
print(f" Error: {e}")
# ============================================================
# Model-family figures
# ============================================================
if 'model_family' in df.columns and df['model_family'].nunique() > 1:
mf_df = df[df['dataset_type'] == 'synthetic'] if 'dataset_type' in df.columns else df
# M1: Influence by model family
print("Generating Figure M1: model family influence...")
try:
# Filter to bounded_degree for main figure
bd_df = mf_df[mf_df['graph_type'] == 'bounded_degree'] if 'graph_type' in mf_df.columns else mf_df
paths = plot_model_family_influence(
bd_df, os.path.join(fig_dir, 'model_family_influence_vs_distance.pdf'),
title='Influence Decay by Model Family (Bounded Degree)',
input_files=input_files)
generated_figures.append(('M1', paths))
except Exception as e:
print(f" Error: {e}")
# M2: Decay mu by model family
print("Generating Figure M2: model family decay mu...")
try:
# Need per-config aggregated mu_emp
mu_df = mf_df[mf_df['empirical_decay_mu'].notna()] if 'empirical_decay_mu' in mf_df.columns else pd.DataFrame()
if len(mu_df) > 0:
paths = plot_model_family_decay_mu(
mu_df, os.path.join(fig_dir, 'model_family_decay_mu.pdf'),
title='Empirical Decay Rate by Model Family',
input_files=input_files)
generated_figures.append(('M2', paths))
except Exception as e:
print(f" Error: {e}")
# M3: Error vs radius by model family
print("Generating Figure M3: model family error vs radius...")
try:
paths = plot_model_family_error_vs_radius(
mf_df, os.path.join(fig_dir, 'model_family_error_vs_radius.pdf'),
title='Error vs Radius by Model Family',
input_files=input_files)
generated_figures.append(('M3', paths))
except Exception as e:
print(f" Error: {e}")
# M4: Proxy vs error across families
print("Generating Figure M4: model family proxy vs error...")
try:
paths = plot_model_family_proxy_vs_error(
mf_df, os.path.join(fig_dir, 'model_family_proxy_vs_error.pdf'),
title='Interaction Proxy vs Error Across Models',
input_files=input_files)
generated_figures.append(('M4', paths))
except Exception as e:
print(f" Error: {e}")
# M5: Prior/noise ablation
print("Generating Figure M5: prior/noise ablation...")
try:
paths = plot_model_family_prior_noise_ablation(
mf_df, os.path.join(fig_dir, 'model_family_prior_noise_ablation.pdf'),
title='Prior/Noise Ablation by Model Family',
input_files=input_files)
generated_figures.append(('M5', paths))
except Exception as e:
print(f" Error: {e}")
print(f"\nGenerated {len(generated_figures)} figures:")
for name, paths in generated_figures:
print(f" {name}: {paths}")
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status
    # (None or 0 -> success, non-zero -> failure) so pipelines can detect errors.
    sys.exit(main())