| |
| """Run synthetic experiments for Dobrushin unlearning theory validation.""" |
| import os |
| import sys |
| import json |
| import time |
| import argparse |
| import yaml |
| import numpy as np |
| from collections import defaultdict |
| from datetime import datetime |
|
|
| |
# Prepend the parent of this script's directory to sys.path so that the
# `src.*` imports below resolve when the file is executed directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| from src.data import generate_gamma_poisson_data, generate_gaussian_gaussian_data, generate_gaussian_gamma_data, sample_deletions |
| from src.model import PoissonGammaVI, GaussianGaussianVI, GaussianGammaMAP, get_model |
| from src.graph_utils import build_adjacency, compute_graph_stats |
| from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance, |
| fit_exponential_decay, compute_local_error, compute_gradient_interference, |
| compute_chi_poisson_gamma, compute_chi_gaussian) |
| from src.unlearning import one_step_downdate_poisson_gamma |
| from src.utils import FitResult, generate_run_id, generate_config_id, save_jsonl, ensure_dir |
|
|
|
|
| def run_single_config(config, mode='pilot'): |
| """Run all deletions for a single configuration.""" |
| graph_type = config['graph_type'] |
| N = config['N'] |
| M = config['M'] |
| K = config['K'] |
| avg_degree = config['avg_degree'] |
| count_scale = config.get('count_scale', 1.0) |
| prior_strength = config.get('prior_strength', 'strong') |
| prior_cfg = config.get('prior_config', {}) |
| a0 = prior_cfg.get('a0', 0.3) |
| b0 = prior_cfg.get('b0', 1.0) |
| c0 = prior_cfg.get('c0', 0.3) |
| d0 = prior_cfg.get('d0', 1.0) |
| num_deletions = config.get('num_deletions', 20) |
| radii = config.get('radii', [1, 2, 3, 4]) |
| max_iter = config.get('max_iter', 500) |
| tol = config.get('tol', 1e-5) |
| seed = config.get('seed', 42) |
| model_family = config.get('model_family', 'poisson_gamma') |
| |
| config_id = generate_config_id(config) |
| run_id = generate_run_id() |
| |
| print(f"\n{'='*60}") |
| print(f"Config: {graph_type}, K={K}, deg={avg_degree}, count={count_scale}, prior={prior_strength}") |
| print(f"Model: {model_family}, Config ID: {config_id}") |
| print(f"{'='*60}") |
| |
| |
| if model_family == 'poisson_gamma': |
| edges, U_true, V_true, graph_edges = generate_gamma_poisson_data( |
| N, M, K, graph_type, avg_degree, count_scale, a0, b0, c0, d0, seed=seed) |
| elif model_family == 'gaussian_gaussian': |
| sigma_U = config.get('sigma_U', 1.0) |
| sigma_V = config.get('sigma_V', 1.0) |
| sigma_x = config.get('sigma_x', 1.0) |
| edges, U_true, V_true, graph_edges = generate_gaussian_gaussian_data( |
| N, M, K, graph_type, avg_degree, sigma_U, sigma_V, sigma_x, seed=seed) |
| elif model_family == 'gaussian_gamma_map': |
| sigma_x = config.get('sigma_x', 1.0) |
| edges, U_true, V_true, graph_edges = generate_gaussian_gamma_data( |
| N, M, K, graph_type, avg_degree, a0, b0, c0, d0, sigma_x, seed=seed) |
| |
| if len(edges) < 10: |
| print(f" WARNING: Only {len(edges)} edges generated, skipping config") |
| return [] |
| |
| graph_stats = compute_graph_stats([(e[0], e[1]) for e in edges], N, M) |
| print(f" Graph: {graph_stats['n_edges']} edges, mean user deg={graph_stats['user_degree_mean']:.1f}") |
| |
| |
| print(f" Fitting full model...") |
| t_full_start = time.time() |
| |
| if model_family == 'poisson_gamma': |
| model = PoissonGammaVI(N, M, K, a0, b0, c0, d0, max_iter=max_iter, tol=tol, seed=seed) |
| elif model_family == 'gaussian_gaussian': |
| model = GaussianGaussianVI(N, M, K, sigma_U=config.get('sigma_U', 1.0), |
| sigma_V=config.get('sigma_V', 1.0), |
| sigma_x=config.get('sigma_x', 1.0), |
| max_iter=max_iter, tol=tol, seed=seed) |
| elif model_family == 'gaussian_gamma_map': |
| model = GaussianGammaMAP(N, M, K, a0, b0, c0, d0, |
| sigma_x=config.get('sigma_x', 1.0), |
| lr=config.get('lr', 0.01), |
| max_iter=max_iter, tol=tol, seed=seed) |
| |
| full_result = model.fit_full(edges) |
| full_params = full_result.params |
| t_full = time.time() - t_full_start |
| print(f" Full fit: {full_result.n_iterations} iters, {t_full:.1f}s, converged={full_result.converged}") |
| |
| |
| user_to_items, item_to_users, edge_dict = build_adjacency(edges, N, M) |
| deletion_samples = sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=seed) |
| |
| print(f" Running {len(deletion_samples)} deletions...") |
| |
| records = [] |
| for del_idx, (edge_to_del, del_type) in enumerate(deletion_samples): |
| if del_idx % 5 == 0: |
| print(f" Deletion {del_idx+1}/{len(deletion_samples)} ({del_type})") |
| |
| i_del, j_del, x_del = edge_to_del |
| |
| |
| t0 = time.time() |
| exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params) |
| t_exact = time.time() - t0 |
| exact_params = exact_result.params |
| |
| |
| local_results = {} |
| local_params_by_radius = {} |
| for R in radii: |
| t0 = time.time() |
| local_result = model.fit_local(edges, edge_to_del, R, init_params=full_params) |
| local_results[R] = local_result |
| local_params_by_radius[R] = local_result.params |
| |
| |
| t0 = time.time() |
| ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params) |
| |
| |
| one_step_params = None |
| one_step_runtime = None |
| if model_family == 'poisson_gamma': |
| os_result = one_step_downdate_poisson_gamma( |
| edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0) |
| one_step_params = os_result.params |
| one_step_runtime = os_result.runtime_sec |
| |
| |
| model_kwargs = {} |
| if model_family == 'poisson_gamma': |
| model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0} |
| else: |
| model_kwargs = {'sigma_x': config.get('sigma_x', 1.0)} |
| |
| metrics = compute_all_metrics( |
| full_params, exact_params, local_params_by_radius, |
| ws_result.params, one_step_params, |
| edge_to_del, edges, N, M, K, |
| model_family, model=model, radii=radii, |
| model_kwargs=model_kwargs) |
| |
| |
| record = { |
| 'run_id': run_id, |
| 'config_id': config_id, |
| 'dataset_type': 'synthetic', |
| 'dataset_name': f'synthetic_{model_family}', |
| 'model_family': model_family, |
| 'inference_type': 'vi' if model_family != 'gaussian_gamma_map' else 'map', |
| 'likelihood': 'poisson' if model_family == 'poisson_gamma' else 'gaussian', |
| 'prior': 'gamma' if 'gamma' in model_family else 'gaussian', |
| 'graph_type': graph_type, |
| 'seed': seed, |
| 'N': N, |
| 'M': M, |
| 'K': K, |
| 'avg_degree': avg_degree, |
| 'count_scale': count_scale if model_family == 'poisson_gamma' else None, |
| 'prior_strength': prior_strength, |
| 'deletion_edge': [int(i_del), int(j_del), float(x_del)], |
| 'deletion_type': del_type, |
| 'deletion_index': del_idx, |
| |
| 'runtime_full': t_full, |
| 'runtime_exact': t_exact, |
| 'runtime_warm_start': ws_result.runtime_sec, |
| 'runtime_one_step': one_step_runtime, |
| |
| 'exact_converged': exact_result.converged, |
| 'exact_iterations': exact_result.n_iterations, |
| 'ws_converged': ws_result.converged, |
| 'ws_iterations': ws_result.n_iterations, |
| } |
| |
| |
| for R in radii: |
| record[f'runtime_local_R{R}'] = local_results[R].runtime_sec |
| record[f'local_R{R}_converged'] = local_results[R].converged |
| record[f'local_R{R}_iterations'] = local_results[R].n_iterations |
| |
| |
| record.update(metrics) |
| |
| |
| if 'influence_by_distance' in record: |
| for d_str, val in record['influence_by_distance'].items(): |
| record[f'influence_d{d_str}'] = val |
| |
| |
| record['regime'] = f"{graph_type}_{prior_strength}_deg{avg_degree}_cs{count_scale}" |
| |
| |
| if model_family == 'poisson_gamma': |
| record['a0'] = a0 |
| record['b0'] = b0 |
| record['c0'] = c0 |
| record['d0'] = d0 |
| if model_family in ('gaussian_gaussian', 'gaussian_gamma_map'): |
| record['sigma_x'] = config.get('sigma_x', 1.0) |
| if model_family == 'gaussian_gaussian': |
| record['sigma_U'] = config.get('sigma_U', 1.0) |
| record['sigma_V'] = config.get('sigma_V', 1.0) |
| |
| records.append(record) |
| |
| return records |
|
|
|
|
def build_config_grid(grid_cfg, mode='pilot'):
    """Expand a grid specification into a flat list of Poisson-Gamma configs.

    Section lookup order:

    1. ``grid_cfg[mode]``, if that key exists;
    2. otherwise ``grid_cfg['pilot']``;
    3. otherwise ``grid_cfg`` itself.

    One config is produced per element of the Cartesian product of
    ``K_values``, ``graph_types``, ``avg_degree_levels``,
    ``count_scale_levels`` and ``prior_strength_configs``. The first of these
    is the outermost loop and the last is the innermost. The three ``*_levels``
    and ``*_configs`` entries are mappings from a label to a value, and both
    the label and the value are stored in the config.
    """
    if mode in grid_cfg:
        section = grid_cfg[mode]
    else:
        section = grid_cfg.get('pilot', grid_cfg)

    n_users = section['N']
    n_items = section['M']
    radius_list = section.get('radii', [1, 2, 3, 4])
    deletions = section.get('num_deletions_per_config', 20)
    base_seed = section.get('seed', 42)

    return [
        {
            'N': n_users, 'M': n_items, 'K': k,
            'graph_type': graph_type,
            'avg_degree': degree,
            'avg_degree_label': degree_label,
            'count_scale': scale,
            'count_scale_label': scale_label,
            'prior_strength': strength_label,
            'prior_config': strength_cfg,
            'num_deletions': deletions,
            'radii': radius_list,
            'seed': base_seed,
            'model_family': 'poisson_gamma',
        }
        for k in section['K_values']
        for graph_type in section['graph_types']
        for degree_label, degree in section['avg_degree_levels'].items()
        for scale_label, scale in section['count_scale_levels'].items()
        for strength_label, strength_cfg in section['prior_strength_configs'].items()
    ]
