# src/data.py — commit 95fa396 (Hugging Face upload-page header removed)
"""Data generation and loading for experiments."""
import numpy as np
from typing import List, Tuple, Dict, Optional
from collections import defaultdict
from src.graph_utils import (
generate_bounded_degree_graph, generate_erdos_renyi_graph,
generate_power_law_graph, build_adjacency
)
def generate_gamma_poisson_data(N, M, K, graph_type, avg_degree,
                                count_scale, a0, b0, c0, d0,
                                seed=0, keep_zeros=False):
    """Generate synthetic Gamma-Poisson matrix factorization data.

    Draws U ~ Gamma(a0, scale=1/b0) and V ~ Gamma(c0, scale=1/d0),
    builds an observation graph of the requested type, then emits a
    Poisson count with rate count_scale * <U_i, V_j> on each edge.
    Zero counts are dropped unless keep_zeros is True.

    Returns:
        (edges, U_true, V_true, graph_edges) where edges is a list of
        (i, j, count) triples.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.gamma(a0, 1.0 / b0, size=(N, K))
    V_true = rng.gamma(c0, 1.0 / d0, size=(M, K))
    # Lazy dispatch: the lambdas defer calling a graph builder until one
    # is actually selected.
    builders = {
        'bounded_degree': lambda: generate_bounded_degree_graph(N, M, avg_degree, seed),
        'erdos_renyi': lambda: generate_erdos_renyi_graph(N, M, avg_degree, seed),
        'power_law': lambda: generate_power_law_graph(N, M, avg_degree, seed),
    }
    if graph_type not in builders:
        raise ValueError(f"Unknown graph type: {graph_type}")
    graph_edges = builders[graph_type]()
    edges = []
    for row, col in graph_edges:
        lam = count_scale * np.dot(U_true[row], V_true[col])
        # Floor the rate so a (numerically) zero rate is still a valid
        # Poisson parameter.
        count = rng.poisson(max(lam, 1e-10))
        if count > 0 or keep_zeros:
            edges.append((row, col, int(count)))
    return edges, U_true, V_true, graph_edges
def generate_gaussian_gaussian_data(N, M, K, graph_type, avg_degree,
                                    sigma_U, sigma_V, sigma_x, seed=0):
    """Generate synthetic Gaussian-Gaussian MF data.

    U and V entries are zero-mean Gaussians with std sigma_U / sigma_V;
    each edge observation is <U_i, V_j> plus Gaussian noise of std
    sigma_x.

    Returns:
        (edges, U_true, V_true, graph_edges) where edges is a list of
        (i, j, value) triples with float values.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.normal(0, sigma_U, size=(N, K))
    V_true = rng.normal(0, sigma_V, size=(M, K))
    # Lazy dispatch keeps the unknown-type error path cheap.
    builders = {
        'bounded_degree': lambda: generate_bounded_degree_graph(N, M, avg_degree, seed),
        'erdos_renyi': lambda: generate_erdos_renyi_graph(N, M, avg_degree, seed),
        'power_law': lambda: generate_power_law_graph(N, M, avg_degree, seed),
    }
    if graph_type not in builders:
        raise ValueError(f"Unknown graph type: {graph_type}")
    graph_edges = builders[graph_type]()
    # One rng.normal call per edge, in graph order (matters for
    # reproducibility at a fixed seed).
    edges = [
        (row, col, float(rng.normal(np.dot(U_true[row], V_true[col]), sigma_x)))
        for row, col in graph_edges
    ]
    return edges, U_true, V_true, graph_edges
