#!/usr/bin/env python3
"""
Large-scale experiment runner for edge-deletion (unlearning) experiments.

Runs the full Poisson-Gamma synthetic grid and/or a model-family ablation
(Poisson-Gamma VI, Gaussian-Gaussian VI, Gaussian-Gamma MAP).  For each
configuration it fits the full model, then for every sampled deletion it
computes an exact refit, local refits at several radii, a warm-started global
refit and (Poisson-Gamma only) a one-step downdate, and records metrics plus
per-deletion sanity checks as JSONL under ``results/raw``.
"""
import argparse
import itertools
import json
import os
import sys
import time
import traceback
from collections import defaultdict
from datetime import datetime

import numpy as np
import yaml
from scipy.special import digamma

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.data import (generate_gamma_poisson_data, generate_gaussian_gaussian_data,
                      generate_gaussian_gamma_data, sample_deletions)
from src.model import PoissonGammaVI, GaussianGaussianVI, GaussianGammaMAP
from src.graph_utils import build_adjacency, compute_graph_stats
from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance,
                         fit_exponential_decay, compute_local_error,
                         compute_all_param_vector)
from src.unlearning import one_step_downdate_poisson_gamma
from src.utils import FitResult, generate_run_id, generate_config_id, save_jsonl, ensure_dir


# ===========================================================
# Sanity checks
# ===========================================================

def _responsibilities_ok(a, b, c, d, edges, max_edges=20, atol=1e-6):
    """Return True if Poisson-Gamma responsibilities are well-formed.

    For up to ``max_edges`` leading edges with a positive count, builds the
    unnormalised log-responsibilities
    ``digamma(a[i]) - log(b[i]) + digamma(c[j]) - log(d[j])``, exponentiates
    them stably and normalises.  Fails if the normaliser is non-finite or
    non-positive, or if the normalised vector does not sum to 1 within
    ``atol``.  Checking the normaliser matters: a NaN sum would otherwise
    slip through, because ``abs(nan - 1) > atol`` is False.
    """
    for i, j, x in edges[:max_edges]:
        if not x > 0:
            continue
        i, j = int(i), int(j)  # tolerate float-typed edge arrays as indices
        log_r = digamma(a[i]) - np.log(b[i]) + digamma(c[j]) - np.log(d[j])
        r = np.exp(log_r - log_r.max())  # shift by max for numerical stability
        r_sum = r.sum()
        if not (np.isfinite(r_sum) and r_sum > 0):
            return False
        if abs((r / r_sum).sum() - 1.0) > atol:
            return False
    return True


