# Provenance: src/unlearning.py, uploaded by serliezer (commit a8951a1, verified).
"""Unlearning methods wrapper."""
import numpy as np
from typing import Dict, Optional, Tuple
from src.graph_utils import build_adjacency, get_deletion_neighborhood, get_user_item_sets_in_radius
def one_step_downdate_poisson_gamma(edges, edge_to_remove, full_params,
                                    N, M, K, a0, b0, c0, d0):
    """One-step local downdate for Poisson-Gamma VI.

    Subtracts the deleted edge's contribution from the seed user/item
    parameter blocks only (no full refit).

    Parameters
    ----------
    edges : unused here; kept for interface parity with the refit methods.
    edge_to_remove : (i, j, x) triple identifying the deleted count.
    full_params : dict with Gamma variational arrays 'a', 'b' (user shape/rate)
        and 'c', 'd' (item shape/rate). Copied, never mutated in place.
    N, M, K : problem dimensions (users, items, factors); unused directly
        but kept for a uniform downdate signature.
    a0, b0, c0, d0 : Gamma prior hyperparameters, used as floors so the
        downdate cannot push parameters below (a fraction of) the prior.

    Returns
    -------
    FitResult holding the downdated params and bookkeeping diagnostics.
    """
    import time
    t0 = time.time()
    i_del, j_del, x_del = edge_to_remove
    # Work on a copy so the caller's full-model parameters stay intact.
    params = {k: v.copy() for k, v in full_params.items()}
    a, b, c, d = params['a'], params['b'], params['c'], params['d']
    if x_del > 0:
        from scipy.special import digamma
        from src.utils import stable_softmax
        # Per-factor responsibility for the deleted count under the
        # current variational posterior: r_k ∝ exp(E[log theta_ik] + E[log beta_jk]).
        log_r = digamma(a[i_del]) - np.log(b[i_del]) + digamma(c[j_del]) - np.log(d[j_del])
        r = stable_softmax(log_r)
        # Snapshot the *pre-downdate* expectations. Both sides must subtract
        # the contribution as it stood in the full model; previously the
        # item-rate update read the already-downdated user parameters,
        # making the result depend on update order.
        theta_mean = a[i_del] / b[i_del]  # E[theta_i] under the full posterior
        beta_mean = c[j_del] / d[j_del]   # E[beta_j] under the full posterior
        # Remove the edge's contribution from user i's shape/rate...
        a[i_del] = np.maximum(a0, a[i_del] - x_del * r)
        b[i_del] = np.maximum(b0 * 0.5, b[i_del] - beta_mean)
        # ...and from item j's shape/rate, using the snapshotted E[theta_i].
        c[j_del] = np.maximum(c0, c[j_del] - x_del * r)
        d[j_del] = np.maximum(d0 * 0.5, d[j_del] - theta_mean)
    runtime = time.time() - t0
    from src.utils import FitResult
    return FitResult(
        params=params,
        objective_trace=[],
        n_iterations=1,
        converged=True,
        runtime_sec=runtime,
        model_family='poisson_gamma',
        inference_type='vi',
        likelihood='poisson',
        prior='gamma',
        diagnostics={'method': 'one_step_downdate'}
    )
def one_step_downdate_gaussian(edges, edge_to_remove, full_params,
                               N, M, K, sigma_x, sigma_U=None, sigma_V=None,
                               model_family='gaussian_gaussian'):
    """One-step local downdate for Gaussian matrix-factorization models.

    For the fully Gaussian VI model, removes the deleted edge's precision
    contribution from user i's and item j's posterior variances only.

    Parameters
    ----------
    edges : unused here; kept for interface parity with the refit methods.
    edge_to_remove : (i, j, x) triple identifying the deleted rating.
    full_params : dict of variational arrays; for 'gaussian_gaussian' it holds
        means 'm_U', 'm_V' and per-factor variances 's_U', 's_V'
        (assumed shape (N, K) / (M, K) — consistent with the indexing below).
    N, M, K : problem dimensions; unused directly but kept for uniformity.
    sigma_x : observation noise std; its inverse square is the per-edge
        precision contribution being removed.
    sigma_U, sigma_V : optional prior stds; set the precision floor
        (falls back to 1.0 when not given).
    model_family : 'gaussian_gaussian' (VI downdate) or
        'gaussian_gamma_map' (MAP branch, currently a no-op stub).

    Returns
    -------
    FitResult holding the downdated params and bookkeeping diagnostics.
    """
    import time
    t0 = time.time()
    i_del, j_del, x_del = edge_to_remove
    # Work on a copy so the caller's full-model parameters stay intact.
    params = {k: v.copy() for k, v in full_params.items()}
    if model_family == 'gaussian_gaussian':
        m_U, s_U = params['m_U'], params['s_U']
        m_V, s_V = params['m_V'], params['s_V']
        prec_x = 1.0 / (sigma_x ** 2)
        # Precision floors: never drop below the prior precision.
        floor_U = 1.0 / (sigma_U ** 2 if sigma_U else 1.0)
        floor_V = 1.0 / (sigma_V ** 2 if sigma_V else 1.0)
        # Snapshot the *pre-downdate* second moments. Both updates must
        # subtract the contribution as it stood in the full model;
        # previously the item loop read the already-downdated s_U, making
        # the result depend on update order.
        u2 = m_U[i_del] ** 2 + s_U[i_del]  # E[U_i^2] under the full posterior
        v2 = m_V[j_del] ** 2 + s_V[j_del]  # E[V_j^2] under the full posterior
        # Vectorized over all K factors (replaces the per-k Python loops).
        new_prec_U = np.maximum(1.0 / s_U[i_del] - prec_x * v2, floor_U)
        s_U[i_del] = 1.0 / new_prec_U
        new_prec_V = np.maximum(1.0 / s_V[j_del] - prec_x * u2, floor_V)
        s_V[j_del] = 1.0 / new_prec_V
    elif model_family == 'gaussian_gamma_map':
        # For MAP: one gradient step removing the edge's contribution.
        # Not implemented yet; params are returned unchanged.
        pass
    runtime = time.time() - t0
    from src.utils import FitResult
    return FitResult(
        params=params, objective_trace=[], n_iterations=1,
        converged=True, runtime_sec=runtime,
        model_family=model_family, inference_type='vi' if 'gaussian_gaussian' == model_family else 'map',
        likelihood='gaussian', prior='gaussian' if model_family == 'gaussian_gaussian' else 'gamma',
        diagnostics={'method': 'one_step_downdate'}
    )