serliezer committed on
Commit
95fa396
·
verified ·
1 Parent(s): 1c8033c

Add src/data.py

Browse files
Files changed (1) hide show
  1. src/data.py +270 -0
src/data.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data generation and loading for experiments."""
2
+ import numpy as np
3
+ from typing import List, Tuple, Dict, Optional
4
+ from collections import defaultdict
5
+ from src.graph_utils import (
6
+ generate_bounded_degree_graph, generate_erdos_renyi_graph,
7
+ generate_power_law_graph, build_adjacency
8
+ )
9
+
10
+
11
def _build_graph_edges(graph_type, N, M, avg_degree, seed):
    """Dispatch to the graph generator selected by ``graph_type``.

    Shared by the synthetic data generators so the supported names and the
    error raised for unknown names are defined in exactly one place.

    Args:
        graph_type: One of 'bounded_degree', 'erdos_renyi' or 'power_law'.
        N, M, avg_degree, seed: Forwarded unchanged to the selected generator.

    Returns:
        Whatever the selected generator returns; callers iterate it as
        ``(i, j)`` pairs.

    Raises:
        ValueError: If ``graph_type`` is not one of the supported names.
    """
    if graph_type == 'bounded_degree':
        return generate_bounded_degree_graph(N, M, avg_degree, seed)
    if graph_type == 'erdos_renyi':
        return generate_erdos_renyi_graph(N, M, avg_degree, seed)
    if graph_type == 'power_law':
        return generate_power_law_graph(N, M, avg_degree, seed)
    raise ValueError(f"Unknown graph type: {graph_type}")


def _gaussian_observations(rng, graph_edges, U_true, V_true, sigma_x):
    """Draw one Gaussian observation per edge with mean ``U_true[i] . V_true[j]``.

    Draws are made sequentially in edge order so the random stream is consumed
    exactly as in a plain per-edge loop.
    """
    return [
        (i, j, float(rng.normal(np.dot(U_true[i], V_true[j]), sigma_x)))
        for i, j in graph_edges
    ]


def generate_gamma_poisson_data(N, M, K, graph_type, avg_degree,
                                count_scale, a0, b0, c0, d0,
                                seed=0, keep_zeros=False):
    """Generate synthetic Gamma-Poisson matrix factorization data.

    Factors are drawn as ``U ~ Gamma(shape=a0, scale=1/b0)`` with shape
    ``(N, K)`` and ``V ~ Gamma(shape=c0, scale=1/d0)`` with shape ``(M, K)``.
    For every graph edge ``(i, j)`` a count is drawn from
    ``Poisson(count_scale * U[i] . V[j])``.

    Args:
        N, M: Number of rows of ``U`` and ``V`` respectively.
        K: Latent dimension.
        graph_type: 'bounded_degree', 'erdos_renyi' or 'power_law'.
        avg_degree: Forwarded to the graph generator.
        count_scale: Multiplier applied to the inner product to get the rate.
        a0, b0: Gamma shape and rate for ``U``.
        c0, d0: Gamma shape and rate for ``V``.
        seed: Seed for both the local RNG and the graph generator.
        keep_zeros: If False, edges whose sampled count is 0 are dropped.

    Returns:
        ``(edges, U_true, V_true, graph_edges)`` where ``edges`` is a list of
        ``(i, j, count)`` tuples with ``count`` a Python int.

    Raises:
        ValueError: If ``graph_type`` is unknown.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.gamma(a0, 1.0 / b0, size=(N, K))
    V_true = rng.gamma(c0, 1.0 / d0, size=(M, K))

    graph_edges = _build_graph_edges(graph_type, N, M, avg_degree, seed)

    edges = []
    for i, j in graph_edges:
        rate = count_scale * np.dot(U_true[i], V_true[j])
        # Floor the rate at a tiny positive value before sampling.
        x = rng.poisson(max(rate, 1e-10))
        if x > 0 or keep_zeros:
            edges.append((i, j, int(x)))

    return edges, U_true, V_true, graph_edges