def run_sanity_checks(model, edges, full_params, edge_to_del, exact_params,
                      local_params_by_R, ws_params, model_family, N, M, K):
    """Run post-fit sanity checks for one deletion and return a result dict.

    Args:
        model: Fitted model; ``model.compute_elbo(edges, params)`` is used
            when the method exists.
        edges: Sequence of ``(i, j, x)`` observations used for the full fit.
        full_params: Parameter dict from the full-data fit.
        edge_to_del: The deleted edge (not used by the current checks).
        exact_params: Parameters from the exact refit without ``edge_to_del``.
        local_params_by_R: Mapping radius R -> parameters of the local refit.
        ws_params: Parameters of the warm-start global refit, or None.
        model_family: 'poisson_gamma', 'gaussian_gaussian' or
            'gaussian_gamma_map'.
        N, M, K: Problem dimensions (not used by the current checks).

    Returns:
        dict of check name -> bool or diagnostic value.  Only checks that
        apply to ``model_family`` (and to the available inputs) are present.
    """
    checks = {}

    # 1. Objective (ELBO) of the full fit is finite.  Best-effort: a failure
    #    is recorded rather than aborting the whole configuration.
    if hasattr(model, 'compute_elbo'):
        try:
            elbo_full = float(model.compute_elbo(edges, full_params))
        except Exception as exc:
            checks['objective_finite'] = False
            checks['objective_error'] = repr(exc)
        else:
            checks['full_objective'] = elbo_full
            checks['objective_finite'] = bool(np.isfinite(elbo_full))

    if model_family == 'poisson_gamma':
        gamma_params = [full_params[key] for key in ('a', 'b', 'c', 'd')]
        # 2. Gamma shape/rate parameters must be strictly positive and NaN-free.
        checks['params_positive'] = bool(all(np.all(p > 0) for p in gamma_params))
        checks['params_no_nan'] = bool(not any(np.any(np.isnan(p)) for p in gamma_params))
        # 3. Responsibilities are finite and normalise to 1 on a sample of edges.
        checks['responsibilities_sum_to_one'] = _responsibilities_ok(*gamma_params, edges)

    # 4. Deleting an edge must actually change the exact solution.
    v_full = compute_all_param_vector(full_params, model_family)
    v_exact = compute_all_param_vector(exact_params, model_family)
    diff = float(np.linalg.norm(v_full - v_exact))
    checks['exact_differs_from_full'] = bool(diff > 1e-10)
    checks['exact_full_diff_norm'] = diff

    # 5. Local-refit error vs. the exact refit should shrink as R grows
    #    (compares only the smallest and largest radius).
    errors_by_R = {
        R: compute_local_error(lp, exact_params, model_family)['relative_error']
        for R, lp in sorted(local_params_by_R.items())
    }
    checks['errors_by_R'] = errors_by_R
    if len(errors_by_R) >= 2:
        radii = sorted(errors_by_R)
        checks['error_decreases_with_R'] = bool(errors_by_R[radii[-1]] <= errors_by_R[radii[0]])

    # 6. Warm-start global refit should be roughly as accurate as the R=4
    #    local refit.  Tolerance (5x + 0.01 absolute) is a loose heuristic.
    ref_R = 4
    if ws_params is not None and ref_R in errors_by_R:
        ws_err = compute_local_error(ws_params, exact_params, model_family)['relative_error']
        checks['ws_error'] = ws_err
        checks['ws_close_to_R4'] = bool(ws_err <= errors_by_R[ref_R] * 5 + 0.01)

    return checks


# ===========================================================
# Config builders
# ===========================================================

_GRAPH_TYPES = ['bounded_degree', 'erdos_renyi', 'power_law']
_COUNT_SCALES = [('low', 0.5), ('medium', 1.0), ('high', 3.0)]
_NOISE_LEVELS = [('high_noise', 2.0), ('medium_noise', 1.0), ('low_noise', 0.3)]


def _symmetric_prior(value):
    """Return a Gamma prior config with a0 = b0 = c0 = d0 = ``value``."""
    return {'a0': value, 'b0': value, 'c0': value, 'd0': value}


def build_full_synthetic_configs():
    """Build the full Poisson-Gamma synthetic grid.

    3 K x 3 graph types x 3 degrees x 3 count scales x 2 prior strengths
    = 162 configs, all with N = M = 300, 50 deletions, radii 1-4, seed 42.
    Loop order is K-major so ``--max_configs`` takes a stable prefix.
    """
    N, M = 300, 300
    radii = [1, 2, 3, 4]
    num_del = 50
    degrees = [('low', 5), ('medium', 10), ('high', 20)]
    priors = [('strong', 1.0), ('weak', 0.1)]

    configs = []
    grid = itertools.product([5, 10, 20], _GRAPH_TYPES, degrees, _COUNT_SCALES, priors)
    for K, gt, (_deg_name, deg), (cs_name, cs), (ps_name, ps_value) in grid:
        configs.append({
            'N': N, 'M': M, 'K': K, 'graph_type': gt, 'avg_degree': deg,
            'count_scale': cs, 'count_scale_label': cs_name,
            'prior_strength': ps_name, 'prior_config': _symmetric_prior(ps_value),
            'num_deletions': num_del, 'radii': list(radii), 'seed': 42,
            'model_family': 'poisson_gamma', 'max_iter': 300, 'tol': 1e-5,
        })
    return configs