def generate_gaussian_gamma_data(N, M, K, graph_type, avg_degree,
                                 a0, b0, c0, d0, sigma_x, seed=0):
    """Generate synthetic Gaussian likelihood + Gamma prior data.

    U ~ Gamma(a0, scale=1/b0) and V ~ Gamma(c0, scale=1/d0); each edge
    observation is <U_i, V_j> plus Gaussian noise of std sigma_x.

    Returns:
        (edges, U_true, V_true, graph_edges) where edges is a list of
        (i, j, value) triples with float values.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.gamma(a0, 1.0 / b0, size=(N, K))
    V_true = rng.gamma(c0, 1.0 / d0, size=(M, K))
    # Lazy dispatch keeps the unknown-type error path cheap.
    builders = {
        'bounded_degree': lambda: generate_bounded_degree_graph(N, M, avg_degree, seed),
        'erdos_renyi': lambda: generate_erdos_renyi_graph(N, M, avg_degree, seed),
        'power_law': lambda: generate_power_law_graph(N, M, avg_degree, seed),
    }
    if graph_type not in builders:
        raise ValueError(f"Unknown graph type: {graph_type}")
    graph_edges = builders[graph_type]()
    # One rng.normal call per edge, in graph order (matters for
    # reproducibility at a fixed seed).
    edges = [
        (row, col, float(rng.normal(np.dot(U_true[row], V_true[col]), sigma_x)))
        for row, col in graph_edges
    ]
    return edges, U_true, V_true, graph_edges
def load_lastfm_data(max_users=2000, max_items=2000, max_edges=100000,
                     min_user_degree=5, min_item_degree=5, max_count=100, seed=42):
    """Load Last.fm user-artist counts from HF dataset.

    Aggregates rows into per-(user, artist) counts, filters out users and
    items below the degree thresholds, subsamples down to max_users /
    max_items / max_edges, and remaps ids to contiguous 0-based indices.

    Returns:
        edges: list of (user_idx, item_idx, count), count clipped to max_count.
        N: number of users after filtering/subsampling.
        M: number of items after filtering/subsampling.
        preprocessing: dict recording the settings actually used.
    """
    # Local import keeps the heavy 'datasets' dependency out of module import.
    from datasets import load_dataset
    print("Loading Last.fm dataset...")
    ds = load_dataset("matthewfranglen/lastfm-1k", split="train")
    # Count how many rows each (user, artist) pair appears in.
    user_artist_counts = defaultdict(lambda: defaultdict(int))
    for row in ds:
        uid = row['user_index']  # assumes integer user/artist indices — TODO confirm dataset schema
        aid = row['artist_index']
        user_artist_counts[uid][aid] += 1
    # Keep users who listened to enough distinct artists.
    user_degrees = {u: len(v) for u, v in user_artist_counts.items()}
    valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree]
    # NOTE(review): item degrees are computed BEFORE users are subsampled
    # below, so an item can end up below min_item_degree afterwards.
    item_degree = defaultdict(int)
    for u in valid_users:
        for a in user_artist_counts[u]:
            item_degree[a] += 1
    valid_items = set(a for a, d in item_degree.items() if d >= min_item_degree)
    rng = np.random.RandomState(seed)
    valid_users = sorted(valid_users)  # sort before rng.choice for reproducibility
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)
    # Restrict items to those reachable from the retained users.
    all_items = set()
    for u in valid_users:
        for a in user_artist_counts[u]:
            if a in valid_items:
                all_items.add(a)
    all_items = sorted(all_items)
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)
    # Contiguous 0-based reindexing; sorted so the mapping is deterministic.
    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {a: idx for idx, a in enumerate(sorted(valid_items_set))}
    edges = []
    for u in valid_users_set:
        for a, count in user_artist_counts[u].items():
            if a in valid_items_set:
                c = min(count, max_count)  # clip extreme play counts
                if c > 0:
                    edges.append((user_map[u], item_map[a], int(c)))
    # Uniformly subsample edges if there are too many.
    if len(edges) > max_edges:
        indices = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in indices]
    N = len(user_map)
    M = len(item_map)
    preprocessing = {
        'dataset': 'matthewfranglen/lastfm-1k', 'N': N, 'M': M,
        'n_edges': len(edges), 'max_count': max_count, 'seed': seed,
    }
    print(f"Last.fm loaded: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M, preprocessing
def load_movielens_data(mode='rating_count', max_users=2000, max_items=2000,
                        max_edges=100000, min_user_degree=5, min_item_degree=5, seed=42):
    """Load MovieLens ratings from HF dataset.

    Filters users/items by degree, subsamples to the requested sizes, and
    converts ratings to edge values according to `mode`:
      - 'rating_count': value = ceil(rating) as an int.
      - 'binary': value = 1 for every retained rating.

    Returns:
        edges: list of (user_idx, item_idx, value).
        N: number of users after filtering/subsampling.
        M: number of items after filtering/subsampling.
        preprocessing: dict recording the settings actually used.

    Raises:
        ValueError: if `mode` is not 'rating_count' or 'binary'.
    """
    # Local import keeps the heavy 'datasets' dependency out of module import.
    from datasets import load_dataset
    print("Loading MovieLens dataset...")
    ds = load_dataset("ashraq/movielens_ratings", split="train")
    rng = np.random.RandomState(seed)
    # Last rating wins if a (user, movie) pair appears more than once.
    user_item_ratings = defaultdict(dict)
    for row in ds:
        uid = row['user_id']  # assumes hashable ids — TODO confirm dataset schema
        mid = row['movie_id']
        rating = row['rating']
        user_item_ratings[uid][mid] = rating
    # Keep users with enough distinct rated movies.
    user_degrees = {u: len(v) for u, v in user_item_ratings.items()}
    valid_users = [u for u, d in user_degrees.items() if d >= min_user_degree]
    # NOTE(review): item degrees are computed BEFORE users are subsampled
    # below, so an item can end up below min_item_degree afterwards.
    item_degree = defaultdict(int)
    for u in valid_users:
        for m in user_item_ratings[u]:
            item_degree[m] += 1
    valid_items = set(m for m, d in item_degree.items() if d >= min_item_degree)
    # NOTE(review): unlike load_lastfm_data, valid_users is not sorted
    # before rng.choice; determinism relies on dict insertion order.
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)
    # Restrict items to those reachable from the retained users.
    all_items = set()
    for u in valid_users_set:
        for m in user_item_ratings[u]:
            if m in valid_items:
                all_items.add(m)
    all_items = sorted(all_items)
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)
    # Contiguous 0-based reindexing; sorted so the mapping is deterministic.
    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {m: idx for idx, m in enumerate(sorted(valid_items_set))}
    edges = []
    for u in valid_users_set:
        for m, rating in user_item_ratings[u].items():
            if m in valid_items_set:
                if mode == 'rating_count':
                    x = int(np.ceil(rating))
                elif mode == 'binary':
                    x = 1
                else:
                    raise ValueError(f"Unknown mode: {mode}")
                if x > 0:
                    edges.append((user_map[u], item_map[m], x))
    # Uniformly subsample edges if there are too many.
    if len(edges) > max_edges:
        indices = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in indices]
    N = len(user_map)
    M = len(item_map)
    preprocessing = {
        'dataset': 'ashraq/movielens_ratings', 'mode': mode,
        'N': N, 'M': M, 'n_edges': len(edges), 'seed': seed,
    }
    print(f"MovieLens ({mode}) loaded: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M, preprocessing
def sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=0):
    """Sample deletions with 25% each: random, high-count, hub-adjacent, low-degree.

    Returns a list of (edge, deletion_type) pairs with distinct edges;
    the 'random' bucket absorbs any remainder from the 4-way split.
    (user_to_items / item_to_users are accepted for interface
    compatibility but not used; degrees are recomputed from edges.)
    """
    rng = np.random.RandomState(seed)
    per_bucket = num_deletions // 4
    leftover = num_deletions - 4 * per_bucket
    counts = np.array([e[2] for e in edges], dtype=float)
    # Recompute endpoint degrees directly from the edge list.
    user_degrees = defaultdict(int)
    item_degrees = defaultdict(int)
    for u, v, _ in edges:
        user_degrees[u] += 1
        item_degrees[v] += 1
    hub_scores = np.array([max(user_degrees[e[0]], item_degrees[e[1]]) for e in edges], dtype=float)
    low_deg_scores = np.array([min(user_degrees[e[0]], item_degrees[e[1]]) for e in edges], dtype=float)
    sampled = []
    used = set()

    def _pick(scores, n, label, high=True):
        # Rank the not-yet-used edges by score, then draw uniformly from
        # the top slice (at least 3n or 20 candidates) of the ranking.
        remaining = [k for k in range(len(edges)) if k not in used]
        if not remaining or n <= 0:
            return
        sub = scores[remaining]
        ranked = np.argsort(-sub) if high else np.argsort(sub)
        pool = ranked[:min(len(remaining), max(n * 3, 20))]
        for pos in rng.choice(pool, size=min(n, len(pool)), replace=False):
            edge_idx = remaining[pos]
            used.add(edge_idx)
            sampled.append((edges[edge_idx], label))

    # Random bucket first; it also absorbs the division remainder.
    order = list(range(len(edges)))
    rng.shuffle(order)
    for edge_idx in order[:per_bucket + leftover]:
        used.add(edge_idx)
        sampled.append((edges[edge_idx], 'random'))
    _pick(counts, per_bucket, 'high_count', high=True)
    _pick(hub_scores, per_bucket, 'hub_adjacent', high=True)
    _pick(low_deg_scores, per_bucket, 'low_degree', high=False)
    return sampled