|
|
|
|
def build_model_family_grid(grid_cfg):
    """Build the config grid for the model-family ablation.

    Reads the ``model_family_ablation`` section of ``grid_cfg``. For every
    combination of ``K_values`` x ``graph_types`` x ``avg_degree_values`` it
    emits configs for three model families:

    * ``poisson_gamma``: crossed with ``count_scale_levels`` and
      ``prior_configs`` from the ``poisson_gamma`` subsection.
    * ``gaussian_gaussian``: crossed with ``sigma_x_values`` and
      ``sigma_prior_values``. Each ``sigma_prior_values`` entry sets both
      ``sigma_U`` and ``sigma_V``.
    * ``gaussian_gamma_map``: crossed with ``sigma_x_values`` and
      ``gamma_prior_strength``. The learning rate comes from the subsection's
      optional ``lr`` key and defaults to 0.005.

    When a subsection or key is missing, the literal defaults below apply.

    Args:
        grid_cfg: Parsed grid YAML as a dict.

    Returns:
        A list of config dicts: all poisson_gamma configs first, then
        gaussian_gaussian, then gaussian_gamma_map. The list is empty when
        the ``model_family_ablation`` section is missing or empty.
    """
    cfg = grid_cfg.get('model_family_ablation', {})
    if not cfg:
        return []

    configs = []
    N = cfg.get('N', 200)
    M = cfg.get('M', 200)
    radii = cfg.get('radii', [1, 2, 3, 4])
    num_del = cfg.get('num_deletions_per_config', 30)
    seed = cfg.get('seed', 42)

    graph_types = cfg.get('graph_types', ['bounded_degree', 'erdos_renyi', 'power_law'])
    deg_values = cfg.get('avg_degree_values', [5, 15])
    K_values = cfg.get('K_values', [5, 10])

    # Poisson-Gamma VI.
    pg_cfg = cfg.get('poisson_gamma', {})
    pg_count_scales = pg_cfg.get('count_scale_levels', {'medium': 1.0})
    pg_priors = pg_cfg.get('prior_configs', {'strong': {'a0': 1.0, 'b0': 1.0, 'c0': 1.0, 'd0': 1.0}})
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for cs_name, cs_val in pg_count_scales.items():
                    for ps_name, ps_dict in pg_priors.items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt, 'avg_degree': deg,
                            'count_scale': cs_val, 'count_scale_label': cs_name,
                            'prior_strength': ps_name, 'prior_config': ps_dict,
                            'num_deletions': num_del, 'radii': radii, 'seed': seed,
                            'model_family': 'poisson_gamma',
                        })

    # Gaussian-Gaussian VI. sigma_U and sigma_V share the same prior scale.
    gg_cfg = cfg.get('gaussian_gaussian', {})
    gg_sigma_x = gg_cfg.get('sigma_x_values', {'medium_noise': 1.0})
    gg_sigma_prior = gg_cfg.get('sigma_prior_values', {'strong_prior': 0.5})
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for sx_name, sx_val in gg_sigma_x.items():
                    for sp_name, sp_val in gg_sigma_prior.items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt, 'avg_degree': deg,
                            'sigma_x': sx_val, 'sigma_x_label': sx_name,
                            'sigma_U': sp_val, 'sigma_V': sp_val,
                            'prior_strength': sp_name,
                            'num_deletions': num_del, 'radii': radii, 'seed': seed,
                            'model_family': 'gaussian_gaussian',
                        })

    # Gaussian likelihood with Gamma priors, fitted by MAP.
    ggm_cfg = cfg.get('gaussian_gamma_map', {})
    ggm_sigma_x = ggm_cfg.get('sigma_x_values', {'medium_noise': 1.0})
    ggm_priors = ggm_cfg.get('gamma_prior_strength', {'strong': {'a0': 2.0, 'b0': 2.0, 'c0': 2.0, 'd0': 2.0}})
    # Previously hard-coded; the default keeps existing grids unchanged.
    ggm_lr = ggm_cfg.get('lr', 0.005)
    for K in K_values:
        for gt in graph_types:
            for deg in deg_values:
                for sx_name, sx_val in ggm_sigma_x.items():
                    for gp_name, gp_dict in ggm_priors.items():
                        configs.append({
                            'N': N, 'M': M, 'K': K,
                            'graph_type': gt, 'avg_degree': deg,
                            'sigma_x': sx_val, 'sigma_x_label': sx_name,
                            'prior_strength': gp_name, 'prior_config': gp_dict,
                            'num_deletions': num_del, 'radii': radii, 'seed': seed,
                            'model_family': 'gaussian_gamma_map',
                            'lr': ggm_lr,
                        })

    return configs
