"""Data generation and loading for experiments.""" import numpy as np from typing import List, Tuple, Dict, Optional from collections import defaultdict from src.graph_utils import ( generate_bounded_degree_graph, generate_erdos_renyi_graph, generate_power_law_graph, build_adjacency ) def generate_gamma_poisson_data(N, M, K, graph_type, avg_degree, count_scale, a0, b0, c0, d0, seed=0, keep_zeros=False): """Generate synthetic Gamma-Poisson matrix factorization data.""" rng = np.random.RandomState(seed) U_true = rng.gamma(a0, 1.0 / b0, size=(N, K)) V_true = rng.gamma(c0, 1.0 / d0, size=(M, K)) if graph_type == 'bounded_degree': graph_edges = generate_bounded_degree_graph(N, M, avg_degree, seed) elif graph_type == 'erdos_renyi': graph_edges = generate_erdos_renyi_graph(N, M, avg_degree, seed) elif graph_type == 'power_law': graph_edges = generate_power_law_graph(N, M, avg_degree, seed) else: raise ValueError(f"Unknown graph type: {graph_type}") edges = [] for i, j in graph_edges: rate = count_scale * np.dot(U_true[i], V_true[j]) x = rng.poisson(max(rate, 1e-10)) if x > 0 or keep_zeros: edges.append((i, j, int(x))) return edges, U_true, V_true, graph_edges def generate_gaussian_gaussian_data(N, M, K, graph_type, avg_degree, sigma_U, sigma_V, sigma_x, seed=0): """Generate synthetic Gaussian-Gaussian MF data.""" rng = np.random.RandomState(seed) U_true = rng.normal(0, sigma_U, size=(N, K)) V_true = rng.normal(0, sigma_V, size=(M, K)) if graph_type == 'bounded_degree': graph_edges = generate_bounded_degree_graph(N, M, avg_degree, seed) elif graph_type == 'erdos_renyi': graph_edges = generate_erdos_renyi_graph(N, M, avg_degree, seed) elif graph_type == 'power_law': graph_edges = generate_power_law_graph(N, M, avg_degree, seed) else: raise ValueError(f"Unknown graph type: {graph_type}") edges = [] for i, j in graph_edges: mean = np.dot(U_true[i], V_true[j]) x = rng.normal(mean, sigma_x) edges.append((i, j, float(x))) return edges, U_true, V_true, graph_edges def generate_gaussian_gamma_data(N, M, K, graph_type, avg_degree, a0, b0, c0, d0, sigma_x, seed=0): """Generate synthetic Gaussian likelihood + Gamma prior data.""" rng = np.random.RandomState(seed) U_true = rng.gamma(a0, 1.0 / b0, size=(N, K)) V_true = rng.gamma(c0, 1.0 / d0, size=(M, K)) if graph_type == 'bounded_degree': graph_edges = generate_bounded_degree_graph(N, M, avg_degree, seed) elif graph_type == 'erdos_renyi': graph_edges = generate_erdos_renyi_graph(N, M, avg_degree, seed) elif graph_type == 'power_law': graph_edges = generate_power_law_graph(N, M, avg_degree, seed) else: raise ValueError(f"Unknown graph type: {graph_type}") edges = [] for i, j in graph_edges: mean = np.dot(U_true[i], V_true[j]) x = rng.normal(mean, sigma_x) edges.append((i, j, float(x))) return edges, U_true, V_true, graph_edges def load_lastfm_data(max_users=2000, max_items=2000, max_edges=100000, min_user_degree=5, min_item_degree=5, max_count=100, seed=42): """Load Last.fm user-artist counts from HF dataset.""" from datasets import load_dataset print("Loading Last.fm dataset...") ds = load_dataset("matthewfranglen/lastfm-1k", split="train") user_artist_counts = defaultdict(lambda: defaultdict(int)) for row in ds: uid = row['user_index'] aid = row['artist_index'] user_artist_counts[uid][aid] += 1 user_degrees = {u: len(v) for u, v in user_artist_counts.items()} valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree] item_degree = defaultdict(int) for u in valid_users: for a in user_artist_counts[u]: item_degree[a] += 1 valid_items = set(a for a, d in item_degree.items() if d >= min_item_degree) rng = np.random.RandomState(seed) valid_users = sorted(valid_users) if len(valid_users) > max_users: valid_users = list(rng.choice(valid_users, max_users, replace=False)) valid_users_set = set(valid_users) all_items = set() for u in valid_users: for a in user_artist_counts[u]: if a in valid_items: all_items.add(a) all_items = sorted(all_items) if len(all_items) > max_items: all_items = list(rng.choice(all_items, max_items, replace=False)) valid_items_set = set(all_items) user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))} item_map = {a: idx for idx, a in enumerate(sorted(valid_items_set))} edges = [] for u in valid_users_set: for a, count in user_artist_counts[u].items(): if a in valid_items_set: c = min(count, max_count) if c > 0: edges.append((user_map[u], item_map[a], int(c))) if len(edges) > max_edges: indices = rng.choice(len(edges), max_edges, replace=False) edges = [edges[i] for i in indices] N = len(user_map) M = len(item_map) preprocessing = { 'dataset': 'matthewfranglen/lastfm-1k', 'N': N, 'M': M, 'n_edges': len(edges), 'max_count': max_count, 'seed': seed, } print(f"Last.fm loaded: N={N}, M={M}, edges={len(edges)}") return edges, N, M, preprocessing def load_movielens_data(mode='rating_count', max_users=2000, max_items=2000, max_edges=100000, min_user_degree=5, min_item_degree=5, seed=42): """Load MovieLens ratings from HF dataset.""" from datasets import load_dataset print("Loading MovieLens dataset...") ds = load_dataset("ashraq/movielens_ratings", split="train") rng = np.random.RandomState(seed) user_item_ratings = defaultdict(dict) for row in ds: uid = row['user_id'] mid = row['movie_id'] rating = row['rating'] user_item_ratings[uid][mid] = rating user_degrees = {u: len(v) for u, v in user_item_ratings.items()} valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree] item_degree = defaultdict(int) for u in valid_users: for m in user_item_ratings[u]: item_degree[m] += 1 valid_items = set(m for m, d in item_degree.items() if d >= min_item_degree) if len(valid_users) > max_users: valid_users = list(rng.choice(valid_users, max_users, replace=False)) valid_users_set = set(valid_users) all_items = set() for u in valid_users_set: for m in user_item_ratings[u]: if m in valid_items: all_items.add(m) all_items = sorted(all_items) if len(all_items) > max_items: all_items = list(rng.choice(all_items, max_items, replace=False)) valid_items_set = set(all_items) user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))} item_map = {m: idx for idx, m in enumerate(sorted(valid_items_set))} edges = [] for u in valid_users_set: for m, rating in user_item_ratings[u].items(): if m in valid_items_set: if mode == 'rating_count': x = int(np.ceil(rating)) elif mode == 'binary': x = 1 else: raise ValueError(f"Unknown mode: {mode}") if x > 0: edges.append((user_map[u], item_map[m], x)) if len(edges) > max_edges: indices = rng.choice(len(edges), max_edges, replace=False) edges = [edges[i] for i in indices] N = len(user_map) M = len(item_map) preprocessing = { 'dataset': 'ashraq/movielens_ratings', 'mode': mode, 'N': N, 'M': M, 'n_edges': len(edges), 'seed': seed, } print(f"MovieLens ({mode}) loaded: N={N}, M={M}, edges={len(edges)}") return edges, N, M, preprocessing def sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=0): """Sample deletions with 25% each: random, high-count, hub-adjacent, low-degree.""" rng = np.random.RandomState(seed) n_per_type = num_deletions // 4 remainder = num_deletions - 4 * n_per_type counts = np.array([e[2] for e in edges], dtype=float) user_degrees = defaultdict(int) item_degrees = defaultdict(int) for i, j, x in edges: user_degrees[i] += 1 item_degrees[j] += 1 hub_scores = np.array([max(user_degrees[e[0]], item_degrees[e[1]]) for e in edges], dtype=float) low_deg_scores = np.array([min(user_degrees[e[0]], item_degrees[e[1]]) for e in edges], dtype=float) sampled = [] used = set() def _sample(scores, n, dtype, high=True): avail = [i for i in range(len(edges)) if i not in used] if not avail or n <= 0: return sc = scores[avail] if high: ranked = np.argsort(-sc) else: ranked = np.argsort(sc) pool = ranked[:min(len(avail), max(n * 3, 20))] chosen = rng.choice(pool, size=min(n, len(pool)), replace=False) for idx in chosen: eidx = avail[idx] used.add(eidx) sampled.append((edges[eidx], dtype)) # Random avail = list(range(len(edges))) rng.shuffle(avail) for idx in avail[:n_per_type + remainder]: used.add(idx) sampled.append((edges[idx], 'random')) _sample(counts, n_per_type, 'high_count', high=True) _sample(hub_scores, n_per_type, 'hub_adjacent', high=True) _sample(low_deg_scores, n_per_type, 'low_degree', high=False) return sampled