#!/usr/bin/env python3
"""Run real-data experiments at NeurIPS scale."""
import os
import sys
import json
import time
import numpy as np
from datetime import datetime
from collections import defaultdict
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.model import PoissonGammaVI
from src.graph_utils import build_adjacency, compute_graph_stats
from src.metrics import compute_all_metrics
from src.unlearning import one_step_downdate_poisson_gamma
from src.data import sample_deletions
from src.utils import save_jsonl, ensure_dir
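# Pipeline overview: load Last.fm play counts and MovieLens ratings (as counts
# and as binary interactions), fit a Poisson-Gamma VI model for K in {5, 10},
# then compare unlearning strategies per deleted edge: exact refit, local
# refits at radii 1-4, a warm-started global refit, and a one-step downdate.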
def load_lastfm_fast(max_users=1000, max_items=1000, max_edges=50000,
min_user_degree=5, min_item_degree=5, max_count=50, seed=42):
"""Load Last.fm with efficient random sampling."""
from datasets import load_dataset
print("Loading Last.fm...")
t0 = time.time()
ds = load_dataset("matthewfranglen/lastfm-1k", split="train")
print(f" Dataset loaded: {len(ds)} rows in {time.time()-t0:.1f}s")
rng = np.random.RandomState(seed)
    # Subsample at most 2M rows before aggregating counts, to keep loading fast.
n_sample = min(2_000_000, len(ds))
indices = sorted(rng.choice(len(ds), n_sample, replace=False).tolist())
print(f" Sampling {n_sample} rows...")
user_artist_counts = defaultdict(lambda: defaultdict(int))
for idx in indices:
row = ds[idx]
user_artist_counts[row['user_index']][row['artist_index']] += 1
print(f" {len(user_artist_counts)} unique users")
    # Filter users by minimum degree, subsample to max_users, then filter and
    # subsample items analogously.
user_degrees = {u: len(v) for u, v in user_artist_counts.items()}
valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree]
if len(valid_users) > max_users:
valid_users = list(rng.choice(valid_users, max_users, replace=False))
valid_users_set = set(valid_users)
item_degree = defaultdict(int)
for u in valid_users:
for a in user_artist_counts[u]:
item_degree[a] += 1
valid_items = set(a for a, d in item_degree.items() if d >= min_item_degree)
all_items = sorted(valid_items)
if len(all_items) > max_items:
all_items = list(rng.choice(all_items, max_items, replace=False))
valid_items_set = set(all_items)
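    # Remap the surviving user and artist ids to contiguous 0-based indices.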
user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
item_map = {a: idx for idx, a in enumerate(sorted(valid_items_set))}
edges = []
for u in valid_users_set:
for a, count in user_artist_counts[u].items():
if a in valid_items_set:
c = min(count, max_count)
if c > 0:
edges.append((user_map[u], item_map[a], int(c)))
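    # Cap the edge list at max_edges by uniform subsampling.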
if len(edges) > max_edges:
idx = rng.choice(len(edges), max_edges, replace=False)
edges = [edges[i] for i in idx]
N, M = len(user_map), len(item_map)
print(f" Last.fm: N={N}, M={M}, edges={len(edges)}")
return edges, N, M
def load_movielens_fast(mode='rating_count', max_users=1000, max_items=1000,
max_edges=50000, min_user_degree=5, min_item_degree=5, seed=42):
"""Load MovieLens efficiently."""
from datasets import load_dataset
print(f"Loading MovieLens ({mode})...")
ds = load_dataset("ashraq/movielens_ratings", split="train")
rng = np.random.RandomState(seed)
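    # Collect one rating per (user, movie) pair; later rows overwrite earlier ones.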
user_item_ratings = defaultdict(dict)
for row in ds:
user_item_ratings[row['user_id']][row['movie_id']] = row['rating']
user_degrees = {u: len(v) for u, v in user_item_ratings.items()}
valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree]
if len(valid_users) > max_users:
valid_users = list(rng.choice(valid_users, max_users, replace=False))
valid_users_set = set(valid_users)
item_degree = defaultdict(int)
for u in valid_users:
for m in user_item_ratings[u]:
item_degree[m] += 1
valid_items = set(m for m, d in item_degree.items() if d >= min_item_degree)
all_items = sorted(valid_items)
if len(all_items) > max_items:
all_items = list(rng.choice(all_items, max_items, replace=False))
valid_items_set = set(all_items)
user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
item_map = {m: idx for idx, m in enumerate(sorted(valid_items_set))}
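    # Build edges: ceil(rating) serves as a pseudo-count in 'rating_count' mode,
    # while 'binary' mode records a 1 for every kept interaction.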
edges = []
for u in valid_users_set:
for m, rating in user_item_ratings[u].items():
if m in valid_items_set:
x = int(np.ceil(rating)) if mode == 'rating_count' else 1
if x > 0:
edges.append((user_map[u], item_map[m], x))
if len(edges) > max_edges:
idx = rng.choice(len(edges), max_edges, replace=False)
edges = [edges[i] for i in idx]
N, M = len(user_map), len(item_map)
print(f" MovieLens ({mode}): N={N}, M={M}, edges={len(edges)}")
return edges, N, M
def run_real_experiment(dataset_name, edges, N, M, K, num_deletions=100,
a0=0.3, b0=1.0, c0=0.3, d0=1.0, max_iter=300, tol=1e-4, seed=42):
"""Run full real-data experiment."""
model = PoissonGammaVI(N, M, K, a0, b0, c0, d0, max_iter=max_iter, tol=tol, seed=seed)
print(f" Fitting full model (K={K})...")
t0 = time.time()
full_result = model.fit_full(edges)
t_full = time.time() - t0
print(f" {full_result.n_iterations} iters, {t_full:.1f}s, conv={full_result.converged}")
u2i, i2u, ed = build_adjacency(edges, N, M)
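    # Sample edges to delete; each deletion comes tagged with its deletion type.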
dels = sample_deletions(edges, u2i, i2u, num_deletions, seed=seed)
records = []
print(f" Running {len(dels)} deletions...")
for idx, (edge, dtype) in enumerate(dels):
if idx % 20 == 0:
print(f" {idx+1}/{len(dels)}")
exact = model.fit_without_edge(edges, edge, init_params=full_result.params)
        local_params = {}
        local_results = {}
        for R in [1, 2, 3, 4]:
            lr = model.fit_local(edges, edge, R, init_params=full_result.params)
            local_results[R] = lr
            local_params[R] = lr.params
ws = model.fit_warm_start_global(edges, edge, init_params=full_result.params)
os_res = one_step_downdate_poisson_gamma(
edges, edge, full_result.params, N, M, K, a0, b0, c0, d0)
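        # Compute the shared metric suite over the exact, local, warm-start,
        # and one-step solutions.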
metrics = compute_all_metrics(
full_result.params, exact.params, local_params,
ws.params, os_res.params, edge, edges, N, M, K,
'poisson_gamma', radii=[1,2,3,4],
model_kwargs={'a0':a0,'b0':b0,'c0':c0,'d0':d0})
record = {
'dataset_type': 'real', 'dataset_name': dataset_name,
'model_family': 'poisson_gamma', 'inference_type': 'vi',
'likelihood': 'poisson', 'prior': 'gamma',
'N': N, 'M': M, 'K': K, 'n_edges': len(edges),
'deletion_type': dtype, 'deletion_index': idx,
'runtime_full': t_full,
'runtime_exact': exact.runtime_sec,
'runtime_warm_start': ws.runtime_sec,
'runtime_one_step': os_res.runtime_sec,
'exact_converged': exact.converged,
'a0': a0, 'b0': b0, 'c0': c0, 'd0': d0,
}
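        # Record per-radius runtime and convergence for the local refits.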
for R in [1,2,3,4]:
record[f'runtime_local_R{R}'] = local_results[R].runtime_sec
record[f'local_R{R}_converged'] = local_results[R].converged
record.update(metrics)
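        # Flatten the per-distance influence dict into scalar record fields.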
if 'influence_by_distance' in record:
for d_str, val in record['influence_by_distance'].items():
record[f'influence_d{d_str}'] = val
records.append(record)
return records
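# Hypothetical smoke test (toy graph, not part of the experiments below):
#   edges = [(0, 0, 3), (0, 1, 1), (1, 1, 2), (2, 0, 4)]
#   recs = run_real_experiment('toy', edges, N=3, M=2, K=2,
#                              num_deletions=2, max_iter=50)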
def main():
output_dir = ensure_dir('results/raw')
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = os.path.join(output_dir, f'real_scaled_{ts}.jsonl')
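    # All datasets and K values from this run share one timestamped output file.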
all_records = []
# Last.fm
print("="*60)
print("LAST.FM")
print("="*60)
edges_fm, N_fm, M_fm = load_lastfm_fast(
max_users=1000, max_items=1000, max_edges=50000,
min_user_degree=5, min_item_degree=5, max_count=50, seed=42)
gs = compute_graph_stats([(e[0],e[1]) for e in edges_fm], N_fm, M_fm)
print(f" Stats: {json.dumps(gs)}")
for K in [5, 10]:
records = run_real_experiment('lastfm', edges_fm, N_fm, M_fm, K,
num_deletions=100, seed=42)
all_records.extend(records)
save_jsonl(records, output_file)
print(f" K={K}: {len(records)} records")
# MovieLens rating count
print("="*60)
print("MOVIELENS RATING COUNT")
print("="*60)
edges_ml, N_ml, M_ml = load_movielens_fast(
mode='rating_count', max_users=1000, max_items=1000, max_edges=50000,
min_user_degree=5, min_item_degree=5, seed=42)
gs = compute_graph_stats([(e[0],e[1]) for e in edges_ml], N_ml, M_ml)
print(f" Stats: {json.dumps(gs)}")
for K in [5, 10]:
records = run_real_experiment('movielens_rating_count', edges_ml, N_ml, M_ml, K,
num_deletions=100, seed=42)
all_records.extend(records)
save_jsonl(records, output_file)
print(f" K={K}: {len(records)} records")
# MovieLens binary
print("="*60)
print("MOVIELENS BINARY")
print("="*60)
edges_mb, N_mb, M_mb = load_movielens_fast(
mode='binary', max_users=1000, max_items=1000, max_edges=50000,
min_user_degree=5, min_item_degree=5, seed=42)
for K in [5, 10]:
records = run_real_experiment('movielens_binary', edges_mb, N_mb, M_mb, K,
num_deletions=100, seed=42)
all_records.extend(records)
save_jsonl(records, output_file)
print(f" K={K}: {len(records)} records")
print(f"\nTotal: {len(all_records)} records in {output_file}")
if __name__ == '__main__':
main()