|
|
|
|
def main():
    """Command-line entry point: expand the config grid and run every config.

    Options:
        --config       Path to the grid YAML (default config/synthetic_grid.yaml).
        --mode         One of 'pilot', 'full' or 'model_family'. 'pilot' and
                       'full' are expanded by build_config_grid; 'model_family'
                       uses build_model_family_grid.
        --max_configs  If given, keep only the first N configs. 0 runs none;
                       negative values are rejected.

    After each config finishes, its records are passed to save_jsonl with
    results/raw/synthetic_<mode>_<timestamp>.jsonl as the target. A config
    that raises is reported (message and traceback), appended to
    debug_errors.jsonl in the working directory, and skipped, so the
    remaining configs still run.

    Returns:
        The output JSONL path.
    """
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config/synthetic_grid.yaml')
    parser.add_argument('--mode', type=str, default='pilot', choices=['pilot', 'full', 'model_family'])
    parser.add_argument('--max_configs', type=int, default=None, help='Limit number of configs for testing')
    args = parser.parse_args()

    # A negative slice bound would silently drop configs from the end.
    if args.max_configs is not None and args.max_configs < 0:
        parser.error('--max_configs must be non-negative')

    with open(args.config, encoding='utf-8') as f:
        grid_cfg = yaml.safe_load(f)

    if args.mode == 'model_family':
        configs = build_model_family_grid(grid_cfg)
    else:
        configs = build_config_grid(grid_cfg, args.mode)

    # Compare against None so that an explicit `--max_configs 0` is honoured.
    # A plain truthiness test would treat 0 as "no limit".
    if args.max_configs is not None:
        configs = configs[:args.max_configs]

    print(f"Running {len(configs)} configurations in {args.mode} mode")

    output_dir = ensure_dir('results/raw')
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'synthetic_{args.mode}_{timestamp}.jsonl')

    all_records = []
    for cfg_idx, config in enumerate(configs):
        print(f"\n>>> Config {cfg_idx+1}/{len(configs)}")
        try:
            records = run_single_config(config, mode=args.mode)
            # Save after every config, and count records only once the save
            # has succeeded.
            # NOTE(review): this relies on save_jsonl appending to output_file
            # rather than overwriting it; confirm in src.utils.
            save_jsonl(records, output_file)
            all_records.extend(records)
            print(f" Saved {len(records)} records (total: {len(all_records)})")
        except Exception as e:
            # Top-level boundary: report the failure and continue with the
            # next config.
            print(f" ERROR in config {cfg_idx+1}: {e}")
            traceback.print_exc()

            debug_entry = {
                'config_index': cfg_idx,
                'config': {k: str(v) for k, v in config.items()},
                'error': str(e),
                'timestamp': datetime.now().isoformat(),
            }
            debug_path = 'debug_errors.jsonl'
            with open(debug_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(debug_entry) + '\n')

    print(f"\n{'='*60}")
    print(f"Completed. Total records: {len(all_records)}")
    print(f"Output: {output_file}")

    return output_file
|
|
|
|
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|