"""Unlearning methods wrapper."""
import numpy as np
from typing import Dict, Optional, Tuple
from src.graph_utils import build_adjacency, get_deletion_neighborhood, get_user_item_sets_in_radius


def one_step_downdate_poisson_gamma(edges, edge_to_remove, full_params, N, M, K, a0, b0, c0, d0):
    """One-step local downdate: subtract deleted contribution from seed blocks only.

    Removes the contribution of a single edge ``(i, j, x)`` from the Gamma
    variational parameters of row ``i`` of ``a``/``b`` and row ``j`` of
    ``c``/``d``. Every other row is returned unchanged. Nothing is modified
    when ``x <= 0``.

    Args:
        edges: Not used by this routine.
        edge_to_remove: ``(i_del, j_del, x_del)`` triple.
        full_params: Dict containing arrays ``'a'``, ``'b'``, ``'c'``, ``'d'``
            from the full-data fit. ``a/b`` and ``c/d`` are used as Gamma
            means (shape/rate). The dict is not mutated; every entry is copied.
        N, M, K: Not used by this routine.
        a0, b0, c0, d0: Prior hyperparameters. Downdated shapes are floored
            at ``a0``/``c0``; downdated rates at ``0.5 * b0`` / ``0.5 * d0``.

    Returns:
        ``FitResult`` holding the downdated parameter copies, with
        ``n_iterations=1`` and ``converged=True``.
    """
    import time
    t0 = time.time()

    i_del, j_del, x_del = edge_to_remove
    params = {k: v.copy() for k, v in full_params.items()}
    a, b, c, d = params['a'], params['b'], params['c'], params['d']

    if x_del > 0:
        from scipy.special import digamma
        from src.utils import stable_softmax

        # Snapshot the full-fit rows before anything is overwritten. The
        # deleted edge's contribution to each side was built from these
        # values. Both sides must therefore subtract terms computed from
        # them, otherwise the result depends on the user-then-item order.
        a_i, b_i = a[i_del].copy(), b[i_del].copy()
        c_j, d_j = c[j_del].copy(), d[j_del].copy()

        # Responsibility of each component for the deleted count:
        # softmax over E[log U_i] + E[log V_j] under the Gamma posteriors.
        log_r = digamma(a_i) - np.log(b_i) + digamma(c_j) - np.log(d_j)
        r = stable_softmax(log_r)

        # Shapes lose the responsibility-weighted share of the count.
        # Rates lose the expected value of the opposite side's factor.
        a[i_del] = np.maximum(a0, a_i - x_del * r)
        b[i_del] = np.maximum(b0 * 0.5, b_i - c_j / d_j)
        c[j_del] = np.maximum(c0, c_j - x_del * r)
        d[j_del] = np.maximum(d0 * 0.5, d_j - a_i / b_i)

    runtime = time.time() - t0
    from src.utils import FitResult
    return FitResult(
        params=params,
        objective_trace=[],
        n_iterations=1,
        converged=True,
        runtime_sec=runtime,
        model_family='poisson_gamma',
        inference_type='vi',
        likelihood='poisson',
        prior='gamma',
        diagnostics={'method': 'one_step_downdate'},
    )


def one_step_downdate_gaussian(edges, edge_to_remove, full_params, N, M, K, sigma_x,
                               sigma_U=None, sigma_V=None, model_family='gaussian_gaussian'):
    """One-step local downdate for Gaussian models.

    For ``model_family == 'gaussian_gaussian'``, removes the deleted edge's
    precision contribution, ``prec_x * E[other factor ** 2]``, from the first
    ``K`` columns of row ``i`` of ``s_U`` and row ``j`` of ``s_V``. Each
    precision is floored at its prior precision.

    Args:
        edges: Not used by this routine.
        edge_to_remove: ``(i_del, j_del, x_del)`` triple.
        full_params: Dict of arrays from the full fit. The Gaussian-Gaussian
            path reads ``'m_U'``, ``'s_U'``, ``'m_V'`` and ``'s_V'``; ``s_*``
            are treated as variances. The dict is not mutated; every entry is
            copied.
        N, M: Not used by this routine.
        K: Number of leading columns that are downdated.
        sigma_x: Observation noise standard deviation.
        sigma_U, sigma_V: Prior standard deviations. A falsy value (``None``
            or ``0``) falls back to a unit prior when computing the floor.
        model_family: ``'gaussian_gaussian'`` downdates variances. Any other
            value, including ``'gaussian_gamma_map'``, returns the copied
            parameters unchanged.

    Returns:
        ``FitResult`` holding the parameter copies, with ``n_iterations=1``.
    """
    import time
    t0 = time.time()

    i_del, j_del, x_del = edge_to_remove
    params = {k: v.copy() for k, v in full_params.items()}
    is_gaussian_gaussian = model_family == 'gaussian_gaussian'

    if is_gaussian_gaussian:
        m_U, s_U = params['m_U'], params['s_U']
        m_V, s_V = params['m_V'], params['s_V']
        prec_x = 1.0 / (sigma_x ** 2)
        prior_prec_U = 1.0 / (sigma_U ** 2 if sigma_U else 1.0)
        prior_prec_V = 1.0 / (sigma_V ** 2 if sigma_V else 1.0)

        # Second moments E[f**2] = mean**2 + variance from the full fit,
        # taken before either side is modified. The edge's contribution to
        # each side's precision was built from these values, so subtracting
        # them keeps the downdate independent of update order.
        m2_U_i = m_U[i_del, :K] ** 2 + s_U[i_del, :K]
        m2_V_j = m_V[j_del, :K] ** 2 + s_V[j_del, :K]

        s_U[i_del, :K] = 1.0 / np.maximum(1.0 / s_U[i_del, :K] - prec_x * m2_V_j, prior_prec_U)
        s_V[j_del, :K] = 1.0 / np.maximum(1.0 / s_V[j_del, :K] - prec_x * m2_U_i, prior_prec_V)
        # NOTE(review): the means m_U / m_V are not downdated, and x_del is
        # unused on this path. Confirm that a variance-only downdate is intended.
    elif model_family == 'gaussian_gamma_map':
        # NOTE(review): placeholder. No MAP downdate is applied yet; the
        # full-fit parameters are returned as-is.
        pass

    runtime = time.time() - t0
    from src.utils import FitResult
    return FitResult(
        params=params,
        objective_trace=[],
        n_iterations=1,
        converged=True,
        runtime_sec=runtime,
        model_family=model_family,
        inference_type='vi' if is_gaussian_gaussian else 'map',
        likelihood='gaussian',
        prior='gaussian' if is_gaussian_gaussian else 'gamma',
        diagnostics={'method': 'one_step_downdate'},
    )