def generate_gaussian_gaussian_data(N, M, K, graph_type, avg_degree,
                                    sigma_U, sigma_V, sigma_x, seed=0):
    """Generate synthetic Gaussian-Gaussian MF data.

    Factors are drawn as ``U ~ Normal(0, sigma_U)`` with shape ``(N, K)`` and
    ``V ~ Normal(0, sigma_V)`` with shape ``(M, K)``. Each graph edge
    ``(i, j)`` gets an observation from ``Normal(U[i] . V[j], sigma_x)``.

    Returns:
        ``(edges, U_true, V_true, graph_edges)`` where ``edges`` is a list of
        ``(i, j, value)`` tuples with ``value`` a Python float.

    Raises:
        ValueError: If ``graph_type`` is unknown.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.normal(0, sigma_U, size=(N, K))
    V_true = rng.normal(0, sigma_V, size=(M, K))

    graph_edges = _build_graph_edges(graph_type, N, M, avg_degree, seed)
    edges = _gaussian_observations(rng, graph_edges, U_true, V_true, sigma_x)

    return edges, U_true, V_true, graph_edges


def generate_gaussian_gamma_data(N, M, K, graph_type, avg_degree,
                                 a0, b0, c0, d0, sigma_x, seed=0):
    """Generate synthetic Gaussian likelihood + Gamma prior data.

    Factors are drawn as ``U ~ Gamma(shape=a0, scale=1/b0)`` with shape
    ``(N, K)`` and ``V ~ Gamma(shape=c0, scale=1/d0)`` with shape ``(M, K)``.
    Each graph edge ``(i, j)`` gets an observation from
    ``Normal(U[i] . V[j], sigma_x)``.

    Returns:
        ``(edges, U_true, V_true, graph_edges)`` where ``edges`` is a list of
        ``(i, j, value)`` tuples with ``value`` a Python float.

    Raises:
        ValueError: If ``graph_type`` is unknown.
    """
    rng = np.random.RandomState(seed)
    U_true = rng.gamma(a0, 1.0 / b0, size=(N, K))
    V_true = rng.gamma(c0, 1.0 / d0, size=(M, K))

    graph_edges = _build_graph_edges(graph_type, N, M, avg_degree, seed)
    edges = _gaussian_observations(rng, graph_edges, U_true, V_true, sigma_x)

    return edges, U_true, V_true, graph_edges
86
+
87
+
88
def load_lastfm_data(max_users=2000, max_items=2000, max_edges=100000,
                     min_user_degree=5, min_item_degree=5, max_count=100, seed=42):
    """Load Last.fm user-artist counts from HF dataset.

    Pipeline:
      1. Count rows per ``(user_index, artist_index)`` pair.
      2. Keep users with at least ``min_user_degree`` distinct artists, then
         artists listened to by at least ``min_item_degree`` of those users.
      3. Subsample at most ``max_users`` users and ``max_items`` artists.
      4. Emit one edge per remaining pair, with the count clipped at
         ``max_count``; subsample at most ``max_edges`` edges.

    Edges are enumerated in sorted-user order, not set iteration order, so
    the seeded edge subsample does not depend on how ids happen to hash.

    Args:
        max_users, max_items, max_edges: Upper bounds applied by seeded
            sampling without replacement.
        min_user_degree, min_item_degree: Degree thresholds from step 2.
        max_count: Upper clip applied to every play count.
        seed: Seed of the local ``RandomState`` used for all subsampling.

    Returns:
        ``(edges, N, M, preprocessing)`` where ``edges`` is a list of
        ``(user_idx, item_idx, count)`` tuples with contiguous indices,
        ``N``/``M`` are the sizes of the user/item index maps and
        ``preprocessing`` is a dict summarising the settings used.
    """
    from datasets import load_dataset

    print("Loading Last.fm dataset...")
    ds = load_dataset("matthewfranglen/lastfm-1k", split="train")

    user_artist_counts = defaultdict(lambda: defaultdict(int))
    for row in ds:
        user_artist_counts[row['user_index']][row['artist_index']] += 1

    # Sorted so the seeded user subsample does not depend on row order.
    valid_users = sorted(
        u for u, artists in user_artist_counts.items()
        if len(artists) >= min_user_degree
    )

    item_degree = defaultdict(int)
    for u in valid_users:
        for a in user_artist_counts[u]:
            item_degree[a] += 1
    valid_items = {a for a, d in item_degree.items() if d >= min_item_degree}

    rng = np.random.RandomState(seed)
    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)

    all_items = sorted({
        a
        for u in valid_users
        for a in user_artist_counts[u]
        if a in valid_items
    })
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)

    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {a: idx for idx, a in enumerate(sorted(valid_items_set))}

    edges = []
    # user_map preserves sorted insertion order, which fixes the edge order
    # that the index-based subsample below relies on.
    for u, u_idx in user_map.items():
        for a, count in user_artist_counts[u].items():
            if a not in valid_items_set:
                continue
            c = min(count, max_count)
            if c > 0:
                edges.append((u_idx, item_map[a], int(c)))

    if len(edges) > max_edges:
        indices = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in indices]

    N = len(user_map)
    M = len(item_map)
    preprocessing = {
        'dataset': 'matthewfranglen/lastfm-1k', 'N': N, 'M': M,
        'n_edges': len(edges), 'max_count': max_count, 'seed': seed,
    }

    print(f"Last.fm loaded: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M, preprocessing
151
+
152
+
153
def load_movielens_data(mode='rating_count', max_users=2000, max_items=2000,
                        max_edges=100000, min_user_degree=5, min_item_degree=5, seed=42):
    """Load MovieLens ratings from HF dataset.

    Pipeline:
      1. Record the rating per ``(user_id, movie_id)`` pair (a later row for
         the same pair overwrites an earlier one).
      2. Keep users with at least ``min_user_degree`` rated movies, then
         movies rated by at least ``min_item_degree`` of those users.
      3. Subsample at most ``max_users`` users and ``max_items`` movies.
      4. Emit one edge per remaining pair and subsample at most ``max_edges``.

    ``mode`` is validated before the dataset is loaded, so a typo fails
    immediately. Users are sorted before seeded subsampling and edges are
    enumerated in sorted-user order, so results for a given seed do not
    depend on dataset row order or on set iteration order.

    Args:
        mode: 'rating_count' uses ``ceil(rating)`` as the edge value;
            'binary' uses 1 for every rated pair.
        max_users, max_items, max_edges: Upper bounds applied by seeded
            sampling without replacement.
        min_user_degree, min_item_degree: Degree thresholds from step 2.
        seed: Seed of the local ``RandomState`` used for all subsampling.

    Returns:
        ``(edges, N, M, preprocessing)`` where ``edges`` is a list of
        ``(user_idx, item_idx, value)`` tuples with contiguous indices,
        ``N``/``M`` are the sizes of the user/item index maps and
        ``preprocessing`` is a dict summarising the settings used.

    Raises:
        ValueError: If ``mode`` is neither 'rating_count' nor 'binary'.
    """
    # Fail fast: reject a bad mode before importing or loading anything.
    if mode not in ('rating_count', 'binary'):
        raise ValueError(f"Unknown mode: {mode}")

    from datasets import load_dataset

    print("Loading MovieLens dataset...")
    ds = load_dataset("ashraq/movielens_ratings", split="train")

    rng = np.random.RandomState(seed)
    user_item_ratings = defaultdict(dict)
    for row in ds:
        user_item_ratings[row['user_id']][row['movie_id']] = row['rating']

    # Sorted so the seeded user subsample does not depend on row order.
    valid_users = sorted(
        u for u, rated in user_item_ratings.items()
        if len(rated) >= min_user_degree
    )

    item_degree = defaultdict(int)
    for u in valid_users:
        for m in user_item_ratings[u]:
            item_degree[m] += 1
    valid_items = {m for m, d in item_degree.items() if d >= min_item_degree}

    if len(valid_users) > max_users:
        valid_users = list(rng.choice(valid_users, max_users, replace=False))
    valid_users_set = set(valid_users)

    all_items = sorted({
        m
        for u in valid_users_set
        for m in user_item_ratings[u]
        if m in valid_items
    })
    if len(all_items) > max_items:
        all_items = list(rng.choice(all_items, max_items, replace=False))
    valid_items_set = set(all_items)

    user_map = {u: idx for idx, u in enumerate(sorted(valid_users_set))}
    item_map = {m: idx for idx, m in enumerate(sorted(valid_items_set))}

    use_rating = mode == 'rating_count'
    edges = []
    # user_map preserves sorted insertion order, which fixes the edge order
    # that the index-based subsample below relies on.
    for u, u_idx in user_map.items():
        for m, rating in user_item_ratings[u].items():
            if m not in valid_items_set:
                continue
            x = int(np.ceil(rating)) if use_rating else 1
            if x > 0:
                edges.append((u_idx, item_map[m], x))

    if len(edges) > max_edges:
        indices = rng.choice(len(edges), max_edges, replace=False)
        edges = [edges[i] for i in indices]

    N = len(user_map)
    M = len(item_map)
    preprocessing = {
        'dataset': 'ashraq/movielens_ratings', 'mode': mode,
        'N': N, 'M': M, 'n_edges': len(edges), 'seed': seed,
    }

    print(f"MovieLens ({mode}) loaded: N={N}, M={M}, edges={len(edges)}")
    return edges, N, M, preprocessing
221
+
222
+
223
def sample_deletions(edges, user_to_items, item_to_users, num_deletions, seed=0):
    """Sample edges to delete, split across four selection strategies.

    Each strategy gets ``num_deletions // 4`` edges; the 'random' strategy
    also absorbs the remainder. Strategies run in this order and never pick
    an edge that an earlier strategy already took:

      * 'random'       - uniform over all edges.
      * 'high_count'   - drawn from the edges with the largest count.
      * 'hub_adjacent' - drawn from the edges with the largest
                         ``max(user degree, item degree)``.
      * 'low_degree'   - drawn from the edges with the smallest
                         ``min(user degree, item degree)``.

    The three ranked strategies draw uniformly without replacement from the
    best ``max(3 * n, 20)`` remaining candidates, not the strict top ``n``.
    Degrees are computed from ``edges`` itself.

    Args:
        edges: Sequence of ``(user, item, count)`` triples.
        user_to_items: Unused; kept for interface compatibility.
        item_to_users: Unused; kept for interface compatibility.
        num_deletions: Total number of edges requested. Non-positive values
            yield an empty result.
        seed: Seed of the local ``RandomState``.

    Returns:
        List of ``(edge, label)`` pairs, where ``edge`` is an element of
        ``edges`` and ``label`` is one of the strategy names above. It holds
        fewer than ``num_deletions`` entries when edges run out.
    """
    # Floor division on a negative request gives a negative per-type count
    # but a positive remainder, which would select spurious random edges.
    if num_deletions <= 0:
        return []

    rng = np.random.RandomState(seed)
    n_edges = len(edges)
    n_per_type = num_deletions // 4
    remainder = num_deletions - 4 * n_per_type

    counts = np.array([e[2] for e in edges], dtype=float)

    user_degrees = defaultdict(int)
    item_degrees = defaultdict(int)
    for i, j, _ in edges:
        user_degrees[i] += 1
        item_degrees[j] += 1

    hub_scores = np.array(
        [max(user_degrees[i], item_degrees[j]) for i, j, _ in edges], dtype=float
    )
    low_deg_scores = np.array(
        [min(user_degrees[i], item_degrees[j]) for i, j, _ in edges], dtype=float
    )

    sampled = []
    used = np.zeros(n_edges, dtype=bool)

    def _take(edge_idx, label):
        """Mark one edge as consumed and record it under ``label``."""
        used[edge_idx] = True
        sampled.append((edges[edge_idx], label))

    def _sample_ranked(scores, n, label, high=True):
        """Pick up to ``n`` unused edges from the best-scoring candidates."""
        avail = np.flatnonzero(~used)  # ascending indices of unused edges
        if avail.size == 0 or n <= 0:
            return
        sc = scores[avail]
        ranked = np.argsort(-sc) if high else np.argsort(sc)
        # The pool is wider than n so the pick is randomised among the top.
        pool = ranked[:min(avail.size, max(n * 3, 20))]
        chosen = rng.choice(pool, size=min(n, len(pool)), replace=False)
        for pos in chosen:
            _take(int(avail[pos]), label)

    # Random
    order = list(range(n_edges))
    rng.shuffle(order)
    for edge_idx in order[:n_per_type + remainder]:
        _take(edge_idx, 'random')

    _sample_ranked(counts, n_per_type, 'high_count', high=True)
    _sample_ranked(hub_scores, n_per_type, 'hub_adjacent', high=True)
    _sample_ranked(low_deg_scores, n_per_type, 'low_degree', high=False)

    return sampled