#!/usr/bin/env python3
"""Run synthetic experiments for Dobrushin unlearning theory validation."""
import os
import sys
import json
import time
import argparse
import traceback
import yaml
import numpy as np
from collections import defaultdict
from datetime import datetime

# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.data import generate_gamma_poisson_data, generate_gaussian_gaussian_data, generate_gaussian_gamma_data, sample_deletions
from src.model import PoissonGammaVI, GaussianGaussianVI, GaussianGammaMAP, get_model
from src.graph_utils import build_adjacency, compute_graph_stats
from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance, fit_exponential_decay,
                         compute_local_error, compute_gradient_interference, compute_chi_poisson_gamma,
                         compute_chi_gaussian)
from src.unlearning import one_step_downdate_poisson_gamma
from src.utils import FitResult, generate_run_id, generate_config_id, save_jsonl, ensure_dir

# Model families this runner knows how to generate data for and fit.
SUPPORTED_MODEL_FAMILIES = ('poisson_gamma', 'gaussian_gaussian', 'gaussian_gamma_map')


def run_single_config(config, mode='pilot'):
    """Run all deletions for a single configuration.

    Generates a synthetic dataset, fits the full model, then for each sampled
    deletion runs exact refit, local refits at every radius, a warm-started
    global refit and (Poisson-Gamma only) a one-step downdate, and collects
    metrics into one flat record per deletion.

    Args:
        config: Config dict as produced by ``build_config_grid`` or
            ``build_model_family_grid``. ``graph_type``, ``N``, ``M``, ``K``
            and ``avg_degree`` are required; other keys fall back to defaults.
        mode: Run-mode label. Not used inside this function; kept for
            call-site compatibility.

    Returns:
        List of per-deletion record dicts. Empty if fewer than 10 edges were
        generated.

    Raises:
        ValueError: If ``config['model_family']`` is not one of
            ``SUPPORTED_MODEL_FAMILIES``.
    """
    graph_type = config['graph_type']
    N = config['N']
    M = config['M']
    K = config['K']
    avg_degree = config['avg_degree']
    count_scale = config.get('count_scale', 1.0)
    prior_strength = config.get('prior_strength', 'strong')
    prior_cfg = config.get('prior_config', {})
    a0 = prior_cfg.get('a0', 0.3)
    b0 = prior_cfg.get('b0', 1.0)
    c0 = prior_cfg.get('c0', 0.3)
    d0 = prior_cfg.get('d0', 1.0)
    num_deletions = config.get('num_deletions', 20)
    radii = config.get('radii', [1, 2, 3, 4])
    max_iter = config.get('max_iter', 500)
    tol = config.get('tol', 1e-5)
    seed = config.get('seed', 42)
    model_family = config.get('model_family', 'poisson_gamma')

    # Fail fast: an unknown family would otherwise surface later as a
    # confusing NameError on `edges` / `model`.
    if model_family not in SUPPORTED_MODEL_FAMILIES:
        raise ValueError(
            f"Unknown model_family {model_family!r}; expected one of {SUPPORTED_MODEL_FAMILIES}")

    # Gaussian noise/prior scales, read once so data generation, model and
    # record all see the same values (only used by the Gaussian families).
    sigma_U = config.get('sigma_U', 1.0)
    sigma_V = config.get('sigma_V', 1.0)
    sigma_x = config.get('sigma_x', 1.0)

    config_id = generate_config_id(config)
    run_id = generate_run_id()

    print(f"\n{'='*60}")
    print(f"Config: {graph_type}, K={K}, deg={avg_degree}, count={count_scale}, prior={prior_strength}")
    print(f"Model: {model_family}, Config ID: {config_id}")
    print(f"{'='*60}")

    # Generate data
    if model_family == 'poisson_gamma':
        edges, U_true, V_true, graph_edges = generate_gamma_poisson_data(
            N, M, K, graph_type, avg_degree, count_scale, a0, b0, c0, d0, seed=seed)
    elif model_family == 'gaussian_gaussian':
        edges, U_true, V_true, graph_edges = generate_gaussian_gaussian_data(
            N, M, K, graph_type, avg_degree, sigma_U, sigma_V, sigma_x, seed=seed)
    else:  # 'gaussian_gamma_map'
        edges, U_true, V_true, graph_edges = generate_gaussian_gamma_data(
            N, M, K, graph_type, avg_degree, a0, b0, c0, d0, sigma_x, seed=seed)

    if len(edges) < 10:
        print(f" WARNING: Only {len(edges)} edges generated, skipping config")
        return []

    graph_stats = compute_graph_stats([(e[0], e[1]) for e in edges], N, M)
    print(f" Graph: {graph_stats['n_edges']} edges, mean user deg={graph_stats['user_degree_mean']:.1f}")

    # Fit full model (timing deliberately includes model construction).
    print(" Fitting full model...")
    t_full_start = time.time()
    if model_family == 'poisson_gamma':
        model = PoissonGammaVI(N, M, K, a0, b0, c0, d0, max_iter=max_iter, tol=tol, seed=seed)
    elif model_family == 'gaussian_gaussian':
        model = GaussianGaussianVI(N, M, K, sigma_U=sigma_U, sigma_V=sigma_V, sigma_x=sigma_x,
                                   max_iter=max_iter, tol=tol, seed=seed)
    else:  # 'gaussian_gamma_map'
        model = GaussianGammaMAP(N, M, K, a0, b0, c0, d0, sigma_x=sigma_x,
                                 lr=config.get('lr', 0.01), max_iter=max_iter, tol=tol, seed=seed)
    full_result = model.fit_full(edges)
    full_params = full_result.params
    t_full = time.time() - t_full_start
    print(f" Full fit: {full_result.n_iterations} iters, {t_full:.1f}s, converged={full_result.converged}")

    # Sample deletions
    user_to_items, item_to_users, edge_dict = build_adjacency(edges, N, M)
    deletion_samples = sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=seed)
    print(f" Running {len(deletion_samples)} deletions...")

    records = []
    for del_idx, (edge_to_del, del_type) in enumerate(deletion_samples):
        if del_idx % 5 == 0:
            print(f" Deletion {del_idx+1}/{len(deletion_samples)} ({del_type})")
        i_del, j_del, x_del = edge_to_del

        # Exact deletion (full refit without the edge, warm-started from full_params)
        t0 = time.time()
        exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params)
        t_exact = time.time() - t0
        exact_params = exact_result.params

        # Local deletions for each radius (runtime comes from result.runtime_sec)
        local_results = {}
        local_params_by_radius = {}
        for R in radii:
            local_result = model.fit_local(edges, edge_to_del, R, init_params=full_params)
            local_results[R] = local_result
            local_params_by_radius[R] = local_result.params

        # Warm-start global (runtime comes from ws_result.runtime_sec)
        ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params)

        # One-step downdate (only for poisson_gamma)
        one_step_params = None
        one_step_runtime = None
        if model_family == 'poisson_gamma':
            os_result = one_step_downdate_poisson_gamma(
                edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0)
            one_step_params = os_result.params
            one_step_runtime = os_result.runtime_sec

        # Compute all metrics (a fresh kwargs dict per call, as before)
        if model_family == 'poisson_gamma':
            model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0}
        else:
            model_kwargs = {'sigma_x': sigma_x}
        metrics = compute_all_metrics(
            full_params, exact_params, local_params_by_radius, ws_result.params, one_step_params,
            edge_to_del, edges, N, M, K, model_family, model=model, radii=radii,
            model_kwargs=model_kwargs)

        # Build record
        record = {
            'run_id': run_id,
            'config_id': config_id,
            'dataset_type': 'synthetic',
            'dataset_name': f'synthetic_{model_family}',
            'model_family': model_family,
            'inference_type': 'vi' if model_family != 'gaussian_gamma_map' else 'map',
            'likelihood': 'poisson' if model_family == 'poisson_gamma' else 'gaussian',
            'prior': 'gamma' if 'gamma' in model_family else 'gaussian',
            'graph_type': graph_type,
            'seed': seed,
            'N': N,
            'M': M,
            'K': K,
            'avg_degree': avg_degree,
            'count_scale': count_scale if model_family == 'poisson_gamma' else None,
            'prior_strength': prior_strength,
            'deletion_edge': [int(i_del), int(j_del), float(x_del)],
            'deletion_type': del_type,
            'deletion_index': del_idx,
            # Runtimes
            'runtime_full': t_full,
            'runtime_exact': t_exact,
            'runtime_warm_start': ws_result.runtime_sec,
            'runtime_one_step': one_step_runtime,
            # Exact deletion convergence
            'exact_converged': exact_result.converged,
            'exact_iterations': exact_result.n_iterations,
            'ws_converged': ws_result.converged,
            'ws_iterations': ws_result.n_iterations,
        }

        # Add local runtimes
        for R in radii:
            record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
            record[f'local_R{R}_converged'] = local_results[R].converged
            record[f'local_R{R}_iterations'] = local_results[R].n_iterations

        # Add metrics
        record.update(metrics)

        # Flatten influence_by_distance for CSV compatibility
        if 'influence_by_distance' in record:
            for d_str, val in record['influence_by_distance'].items():
                record[f'influence_d{d_str}'] = val

        # Add regime label
        record['regime'] = f"{graph_type}_{prior_strength}_deg{avg_degree}_cs{count_scale}"

        # Extra config fields
        if model_family == 'poisson_gamma':
            record['a0'] = a0
            record['b0'] = b0
            record['c0'] = c0
            record['d0'] = d0
        if model_family in ('gaussian_gaussian', 'gaussian_gamma_map'):
            record['sigma_x'] = sigma_x
        if model_family == 'gaussian_gaussian':
            record['sigma_U'] = sigma_U
            record['sigma_V'] = sigma_V

        records.append(record)

    return records


