| |
| """Run real-data experiments at NeurIPS scale.""" |
| import os, sys, json, time, numpy as np |
| from datetime import datetime |
| from collections import defaultdict |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from src.model import PoissonGammaVI |
| from src.graph_utils import build_adjacency, compute_graph_stats |
| from src.metrics import compute_all_metrics |
| from src.unlearning import one_step_downdate_poisson_gamma |
| from src.data import sample_deletions |
| from src.utils import save_jsonl, ensure_dir |
|
|
|
|
def load_lastfm_fast(max_users=1000, max_items=1000, max_edges=50000,
                     min_user_degree=5, min_item_degree=5, max_count=50, seed=42):
    """Load a seeded random subsample of Last.fm as a weighted bipartite edge list.

    Steps, all driven by one ``RandomState(seed)``:

    1. Draw up to 2,000,000 row indices without replacement and count, per
       (user_index, artist_index) pair, how many sampled rows mention it.
    2. Keep users with at least ``min_user_degree`` distinct artists; if more
       than ``max_users`` remain, sample ``max_users`` of them.
    3. Keep artists listened to by at least ``min_item_degree`` kept users; if
       more than ``max_items`` remain, sample ``max_items`` of them.
    4. Re-index users and artists to contiguous ids (ascending original id).
    5. Emit one edge per kept (user, artist) pair with weight
       ``min(count, max_count)``; if more than ``max_edges`` edges result,
       sample ``max_edges`` of them.

    Args:
        max_users: Upper bound on the number of users kept.
        max_items: Upper bound on the number of artists kept.
        max_edges: Upper bound on the number of edges returned.
        min_user_degree: Minimum number of distinct artists for a user to be kept.
        min_item_degree: Minimum number of kept users for an artist to be kept.
        max_count: Cap applied to each edge weight.
        seed: Seed for every random draw in this function.

    Returns:
        ``(edges, N, M)`` where ``edges`` is a list of
        ``(user_idx, item_idx, count)`` tuples, ``N`` is the number of indexed
        users and ``M`` the number of indexed artists. ``N`` and ``M`` are
        taken from the index maps, so they still include users/artists whose
        edges were all dropped by the final edge subsample.
    """
    # Imported lazily, as in the rest of this script's loaders.
    from datasets import load_dataset

    print("Loading Last.fm...")
    t0 = time.time()
    ds = load_dataset("matthewfranglen/lastfm-1k", split="train")
    print(f" Dataset loaded: {len(ds)} rows in {time.time()-t0:.1f}s")

    rng = np.random.RandomState(seed)

    n_sample = min(2_000_000, len(ds))
    # Sorted so the dataset is read in ascending row order.
    indices = sorted(rng.choice(len(ds), n_sample, replace=False).tolist())

    print(f" Sampling {n_sample} rows...")
    user_artist_counts = defaultdict(lambda: defaultdict(int))
    for idx in indices:
        row = ds[idx]
        user_artist_counts[row['user_index']][row['artist_index']] += 1

    print(f" {len(user_artist_counts)} unique users")

    # A user's degree is the number of distinct artists seen in the sample.
    valid_users = [u for u, artists in user_artist_counts.items()
                   if len(artists) >= min_user_degree]
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)

    # Artist degree is counted over the kept users only.
    item_degree = defaultdict(int)
    for u in valid_users:
        for a in user_artist_counts[u]:
            item_degree[a] += 1

    all_items = sorted(a for a, d in item_degree.items() if d >= min_item_degree)
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)

    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {a: idx for idx, a in enumerate(sorted(valid_items_set))}

    # Walk users through user_map (ascending id order) instead of the raw set:
    # the positional subsample below picks edges by list index, so the edge
    # order must not depend on set/hash iteration order for `seed` to fully
    # determine the result.
    edges = []
    for u, u_idx in user_map.items():
        for a, count in user_artist_counts[u].items():
            if a in item_map:
                c = min(count, max_count)
                if c > 0:
                    edges.append((u_idx, item_map[a], int(c)))

    if len(edges) > max_edges:
        keep = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in keep]

    N, M = len(user_map), len(item_map)
    print(f" Last.fm: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M
|
|
|
|
def load_movielens_fast(mode='rating_count', max_users=1000, max_items=1000,
                        max_edges=50000, min_user_degree=5, min_item_degree=5, seed=42):
    """Load MovieLens ratings as a weighted bipartite edge list.

    Every row of the train split is read; a later rating for the same
    (user_id, movie_id) pair overwrites an earlier one. Users and movies are
    then filtered by degree, optionally subsampled, and re-indexed to
    contiguous ids in ascending original-id order. All random draws use one
    ``RandomState(seed)``.

    Args:
        mode: ``'rating_count'`` uses ``ceil(rating)`` as the edge weight; any
            other value gives every edge weight 1.
        max_users: Upper bound on the number of users kept.
        max_items: Upper bound on the number of movies kept.
        max_edges: Upper bound on the number of edges returned.
        min_user_degree: Minimum number of rated movies for a user to be kept.
        min_item_degree: Minimum number of kept users for a movie to be kept.
        seed: Seed for every random draw in this function.

    Returns:
        ``(edges, N, M)`` where ``edges`` is a list of
        ``(user_idx, item_idx, weight)`` tuples (non-positive weights are
        dropped), ``N`` is the number of indexed users and ``M`` the number of
        indexed movies. ``N`` and ``M`` come from the index maps, so they
        still include nodes whose edges were all removed by the edge subsample.
    """
    # Imported lazily, as in the rest of this script's loaders.
    from datasets import load_dataset

    print(f"Loading MovieLens ({mode})...")
    ds = load_dataset("ashraq/movielens_ratings", split="train")

    rng = np.random.RandomState(seed)
    user_item_ratings = defaultdict(dict)
    for row in ds:
        user_item_ratings[row['user_id']][row['movie_id']] = row['rating']

    # A user's degree is the number of distinct movies they rated.
    valid_users = [u for u, rated in user_item_ratings.items()
                   if len(rated) >= min_user_degree]
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)

    # Movie degree is counted over the kept users only.
    item_degree = defaultdict(int)
    for u in valid_users:
        for m in user_item_ratings[u]:
            item_degree[m] += 1

    all_items = sorted(m for m, d in item_degree.items() if d >= min_item_degree)
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)

    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {m: idx for idx, m in enumerate(sorted(valid_items_set))}

    use_rating = mode == 'rating_count'

    # Walk users through user_map (ascending id order) instead of the raw set:
    # the positional subsample below picks edges by list index, so the edge
    # order must not depend on set/hash iteration order for `seed` to fully
    # determine the result.
    edges = []
    for u, u_idx in user_map.items():
        for m, rating in user_item_ratings[u].items():
            if m in item_map:
                x = int(np.ceil(rating)) if use_rating else 1
                if x > 0:
                    edges.append((u_idx, item_map[m], x))

    if len(edges) > max_edges:
        keep = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in keep]

    N, M = len(user_map), len(item_map)
    print(f" MovieLens ({mode}): N={N}, M={M}, edges={len(edges)}")
    return edges, N, M
