#!/usr/bin/env python3
"""Run real-data experiments at NeurIPS scale."""

import json
import os
import sys
import time
from collections import defaultdict
from datetime import datetime

import numpy as np

# The project root (parent of this file's directory) must be on sys.path
# before the ``src`` imports below can resolve when run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.model import PoissonGammaVI
from src.graph_utils import build_adjacency, compute_graph_stats
from src.metrics import compute_all_metrics
from src.unlearning import one_step_downdate_poisson_gamma
from src.data import sample_deletions
from src.utils import save_jsonl, ensure_dir


def _select_users_and_items(user_items, rng, max_users, max_items,
                            min_user_degree, min_item_degree):
    """Filter users/items by degree, subsample them, and build index maps.

    ``user_items`` maps each user id to a mapping keyed by item id.  Users
    with fewer than ``min_user_degree`` distinct items are dropped; item
    degrees are then counted over the surviving users only.

    Returns ``(valid_users_set, valid_items_set, user_map, item_map)`` where
    the maps assign contiguous indices in sorted-id order.
    """
    valid_users = [u for u, items in user_items.items()
                   if len(items) >= min_user_degree]
    # RNG call order (users, then items, then edges in the caller) is part of
    # the seeded behaviour and must not be reordered.
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)

    item_degree = defaultdict(int)
    for u in valid_users:
        for item in user_items[u]:
            item_degree[item] += 1

    # Sort before sampling so the draw does not depend on dict order.
    all_items = sorted(i for i, d in item_degree.items()
                       if d >= min_item_degree)
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)

    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {i: idx for idx, i in enumerate(sorted(valid_items_set))}
    return valid_users_set, valid_items_set, user_map, item_map


def _cap_edges(edges, max_edges, rng):
    """Return ``edges`` unchanged, or a random ``max_edges``-sized subset."""
    if len(edges) > max_edges:
        keep = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in keep]
    return edges


