serliezer commited on
Commit
0a0d5dc
·
verified ·
1 Parent(s): 1f8dd95

Add scripts/run_real.py

Browse files
Files changed (1) hide show
  1. scripts/run_real.py +195 -0
scripts/run_real.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run real-data experiments."""
3
+ import os
4
+ import sys
5
+ import json
6
+ import time
7
+ import argparse
8
+ import yaml
9
+ import numpy as np
10
+ from datetime import datetime
11
+
12
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13
+
14
+ from src.data import load_lastfm_data, load_movielens_data, sample_deletions
15
+ from src.model import PoissonGammaVI
16
+ from src.graph_utils import build_adjacency, compute_graph_stats
17
+ from src.metrics import (compute_all_metrics, compute_deletion_influence_by_distance,
18
+ fit_exponential_decay, compute_local_error, compute_chi_poisson_gamma,
19
+ compute_gradient_interference)
20
+ from src.unlearning import one_step_downdate_poisson_gamma
21
+ from src.utils import generate_run_id, generate_config_id, save_jsonl, ensure_dir
22
+
23
+
24
+ def run_real_dataset(dataset_name, edges, N, M, preprocessing, config):
25
+ """Run deletion experiments on a real dataset."""
26
+ K_values = config.get('K_values', [5, 10])
27
+ num_deletions = config.get('num_deletions', 50)
28
+ radii = config.get('radii', [1, 2, 3, 4])
29
+ prior = config.get('prior', {})
30
+ a0 = prior.get('a0', 0.3)
31
+ b0 = prior.get('b0', 1.0)
32
+ c0 = prior.get('c0', 0.3)
33
+ d0 = prior.get('d0', 1.0)
34
+ max_iter = config.get('max_iter', 300)
35
+ tol = config.get('tol', 1e-4)
36
+ seed = config.get('seed', 42)
37
+
38
+ all_records = []
39
+
40
+ for K in K_values:
41
+ print(f"\n K={K}")
42
+ run_id = generate_run_id()
43
+ config_id = generate_config_id({**config, 'K': K, 'dataset': dataset_name})
44
+
45
+ model = PoissonGammaVI(N, M, K, a0, b0, c0, d0, max_iter=max_iter, tol=tol, seed=seed)
46
+
47
+ print(f" Fitting full model...")
48
+ t0 = time.time()
49
+ full_result = model.fit_full(edges)
50
+ t_full = time.time() - t0
51
+ full_params = full_result.params
52
+ print(f" Full fit: {full_result.n_iterations} iters, {t_full:.1f}s")
53
+
54
+ user_to_items, item_to_users, edge_dict = build_adjacency(edges, N, M)
55
+ deletion_samples = sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=seed)
56
+
57
+ print(f" Running {len(deletion_samples)} deletions...")
58
+
59
+ for del_idx, (edge_to_del, del_type) in enumerate(deletion_samples):
60
+ if del_idx % 10 == 0:
61
+ print(f" Deletion {del_idx+1}/{len(deletion_samples)}")
62
+
63
+ i_del, j_del, x_del = edge_to_del
64
+
65
+ # Exact
66
+ exact_result = model.fit_without_edge(edges, edge_to_del, init_params=full_params)
67
+ exact_params = exact_result.params
68
+
69
+ # Local
70
+ local_results = {}
71
+ local_params = {}
72
+ for R in radii:
73
+ lr = model.fit_local(edges, edge_to_del, R, init_params=full_params)
74
+ local_results[R] = lr
75
+ local_params[R] = lr.params
76
+
77
+ # Warm-start
78
+ ws_result = model.fit_warm_start_global(edges, edge_to_del, init_params=full_params)
79
+
80
+ # One-step
81
+ os_result = one_step_downdate_poisson_gamma(
82
+ edges, edge_to_del, full_params, N, M, K, a0, b0, c0, d0)
83
+
84
+ # Metrics
85
+ model_kwargs = {'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0}
86
+ metrics = compute_all_metrics(
87
+ full_params, exact_params, local_params,
88
+ ws_result.params, os_result.params,
89
+ edge_to_del, edges, N, M, K,
90
+ 'poisson_gamma', model=model, radii=radii,
91
+ model_kwargs=model_kwargs)
92
+
93
+ record = {
94
+ 'run_id': run_id,
95
+ 'config_id': config_id,
96
+ 'dataset_type': 'real',
97
+ 'dataset_name': dataset_name,
98
+ 'model_family': 'poisson_gamma',
99
+ 'inference_type': 'vi',
100
+ 'likelihood': 'poisson',
101
+ 'prior': 'gamma',
102
+ 'N': N, 'M': M, 'K': K,
103
+ 'n_edges': len(edges),
104
+ 'deletion_edge': [int(i_del), int(j_del), float(x_del)],
105
+ 'deletion_type': del_type,
106
+ 'deletion_index': del_idx,
107
+ 'runtime_full': t_full,
108
+ 'runtime_exact': exact_result.runtime_sec,
109
+ 'runtime_warm_start': ws_result.runtime_sec,
110
+ 'runtime_one_step': os_result.runtime_sec,
111
+ 'exact_converged': exact_result.converged,
112
+ 'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0,
113
+ }
114
+
115
+ for R in radii:
116
+ record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
117
+ record[f'local_R{R}_converged'] = local_results[R].converged
118
+
119
+ record.update(metrics)
120
+
121
+ if 'influence_by_distance' in record:
122
+ for d_str, val in record['influence_by_distance'].items():
123
+ record[f'influence_d{d_str}'] = val
124
+
125
+ all_records.append(record)
126
+
127
+ return all_records
128
+
129
+
130
+ def main():
131
+ parser = argparse.ArgumentParser()
132
+ parser.add_argument('--config', type=str, default='config/real_data.yaml')
133
+ parser.add_argument('--datasets', nargs='*', default=None)
134
+ args = parser.parse_args()
135
+
136
+ with open(args.config) as f:
137
+ real_cfg = yaml.safe_load(f)
138
+
139
+ output_dir = ensure_dir('results/raw')
140
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
141
+ output_file = os.path.join(output_dir, f'real_{timestamp}.jsonl')
142
+
143
+ datasets_to_run = args.datasets or list(real_cfg.keys())
144
+
145
+ for ds_name in datasets_to_run:
146
+ if ds_name not in real_cfg:
147
+ print(f"Unknown dataset config: {ds_name}")
148
+ continue
149
+
150
+ ds_cfg = real_cfg[ds_name]
151
+ print(f"\n{'='*60}")
152
+ print(f"Dataset: {ds_name}")
153
+ print(f"{'='*60}")
154
+
155
+ # Load data
156
+ if 'lastfm' in ds_name:
157
+ edges, N, M, preproc = load_lastfm_data(
158
+ max_users=ds_cfg.get('max_users', 1000),
159
+ max_items=ds_cfg.get('max_items', 1000),
160
+ max_edges=ds_cfg.get('max_edges', 50000),
161
+ min_user_degree=ds_cfg.get('min_user_degree', 5),
162
+ min_item_degree=ds_cfg.get('min_item_degree', 5),
163
+ max_count=ds_cfg.get('max_count', 100),
164
+ seed=ds_cfg.get('seed', 42))
165
+ elif 'movielens' in ds_name:
166
+ mode = ds_cfg.get('mode', 'rating_count')
167
+ edges, N, M, preproc = load_movielens_data(
168
+ mode=mode,
169
+ max_users=ds_cfg.get('max_users', 1000),
170
+ max_items=ds_cfg.get('max_items', 1000),
171
+ max_edges=ds_cfg.get('max_edges', 50000),
172
+ min_user_degree=ds_cfg.get('min_user_degree', 5),
173
+ min_item_degree=ds_cfg.get('min_item_degree', 5),
174
+ seed=ds_cfg.get('seed', 42))
175
+ else:
176
+ print(f" Unsupported dataset: {ds_name}")
177
+ continue
178
+
179
+ # Save preprocessing
180
+ preproc_dir = ensure_dir('results/reports')
181
+ with open(os.path.join(preproc_dir, f'dataset_card_{ds_name}.json'), 'w') as f:
182
+ json.dump(preproc, f, indent=2)
183
+
184
+ graph_stats = compute_graph_stats([(e[0], e[1]) for e in edges], N, M)
185
+ print(f" Graph stats: {json.dumps(graph_stats, indent=2)}")
186
+
187
+ records = run_real_dataset(ds_name, edges, N, M, preproc, ds_cfg)
188
+ save_jsonl(records, output_file)
189
+ print(f" Saved {len(records)} records for {ds_name}")
190
+
191
+ print(f"\nOutput: {output_file}")
192
+
193
+
194
+ if __name__ == '__main__':
195
+ main()