def build_config_grid(grid_cfg, mode='pilot'):
    """Build list of configs from grid specification.

    Uses ``grid_cfg[mode]`` when present, otherwise ``grid_cfg['pilot']``,
    otherwise ``grid_cfg`` itself. Produces the Cartesian product of
    ``K_values`` x ``graph_types`` x ``avg_degree_levels`` x
    ``count_scale_levels`` x ``prior_strength_configs``; every config uses the
    ``poisson_gamma`` model family.

    Args:
        grid_cfg: Parsed YAML grid specification.
        mode: Name of the sub-section to read (e.g. ``'pilot'`` or ``'full'``).

    Returns:
        List of config dicts suitable for ``run_single_config``.
    """
    cfg = grid_cfg[mode] if mode in grid_cfg else grid_cfg.get('pilot', grid_cfg)
    configs = []
    N = cfg['N']
    M = cfg['M']
    radii = cfg.get('radii', [1, 2, 3, 4])
    num_del = cfg.get('num_deletions_per_config', 20)
    seed = cfg.get('seed', 42)

    for K in cfg['K_values']:
        for gt in cfg['graph_types']:
            for deg_name, deg_val in cfg['avg_degree_levels'].items():
                for cs_name, cs_val in cfg['count_scale_levels'].items():
                    for ps_name, ps_cfg in cfg['prior_strength_configs'].items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt,
                            'avg_degree': deg_val,
                            'avg_degree_label': deg_name,
                            'count_scale': cs_val,
                            'count_scale_label': cs_name,
                            'prior_strength': ps_name,
                            'prior_config': ps_cfg,
                            'num_deletions': num_del,
                            'radii': radii,
                            'seed': seed,
                            'model_family': 'poisson_gamma',
                        })
    return configs