|
|
|
|
def run_real_experiment(dataset_name, edges, N, M, K, num_deletions=100,
                        a0=0.3, b0=1.0, c0=0.3, d0=1.0, max_iter=300, tol=1e-4, seed=42):
    """Fit the full model once, then compare unlearning strategies per deleted edge.

    For each sampled deletion the following are computed, all initialised from
    the full-data fit: an exact refit without the edge, local refits for radii
    1-4, a warm-started global refit, and a one-step downdate. Their metrics
    and runtimes are collected into one flat record per deletion.

    Args:
        dataset_name: Label stored in every record.
        edges: Edge list passed to the model and helpers.
        N, M, K: Passed to ``PoissonGammaVI`` and recorded in every record.
        num_deletions: Number of deletions requested from ``sample_deletions``.
        a0, b0, c0, d0: Prior hyper-parameters forwarded to the model, the
            one-step downdate and the metrics.
        max_iter, tol: Forwarded to ``PoissonGammaVI``.
        seed: Seed for the model and for the deletion sampling.

    Returns:
        A list of dict records, one per sampled deletion.
    """
    radii = (1, 2, 3, 4)

    vi = PoissonGammaVI(N, M, K, a0, b0, c0, d0, max_iter=max_iter, tol=tol, seed=seed)

    print(f" Fitting full model (K={K})...")
    started = time.time()
    baseline = vi.fit_full(edges)
    t_full = time.time() - started
    print(f" {baseline.n_iterations} iters, {t_full:.1f}s, conv={baseline.converged}")

    # Third return value of build_adjacency is not needed here.
    u_adj, i_adj, _ = build_adjacency(edges, N, M)
    deletions = sample_deletions(edges, u_adj, i_adj, num_deletions, seed=seed)
    total = len(deletions)

    records = []
    print(f" Running {total} deletions...")
    for position, (edge, dtype) in enumerate(deletions):
        # Progress line every 20th deletion.
        if not position % 20:
            print(f" {position+1}/{total}")

        exact = vi.fit_without_edge(edges, edge, init_params=baseline.params)

        local_results = {}
        local_params = {}
        for radius in radii:
            fitted = vi.fit_local(edges, edge, radius, init_params=baseline.params)
            local_results[radius] = fitted
            local_params[radius] = fitted.params

        warm = vi.fit_warm_start_global(edges, edge, init_params=baseline.params)
        one_step = one_step_downdate_poisson_gamma(
            edges, edge, baseline.params, N, M, K, a0, b0, c0, d0)

        # Fresh list/dict per call so the callee never sees shared state.
        metrics = compute_all_metrics(
            baseline.params, exact.params, local_params,
            warm.params, one_step.params, edge, edges, N, M, K,
            'poisson_gamma', radii=list(radii),
            model_kwargs=dict(a0=a0, b0=b0, c0=c0, d0=d0))

        # Key insertion order is kept stable: identity, sizes, deletion info,
        # runtimes, priors, per-radius fields, metrics, flattened influence.
        record = dict(
            dataset_type='real', dataset_name=dataset_name,
            model_family='poisson_gamma', inference_type='vi',
            likelihood='poisson', prior='gamma',
            N=N, M=M, K=K, n_edges=len(edges),
            deletion_type=dtype, deletion_index=position,
            runtime_full=t_full,
            runtime_exact=exact.runtime_sec,
            runtime_warm_start=warm.runtime_sec,
            runtime_one_step=one_step.runtime_sec,
            exact_converged=exact.converged,
            a0=a0, b0=b0, c0=c0, d0=d0,
        )
        for radius in radii:
            outcome = local_results[radius]
            record[f'runtime_local_R{radius}'] = outcome.runtime_sec
            record[f'local_R{radius}_converged'] = outcome.converged
        record.update(metrics)

        # Flatten the per-distance influence mapping (when present) into
        # top-level columns; the nested mapping itself stays in the record.
        if 'influence_by_distance' in record:
            by_distance = record['influence_by_distance']
            for d_str, val in by_distance.items():
                record[f'influence_d{d_str}'] = val

        records.append(record)

    return records