def build_model_family_configs():
    """Build the model-family ablation, balanced across 3 families.

    For each (K in {5, 10}) x (3 graph types) x (degree in {5, 15}) it emits
    6 Poisson-Gamma, 6 Gaussian-Gaussian and 6 Gaussian-Gamma MAP configs
    (72 per family, 216 total), all with N = M = 200 and 30 deletions.
    """
    N, M = 200, 200
    radii = [1, 2, 3, 4]
    num_del = 30
    pg_priors = [('strong', 1.0), ('weak', 0.3)]
    gg_priors = [('strong_prior', 0.5), ('weak_prior', 3.0)]
    ggm_priors = [('strong', 2.0), ('weak', 0.3)]

    configs = []
    for K, gt, deg in itertools.product([5, 10], _GRAPH_TYPES, [5, 15]):
        base = {'N': N, 'M': M, 'K': K, 'graph_type': gt, 'avg_degree': deg}

        # Poisson-Gamma
        for (cs_name, cs), (ps_name, ps_value) in itertools.product(_COUNT_SCALES, pg_priors):
            configs.append({
                **base,
                'count_scale': cs, 'count_scale_label': cs_name,
                'prior_strength': ps_name, 'prior_config': _symmetric_prior(ps_value),
                'num_deletions': num_del, 'radii': list(radii), 'seed': 42,
                'model_family': 'poisson_gamma', 'max_iter': 300, 'tol': 1e-5,
            })

        # Gaussian-Gaussian (sigma_U = sigma_V = prior scale)
        for (sx_name, sx), (sp_name, sp) in itertools.product(_NOISE_LEVELS, gg_priors):
            configs.append({
                **base,
                'sigma_x': sx, 'sigma_x_label': sx_name,
                'sigma_U': sp, 'sigma_V': sp, 'prior_strength': sp_name,
                'num_deletions': num_del, 'radii': list(radii), 'seed': 42,
                'model_family': 'gaussian_gaussian', 'max_iter': 300, 'tol': 1e-5,
            })

        # Gaussian-Gamma MAP (with fixed optimizer settings)
        for (sx_name, sx), (gp_name, gp_value) in itertools.product(_NOISE_LEVELS, ggm_priors):
            configs.append({
                **base,
                'sigma_x': sx, 'sigma_x_label': sx_name,
                'prior_strength': gp_name, 'prior_config': _symmetric_prior(gp_value),
                'num_deletions': num_del, 'radii': list(radii), 'seed': 42,
                'model_family': 'gaussian_gamma_map',
                'lr': 0.05, 'max_iter': 2000, 'tol': 1e-6, 'grad_clip': 10.0,
            })
    return configs


# ===========================================================
# Run single config
# ===========================================================

def _generate_data(config, N, M, K, a0, b0, c0, d0, seed):
    """Simulate data for ``config['model_family']``.

    Returns whatever the family's generator returns; callers unpack it as
    ``(edges, U_true, V_true, ge)``.  Raises ValueError for an unknown family.
    """
    family = config['model_family']
    gt = config['graph_type']
    avg_degree = config['avg_degree']
    if family == 'poisson_gamma':
        return generate_gamma_poisson_data(
            N, M, K, gt, avg_degree, config.get('count_scale', 1.0),
            a0, b0, c0, d0, seed=seed)
    if family == 'gaussian_gaussian':
        return generate_gaussian_gaussian_data(
            N, M, K, gt, avg_degree, config.get('sigma_U', 1.0),
            config.get('sigma_V', 1.0), config.get('sigma_x', 1.0), seed=seed)
    if family == 'gaussian_gamma_map':
        return generate_gaussian_gamma_data(
            N, M, K, gt, avg_degree, a0, b0, c0, d0,
            config.get('sigma_x', 1.0), seed=seed)
    raise ValueError(f"Unknown model_family: {family!r}")


def _build_model(config, N, M, K, a0, b0, c0, d0, max_iter, tol, seed):
    """Construct the (unfitted) model for ``config['model_family']``."""
    family = config['model_family']
    if family == 'poisson_gamma':
        return PoissonGammaVI(N, M, K, a0, b0, c0, d0,
                              max_iter=max_iter, tol=tol, seed=seed)
    if family == 'gaussian_gaussian':
        return GaussianGaussianVI(N, M, K,
                                  sigma_U=config.get('sigma_U', 1.0),
                                  sigma_V=config.get('sigma_V', 1.0),
                                  sigma_x=config.get('sigma_x', 1.0),
                                  max_iter=max_iter, tol=tol, seed=seed)
    if family == 'gaussian_gamma_map':
        return GaussianGammaMAP(N, M, K, a0, b0, c0, d0,
                                sigma_x=config.get('sigma_x', 1.0),
                                lr=config.get('lr', 0.05),
                                max_iter=max_iter, tol=tol, seed=seed,
                                grad_clip=config.get('grad_clip', 10.0))
    raise ValueError(f"Unknown model_family: {family!r}")


def run_config(config):
    """Run one configuration end-to-end with sanity checks.

    Generates data, fits the full model, then for each sampled deletion runs
    the exact refit, local refits for every radius in ``config['radii']``, a
    warm-start global refit and (Poisson-Gamma only) a one-step downdate, and
    computes metrics.  Sanity checks run on the first
    ``config.get('num_sanity_deletions', 3)`` deletions.

    Returns:
        ``(records, sanity_results)``: two lists of dicts.  Both are empty
        when the generated graph has fewer than 10 edges.

    Raises:
        ValueError: if ``config['model_family']`` is not recognised.
    """
    model_family = config['model_family']
    gt = config['graph_type']
    N, M, K = config['N'], config['M'], config['K']
    avg_degree = config['avg_degree']
    radii = config.get('radii', [1, 2, 3, 4])
    num_del = config.get('num_deletions', 50)
    num_sanity = config.get('num_sanity_deletions', 3)
    seed = config.get('seed', 42)
    max_iter = config.get('max_iter', 300)
    tol = config.get('tol', 1e-5)
    config_id = generate_config_id(config)
    run_id = generate_run_id()

    prior_cfg = config.get('prior_config', {})
    a0 = prior_cfg.get('a0', 0.3)
    b0 = prior_cfg.get('b0', 1.0)
    c0 = prior_cfg.get('c0', 0.3)
    d0 = prior_cfg.get('d0', 1.0)
    count_scale = config.get('count_scale', 1.0)
    prior_strength = config.get('prior_strength', 'strong')

    edges, _U_true, _V_true, _ge = _generate_data(config, N, M, K, a0, b0, c0, d0, seed)
    if len(edges) < 10:
        print(f" SKIP: only {len(edges)} edges")
        return [], []  # main() unpacks two lists

    model = _build_model(config, N, M, K, a0, b0, c0, d0, max_iter, tol, seed)

    t0 = time.perf_counter()
    full_result = model.fit_full(edges)
    t_full = time.perf_counter() - t0
    full_params = full_result.params

    u2i, i2u, _ = build_adjacency(edges, N, M)
    dels = sample_deletions(edges, u2i, i2u, num_del, seed=seed)

    # Extra arguments for compute_all_metrics; identical for every deletion.
    if model_family == 'poisson_gamma':
        model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0}
    else:
        model_kwargs = {'sigma_x': config.get('sigma_x', 1.0)}

    records = []
    sanity_results = []
    for del_idx, (edge_to_del, del_type) in enumerate(dels):
        i_del, j_del, x_del = edge_to_del

        exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params)
        local_results = {
            R: model.fit_local(edges, edge_to_del, R, init_params=full_params)
            for R in radii
        }
        local_params = {R: res.params for R, res in local_results.items()}
        ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params)

        # One-step downdate is only implemented for Poisson-Gamma.
        one_step_params = None
        one_step_runtime = None
        if model_family == 'poisson_gamma':
            os_res = one_step_downdate_poisson_gamma(
                edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0)
            one_step_params = os_res.params
            one_step_runtime = os_res.runtime_sec

        metrics = compute_all_metrics(
            full_params, exact_result.params, local_params, ws_result.params,
            one_step_params, edge_to_del, edges, N, M, K, model_family,
            radii=radii, model_kwargs=model_kwargs)

        if del_idx < num_sanity:
            sanity = run_sanity_checks(
                model, edges, full_params, edge_to_del, exact_result.params,
                local_params, ws_result.params, model_family, N, M, K)
            # Identifiers so sanity rows can be joined back to result records.
            sanity_results.append({
                'config_id': config_id, 'run_id': run_id,
                'model_family': model_family, 'deletion_index': del_idx,
                **sanity,
            })

        record = {
            'run_id': run_id, 'config_id': config_id,
            'dataset_type': 'synthetic', 'dataset_name': f'synthetic_{model_family}',
            'model_family': model_family,
            'inference_type': 'vi' if model_family != 'gaussian_gamma_map' else 'map',
            'likelihood': 'poisson' if model_family == 'poisson_gamma' else 'gaussian',
            'prior': 'gamma' if 'gamma' in model_family else 'gaussian',
            'graph_type': gt, 'seed': seed, 'N': N, 'M': M, 'K': K,
            'avg_degree': avg_degree,
            'count_scale': count_scale if model_family == 'poisson_gamma' else None,
            'prior_strength': prior_strength,
            'deletion_edge': [int(i_del), int(j_del), float(x_del)],
            'deletion_type': del_type, 'deletion_index': del_idx,
            'runtime_full': t_full,
            'runtime_exact': exact_result.runtime_sec,
            'runtime_warm_start': ws_result.runtime_sec,
            'runtime_one_step': one_step_runtime,
            'exact_converged': exact_result.converged,
            'exact_iterations': exact_result.n_iterations,
            'full_converged': full_result.converged,
        }
        for R in radii:
            res = local_results[R]
            record[f'runtime_local_R{R}'] = res.runtime_sec
            record[f'local_R{R}_converged'] = res.converged
            record[f'local_R{R}_iterations'] = res.n_iterations
        record.update(metrics)
        # Flatten influence-by-distance into scalar columns (tolerates None).
        for d_str, val in (record.get('influence_by_distance') or {}).items():
            record[f'influence_d{d_str}'] = val

        record['regime'] = f"{gt}_{prior_strength}_deg{avg_degree}"
        if model_family == 'poisson_gamma':
            record['regime'] += f"_cs{count_scale}"
            record.update({'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0})
        if model_family in ('gaussian_gaussian', 'gaussian_gamma_map'):
            record['sigma_x'] = config.get('sigma_x', 1.0)
        if model_family == 'gaussian_gaussian':
            record['sigma_U'] = config.get('sigma_U', 1.0)
            record['sigma_V'] = config.get('sigma_V', 1.0)
        records.append(record)

    return records, sanity_results


def main():
    """CLI entry point: build configs for ``--mode``, run each, write JSONL.

    Result records go to ``results/raw/<label>_<timestamp>.jsonl`` and
    sanity checks to ``results/raw/sanity_<label>_<timestamp>.jsonl``.
    A failing config is reported and skipped.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--mode', type=str, default='full_synthetic',
                        choices=['full_synthetic', 'model_family', 'both'])
    parser.add_argument('--max_configs', type=int, default=None)
    args = parser.parse_args()

    if args.mode == 'full_synthetic':
        configs, label = build_full_synthetic_configs(), 'full_synthetic'
    elif args.mode == 'model_family':
        configs, label = build_model_family_configs(), 'model_family_v2'
    else:  # 'both'
        configs = build_full_synthetic_configs() + build_model_family_configs()
        label = 'all'

    if args.max_configs is not None:
        if args.max_configs < 0:
            parser.error('--max_configs must be non-negative')
        configs = configs[:args.max_configs]

    print(f"Running {len(configs)} configs ({args.mode})")
    output_dir = ensure_dir('results/raw')
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'{label}_{ts}.jsonl')
    sanity_file = os.path.join(output_dir, f'sanity_{label}_{ts}.jsonl')

    total_records = 0
    all_sanity = []
    try:
        for idx, config in enumerate(configs, start=1):
            print(f"\n[{idx}/{len(configs)}] {config['model_family']} {config['graph_type']} "
                  f"K={config['K']} deg={config['avg_degree']} "
                  f"ps={config.get('prior_strength','')}")
            # Top-level boundary: one bad config must not abort the sweep.
            try:
                records, sanity = run_config(config)
                total_records += len(records)
                all_sanity.extend(sanity)
                # NOTE(review): assumes save_jsonl appends; if it overwrites,
                # only the last config's records survive — confirm in src.utils.
                save_jsonl(records, output_file)
                print(f" -> {len(records)} records (total: {total_records})")
            except Exception as e:
                print(f" ERROR: {e}")
                traceback.print_exc()
    finally:
        # Persist collected sanity checks even if the sweep is interrupted.
        save_jsonl(all_sanity, sanity_file)

    print(f"\n{'='*60}")
    print(f"Done. {total_records} records in {output_file}")
    print(f"Sanity checks: {len(all_sanity)} in {sanity_file}")


if __name__ == '__main__':
    main()