def build_model_family_grid(grid_cfg):
    """Build config grid for model-family ablation.

    Reads ``grid_cfg['model_family_ablation']`` and emits one config per
    combination of K, graph type, average degree and family-specific settings
    for the ``poisson_gamma``, ``gaussian_gaussian`` and
    ``gaussian_gamma_map`` families.

    Args:
        grid_cfg: Parsed YAML grid specification.

    Returns:
        List of config dicts; empty if the ablation section is missing or empty.
    """
    cfg = grid_cfg.get('model_family_ablation', {})
    if not cfg:
        return []

    configs = []
    N = cfg.get('N', 200)
    M = cfg.get('M', 200)
    radii = cfg.get('radii', [1, 2, 3, 4])
    num_del = cfg.get('num_deletions_per_config', 30)
    seed = cfg.get('seed', 42)
    graph_types = cfg.get('graph_types', ['bounded_degree', 'erdos_renyi', 'power_law'])
    deg_values = cfg.get('avg_degree_values', [5, 15])
    K_values = cfg.get('K_values', [5, 10])

    # Poisson-Gamma configs
    pg_cfg = cfg.get('poisson_gamma', {})
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for cs_name, cs_val in pg_cfg.get('count_scale_levels', {'medium': 1.0}).items():
                    for ps_name, ps_dict in pg_cfg.get(
                            'prior_configs',
                            {'strong': {'a0': 1.0, 'b0': 1.0, 'c0': 1.0, 'd0': 1.0}}).items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt,
                            'avg_degree': deg,
                            'count_scale': cs_val,
                            'count_scale_label': cs_name,
                            'prior_strength': ps_name,
                            'prior_config': ps_dict,
                            'num_deletions': num_del,
                            'radii': radii,
                            'seed': seed,
                            'model_family': 'poisson_gamma',
                        })

    # Gaussian-Gaussian configs
    gg_cfg = cfg.get('gaussian_gaussian', {})
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for sx_name, sx_val in gg_cfg.get('sigma_x_values', {'medium_noise': 1.0}).items():
                    for sp_name, sp_val in gg_cfg.get('sigma_prior_values', {'strong_prior': 0.5}).items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt,
                            'avg_degree': deg,
                            'sigma_x': sx_val,
                            'sigma_x_label': sx_name,
                            'sigma_U': sp_val,
                            'sigma_V': sp_val,
                            'prior_strength': sp_name,
                            'num_deletions': num_del,
                            'radii': radii,
                            'seed': seed,
                            'model_family': 'gaussian_gaussian',
                        })

    # Gaussian-Gamma MAP configs
    ggm_cfg = cfg.get('gaussian_gamma_map', {})
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for sx_name, sx_val in ggm_cfg.get('sigma_x_values', {'medium_noise': 1.0}).items():
                    for gp_name, gp_dict in ggm_cfg.get(
                            'gamma_prior_strength',
                            {'strong': {'a0': 2.0, 'b0': 2.0, 'c0': 2.0, 'd0': 2.0}}).items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt,
                            'avg_degree': deg,
                            'sigma_x': sx_val,
                            'sigma_x_label': sx_name,
                            'prior_strength': gp_name,
                            'prior_config': gp_dict,
                            'num_deletions': num_del,
                            'radii': radii,
                            'seed': seed,
                            'model_family': 'gaussian_gamma_map',
                            'lr': 0.005,
                        })

    return configs


def _log_config_error(cfg_idx, config, exc, debug_path='debug_errors.jsonl'):
    """Append one JSON line describing a failed config to ``debug_path``."""
    debug_entry = {
        'config_index': cfg_idx,
        'config': {k: str(v) for k, v in config.items()},
        'error': str(exc),
        'timestamp': datetime.now().isoformat(),
    }
    with open(debug_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(debug_entry) + '\n')


def main():
    """CLI entry point: build the config grid, run every config, save records.

    Records are written via ``save_jsonl`` to
    ``results/raw/synthetic_<mode>_<timestamp>.jsonl`` after each config.
    A config that raises is reported on stdout and appended to
    ``debug_errors.jsonl``; the remaining configs still run.

    Returns:
        Path of the output JSONL file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/synthetic_grid.yaml')
    parser.add_argument('--mode', type=str, default='pilot', choices=['pilot', 'full', 'model_family'])
    parser.add_argument('--max_configs', type=int, default=None, help='Limit number of configs for testing')
    args = parser.parse_args()

    with open(args.config, encoding='utf-8') as f:
        grid_cfg = yaml.safe_load(f)

    if args.mode == 'model_family':
        configs = build_model_family_grid(grid_cfg)
    else:
        configs = build_config_grid(grid_cfg, args.mode)

    # Compare against None so that `--max_configs 0` runs zero configs
    # rather than silently running all of them.
    if args.max_configs is not None:
        configs = configs[:args.max_configs]

    print(f"Running {len(configs)} configurations in {args.mode} mode")

    output_dir = ensure_dir('results/raw')
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'synthetic_{args.mode}_{timestamp}.jsonl')

    all_records = []
    for cfg_idx, config in enumerate(configs):
        print(f"\n>>> Config {cfg_idx+1}/{len(configs)}")
        try:
            records = run_single_config(config, mode=args.mode)
            all_records.extend(records)
            # Save incrementally
            save_jsonl(records, output_file)
            print(f" Saved {len(records)} records (total: {len(all_records)})")
        except Exception as e:
            # Top-level boundary: report and log, then continue with the next config.
            print(f" ERROR in config {cfg_idx+1}: {e}")
            traceback.print_exc()
            _log_config_error(cfg_idx, config, e)

    print(f"\n{'='*60}")
    print(f"Completed. Total records: {len(all_records)}")
    print(f"Output: {output_file}")
    return output_file


if __name__ == '__main__':
    main()