def load_lastfm_fast(max_users=1000, max_items=1000, max_edges=50000,
                     min_user_degree=5, min_item_degree=5, max_count=50,
                     seed=42, max_rows=2_000_000):
    """Load Last.fm with efficient random sampling.

    At most ``max_rows`` rows are sampled without replacement, play counts
    are accumulated per (user, artist) pair, and the result is filtered and
    subsampled down to the requested size.

    Returns ``(edges, N, M)``: ``edges`` is a list of
    ``(user_idx, item_idx, count)`` tuples with ``count`` clipped to
    ``max_count``; ``N`` and ``M`` are the sizes of the user and item index
    maps.  Users kept by the degree filter are counted in ``N`` even if none
    of their edges survive the item filter or the edge cap.
    """
    from datasets import load_dataset

    print("Loading Last.fm...")
    t0 = time.time()
    ds = load_dataset("matthewfranglen/lastfm-1k", split="train")
    print(f" Dataset loaded: {len(ds)} rows in {time.time()-t0:.1f}s")

    rng = np.random.RandomState(seed)

    # Sample rows efficiently
    n_sample = min(max_rows, len(ds))
    indices = sorted(rng.choice(len(ds), n_sample, replace=False).tolist())
    print(f" Sampling {n_sample} rows...")

    user_artist_counts = defaultdict(lambda: defaultdict(int))
    for row_idx in indices:
        row = ds[row_idx]
        user_artist_counts[row['user_index']][row['artist_index']] += 1
    print(f" {len(user_artist_counts)} unique users")

    # Filter
    valid_users_set, valid_items_set, user_map, item_map = (
        _select_users_and_items(user_artist_counts, rng, max_users, max_items,
                                min_user_degree, min_item_degree))

    # NOTE(review): edge order follows set iteration order.  If the ids were
    # strings this would vary between interpreter runs under hash
    # randomization - confirm the ids are integers.
    edges = []
    for u in valid_users_set:
        for a, count in user_artist_counts[u].items():
            if a in valid_items_set:
                c = min(count, max_count)
                if c > 0:
                    edges.append((user_map[u], item_map[a], int(c)))

    edges = _cap_edges(edges, max_edges, rng)

    N, M = len(user_map), len(item_map)
    print(f" Last.fm: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M


def load_movielens_fast(mode='rating_count', max_users=1000, max_items=1000,
                        max_edges=50000, min_user_degree=5,
                        min_item_degree=5, seed=42):
    """Load MovieLens efficiently.

    ``mode='rating_count'`` uses ``ceil(rating)`` as the edge weight; any
    other value of ``mode`` gives every edge weight 1.  If a (user, movie)
    pair occurs more than once, the last rating read is the one kept.

    Returns ``(edges, N, M)`` with ``edges`` a list of
    ``(user_idx, item_idx, weight)`` tuples; ``N`` and ``M`` are the sizes of
    the user and item index maps.
    """
    from datasets import load_dataset

    print(f"Loading MovieLens ({mode})...")
    ds = load_dataset("ashraq/movielens_ratings", split="train")
    rng = np.random.RandomState(seed)

    user_item_ratings = defaultdict(dict)
    for row in ds:
        user_item_ratings[row['user_id']][row['movie_id']] = row['rating']

    valid_users_set, valid_items_set, user_map, item_map = (
        _select_users_and_items(user_item_ratings, rng, max_users, max_items,
                                min_user_degree, min_item_degree))

    # NOTE(review): edge order follows set iteration order.  If the ids were
    # strings this would vary between interpreter runs under hash
    # randomization - confirm the ids are integers.
    edges = []
    for u in valid_users_set:
        for m, rating in user_item_ratings[u].items():
            if m in valid_items_set:
                x = int(np.ceil(rating)) if mode == 'rating_count' else 1
                if x > 0:
                    edges.append((user_map[u], item_map[m], x))

    edges = _cap_edges(edges, max_edges, rng)

    N, M = len(user_map), len(item_map)
    print(f" MovieLens ({mode}): N={N}, M={M}, edges={len(edges)}")
    return edges, N, M


def run_real_experiment(dataset_name, edges, N, M, K, num_deletions=100,
                        a0=0.3, b0=1.0, c0=0.3, d0=1.0,
                        max_iter=300, tol=1e-4, seed=42,
                        radii=(1, 2, 3, 4)):
    """Run full real-data experiment.

    Fits a ``PoissonGammaVI`` model on all ``edges``, samples up to
    ``num_deletions`` edges to delete, and for each one runs
    ``fit_without_edge`` (recorded as "exact"), ``fit_local`` for every
    radius in ``radii``, ``fit_warm_start_global`` and the one-step downdate,
    all initialised from the full fit.  ``compute_all_metrics`` is then
    called on the results.

    Returns a list with one flat dict per deletion holding the run
    configuration, runtimes, convergence flags and the computed metrics.
    """
    radii = list(radii)

    model = PoissonGammaVI(N, M, K, a0, b0, c0, d0,
                           max_iter=max_iter, tol=tol, seed=seed)

    print(f" Fitting full model (K={K})...")
    t0 = time.time()
    full_result = model.fit_full(edges)
    t_full = time.time() - t0
    print(f" {full_result.n_iterations} iters, {t_full:.1f}s, conv={full_result.converged}")

    u2i, i2u, _ = build_adjacency(edges, N, M)
    dels = sample_deletions(edges, u2i, i2u, num_deletions, seed=seed)

    records = []
    print(f" Running {len(dels)} deletions...")
    for idx, (edge, dtype) in enumerate(dels):
        if idx % 20 == 0:
            print(f" {idx+1}/{len(dels)}")

        exact = model.fit_without_edge(edges, edge,
                                       init_params=full_result.params)

        local_results = {
            R: model.fit_local(edges, edge, R, init_params=full_result.params)
            for R in radii
        }
        local_params = {R: res.params for R, res in local_results.items()}

        ws = model.fit_warm_start_global(edges, edge,
                                         init_params=full_result.params)
        os_res = one_step_downdate_poisson_gamma(
            edges, edge, full_result.params, N, M, K, a0, b0, c0, d0)

        metrics = compute_all_metrics(
            full_result.params, exact.params, local_params,
            ws.params, os_res.params, edge, edges, N, M, K,
            'poisson_gamma', radii=radii,
            model_kwargs={'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0})

        record = {
            'dataset_type': 'real',
            'dataset_name': dataset_name,
            'model_family': 'poisson_gamma',
            'inference_type': 'vi',
            'likelihood': 'poisson',
            'prior': 'gamma',
            'N': N,
            'M': M,
            'K': K,
            'n_edges': len(edges),
            'deletion_type': dtype,
            'deletion_index': idx,
            'runtime_full': t_full,
            'runtime_exact': exact.runtime_sec,
            'runtime_warm_start': ws.runtime_sec,
            'runtime_one_step': os_res.runtime_sec,
            'exact_converged': exact.converged,
            'a0': a0,
            'b0': b0,
            'c0': c0,
            'd0': d0,
        }
        for R in radii:
            record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
            record[f'local_R{R}_converged'] = local_results[R].converged

        # Metrics are merged last, so a metric key wins over a same-named
        # configuration key above.
        record.update(metrics)

        # Flatten the per-distance influence mapping into top-level columns;
        # the nested mapping itself stays in the record as well.
        if 'influence_by_distance' in record:
            for d_str, val in record['influence_by_distance'].items():
                record[f'influence_d{d_str}'] = val

        records.append(record)

    return records


def _print_banner(title):
    """Print ``title`` between two 60-character rules."""
    print("="*60)
    print(title)
    print("="*60)


def _print_graph_stats(edges, N, M):
    """Print the graph statistics of the (user, item) pairs in ``edges``."""
    gs = compute_graph_stats([(e[0], e[1]) for e in edges], N, M)
    print(f" Stats: {json.dumps(gs)}")


def _run_k_sweep(dataset_name, edges, N, M, output_file, all_records,
                 ks=(5, 10)):
    """Run the experiment for each K, saving and accumulating the records."""
    for K in ks:
        records = run_real_experiment(dataset_name, edges, N, M, K,
                                      num_deletions=100, seed=42)
        all_records.extend(records)
        # NOTE(review): save_jsonl is called once per K on the same path; the
        # final "Total" message is only accurate if save_jsonl appends rather
        # than overwrites - TODO confirm against src.utils.
        save_jsonl(records, output_file)
        print(f" K={K}: {len(records)} records")


def main():
    """Run the Last.fm and MovieLens experiments and write a JSONL file."""
    output_dir = ensure_dir('results/raw')
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = os.path.join(output_dir, f'real_scaled_{ts}.jsonl')
    all_records = []

    # Last.fm
    _print_banner("LAST.FM")
    edges_fm, N_fm, M_fm = load_lastfm_fast(
        max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, max_count=50, seed=42)
    _print_graph_stats(edges_fm, N_fm, M_fm)
    _run_k_sweep('lastfm', edges_fm, N_fm, M_fm, output_file, all_records)

    # MovieLens rating count
    _print_banner("MOVIELENS RATING COUNT")
    edges_ml, N_ml, M_ml = load_movielens_fast(
        mode='rating_count', max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, seed=42)
    _print_graph_stats(edges_ml, N_ml, M_ml)
    _run_k_sweep('movielens_rating_count', edges_ml, N_ml, M_ml,
                 output_file, all_records)

    # MovieLens binary (graph stats are not printed for this variant)
    _print_banner("MOVIELENS BINARY")
    edges_mb, N_mb, M_mb = load_movielens_fast(
        mode='binary', max_users=1000, max_items=1000, max_edges=50000,
        min_user_degree=5, min_item_degree=5, seed=42)
    _run_k_sweep('movielens_binary', edges_mb, N_mb, M_mb,
                 output_file, all_records)

    print(f"\nTotal: {len(all_records)} records in {output_file}")


if __name__ == '__main__':
    main()