serliezer committed on
Commit
b9f807b
·
verified ·
1 Parent(s): a8951a1

Add src/metrics.py

Browse files
Files changed (1) hide show
  1. src/metrics.py +357 -0
src/metrics.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Metrics computation for unlearning experiments."""
2
+ import numpy as np
3
+ from scipy import stats
4
+ from scipy.special import digamma, polygamma
5
+ from typing import Dict, List, Tuple, Optional
6
+ from collections import defaultdict
7
+
8
+ from src.graph_utils import build_adjacency, get_deletion_neighborhood, get_blocks_at_distance
9
+
10
+
11
+ # ============================================================
12
+ # Parameter distance / influence
13
+ # ============================================================
14
+
15
def compute_block_param_vector(params: dict, node_type: str, idx: int, model_family: str) -> np.ndarray:
    """Return the flat parameter vector of one user or item block.

    Args:
        params: Mapping from parameter name to an array indexed by block.
        node_type: ``'user'`` selects the user-side parameters; any other
            value selects the item-side parameters.
        idx: Index of the block inside the selected arrays.
        model_family: ``'poisson_gamma'``, ``'gaussian_gaussian'`` or
            ``'gaussian_gamma_map'``.

    Returns:
        A new 1-D array: the two per-block parameter rows concatenated for
        the two-parameter families, or a copy of the single row for
        ``'gaussian_gamma_map'``.

    Raises:
        ValueError: If ``model_family`` is not one of the known families.
    """
    on_user_side = node_type == 'user'

    # The MAP family stores a single array per side, so there is nothing to
    # concatenate; a copy keeps the caller from aliasing the stored row.
    if model_family == 'gaussian_gamma_map':
        single_key = 'alpha' if on_user_side else 'beta'
        return params[single_key][idx].copy()

    # (user-side keys, item-side keys) for the two-parameter families.
    paired_keys = {
        'poisson_gamma': (('a', 'b'), ('c', 'd')),
        'gaussian_gaussian': (('m_U', 's_U'), ('m_V', 's_V')),
    }
    if model_family not in paired_keys:
        raise ValueError(f"Unknown model family: {model_family}")

    user_keys, item_keys = paired_keys[model_family]
    first_key, second_key = user_keys if on_user_side else item_keys
    return np.concatenate([params[first_key][idx], params[second_key][idx]])
34
+
35
+
36
def compute_all_param_vector(params: dict, model_family: str) -> np.ndarray:
    """Flatten every parameter array of a model into one 1-D vector.

    Args:
        params: Mapping from parameter name to array.
        model_family: ``'poisson_gamma'``, ``'gaussian_gaussian'`` or
            ``'gaussian_gamma_map'``.

    Returns:
        The raveled arrays concatenated in a fixed per-family order
        (user-side parameters first, then item-side parameters).

    Raises:
        ValueError: If ``model_family`` is not one of the known families.
    """
    # Concatenation order per family; it must stay fixed so that vectors
    # built from different parameter sets line up element by element.
    key_order = {
        'poisson_gamma': ('a', 'b', 'c', 'd'),
        'gaussian_gaussian': ('m_U', 's_U', 'm_V', 's_V'),
        'gaussian_gamma_map': ('alpha', 'beta'),
    }.get(model_family)

    if key_order is None:
        raise ValueError(f"Unknown model family: {model_family}")

    return np.concatenate([params[name].ravel() for name in key_order])
48
+
49
+
50
def compute_deletion_influence_by_distance(full_params, exact_params, edge_to_remove,
                                           edges, N, M, model_family, max_radius=6):
    """Summarize per-block deletion influence, grouped by graph distance.

    For every block ``u`` the influence is the Euclidean norm of the
    difference between its parameter vector in ``full_params`` and in
    ``exact_params``. Blocks are grouped by the distance shells returned by
    the ``src.graph_utils`` helpers.

    Args:
        full_params: Parameters fitted on the full data.
        exact_params: Parameters of the exact retrain without the edge.
        edge_to_remove: The deleted edge, forwarded to
            ``get_deletion_neighborhood``.
        edges: Edge list forwarded to ``build_adjacency``.
        N: Forwarded to the graph helpers (presumably the number of users --
            TODO confirm against ``src.graph_utils``).
        M: Forwarded to the graph helpers (presumably the number of items --
            TODO confirm against ``src.graph_utils``).
        model_family: Model family name understood by
            ``compute_block_param_vector``.
        max_radius: Radius limit forwarded to ``get_deletion_neighborhood``.

    Returns:
        Dict mapping each distance to a dict with keys ``'mean'``, ``'std'``,
        ``'median'``, ``'max'``, ``'n_blocks'`` and ``'values'``. Shells that
        contain no blocks are omitted.
    """
    user_to_items, item_to_users, _ = build_adjacency(edges, N, M)
    distances = get_deletion_neighborhood(edge_to_remove, user_to_items, item_to_users,
                                          N, M, max_radius)
    blocks_by_dist = get_blocks_at_distance(distances, N)

    influence_by_dist = {}
    for dist, blocks in blocks_by_dist.items():
        influences = [
            float(np.linalg.norm(
                compute_block_param_vector(full_params, node_type, idx, model_family)
                - compute_block_param_vector(exact_params, node_type, idx, model_family)
            ))
            for node_type, idx in blocks
        ]

        # An empty shell has no statistics: np.max would raise on a
        # zero-size array and np.mean would yield NaN, so skip it.
        if not influences:
            continue

        influence_by_dist[dist] = {
            'mean': float(np.mean(influences)),
            'std': float(np.std(influences)),
            'median': float(np.median(influences)),
            'max': float(np.max(influences)),
            'n_blocks': len(influences),
            'values': influences,
        }

    return influence_by_dist
77
+
78
+
79
def fit_exponential_decay(influence_by_dist: dict, min_shells: int = 3, eps: float = 1e-12):
    """Fit ``log(Delta(r)) = alpha - mu * r`` to the per-shell mean influence.

    Only shells whose mean influence exceeds ``eps`` and that contain at
    least two blocks take part in the least-squares fit.

    Args:
        influence_by_dist: Mapping from distance ``r`` to a dict with at
            least the keys ``'mean'`` and ``'n_blocks'``.
        min_shells: Minimum number of usable shells required for a fit.
            Values below 2 are treated as 2, because a line cannot be
            regressed on fewer than two points.
        eps: Threshold below which a mean counts as zero; it is also added
            inside the logarithm.

    Returns:
        Dict with keys ``'mu_emp'`` (the negated slope), ``'intercept'``,
        ``'r_squared'``, ``'n_shells'``, ``'p_value'``, ``'std_err'`` and
        ``'valid'``. When too few shells are usable, ``'valid'`` is False and
        every fitted quantity is None; the key set is the same either way.
    """
    r_vals = []
    log_vals = []
    for r in sorted(influence_by_dist):
        shell = influence_by_dist[r]
        mean_inf = shell['mean']
        if mean_inf > eps and shell['n_blocks'] >= 2:
            r_vals.append(r)
            log_vals.append(np.log(mean_inf + eps))

    n_shells = len(r_vals)

    # linregress cannot fit fewer than two points, whatever the caller asks.
    if n_shells < max(min_shells, 2):
        return {
            'mu_emp': None,
            'intercept': None,
            'r_squared': None,
            'n_shells': n_shells,
            'p_value': None,
            'std_err': None,
            'valid': False,
        }

    r_arr = np.array(r_vals, dtype=float)
    log_arr = np.array(log_vals)

    slope, intercept, r_value, p_value, std_err = stats.linregress(r_arr, log_arr)

    return {
        'mu_emp': float(-slope),
        'intercept': float(intercept),
        'r_squared': float(r_value ** 2),
        'n_shells': n_shells,
        'p_value': float(p_value),
        'std_err': float(std_err),
        'valid': True,
    }
114
+
115
+
116
+ # ============================================================
117
+ # Local approximation error
118
+ # ============================================================
119
+
120
def compute_local_error(local_params, exact_params, model_family):
    """Measure how far a local approximation is from the exact retrain.

    Args:
        local_params: Parameters produced by the local update.
        exact_params: Parameters of the exact retrain.
        model_family: Model family name understood by
            ``compute_all_param_vector``.

    Returns:
        Dict with ``'param_error'``, the Euclidean distance between the two
        flattened parameter vectors, and ``'relative_error'``, that distance
        divided by ``1 + ||exact||``.
    """
    flat_local = compute_all_param_vector(local_params, model_family)
    flat_exact = compute_all_param_vector(exact_params, model_family)

    abs_err = np.linalg.norm(flat_local - flat_exact)
    # The "+ 1" keeps the ratio finite when the exact vector is all zeros.
    scale = 1.0 + np.linalg.norm(flat_exact)

    return {
        'param_error': float(abs_err),
        'relative_error': float(abs_err / scale),
    }
132
+
133
+
134
+ # ============================================================
135
+ # ELBO / objective gap
136
+ # ============================================================
137
+
138
def compute_objective_gap(model, edges_without, exact_params, approx_params):
    """Return the objective gap between exact and approximate parameters.

    The objective is ``model.compute_elbo`` when the model has that
    attribute, otherwise ``model.compute_objective``. Both parameter sets
    are evaluated on ``edges_without``.

    Args:
        model: Object exposing ``compute_elbo`` or ``compute_objective``,
            each called as ``f(edges, params)``.
        edges_without: Edge data with the deleted edge removed.
        exact_params: Parameters of the exact retrain.
        approx_params: Parameters of the approximation being scored.

    Returns:
        ``objective(exact) - objective(approx)`` as a float, or None when
        the objective could not be evaluated. This is best-effort: the
        failure is logged as a warning rather than raised.
    """
    import logging

    try:
        # Resolve the objective once so both evaluations use the same callable.
        objective = (model.compute_elbo if hasattr(model, 'compute_elbo')
                     else model.compute_objective)
        obj_exact = objective(edges_without, exact_params)
        obj_approx = objective(edges_without, approx_params)
        return float(obj_exact - obj_approx)
    except Exception:
        # Callers treat None as "gap unavailable"; keep that contract but
        # leave a trace so that failures are not invisible.
        logging.getLogger(__name__).warning(
            "compute_objective_gap: objective evaluation failed", exc_info=True)
        return None
148
+
149
+
150
+ # ============================================================
151
+ # Weighted interaction statistics (chi)
152
+ # ============================================================
153
+
154
def compute_chi_poisson_gamma(edge_to_remove, edges, params, N, M, K,
                              a0=0.3, b0=1.0, c0=0.3, d0=1.0):
    """Compute the chi interaction statistics for the Poisson-Gamma model.

    Args:
        edge_to_remove: Triple ``(i_del, j_del, x_del)``; only the two
            indices are used.
        edges: Iterable of ``(i, j, x)`` triples.
        params: Dict holding the arrays ``'a'``, ``'b'``, ``'c'``, ``'d'``.
        N, M, K: Accepted for signature compatibility; not used here.
        a0, b0, c0, d0: Fallbacks used in place of the smallest positive
            entry when the matching array has no positive entry.

    Returns:
        Dict with ``'chi_i'``, ``'chi_tilde_j'``, ``'chi_max'``,
        ``'chi_sum'``, ``'seed_degree'``, ``'seed_count_sum'`` and the four
        constants ``'C_x'``, ``'C_0'``, ``'C_tilde_x'``, ``'C_tilde_0'``.
    """
    def _smallest_positive(values, fallback):
        # Smallest strictly positive entry, or the fallback if none exists.
        mask = values > 0
        return np.min(values[mask]) if np.any(mask) else fallback

    a, b, c, d = (params[key] for key in ('a', 'b', 'c', 'd'))

    a_min = _smallest_positive(a, a0)
    b_min = _smallest_positive(b, b0)
    c_min = _smallest_positive(c, c0)
    d_min = _smallest_positive(d, d0)
    a_max = np.max(a)
    c_max = np.max(c)

    # The floors keep the constants finite when a minimum is close to zero.
    d_floor = max(d_min, 1e-6)
    b_floor = max(b_min, 1e-6)

    C_x = 0.5 * polygamma(1, max(c_min, 1e-3)) + 0.5 / d_floor
    C_0 = 1.0 / d_floor + c_max / max(d_min**2, 1e-12)

    C_tilde_x = 0.5 * polygamma(1, max(a_min, 1e-3)) + 0.5 / b_floor
    C_tilde_0 = 1.0 / b_floor + a_max / max(b_min**2, 1e-12)

    # Adjacency lists and an (i, j) -> x lookup built from the edge list.
    items_of_user = defaultdict(list)
    users_of_item = defaultdict(list)
    value_of_edge = {}
    for user, item, value in edges:
        items_of_user[user].append(item)
        users_of_item[item].append(user)
        value_of_edge[(user, item)] = value

    i_del, j_del, _ = edge_to_remove

    seed_items = items_of_user.get(i_del, [])
    seed_users = users_of_item.get(j_del, [])
    user_side_values = [value_of_edge.get((i_del, j), 0) for j in seed_items]
    item_side_values = [value_of_edge.get((i, j_del), 0) for i in seed_users]

    # One weighted term per edge incident to the deleted edge's endpoints.
    chi_i = sum(C_x * x + C_0 for x in user_side_values)
    chi_tilde_j = sum(C_tilde_x * x + C_tilde_0 for x in item_side_values)

    # Empirical alternatives that do not depend on the fitted constants.
    seed_degree = len(seed_items) + len(seed_users)
    seed_count_sum = sum(user_side_values) + sum(item_side_values)

    return {
        'chi_i': float(chi_i),
        'chi_tilde_j': float(chi_tilde_j),
        'chi_max': float(max(chi_i, chi_tilde_j)),
        'chi_sum': float(chi_i + chi_tilde_j),
        'seed_degree': int(seed_degree),
        'seed_count_sum': float(seed_count_sum),
        'C_x': float(C_x),
        'C_0': float(C_0),
        'C_tilde_x': float(C_tilde_x),
        'C_tilde_0': float(C_tilde_0),
    }
211
+
212
+
213
def compute_chi_gaussian(edge_to_remove, edges, params, N, M, K, sigma_x,
                         model_family='gaussian_gaussian'):
    """Compute the interaction proxy (chi statistics) for Gaussian models.

    Each edge incident to an endpoint of the deleted edge contributes the
    observation precision ``1 / sigma_x**2`` times a per-neighbour energy:
    ``sum(m**2 + s)`` for ``'gaussian_gaussian'`` and ``sum(softplus(.)**2)``
    for ``'gaussian_gamma_map'``.

    Args:
        edge_to_remove: Triple ``(i_del, j_del, x_del)``; only the two
            indices are used.
        edges: Iterable of ``(i, j, x)`` triples.
        params: Dict with ``'m_U'``, ``'s_U'``, ``'m_V'``, ``'s_V'`` for
            ``'gaussian_gaussian'``, or ``'alpha'``, ``'beta'`` for
            ``'gaussian_gamma_map'``.
        N, M, K: Accepted for signature compatibility; not used here.
        sigma_x: Observation noise scale; must be non-zero.
        model_family: ``'gaussian_gaussian'`` or ``'gaussian_gamma_map'``.

    Returns:
        Dict with ``'chi_i'``, ``'chi_tilde_j'``, ``'chi_max'``,
        ``'chi_sum'`` and ``'seed_degree'``.

    Raises:
        ValueError: If ``model_family`` is not a supported Gaussian family.
    """
    def _softplus(z):
        # NOTE(review): clipping to [-20, 20] caps the output near 20 for
        # large inputs; presumably this mirrors the model's own transform --
        # confirm against src.model.
        return np.log1p(np.exp(np.clip(z, -20, 20)))

    # Pick the per-neighbour energy up front so an unsupported family fails
    # with a clear error instead of an unbound-variable error further down.
    if model_family == 'gaussian_gaussian':
        m_U, s_U = params['m_U'], params['s_U']
        m_V, s_V = params['m_V'], params['s_V']

        def item_energy(j):
            return np.sum(m_V[j]**2 + s_V[j])

        def user_energy(i):
            return np.sum(m_U[i]**2 + s_U[i])
    elif model_family == 'gaussian_gamma_map':
        U = _softplus(params['alpha'])
        V = _softplus(params['beta'])

        def item_energy(j):
            return np.sum(V[j]**2)

        def user_energy(i):
            return np.sum(U[i]**2)
    else:
        raise ValueError(f"Unknown model family: {model_family}")

    user_to_items = defaultdict(list)
    item_to_users = defaultdict(list)
    for i, j, _x in edges:
        user_to_items[i].append(j)
        item_to_users[j].append(i)

    i_del, j_del, _x_del = edge_to_remove
    prec_x = 1.0 / (sigma_x ** 2)

    seed_items = user_to_items.get(i_del, [])
    seed_users = item_to_users.get(j_del, [])

    chi_i = sum(prec_x * item_energy(j) for j in seed_items)
    chi_tilde_j = sum(prec_x * user_energy(i) for i in seed_users)

    return {
        'chi_i': float(chi_i),
        'chi_tilde_j': float(chi_tilde_j),
        'chi_max': float(max(chi_i, chi_tilde_j)),
        'chi_sum': float(chi_i + chi_tilde_j),
        'seed_degree': int(len(seed_items) + len(seed_users)),
    }
254
+
255
+
256
+ # ============================================================
257
+ # Gradient interference proxy
258
+ # ============================================================
259
+
260
def compute_gradient_interference(full_params, exact_params, local_params, model_family):
    """Compute the gradient interference proxy between two parameter shifts.

    With all parameters flattened into vectors:

        g_del = full - exact    (shift caused by the deletion)
        g_ret = exact - local   (what the local update still misses)
        I     = |<g_del, g_ret>|

    Args:
        full_params: Parameters fitted on the full data.
        exact_params: Parameters of the exact retrain.
        local_params: Parameters produced by the local update.
        model_family: Model family name understood by
            ``compute_all_param_vector``.

    Returns:
        Dict with ``'interference_raw'`` (``I``), ``'interference_cosine'``
        (``I`` divided by the product of the two norms plus ``1e-12``),
        ``'g_del_norm'`` and ``'g_ret_norm'``.
    """
    flat_full, flat_exact, flat_local = (
        compute_all_param_vector(p, model_family)
        for p in (full_params, exact_params, local_params)
    )

    deletion_shift = flat_full - flat_exact
    residual_shift = flat_exact - flat_local

    overlap = float(np.abs(np.dot(deletion_shift, residual_shift)))

    deletion_norm = np.linalg.norm(deletion_shift)
    residual_norm = np.linalg.norm(residual_shift)
    # A small constant avoids division by zero when either shift vanishes.
    tiny = 1e-12
    normalized = overlap / (deletion_norm * residual_norm + tiny)

    return {
        'interference_raw': float(overlap),
        'interference_cosine': float(normalized),
        'g_del_norm': float(deletion_norm),
        'g_ret_norm': float(residual_norm),
    }
287
+
288
+
289
+ # ============================================================
290
+ # Full metrics for one deletion
291
+ # ============================================================
292
+
293
def compute_all_metrics(full_params, exact_params, local_params_by_radius,
                        warm_start_params, one_step_params,
                        edge_to_remove, edges, N, M, K,
                        model_family, model=None, radii=(1, 2, 3, 4),
                        model_kwargs=None):
    """Compute all metrics for one deletion.

    Args:
        full_params: Parameters fitted on the full data.
        exact_params: Parameters of the exact retrain without the edge.
        local_params_by_radius: Mapping from radius ``R`` to the parameters
            of the local update at that radius; radii missing from the
            mapping are skipped.
        warm_start_params: Warm-start baseline parameters, or None to skip.
        one_step_params: One-step baseline parameters, or None to skip.
        edge_to_remove: The deleted edge as ``(i, j, x)``.
        edges: Edge list as ``(i, j, x)`` triples.
        N, M, K: Forwarded to the influence and chi helpers.
        model_family: ``'poisson_gamma'`` selects the Poisson-Gamma chi
            statistics; any other value uses the Gaussian chi statistics.
        model: Accepted for signature compatibility; not used here.
        radii: Radii to evaluate. The influence analysis covers distances up
            to ``max(radii) + 2``, or 2 when ``radii`` is empty.
        model_kwargs: Optional hyper-parameters (``a0``, ``b0``, ``c0``,
            ``d0`` or ``sigma_x``); missing entries use the defaults below.

    Returns:
        Flat dict of results: influence by distance, decay fit summary, chi
        statistics, per-radius error and interference entries (keys suffixed
        with ``_R{R}``), and the optional baseline errors.
    """
    results = {}
    if model_kwargs is None:
        model_kwargs = {}

    # Influence by distance; look two shells beyond the largest radius.
    # default=0 keeps an empty `radii` from crashing max().
    influence = compute_deletion_influence_by_distance(
        full_params, exact_params, edge_to_remove, edges, N, M, model_family,
        max_radius=max(radii, default=0) + 2)
    results['influence_by_distance'] = {str(k): v['mean'] for k, v in influence.items()}
    results['influence_by_distance_full'] = influence

    # Decay fit
    decay = fit_exponential_decay(influence)
    results['empirical_decay_mu'] = decay['mu_emp']
    results['empirical_decay_r2'] = decay['r_squared']
    results['decay_valid'] = decay['valid']

    # Chi statistics
    if model_family == 'poisson_gamma':
        chi = compute_chi_poisson_gamma(
            edge_to_remove, edges, full_params, N, M, K,
            a0=model_kwargs.get('a0', 0.3), b0=model_kwargs.get('b0', 1.0),
            c0=model_kwargs.get('c0', 0.3), d0=model_kwargs.get('d0', 1.0))
    else:
        chi = compute_chi_gaussian(
            edge_to_remove, edges, full_params, N, M, K,
            sigma_x=model_kwargs.get('sigma_x', 1.0), model_family=model_family)

    results['chi_seed_max'] = chi['chi_max']
    results['chi_seed_sum'] = chi['chi_sum']
    results['seed_degree'] = chi['seed_degree']

    # Local errors and interference by radius
    for R in radii:
        if R not in local_params_by_radius:
            continue
        local_params = local_params_by_radius[R]

        err = compute_local_error(local_params, exact_params, model_family)
        results[f'error_R{R}'] = err['param_error']
        results[f'rel_error_R{R}'] = err['relative_error']

        interf = compute_gradient_interference(
            full_params, exact_params, local_params, model_family)
        results[f'interference_raw_R{R}'] = interf['interference_raw']
        results[f'interference_cosine_R{R}'] = interf['interference_cosine']

    # Warm start error
    if warm_start_params is not None:
        ws_err = compute_local_error(warm_start_params, exact_params, model_family)
        results['error_warm_start'] = ws_err['param_error']
        results['rel_error_warm_start'] = ws_err['relative_error']

    # One-step error
    if one_step_params is not None:
        os_err = compute_local_error(one_step_params, exact_params, model_family)
        results['error_one_step'] = os_err['param_error']
        results['rel_error_one_step'] = os_err['relative_error']

    return results