| """Plotting functions for experiment results.""" |
| import os |
| import json |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from datetime import datetime |
| from scipy import stats |
|
|
| from src.utils import save_sidecar_json, ensure_dir |
|
|
|
|
def _save_fig(fig, outpath, metadata=None):
    """Save a figure as both PDF and PNG and write a sidecar JSON file.

    Args:
        fig: Matplotlib figure to save. It is always closed afterwards, even
            if saving fails, so figures are not leaked.
        outpath: Target path. Any file extension is stripped and replaced by
            ``.pdf`` and ``.png``. ``outpath`` itself is handed to
            ``save_sidecar_json`` unchanged.
        metadata: Optional dict of extra metadata. It is copied, never
            mutated; generation timestamp and output paths are added.

    Returns:
        Tuple ``(pdf_path, png_path)``.
    """
    out_dir = os.path.dirname(outpath)
    if out_dir:
        # A bare filename has no directory component (the CWD already exists).
        ensure_dir(out_dir)

    # splitext strips only a real extension. The former rsplit('.', 1) broke
    # paths such as '../results/fig' or 'runs/v1.2/fig' (dots in directories).
    base = os.path.splitext(outpath)[0]
    pdf_path = base + '.pdf'
    png_path = base + '.png'

    try:
        fig.savefig(pdf_path, bbox_inches='tight', dpi=150)
        fig.savefig(png_path, bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)

    meta = dict(metadata) if metadata else {}
    meta['generation_timestamp'] = datetime.now().isoformat()
    meta['output_pdf'] = pdf_path
    meta['output_png'] = png_path
    save_sidecar_json(outpath, meta)

    return pdf_path, png_path
|
|
|
|
def plot_influence_vs_distance(df, outpath, group_col='graph_type',
                               title='Deletion Influence vs Distance',
                               input_files=None):
    """Plot mean deletion influence against graph distance on a log-y axis.

    Distances are taken from columns named ``influence_d<k>`` where ``k`` is
    an integer. Columns with that prefix but a non-integer suffix are ignored.
    One line is drawn per value of ``group_col`` (a single ``'all'`` line when
    that column is absent), with a shaded +/- one standard-error band.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        group_col: Column used to split rows into separate lines.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    prefix = 'influence_d'
    # Sort by the integer distance, not by column name: a string sort puts
    # 'influence_d10' before 'influence_d2' and makes the line zig-zag.
    # The column set is identical for every group, so scan it once.
    dist_cols = sorted(
        (int(c[len(prefix):]), c) for c in df.columns
        if isinstance(c, str) and c.startswith(prefix) and c[len(prefix):].isdecimal()
    )

    fig, ax = plt.subplots(figsize=(8, 5))

    has_groups = group_col in df.columns
    groups = df[group_col].unique() if has_groups else ['all']
    colors = plt.cm.tab10(np.linspace(0, 1, max(len(groups), 1)))

    for idx, grp in enumerate(sorted(groups)):
        sub = df[df[group_col] == grp] if has_groups else df

        distances, means, sems = [], [], []
        for dist, col in dist_cols:
            vals = sub[col].dropna()
            n = len(vals)
            if n == 0:
                continue
            distances.append(dist)
            means.append(vals.mean())
            # std() of a single sample is NaN; use a zero-width band instead.
            sems.append(vals.std() / np.sqrt(n) if n > 1 else 0.0)
        if not distances:
            # Nothing to draw: skip instead of adding an empty legend entry.
            continue

        means = np.asarray(means)
        sems = np.asarray(sems)
        ax.semilogy(distances, means, 'o-', color=colors[idx], label=str(grp), linewidth=2)
        # Clip the lower edge so the band stays strictly positive on the log axis.
        ax.fill_between(distances,
                        np.maximum(means - sems, 1e-15),
                        means + sems,
                        color=colors[idx], alpha=0.15)

    ax.set_xlabel('Graph Distance from Seed Set', fontsize=12)
    ax.set_ylabel('Mean Deletion Influence', fontsize=12)
    ax.set_title(title, fontsize=13)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    meta = {'script': 'plotting.py', 'function': 'plot_influence_vs_distance',
            'group_col': group_col, 'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_decay_ablation(df, outpath, x_col='regime', y_col='mu_emp',
                        title='Empirical Decay Rate Ablation', input_files=None):
    """Draw a bar chart of the empirical decay rate for each regime.

    Rows are grouped by ``x_col``. Each bar shows the group mean of ``y_col``
    with a +/- one standard-deviation error bar. Bars are ordered from the
    largest mean to the smallest. If either column is missing, the labelled
    but empty axes is still saved.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        x_col: Column defining the bars (one bar per distinct value).
        y_col: Column whose mean and std are plotted.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    columns_present = x_col in df.columns and y_col in df.columns
    if columns_present:
        summary = (
            df.groupby(x_col)[y_col]
            .agg(['mean', 'std'])
            .reset_index()
            .sort_values('mean', ascending=False)
        )
        positions = np.arange(len(summary))
        shades = plt.cm.viridis(np.linspace(0.3, 0.9, len(positions)))
        ax.bar(positions, summary['mean'], yerr=summary['std'], capsize=3, color=shades)
        ax.set_xticks(positions)
        ax.set_xticklabels(summary[x_col], rotation=45, ha='right', fontsize=9)

    ax.set_ylabel('Empirical Decay Rate μ_emp', fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3, axis='y')

    sidecar = {
        'script': 'plotting.py',
        'function': 'plot_decay_ablation',
        'input_files': input_files or [],
    }
    return _save_fig(fig, outpath, sidecar)
|
|
|
|
def plot_error_vs_radius(df, outpath, group_col='dataset_name', dataset_type='synthetic',
                         title=None, input_files=None):
    """Plot mean relative local-approximation error against radius R (log-y).

    Radii are taken from columns named ``rel_error_R<k>`` where ``k`` is an
    integer. Columns with that prefix but a non-integer suffix are ignored.
    One line is drawn per value of ``group_col`` (a single ``'all'`` line when
    that column is absent), with a shaded +/- one standard-error band.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        group_col: Column used to split rows into separate lines.
        dataset_type: Tag used in the default title and in the metadata.
        title: Axes title; defaults to one mentioning ``dataset_type``.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    if title is None:
        title = f'Local Approximation Error vs Radius ({dataset_type})'

    prefix = 'rel_error_R'
    # Numeric sort: a string sort would put 'rel_error_R10' before 'rel_error_R2'.
    # The column set is identical for every group, so scan it once.
    err_cols = sorted(
        (int(c[len(prefix):]), c) for c in df.columns
        if isinstance(c, str) and c.startswith(prefix) and c[len(prefix):].isdecimal()
    )

    fig, ax = plt.subplots(figsize=(8, 5))
    has_groups = group_col in df.columns
    groups = df[group_col].unique() if has_groups else ['all']
    colors = plt.cm.tab10(np.linspace(0, 1, max(len(groups), 1)))

    for idx, grp in enumerate(sorted(groups)):
        sub = df[df[group_col] == grp] if has_groups else df

        radii, means, sems = [], [], []
        for radius, col in err_cols:
            vals = sub[col].dropna()
            n = len(vals)
            if n == 0:
                continue
            radii.append(radius)
            means.append(vals.mean())
            # std() of a single sample is NaN; use a zero-width band instead.
            sems.append(vals.std() / np.sqrt(n) if n > 1 else 0.0)
        if not radii:
            # Nothing to draw: skip instead of adding an empty legend entry.
            continue

        means = np.asarray(means)
        sems = np.asarray(sems)
        ax.semilogy(radii, means, 'o-', color=colors[idx], label=str(grp), linewidth=2)
        # Clip the lower edge so the band stays strictly positive on the log axis.
        ax.fill_between(radii,
                        np.maximum(means - sems, 1e-15),
                        means + sems,
                        color=colors[idx], alpha=0.15)

    ax.set_xlabel('Radius R', fontsize=12)
    ax.set_ylabel('Mean Relative Error', fontsize=12)
    ax.set_title(title, fontsize=13)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    meta = {'script': 'plotting.py', 'function': 'plot_error_vs_radius',
            'dataset_type': dataset_type, 'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_chi_vs_error(df, outpath, chi_col='chi_seed_max', error_col='rel_error_R2',
                      title=None, input_files=None, use_log_chi=True):
    """Scatter the interaction proxy chi(z) against local error in two panels.

    Left panel: raw chi vs error, annotated with the Spearman rank correlation.
    Right panel: chi vs error with an OLS regression line and Pearson/Spearman
    correlations. When ``use_log_chi`` is true (default), the right panel uses
    log(1 + chi) to compress extreme outliers; otherwise it uses raw chi.

    Only rows where both columns are non-null, finite and strictly positive are
    used. Both panels stay empty when 3 or fewer such rows remain. If the
    (transformed) chi values are all identical, the regression is skipped,
    because ``scipy.stats.linregress`` rejects identical x values.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        chi_col: Column holding the interaction proxy.
        error_col: Column holding the local error. ``rel_error_R<k>`` columns
            get an ``R=<k>`` axis label; any other column is labelled by name.
        title: Figure suptitle; defaults to 'Interaction Proxy vs Local Error'.
        input_files: Optional list recorded in the sidecar metadata.
        use_log_chi: Whether the right panel uses log(1 + chi) instead of chi.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    if title is None:
        title = 'Interaction Proxy vs Local Error'

    # Axis labels follow the columns actually plotted instead of being
    # hard-coded to the defaults.
    chi_label = 'χ_max(z)' if chi_col == 'chi_seed_max' else str(chi_col)
    err_prefix = 'rel_error_R'
    err_suffix = (error_col[len(err_prefix):]
                  if isinstance(error_col, str) and error_col.startswith(err_prefix) else '')
    err_label = f'Relative Error (R={err_suffix})' if err_suffix.isdecimal() else str(error_col)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    if chi_col in df.columns and error_col in df.columns:
        # Row-wise dropna keeps x and y aligned even if the index has duplicate labels.
        pair = df[[chi_col, error_col]].dropna()
        x_raw = pair.iloc[:, 0].to_numpy()
        y_raw = pair.iloc[:, 1].to_numpy()

        mask = np.isfinite(x_raw) & np.isfinite(y_raw) & (y_raw > 0) & (x_raw > 0)
        x_raw, y_raw = x_raw[mask], y_raw[mask]

        if len(x_raw) > 3:
            # Left panel: raw chi, rank correlation only.
            ax = axes[0]
            ax.scatter(x_raw, y_raw, alpha=0.3, s=15, c='steelblue')
            sr_raw, sp_raw = stats.spearmanr(x_raw, y_raw)
            ax.text(0.05, 0.95, f'Spearman ρ={sr_raw:.3f}\n(p={sp_raw:.2e})',
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            ax.set_xlabel(chi_label, fontsize=12)
            ax.set_ylabel(err_label, fontsize=12)
            ax.set_title('Raw χ', fontsize=11)
            ax.grid(True, alpha=0.3)

            # Right panel: (optionally log-transformed) chi with a linear fit.
            x_fit = np.log1p(x_raw) if use_log_chi else x_raw
            ax = axes[1]
            ax.scatter(x_fit, y_raw, alpha=0.3, s=15, c='steelblue')
            if np.ptp(x_fit) > 0:
                fit = stats.linregress(x_fit, y_raw)
                x_line = np.linspace(np.min(x_fit), np.max(x_fit), 100)
                ax.plot(x_line, fit.slope * x_line + fit.intercept, 'r--',
                        linewidth=2, label='OLS fit')
                pr_fit, pp_fit = stats.pearsonr(x_fit, y_raw)
                sr_fit, sp_fit = stats.spearmanr(x_fit, y_raw)
                stats_text = (f'Pearson r={pr_fit:.3f} (p={pp_fit:.2e})\n'
                              f'Spearman ρ={sr_fit:.3f} (p={sp_fit:.2e})')
                ax.legend(fontsize=9)
            else:
                # Constant x: linregress would raise, correlations are undefined.
                stats_text = 'χ is constant: no fit'
            ax.text(0.05, 0.95, stats_text,
                    transform=ax.transAxes, fontsize=9, verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            if use_log_chi:
                ax.set_xlabel(f'log(1 + {chi_label})', fontsize=12)
                ax.set_title('Log-transformed χ', fontsize=11)
            else:
                ax.set_xlabel(chi_label, fontsize=12)
                ax.set_title('Raw χ with OLS fit', fontsize=11)
            ax.set_ylabel(err_label, fontsize=12)
            ax.grid(True, alpha=0.3)

    fig.suptitle(title, fontsize=13, y=1.02)
    fig.tight_layout()

    meta = {'script': 'plotting.py', 'function': 'plot_chi_vs_error',
            'chi_col': chi_col, 'error_col': error_col, 'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_interference_vs_chi(df, outpath, chi_col='chi_seed_max',
                             interf_col='interference_cosine_R2',
                             title='Interference vs Interaction Proxy', input_files=None):
    """Scatter an interference proxy against the interaction proxy chi.

    Values are aligned on the index labels where both columns are non-null,
    then restricted to finite pairs. With more than 3 points, the scatter is
    drawn and annotated with the Pearson and Spearman correlations. Axis labels
    are derived from the column names.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        chi_col: Column plotted on the x axis.
        interf_col: Column plotted on the y axis.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    if chi_col in df.columns and interf_col in df.columns:
        chi = df[chi_col].dropna()
        interf = df[interf_col].dropna()
        shared = chi.index.intersection(interf.index)
        chi_vals = chi.loc[shared].values
        interf_vals = interf.loc[shared].values

        finite = np.isfinite(chi_vals) & np.isfinite(interf_vals)
        chi_vals = chi_vals[finite]
        interf_vals = interf_vals[finite]

        if len(chi_vals) > 3:
            ax.scatter(chi_vals, interf_vals, alpha=0.4, s=20, c='darkorange')
            r_pearson, _ = stats.pearsonr(chi_vals, interf_vals)
            rho_spearman, _ = stats.spearmanr(chi_vals, interf_vals)
            annotation = f'Pearson r={r_pearson:.3f}\nSpearman ρ={rho_spearman:.3f}'
            ax.text(0.05, 0.95, annotation, transform=ax.transAxes, fontsize=9,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5))

    ax.set_xlabel(chi_col.replace('_', ' ').title(), fontsize=12)
    ax.set_ylabel(interf_col.replace('_', ' ').title(), fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3)

    sidecar = {
        'script': 'plotting.py',
        'function': 'plot_interference_vs_chi',
        'input_files': input_files or [],
    }
    return _save_fig(fig, outpath, sidecar)
|
|
|
|
def plot_runtime_vs_error(df, outpath, title='Runtime vs Approximation Error',
                          input_files=None):
    """Plot mean runtime against mean relative error, one marker per method.

    The ``'exact'`` method has no error column. Its mean runtime is drawn as a
    dashed vertical line labelled ``err=0``. Every other method that has both a
    runtime column and an error column is drawn as a single point at
    (mean runtime, mean error), computed over index labels where both values
    are non-null. The y axis switches to log scale only when at least one point
    was plotted and every plotted mean error is strictly positive.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    method_cols = {
        'exact': ('runtime_exact', None),
        'warm_start': ('runtime_warm_start', 'rel_error_warm_start'),
        'one_step': ('runtime_one_step', 'rel_error_one_step'),
    }
    # Local solvers at radii 1..4.
    for R in [1, 2, 3, 4]:
        method_cols[f'local_R{R}'] = (f'runtime_local_R{R}', f'rel_error_R{R}')

    markers = ['s', 'D', '^', 'o', 'o', 'o', 'o']
    colors_map = {'exact': 'red', 'warm_start': 'blue', 'one_step': 'green',
                  'local_R1': '#1f77b4', 'local_R2': '#ff7f0e',
                  'local_R3': '#2ca02c', 'local_R4': '#d62728'}

    plotted_errors = []
    for idx, (method, (rt_col, err_col)) in enumerate(method_cols.items()):
        if rt_col not in df.columns:
            continue
        color = colors_map.get(method, 'gray')
        if err_col is None:
            rt = df[rt_col].dropna().mean()
            ax.axvline(rt, color=color, linestyle='--',
                       alpha=0.5, label=f'{method} (err=0)')
        elif err_col in df.columns:
            rt_vals = df[rt_col].dropna()
            err_vals = df[err_col].dropna()
            common = rt_vals.index.intersection(err_vals.index)
            if len(common) > 0:
                mean_err = err_vals.loc[common].mean()
                plotted_errors.append(mean_err)
                ax.scatter(rt_vals.loc[common].mean(), mean_err,
                           s=120, marker=markers[min(idx, len(markers) - 1)],
                           color=color, label=method, zorder=5, edgecolors='black')

    ax.set_xlabel('Mean Runtime (seconds)', fontsize=12)
    ax.set_ylabel('Mean Relative Error', fontsize=12)
    ax.set_title(title, fontsize=13)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)
    # Decide from the plotted data, not from the autoscaled limits: the default
    # 5% axis margin can push the lower ylim below 0 even when every error is
    # positive, which previously kept the axis linear.
    if plotted_errors and all(err > 0 for err in plotted_errors):
        ax.set_yscale('log')

    meta = {'script': 'plotting.py', 'function': 'plot_runtime_vs_error',
            'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
| |
| |
| |
|
|
def plot_model_family_influence(df, outpath, title='Influence Decay by Model Family',
                                input_files=None):
    """Plot mean deletion influence against graph distance, one line per model family.

    Distances come from columns named ``influence_d<k>`` where ``k`` is an
    integer; other suffixes are ignored. Each family gets a fixed line style
    and colour from ``family_styles`` (grey ``'o-'`` for unknown families),
    plus a shaded +/- one standard-error band. Without a ``model_family``
    column, all rows are plotted as a single ``'poisson_gamma'`` family.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    has_family = 'model_family' in df.columns
    families = df['model_family'].unique() if has_family else ['poisson_gamma']
    family_styles = {
        'poisson_gamma': ('o-', '#1f77b4'),
        'gaussian_gaussian': ('s--', '#ff7f0e'),
        'gaussian_gamma_map': ('^:', '#2ca02c'),
    }

    prefix = 'influence_d'
    # Numeric sort: a string sort would put 'influence_d10' before 'influence_d2'.
    dist_cols = sorted(
        (int(c[len(prefix):]), c) for c in df.columns
        if isinstance(c, str) and c.startswith(prefix) and c[len(prefix):].isdecimal()
    )

    if dist_cols:
        distances = [dist for dist, _ in dist_cols]
        for fam in sorted(families):
            sub = df[df['model_family'] == fam] if has_family else df

            col_vals = [sub[col].dropna() for _, col in dist_cols]
            means = np.array([vals.mean() for vals in col_vals])
            # max(1, ...) avoids dividing by zero for all-NaN columns.
            sems = np.array([vals.std() / np.sqrt(max(1, vals.count())) for vals in col_vals])

            style, color = family_styles.get(fam, ('o-', 'gray'))
            ax.semilogy(distances, means, style, color=color,
                        label=fam.replace('_', ' ').title(), linewidth=2)
            # Clip the lower edge so the band stays strictly positive on the log axis.
            ax.fill_between(distances, np.maximum(means - sems, 1e-15), means + sems,
                            color=color, alpha=0.1)

    ax.set_xlabel('Graph Distance', fontsize=12)
    ax.set_ylabel('Mean Deletion Influence', fontsize=12)
    ax.set_title(title, fontsize=13)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    meta = {'script': 'plotting.py', 'function': 'plot_model_family_influence',
            'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_model_family_decay_mu(df, outpath, title='Empirical Decay Rate by Model Family',
                               input_files=None):
    """Bar chart of mean ``mu_emp`` per model family, optionally split by graph type.

    With a ``graph_type`` column, each family gets one bar per graph type and
    a 'Graph Type' legend. Missing family/graph-type combinations become 0 via
    ``unstack(fill_value=0)``. Without that column, one bar per family is drawn
    and no legend is shown. If ``model_family`` or ``mu_emp`` is missing, only
    the labelled empty axes is saved.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(10, 5))

    has_graph_type = 'graph_type' in df.columns
    plotted = False
    if 'model_family' in df.columns and 'mu_emp' in df.columns:
        if has_graph_type:
            grouped = (df.groupby(['model_family', 'graph_type'])['mu_emp']
                       .mean().unstack(fill_value=0))
        else:
            grouped = df.groupby('model_family')['mu_emp'].mean().to_frame()

        grouped.plot(kind='bar', ax=ax, width=0.7, capsize=3, legend=has_graph_type)
        plotted = True

    ax.set_ylabel('Mean Empirical Decay Rate μ_emp', fontsize=12)
    ax.set_title(title, fontsize=13)
    if plotted and has_graph_type:
        # Only a graph-type split needs a legend. The single-series case used
        # to show 'mu_emp' under a misleading 'Graph Type' title.
        ax.legend(fontsize=9, title='Graph Type')
    ax.grid(True, alpha=0.3, axis='y')
    # Rotate this axes' tick labels directly rather than via pyplot's implicit current axes.
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')

    meta = {'script': 'plotting.py', 'function': 'plot_model_family_decay_mu',
            'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_model_family_error_vs_radius(df, outpath,
                                      title='Error vs Radius by Model Family',
                                      input_files=None):
    """Plot mean relative error against radius R (log-y), one line per model family.

    Radii come from columns named ``rel_error_R<k>`` where ``k`` is an
    integer; other suffixes are ignored. Each family gets a fixed line style
    and colour from ``family_styles`` (grey ``'o-'`` for unknown families).
    Without a ``model_family`` column, all rows are plotted as a single
    ``'poisson_gamma'`` family.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    has_family = 'model_family' in df.columns
    families = df['model_family'].unique() if has_family else ['poisson_gamma']
    family_styles = {
        'poisson_gamma': ('o-', '#1f77b4'),
        'gaussian_gaussian': ('s--', '#ff7f0e'),
        'gaussian_gamma_map': ('^:', '#2ca02c'),
    }

    prefix = 'rel_error_R'
    # Numeric sort: a string sort would put 'rel_error_R10' before 'rel_error_R2'.
    err_cols = sorted(
        (int(c[len(prefix):]), c) for c in df.columns
        if isinstance(c, str) and c.startswith(prefix) and c[len(prefix):].isdecimal()
    )

    if err_cols:
        radii = [radius for radius, _ in err_cols]
        for fam in sorted(families):
            sub = df[df['model_family'] == fam] if has_family else df
            means = [sub[col].dropna().mean() for _, col in err_cols]

            style, color = family_styles.get(fam, ('o-', 'gray'))
            ax.semilogy(radii, means, style, color=color,
                        label=fam.replace('_', ' ').title(), linewidth=2)

    ax.set_xlabel('Radius R', fontsize=12)
    ax.set_ylabel('Mean Relative Error', fontsize=12)
    ax.set_title(title, fontsize=13)
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    meta = {'script': 'plotting.py', 'function': 'plot_model_family_error_vs_radius',
            'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|
|
|
def plot_model_family_proxy_vs_error(df, outpath,
                                     title='Interaction Proxy vs Error Across Models',
                                     input_files=None):
    """Scatter ``chi_seed_max`` against ``rel_error_R2``, coloured by model family.

    For each family, values are aligned on the index labels where both columns
    are non-null. They are then kept only where both are finite and the error
    is strictly positive. Families with two or fewer surviving points are not
    drawn. Without a ``model_family`` column, all rows are plotted as
    ``'poisson_gamma'``.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Axes title.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    by_family = 'model_family' in df.columns
    families = df['model_family'].unique() if by_family else ['poisson_gamma']
    palette = {
        'poisson_gamma': '#1f77b4',
        'gaussian_gaussian': '#ff7f0e',
        'gaussian_gamma_map': '#2ca02c',
    }
    chi_col, err_col = 'chi_seed_max', 'rel_error_R2'

    for fam in sorted(families):
        sub = df[df['model_family'] == fam] if by_family else df
        if chi_col not in sub.columns or err_col not in sub.columns:
            continue

        chi = sub[chi_col].dropna()
        err = sub[err_col].dropna()
        shared = chi.index.intersection(err.index)
        chi_vals = chi.loc[shared].values
        err_vals = err.loc[shared].values
        keep = np.isfinite(chi_vals) & np.isfinite(err_vals) & (err_vals > 0)
        chi_vals, err_vals = chi_vals[keep], err_vals[keep]

        if len(chi_vals) <= 2:
            continue
        ax.scatter(chi_vals, err_vals, alpha=0.4, s=20,
                   color=palette.get(fam, 'gray'),
                   label=fam.replace('_', ' ').title())

    ax.set_xlabel('Interaction Proxy χ_max(z)', fontsize=12)
    ax.set_ylabel('Relative Error (R=2)', fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    sidecar = {
        'script': 'plotting.py',
        'function': 'plot_model_family_proxy_vs_error',
        'input_files': input_files or [],
    }
    return _save_fig(fig, outpath, sidecar)
|
|
|
|
def plot_model_family_prior_noise_ablation(df, outpath,
                                           title='Prior/Noise Ablation by Model Family',
                                           input_files=None):
    """Panel plot of mean ``mu_emp`` per prior/noise regime, one panel per model family.

    Each panel groups its family's rows by ``prior_strength``, falling back to
    ``regime`` when that column is absent. It draws the group mean of
    ``mu_emp`` with +/- one standard-deviation error bars. Panels share the
    y axis. Without a ``model_family`` column, all rows form a single
    ``'poisson_gamma'`` panel. If the family column has no values, a single
    blank panel is saved instead of failing.

    Args:
        df: pandas DataFrame of results.
        outpath: Output path passed to ``_save_fig``.
        title: Figure suptitle.
        input_files: Optional list recorded in the sidecar metadata.

    Returns:
        ``(pdf_path, png_path)`` from ``_save_fig``.
    """
    has_family = 'model_family' in df.columns
    families = sorted(df['model_family'].unique()) if has_family else ['poisson_gamma']
    # plt.subplots(1, 0) raises, so always allocate at least one panel.
    n_panels = max(len(families), 1)

    # squeeze=False always returns a 2-D array, removing the single-panel special case.
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), sharey=True,
                             squeeze=False)
    axes = axes[0]

    for panel_idx, (ax, fam) in enumerate(zip(axes, families)):
        sub = df[df['model_family'] == fam] if has_family else df

        regime_col = 'prior_strength' if 'prior_strength' in sub.columns else 'regime'
        y_col = 'mu_emp'

        if regime_col in sub.columns and y_col in sub.columns:
            grouped = sub.groupby(regime_col)[y_col].agg(['mean', 'std']).reset_index()
            x = np.arange(len(grouped))
            ax.bar(x, grouped['mean'], yerr=grouped['std'], capsize=3, alpha=0.8)
            ax.set_xticks(x)
            ax.set_xticklabels(grouped[regime_col], rotation=30, ha='right')

        ax.set_title(fam.replace('_', ' ').title(), fontsize=11)
        # Only the leftmost panel carries the shared y label.
        ax.set_ylabel('μ_emp' if panel_idx == 0 else '', fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')

    fig.suptitle(title, fontsize=13, y=1.02)
    fig.tight_layout()

    meta = {'script': 'plotting.py', 'function': 'plot_model_family_prior_noise_ablation',
            'input_files': input_files or []}
    return _save_fig(fig, outpath, meta)
|
|