serliezer committed on
Commit
71d7a3b
·
verified ·
1 Parent(s): 714d8d5

v2: analyze_results.py

Browse files
Files changed (1) hide show
  1. scripts/analyze_results.py +59 -11
scripts/analyze_results.py CHANGED
@@ -61,22 +61,36 @@ def process_synthetic(df):
61
 
62
 
63
  def compute_correlations(df):
64
- """Compute correlation table between proxies and error metrics."""
 
 
 
 
65
  rows = []
66
 
67
  proxy_cols = ['chi_seed_max', 'chi_seed_sum', 'seed_degree']
 
68
  target_cols = ['rel_error_R2', 'rel_error_R4', 'interference_cosine_R2']
69
 
70
- # By dataset/regime grouping
71
- if 'dataset_name' in df.columns:
72
- groups = df.groupby('dataset_name')
73
- elif 'regime' in df.columns:
74
- groups = df.groupby('regime')
 
 
 
 
 
 
 
 
 
75
  else:
76
- groups = [('all', df)]
77
 
78
- for grp_name, grp_df in groups:
79
- for proxy in proxy_cols:
80
  for target in target_cols:
81
  if proxy in grp_df.columns and target in grp_df.columns:
82
  x = grp_df[proxy].dropna()
@@ -88,8 +102,13 @@ def compute_correlations(df):
88
  x, y = x[mask], y[mask]
89
 
90
  if len(x) >= 5:
91
- pr, pp = stats.pearsonr(x, y)
92
- sr, sp = stats.spearmanr(x, y)
 
 
 
 
 
93
  rows.append({
94
  'dataset_regime': grp_name,
95
  'model_family': grp_df['model_family'].iloc[0] if 'model_family' in grp_df.columns else 'unknown',
@@ -245,6 +264,35 @@ def main():
245
  save_table(real_summary, os.path.join(tables_dir, 'table_real_datasets'),
246
  'Real Dataset Summary')
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  # Table 3: Correlations
249
  corr_df = compute_correlations(df)
250
  if len(corr_df) > 0:
 
61
 
62
 
63
  def compute_correlations(df):
64
+ """Compute correlation table between proxies and error metrics.
65
+
66
+ Computes within-regime correlations (controlling for graph structure)
67
+ and also log-transformed chi correlations.
68
+ """
69
  rows = []
70
 
71
  proxy_cols = ['chi_seed_max', 'chi_seed_sum', 'seed_degree']
72
+ log_proxy_cols = ['log_chi_max', 'log_chi_sum']
73
  target_cols = ['rel_error_R2', 'rel_error_R4', 'interference_cosine_R2']
74
 
75
+ # Add log-chi columns
76
+ df_copy = df.copy()
77
+ if 'chi_seed_max' in df_copy.columns:
78
+ df_copy['log_chi_max'] = np.log1p(df_copy['chi_seed_max'].clip(lower=0))
79
+ if 'chi_seed_sum' in df_copy.columns:
80
+ df_copy['log_chi_sum'] = np.log1p(df_copy['chi_seed_sum'].clip(lower=0))
81
+
82
+ all_proxies = proxy_cols + log_proxy_cols
83
+
84
+ # Within-regime correlations (most informative)
85
+ if 'regime' in df_copy.columns:
86
+ regime_groups = df_copy.groupby('regime')
87
+ elif 'dataset_name' in df_copy.columns:
88
+ regime_groups = df_copy.groupby('dataset_name')
89
  else:
90
+ regime_groups = [('all', df_copy)]
91
 
92
+ for grp_name, grp_df in regime_groups:
93
+ for proxy in all_proxies:
94
  for target in target_cols:
95
  if proxy in grp_df.columns and target in grp_df.columns:
96
  x = grp_df[proxy].dropna()
 
102
  x, y = x[mask], y[mask]
103
 
104
  if len(x) >= 5:
105
+ try:
106
+ pr, pp = stats.pearsonr(x, y)
107
+ sr, sp = stats.spearmanr(x, y)
108
+ except:
109
+ continue
110
+ if np.isnan(pr) or np.isnan(sr):
111
+ continue
112
  rows.append({
113
  'dataset_regime': grp_name,
114
  'model_family': grp_df['model_family'].iloc[0] if 'model_family' in grp_df.columns else 'unknown',
 
264
  save_table(real_summary, os.path.join(tables_dir, 'table_real_datasets'),
265
  'Real Dataset Summary')
266
 
267
+ # Bootstrap CIs for key metrics
268
+ from src.metrics import compute_bootstrap_summary
269
+
270
+ metric_cols = ['empirical_decay_mu', 'rel_error_R1', 'rel_error_R2', 'rel_error_R3',
271
+ 'rel_error_R4', 'chi_seed_max', 'interference_cosine_R2',
272
+ 'rel_error_warm_start', 'rel_error_one_step']
273
+
274
+ if len(syn_df) > 0:
275
+ boot_syn = compute_bootstrap_summary(
276
+ syn_df, ['graph_type', 'prior_strength', 'K'], metric_cols)
277
+ if len(boot_syn) > 0:
278
+ save_table(boot_syn, os.path.join(tables_dir, 'table_synthetic_bootstrap'),
279
+ 'Synthetic Bootstrap CIs')
280
+
281
+ if len(real_df) > 0:
282
+ boot_real = compute_bootstrap_summary(
283
+ real_df, ['dataset_name', 'K'], metric_cols)
284
+ if len(boot_real) > 0:
285
+ save_table(boot_real, os.path.join(tables_dir, 'table_real_bootstrap'),
286
+ 'Real Data Bootstrap CIs')
287
+
288
+ if 'model_family' in df.columns and df['model_family'].nunique() > 1:
289
+ boot_mf = compute_bootstrap_summary(
290
+ df[df['dataset_type'] == 'synthetic'],
291
+ ['model_family', 'graph_type'], metric_cols)
292
+ if len(boot_mf) > 0:
293
+ save_table(boot_mf, os.path.join(tables_dir, 'table_model_family_bootstrap'),
294
+ 'Model Family Bootstrap CIs')
295
+
296
  # Table 3: Correlations
297
  corr_df = compute_correlations(df)
298
  if len(corr_df) > 0: