serliezer's picture
v2: run_scaled.py
a1ceb4f verified
#!/usr/bin/env python3
"""
Large-scale experiment runner for NeurIPS-quality results.
Runs synthetic (full grid), model-family ablation, and real-data experiments.
Includes sanity checks and bootstrap CIs.
"""
import os, sys, json, time, yaml, argparse
import numpy as np
from datetime import datetime
from collections import defaultdict
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.data import (generate_gamma_poisson_data, generate_gaussian_gaussian_data,
generate_gaussian_gamma_data, sample_deletions)
from src.model import PoissonGammaVI, GaussianGaussianVI, GaussianGammaMAP
from src.graph_utils import build_adjacency, compute_graph_stats
from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance,
fit_exponential_decay, compute_local_error)
from src.unlearning import one_step_downdate_poisson_gamma
from src.utils import FitResult, generate_run_id, generate_config_id, save_jsonl, ensure_dir
# ===========================================================
# Sanity checks
# ===========================================================
def run_sanity_checks(model, edges, full_params, edge_to_del, exact_params,
                      local_params_by_R, ws_params, model_family, N, M, K):
    """Run the post-fit sanity checks for one deletion and return the results.

    Args:
        model: Fitted model. Its ``compute_elbo(edges, params)`` method is
            used when present; models without it skip the objective check.
        edges: Sequence of ``(i, j, x)`` observations.
        full_params: Parameters fitted on all edges. For ``poisson_gamma``
            the keys ``'a'``, ``'b'``, ``'c'`` and ``'d'`` are read.
        edge_to_del: The deleted edge (not read here; kept for the interface).
        exact_params: Parameters from the exact refit without the edge.
        local_params_by_R: Mapping from radius to local-refit parameters.
            May be empty.
        ws_params: Warm-start refit parameters, or ``None``.
        model_family: Model family name; ``'poisson_gamma'`` enables the
            Gamma-specific checks.
        N, M, K: Problem dimensions (not read here; kept for the interface).

    Returns:
        Dict mapping check names to their results.
    """
    checks = {}

    # 1. ELBO / objective is computable and finite.
    if hasattr(model, 'compute_elbo'):
        try:
            elbo_full = model.compute_elbo(edges, full_params)
            checks['full_objective'] = elbo_full
            checks['objective_finite'] = bool(np.isfinite(elbo_full))
        except Exception:
            # A failing objective is recorded as a failed check, but
            # KeyboardInterrupt / SystemExit are allowed to propagate.
            checks['objective_finite'] = False

    if model_family == 'poisson_gamma':
        from scipy.special import digamma

        a, b, c, d = (full_params[key] for key in ('a', 'b', 'c', 'd'))
        gamma_params = (a, b, c, d)

        # 2. Gamma shape/rate parameters are positive and NaN-free.
        #    All four blocks are inspected for NaN, not only 'a' and 'b'.
        checks['params_positive'] = bool(
            all(np.all(p > 0) for p in gamma_params))
        checks['params_no_nan'] = bool(
            not any(np.any(np.isnan(p)) for p in gamma_params))

        # 3. Responsibilities sum to 1 on the first (up to) 20 edges.
        resp_ok = True
        for i, j, x in edges[:20]:
            if x > 0:
                log_r = (digamma(a[i]) - np.log(b[i])
                         + digamma(c[j]) - np.log(d[j]))
                log_r -= log_r.max()  # stabilise the softmax
                r = np.exp(log_r)
                r /= r.sum()
                # isclose() is False for NaN/inf, so non-finite
                # responsibilities fail the check instead of passing it.
                if not np.isclose(r.sum(), 1.0, rtol=0.0, atol=1e-6):
                    resp_ok = False
                    break
        checks['responsibilities_sum_to_one'] = resp_ok

    # 4. Exact deletion differs from the full fit.
    from src.metrics import compute_all_param_vector
    v_full = compute_all_param_vector(full_params, model_family)
    v_exact = compute_all_param_vector(exact_params, model_family)
    diff = np.linalg.norm(v_full - v_exact)
    checks['exact_differs_from_full'] = bool(diff > 1e-10)
    checks['exact_full_diff_norm'] = float(diff)

    # 5. Local error at the largest radius is no worse than at the smallest.
    errors_by_R = {}
    for R in sorted(local_params_by_R):
        err = compute_local_error(local_params_by_R[R], exact_params, model_family)
        errors_by_R[R] = err['relative_error']
    checks['errors_by_R'] = errors_by_R
    if len(errors_by_R) >= 2:
        R_list = sorted(errors_by_R)
        checks['error_decreases_with_R'] = bool(
            errors_by_R[R_list[-1]] <= errors_by_R[R_list[0]])

    # 6. Warm-start error is comparable to the R=4 local error.
    #    Looking up R=4 directly also covers an empty radius mapping, where
    #    max() over the keys would raise.
    r4_err = errors_by_R.get(4)
    if ws_params is not None and r4_err is not None:
        ws_err = compute_local_error(ws_params, exact_params, model_family)
        checks['ws_error'] = ws_err['relative_error']
        checks['ws_close_to_R4'] = bool(
            ws_err['relative_error'] <= r4_err * 5 + 0.01)

    return checks
# ===========================================================
# Config builders
# ===========================================================
def build_full_synthetic_configs():
    """Enumerate the full Poisson-Gamma synthetic grid.

    The grid is 3 latent ranks x 3 graph types x 3 average degrees x
    3 count scales x 2 prior strengths = 162 configurations. K varies
    slowest and the prior strength varies fastest.
    """
    shared_radii = [1, 2, 3, 4]
    ranks = (5, 10, 20)
    graph_types = ('bounded_degree', 'erdos_renyi', 'power_law')
    degrees = (5, 10, 20)
    count_scales = (('low', 0.5), ('medium', 1.0), ('high', 3.0))
    # Each prior level uses one value for all four hyperparameters.
    prior_levels = (('strong', 1.0), ('weak', 0.1))
    prior_keys = ('a0', 'b0', 'c0', 'd0')

    return [
        {
            'N': 300, 'M': 300, 'K': rank,
            'graph_type': graph, 'avg_degree': degree,
            'count_scale': scale, 'count_scale_label': scale_label,
            'prior_strength': prior_label,
            'prior_config': dict.fromkeys(prior_keys, level),
            'num_deletions': 50, 'radii': shared_radii, 'seed': 42,
            'model_family': 'poisson_gamma', 'max_iter': 300, 'tol': 1e-5,
        }
        for rank in ranks
        for graph in graph_types
        for degree in degrees
        for scale_label, scale in count_scales
        for prior_label, level in prior_levels
    ]
def build_model_family_configs():
    """Enumerate the model-family ablation grid, balanced across 3 families.

    For every (K, graph type, average degree) cell the function emits six
    Poisson-Gamma configs, then six Gaussian-Gaussian configs, then six
    Gaussian-Gamma MAP configs: 2 x 3 x 2 x 18 = 216 in total, 72 per family.
    """
    radii = [1, 2, 3, 4]
    prior_keys = ('a0', 'b0', 'c0', 'd0')
    count_scales = (('low', 0.5), ('medium', 1.0), ('high', 3.0))
    noise_levels = (('high_noise', 2.0), ('medium_noise', 1.0), ('low_noise', 0.3))

    configs = []
    for K in (5, 10):
        for gt in ('bounded_degree', 'erdos_renyi', 'power_law'):
            for deg in (5, 15):
                # Fields shared by every config in this cell, split so the
                # key order of each emitted dict is unchanged.
                head = {'N': 200, 'M': 200, 'K': K,
                        'graph_type': gt, 'avg_degree': deg}
                tail = {'num_deletions': 30, 'radii': radii, 'seed': 42}

                # Poisson-Gamma
                for cs_name, cs in count_scales:
                    for ps_name, level in (('strong', 1.0), ('weak', 0.3)):
                        configs.append({
                            **head,
                            'count_scale': cs, 'count_scale_label': cs_name,
                            'prior_strength': ps_name,
                            'prior_config': dict.fromkeys(prior_keys, level),
                            **tail,
                            'model_family': 'poisson_gamma',
                            'max_iter': 300, 'tol': 1e-5,
                        })

                # Gaussian-Gaussian
                for sx_name, sx in noise_levels:
                    for sp_name, sp in (('strong_prior', 0.5), ('weak_prior', 3.0)):
                        configs.append({
                            **head,
                            'sigma_x': sx, 'sigma_x_label': sx_name,
                            'sigma_U': sp, 'sigma_V': sp,
                            'prior_strength': sp_name,
                            **tail,
                            'model_family': 'gaussian_gaussian',
                            'max_iter': 300, 'tol': 1e-5,
                        })

                # Gaussian-Gamma MAP (with fixed optimizer settings)
                for sx_name, sx in noise_levels:
                    for gp_name, level in (('strong', 2.0), ('weak', 0.3)):
                        configs.append({
                            **head,
                            'sigma_x': sx, 'sigma_x_label': sx_name,
                            'prior_strength': gp_name,
                            'prior_config': dict.fromkeys(prior_keys, level),
                            **tail,
                            'model_family': 'gaussian_gamma_map',
                            'lr': 0.05, 'max_iter': 2000, 'tol': 1e-6,
                            'grad_clip': 10.0,
                        })
    return configs
# ===========================================================
# Run single config
# ===========================================================
def run_config(config):
    """Run one configuration end-to-end with sanity checks.

    Generates a synthetic dataset for ``config['model_family']`` and fits
    the model on all edges. For each sampled deletion it then computes the
    exact refit, one local refit per radius, a warm-start global refit and
    (Poisson-Gamma only) the one-step downdate, and collects metrics.

    Args:
        config: Dict with required keys ``model_family``, ``graph_type``,
            ``N``, ``M``, ``K`` and ``avg_degree``. All other settings
            (radii, num_deletions, seed, max_iter, tol, prior_config,
            count_scale, prior_strength, sigma_x/sigma_U/sigma_V, lr,
            grad_clip) fall back to the defaults used below.

    Returns:
        A ``(records, sanity_results)`` tuple: one record dict per deletion
        and the sanity-check dicts of the first three deletions. Both lists
        are empty when the generated graph has fewer than 10 edges.

    Raises:
        ValueError: If ``model_family`` is not a supported family.
    """
    model_family = config['model_family']
    if model_family not in ('poisson_gamma', 'gaussian_gaussian', 'gaussian_gamma_map'):
        # Fail early with a clear message instead of an UnboundLocalError on
        # `edges` further down.
        raise ValueError(f"Unknown model_family: {model_family!r}")

    gt = config['graph_type']
    N, M, K = config['N'], config['M'], config['K']
    avg_degree = config['avg_degree']
    radii = config.get('radii', [1, 2, 3, 4])
    num_del = config.get('num_deletions', 50)
    seed = config.get('seed', 42)
    max_iter = config.get('max_iter', 300)
    tol = config.get('tol', 1e-5)
    config_id = generate_config_id(config)
    run_id = generate_run_id()

    prior_cfg = config.get('prior_config', {})
    a0 = prior_cfg.get('a0', 0.3)
    b0 = prior_cfg.get('b0', 1.0)
    c0 = prior_cfg.get('c0', 0.3)
    d0 = prior_cfg.get('d0', 1.0)
    count_scale = config.get('count_scale', 1.0)
    prior_strength = config.get('prior_strength', 'strong')
    sigma_x = config.get('sigma_x', 1.0)
    sigma_U = config.get('sigma_U', 1.0)
    sigma_V = config.get('sigma_V', 1.0)

    # Generate data (only the edge list is used downstream).
    if model_family == 'poisson_gamma':
        edges, _, _, _ = generate_gamma_poisson_data(
            N, M, K, gt, avg_degree, count_scale, a0, b0, c0, d0, seed=seed)
    elif model_family == 'gaussian_gaussian':
        edges, _, _, _ = generate_gaussian_gaussian_data(
            N, M, K, gt, avg_degree, sigma_U, sigma_V, sigma_x, seed=seed)
    else:  # gaussian_gamma_map
        edges, _, _, _ = generate_gaussian_gamma_data(
            N, M, K, gt, avg_degree, a0, b0, c0, d0, sigma_x, seed=seed)

    if len(edges) < 10:
        print(f" SKIP: only {len(edges)} edges")
        # Same shape as the normal return so callers can always unpack.
        return [], []

    # Create model
    if model_family == 'poisson_gamma':
        model = PoissonGammaVI(N, M, K, a0, b0, c0, d0,
                               max_iter=max_iter, tol=tol, seed=seed)
    elif model_family == 'gaussian_gaussian':
        model = GaussianGaussianVI(N, M, K, sigma_U=sigma_U,
                                   sigma_V=sigma_V,
                                   sigma_x=sigma_x,
                                   max_iter=max_iter, tol=tol, seed=seed)
    else:  # gaussian_gamma_map
        model = GaussianGammaMAP(N, M, K, a0, b0, c0, d0,
                                 sigma_x=sigma_x,
                                 lr=config.get('lr', 0.05),
                                 max_iter=max_iter, tol=tol, seed=seed,
                                 grad_clip=config.get('grad_clip', 10.0))

    # Fit on the full edge set
    t0 = time.time()
    full_result = model.fit_full(edges)
    t_full = time.time() - t0
    full_params = full_result.params

    # Deletions
    u2i, i2u, _ = build_adjacency(edges, N, M)
    dels = sample_deletions(edges, u2i, i2u, num_del, seed=seed)

    records = []
    sanity_results = []
    for del_idx, (edge_to_del, del_type) in enumerate(dels):
        i_del, j_del, x_del = edge_to_del

        # Exact refit without the edge
        exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params)

        # Local refits, one per radius
        local_results = {}
        local_params = {}
        for R in radii:
            local_fit = model.fit_local(edges, edge_to_del, R, init_params=full_params)
            local_results[R] = local_fit
            local_params[R] = local_fit.params

        # Warm-start global refit
        ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params)

        # One-step downdate (Poisson-Gamma only)
        one_step_params = None
        one_step_runtime = None
        if model_family == 'poisson_gamma':
            os_res = one_step_downdate_poisson_gamma(
                edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0)
            one_step_params = os_res.params
            one_step_runtime = os_res.runtime_sec

        # Metrics
        if model_family == 'poisson_gamma':
            model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0}
        else:
            model_kwargs = {'sigma_x': sigma_x}
        metrics = compute_all_metrics(
            full_params, exact_result.params, local_params,
            ws_result.params, one_step_params,
            edge_to_del, edges, N, M, K,
            model_family, radii=radii, model_kwargs=model_kwargs)

        # Sanity checks (first 3 deletions only)
        if del_idx < 3:
            sanity = run_sanity_checks(
                model, edges, full_params, edge_to_del,
                exact_result.params, local_params, ws_result.params,
                model_family, N, M, K)
            sanity_results.append(sanity)

        # Build record
        record = {
            'run_id': run_id, 'config_id': config_id,
            'dataset_type': 'synthetic', 'dataset_name': f'synthetic_{model_family}',
            'model_family': model_family,
            'inference_type': 'vi' if model_family != 'gaussian_gamma_map' else 'map',
            'likelihood': 'poisson' if model_family == 'poisson_gamma' else 'gaussian',
            'prior': 'gamma' if 'gamma' in model_family else 'gaussian',
            'graph_type': gt, 'seed': seed, 'N': N, 'M': M, 'K': K,
            'avg_degree': avg_degree,
            'count_scale': count_scale if model_family == 'poisson_gamma' else None,
            'prior_strength': prior_strength,
            'deletion_edge': [int(i_del), int(j_del), float(x_del)],
            'deletion_type': del_type, 'deletion_index': del_idx,
            'runtime_full': t_full, 'runtime_exact': exact_result.runtime_sec,
            'runtime_warm_start': ws_result.runtime_sec,
            'runtime_one_step': one_step_runtime,
            'exact_converged': exact_result.converged,
            'exact_iterations': exact_result.n_iterations,
            'full_converged': full_result.converged,
        }
        for R in radii:
            record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
            record[f'local_R{R}_converged'] = local_results[R].converged
            record[f'local_R{R}_iterations'] = local_results[R].n_iterations
        record.update(metrics)

        # Flatten the per-distance influence dict into top-level columns.
        if 'influence_by_distance' in record:
            for d_str, val in record['influence_by_distance'].items():
                record[f'influence_d{d_str}'] = val

        record['regime'] = f"{gt}_{prior_strength}_deg{avg_degree}"
        if model_family == 'poisson_gamma':
            record['regime'] += f"_cs{count_scale}"
            record['a0'] = a0
            record['b0'] = b0
            record['c0'] = c0
            record['d0'] = d0
        if model_family in ('gaussian_gaussian', 'gaussian_gamma_map'):
            record['sigma_x'] = sigma_x
        if model_family == 'gaussian_gaussian':
            record['sigma_U'] = sigma_U
            record['sigma_V'] = sigma_V
        records.append(record)

    return records, sanity_results
def main():
    """Command-line entry point: build the requested configs and run them.

    ``--mode`` picks the config grid and ``--max_configs`` optionally
    truncates it. Records are passed to ``save_jsonl`` after every config;
    sanity-check results are saved once at the end. A failing config is
    reported with its traceback and does not stop the remaining configs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='full_synthetic',
                        choices=['full_synthetic', 'model_family', 'both'])
    parser.add_argument('--max_configs', type=int, default=None)
    args = parser.parse_args()

    # argparse restricts --mode to the three choices, so the else branch
    # covers both 'full_synthetic' and 'both'.
    if args.mode == 'model_family':
        configs, label = build_model_family_configs(), 'model_family_v2'
    else:
        configs, label = build_full_synthetic_configs(), 'full_synthetic'
        if args.mode == 'both':
            configs.extend(build_model_family_configs())
            label = 'all'

    if args.max_configs:
        configs = configs[:args.max_configs]
    n_configs = len(configs)
    print(f"Running {n_configs} configs ({args.mode})")

    output_dir = ensure_dir('results/raw')
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'{label}_{ts}.jsonl')
    sanity_file = os.path.join(output_dir, f'sanity_{label}_{ts}.jsonl')

    total_records = 0
    all_sanity = []
    for position, config in enumerate(configs, start=1):
        family = config['model_family']
        graph = config['graph_type']
        rank = config['K']
        degree = config['avg_degree']
        prior = config.get('prior_strength', '')
        print(f"\n[{position}/{n_configs}] {family} {graph} K={rank} deg={degree} ps={prior}")
        try:
            records, sanity = run_config(config)
            total_records += len(records)
            all_sanity.extend(sanity)
            save_jsonl(records, output_file)
            print(f" -> {len(records)} records (total: {total_records})")
        except Exception as e:
            print(f" ERROR: {e}")
            import traceback
            traceback.print_exc()

    # Save sanity checks
    save_jsonl(all_sanity, sanity_file)

    print(f"\n{'='*60}")
    print(f"Done. {total_records} records in {output_file}")
    print(f"Sanity checks: {len(all_sanity)} in {sanity_file}")


if __name__ == '__main__':
    main()