|
|
|
|
def main():
    """Run the Last.fm and MovieLens experiments and write records to JSONL.

    Three datasets are processed in order: Last.fm, MovieLens with
    rating-derived counts, and MovieLens with binary weights. Each is run with
    K=5 and K=10 and 100 deletions; every batch of records is handed to
    ``save_jsonl`` with the same timestamped output path. Graph statistics are
    printed for the first two datasets only.
    """
    output_dir = ensure_dir('results/raw')
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'real_scaled_{ts}.jsonl')

    all_records = []
    rule = "="*60

    def banner(title):
        # Section header: rule, title, rule.
        print(rule)
        print(title)
        print(rule)

    def report_stats(edges, n_users, n_items):
        # Graph statistics are computed on the unweighted (user, item) pairs.
        pairs = [(edge[0], edge[1]) for edge in edges]
        gs = compute_graph_stats(pairs, n_users, n_items)
        print(f" Stats: {json.dumps(gs)}")

    def sweep(name, edges, n_users, n_items):
        # One experiment per latent dimension; each batch is saved right away.
        # NOTE(review): every batch goes to the same output_file -- confirm
        # that save_jsonl appends rather than overwrites.
        for K in (5, 10):
            records = run_real_experiment(name, edges, n_users, n_items, K,
                                          num_deletions=100, seed=42)
            all_records.extend(records)
            save_jsonl(records, output_file)
            print(f" K={K}: {len(records)} records")

    banner("LAST.FM")
    edges_fm, N_fm, M_fm = load_lastfm_fast(
        max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, max_count=50, seed=42)
    report_stats(edges_fm, N_fm, M_fm)
    sweep('lastfm', edges_fm, N_fm, M_fm)

    banner("MOVIELENS RATING COUNT")
    edges_ml, N_ml, M_ml = load_movielens_fast(
        mode='rating_count', max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, seed=42)
    report_stats(edges_ml, N_ml, M_ml)
    sweep('movielens_rating_count', edges_ml, N_ml, M_ml)

    # No graph statistics are printed for the binary variant.
    banner("MOVIELENS BINARY")
    edges_mb, N_mb, M_mb = load_movielens_fast(
        mode='binary', max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, seed=42)
    sweep('movielens_binary', edges_mb, N_mb, M_mb)

    print(f"\nTotal: {len(all_records)} records in {output_file}")


# Script entry point: run every experiment when executed directly.
if __name__ == '__main__':
    main()
|
|