"""Model implementations: Poisson-Gamma VI, Gaussian-Gaussian VI, Gaussian-Gamma MAP. All implementations use vectorized numpy operations over edge arrays for performance. """ import numpy as np import time from scipy.special import digamma, polygamma from typing import Dict, Optional, Tuple, List from collections import defaultdict from src.utils import FitResult, stable_softmax, relative_param_change from src.graph_utils import build_adjacency, get_deletion_neighborhood, get_user_item_sets_in_radius # ============================================================ # Poisson-Gamma CAVI (main model) - vectorized # ============================================================ class PoissonGammaVI: """Augmented Gamma-Poisson MF with mean-field CAVI. Vectorized.""" def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0, max_iter=200, tol=1e-5, damping=1.0, seed=0): self.N = N self.M = M self.K = K self.a0 = a0 self.b0 = b0 self.c0 = c0 self.d0 = d0 self.max_iter = max_iter self.tol = tol self.damping = damping self.seed = seed def _init_params(self, rng=None): if rng is None: rng = np.random.RandomState(self.seed) return { 'a': np.full((self.N, self.K), self.a0) + rng.gamma(1, 0.1, (self.N, self.K)), 'b': np.full((self.N, self.K), self.b0) + rng.gamma(1, 0.1, (self.N, self.K)), 'c': np.full((self.M, self.K), self.c0) + rng.gamma(1, 0.1, (self.M, self.K)), 'd': np.full((self.M, self.K), self.d0) + rng.gamma(1, 0.1, (self.M, self.K)), } def _prepare_edges(self, edges): """Convert edge list to vectorized arrays.""" n_edges = len(edges) I = np.array([e[0] for e in edges], dtype=np.int32) J = np.array([e[1] for e in edges], dtype=np.int32) X = np.array([e[2] for e in edges], dtype=np.float64) return I, J, X, n_edges def _cavi_sweep(self, I, J, X, n_edges, params, update_users=None, update_items=None): """One vectorized CAVI sweep.""" a, b, c, d = params['a'], params['b'], params['c'], params['d'] K = self.K # Compute responsibilities for all edges at once # log_r[e, k] = psi(a[I[e],k]) - log(b[I[e],k]) + psi(c[J[e],k]) - log(d[J[e],k]) psi_a = digamma(a) # (N, K) log_b = np.log(b + 1e-30) psi_c = digamma(c) # (M, K) log_d = np.log(d + 1e-30) log_r = psi_a[I] - log_b[I] + psi_c[J] - log_d[J] # (n_edges, K) log_r -= log_r.max(axis=1, keepdims=True) r = np.exp(log_r) r /= r.sum(axis=1, keepdims=True) + 1e-30 # weighted responsibilities: x[e] * r[e, k] xr = X[:, None] * r # (n_edges, K) # E[V_jk] = c[j,k] / d[j,k] EV = c / d # (M, K) EU = a / b # (N, K) # Update user params a_new = np.full_like(a, self.a0) b_new = np.full_like(b, self.b0) # Scatter-add xr to user rows np.add.at(a_new, I, xr) # b_ik = b0 + sum_{j in Omega_i} EV_jk np.add.at(b_new, I, EV[J]) # Update item params c_new = np.full_like(c, self.c0) d_new = np.full_like(d, self.d0) np.add.at(c_new, J, xr) np.add.at(d_new, J, EU[I]) # Apply updates only to specified blocks if update_users is not None: mask_u = np.zeros(self.N, dtype=bool) mask_u[list(update_users)] = True a_out = np.where(mask_u[:, None], a_new, a) b_out = np.where(mask_u[:, None], b_new, b) else: a_out = a_new b_out = b_new if update_items is not None: mask_v = np.zeros(self.M, dtype=bool) mask_v[list(update_items)] = True c_out = np.where(mask_v[:, None], c_new, c) d_out = np.where(mask_v[:, None], d_new, d) else: c_out = c_new d_out = d_new # Damping if self.damping < 1.0: alpha = self.damping a_out = (1 - alpha) * a + alpha * a_out b_out = (1 - alpha) * b + alpha * b_out c_out = (1 - alpha) * c + alpha * c_out d_out = (1 - alpha) * d + alpha * d_out return {'a': a_out, 'b': b_out, 'c': c_out, 'd': d_out} def compute_elbo(self, edges, params): """Approximate ELBO (likelihood term only for speed).""" I, J, X, n_edges = self._prepare_edges(edges) a, b, c, d = params['a'], params['b'], params['c'], params['d'] EU = a / b EV = c / d # E[UV] for each edge pred = np.sum(EU[I] * EV[J], axis=1) # Poisson log-likelihood proxy: sum(x * log(pred) - pred) elbo = np.sum(X * np.log(pred + 1e-30) - pred) return float(elbo) def fit_full(self, edges, config=None, init_params=None): t0 = time.time() I, J, X, n_edges = self._prepare_edges(edges) if init_params is not None: params = {k: v.copy() for k, v in init_params.items()} else: params = self._init_params() elbo_trace = [] converged = False for it in range(self.max_iter): old_params = {k: v.copy() for k, v in params.items()} params = self._cavi_sweep(I, J, X, n_edges, params) change = relative_param_change(old_params, params) if it % 50 == 0: elbo = self.compute_elbo(edges, params) elbo_trace.append(elbo) if change < self.tol: converged = True break return FitResult( params=params, objective_trace=elbo_trace, n_iterations=it + 1, converged=converged, runtime_sec=time.time() - t0, model_family='poisson_gamma', inference_type='vi', likelihood='poisson', prior='gamma', ) def fit_without_edge(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None): t0 = time.time() i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] if init_params is None: raise ValueError("Local unlearning requires init_params") params = {k: v.copy() for k, v in init_params.items()} # Get neighborhood from ORIGINAL graph u2i_orig, i2u_orig, _ = build_adjacency(edges, self.N, self.M) distances = get_deletion_neighborhood(edge_to_remove, u2i_orig, i2u_orig, self.N, self.M, radius) users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius) # KEY OPTIMIZATION: filter edges to only those touching neighborhood # For user i update: need all edges (i, j, x) where i in users_in_R # For item j update: need all edges (i, j, x) where j in items_in_R # Union: edges where i in users_in_R OR j in items_in_R local_edges = [(i, j, x) for i, j, x in filtered if i in users_in_R or j in items_in_R] I, J, X, n_edges = self._prepare_edges(local_edges) converged = False for it in range(self.max_iter): old_a = params['a'].copy() old_b = params['b'].copy() old_c = params['c'].copy() old_d = params['d'].copy() params = self._cavi_sweep(I, J, X, n_edges, params, update_users=users_in_R, update_items=items_in_R) # Check convergence on updated blocks only max_change = 0.0 if users_in_R: ul = list(users_in_R) max_change = max(max_change, np.max(np.abs(params['a'][ul] - old_a[ul]) / (1 + np.abs(old_a[ul]))), np.max(np.abs(params['b'][ul] - old_b[ul]) / (1 + np.abs(old_b[ul])))) if items_in_R: il = list(items_in_R) max_change = max(max_change, np.max(np.abs(params['c'][il] - old_c[il]) / (1 + np.abs(old_c[il]))), np.max(np.abs(params['d'][il] - old_d[il]) / (1 + np.abs(old_d[il])))) if max_change < self.tol: converged = True break return FitResult( params=params, objective_trace=[], n_iterations=it + 1, converged=converged, runtime_sec=time.time() - t0, model_family='poisson_gamma', inference_type='vi', likelihood='poisson', prior='gamma', diagnostics={'radius': radius, 'n_users_updated': len(users_in_R), 'n_items_updated': len(items_in_R)} ) def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) # ============================================================ # Gaussian-Gaussian VI - vectorized # ============================================================ class GaussianGaussianVI: """Gaussian-Gaussian MF with mean-field Gaussian VI. Vectorized.""" def __init__(self, N, M, K, sigma_U=1.0, sigma_V=1.0, sigma_x=1.0, max_iter=200, tol=1e-5, damping=1.0, seed=0): self.N = N self.M = M self.K = K self.sigma_U = sigma_U self.sigma_V = sigma_V self.sigma_x = sigma_x self.max_iter = max_iter self.tol = tol self.damping = damping self.seed = seed def _init_params(self, rng=None): if rng is None: rng = np.random.RandomState(self.seed) return { 'm_U': rng.randn(self.N, self.K) * 0.1, 's_U': np.ones((self.N, self.K)) * 0.5, 'm_V': rng.randn(self.M, self.K) * 0.1, 's_V': np.ones((self.M, self.K)) * 0.5, } def _prepare_edges(self, edges): n_edges = len(edges) I = np.array([e[0] for e in edges], dtype=np.int32) J = np.array([e[1] for e in edges], dtype=np.int32) X = np.array([e[2] for e in edges], dtype=np.float64) return I, J, X, n_edges def _cavi_sweep(self, I, J, X, n_edges, params, update_users=None, update_items=None): """Vectorized Gaussian-Gaussian CAVI using coordinate updates.""" m_U = params['m_U'].copy() s_U = params['s_U'].copy() m_V = params['m_V'].copy() s_V = params['s_V'].copy() prec_x = 1.0 / (self.sigma_x ** 2) prec_U = 1.0 / (self.sigma_U ** 2) prec_V = 1.0 / (self.sigma_V ** 2) # For each user i, update all K components # Precision: prec_U + prec_x * sum_{j in Omega_i} (m_V[j,k]^2 + s_V[j,k]) # This requires scatter-add of (m_V[J]^2 + s_V[J]) over I V_sq_plus_var = m_V ** 2 + s_V # (M, K) U_sq_plus_var = m_U ** 2 + s_U # (N, K) # User precision: for each user, sum of V_sq_plus_var over neighbors user_prec_sum = np.zeros((self.N, self.K)) np.add.at(user_prec_sum, I, V_sq_plus_var[J]) s_U_new = 1.0 / (prec_U + prec_x * user_prec_sum) # User mean: for each edge, compute x_ij * m_V[j] contribution # Simplified: for each user i, k: m_U[i,k] = s_U[i,k] * prec_x * sum_j m_V[j,k] * (x_ij - sum_{l!=k} m_U[i,l]*m_V[j,l]) # Approximate: use current m_U for cross-terms # predicted = sum_k m_U[I,k] * m_V[J,k] predicted = np.sum(m_U[I] * m_V[J], axis=1) # (n_edges,) m_U_new = np.zeros((self.N, self.K)) for k in range(self.K): # Residual without component k resid_k = X - predicted + m_U[I, k] * m_V[J, k] contrib = m_V[J, k] * resid_k # (n_edges,) user_sum = np.zeros(self.N) np.add.at(user_sum, I, contrib) m_U_new[:, k] = s_U_new[:, k] * prec_x * user_sum # Item precision item_prec_sum = np.zeros((self.M, self.K)) np.add.at(item_prec_sum, J, U_sq_plus_var[I]) s_V_new = 1.0 / (prec_V + prec_x * item_prec_sum) # Update predicted with new U predicted_new = np.sum(m_U_new[I] * m_V[J], axis=1) m_V_new = np.zeros((self.M, self.K)) for k in range(self.K): resid_k = X - predicted_new + m_U_new[I, k] * m_V[J, k] contrib = m_U_new[I, k] * resid_k item_sum = np.zeros(self.M) np.add.at(item_sum, J, contrib) m_V_new[:, k] = s_V_new[:, k] * prec_x * item_sum # Apply masks if update_users is not None: mask_u = np.zeros(self.N, dtype=bool) mask_u[list(update_users)] = True m_U_new = np.where(mask_u[:, None], m_U_new, m_U) s_U_new = np.where(mask_u[:, None], s_U_new, s_U) if update_items is not None: mask_v = np.zeros(self.M, dtype=bool) mask_v[list(update_items)] = True m_V_new = np.where(mask_v[:, None], m_V_new, m_V) s_V_new = np.where(mask_v[:, None], s_V_new, s_V) if self.damping < 1.0: alpha = self.damping m_U_new = (1 - alpha) * m_U + alpha * m_U_new s_U_new = (1 - alpha) * s_U + alpha * s_U_new m_V_new = (1 - alpha) * m_V + alpha * m_V_new s_V_new = (1 - alpha) * s_V + alpha * s_V_new return {'m_U': m_U_new, 's_U': s_U_new, 'm_V': m_V_new, 's_V': s_V_new} def compute_objective(self, edges, params): """Approximate ELBO (likelihood only).""" I, J, X, _ = self._prepare_edges(edges) m_U, m_V = params['m_U'], params['m_V'] pred = np.sum(m_U[I] * m_V[J], axis=1) mse = np.mean((X - pred) ** 2) return -float(mse) def fit_full(self, edges, config=None, init_params=None): t0 = time.time() I, J, X, n_edges = self._prepare_edges(edges) params = {k: v.copy() for k, v in (init_params or self._init_params()).items()} obj_trace = [] converged = False for it in range(self.max_iter): old_params = {k: v.copy() for k, v in params.items()} params = self._cavi_sweep(I, J, X, n_edges, params) change = relative_param_change(old_params, params) if it % 50 == 0: obj_trace.append(self.compute_objective(edges, params)) if change < self.tol: converged = True break return FitResult( params=params, objective_trace=obj_trace, n_iterations=it + 1, converged=converged, runtime_sec=time.time() - t0, model_family='gaussian_gaussian', inference_type='vi', likelihood='gaussian', prior='gaussian', ) def fit_without_edge(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None): t0 = time.time() i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] if init_params is None: raise ValueError("Local unlearning requires init_params") params = {k: v.copy() for k, v in init_params.items()} u2i_orig, i2u_orig, _ = build_adjacency(edges, self.N, self.M) distances = get_deletion_neighborhood(edge_to_remove, u2i_orig, i2u_orig, self.N, self.M, radius) users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius) # Filter edges to neighborhood local_edges = [(i, j, x) for i, j, x in filtered if i in users_in_R or j in items_in_R] I, J, X, n_edges = self._prepare_edges(local_edges) converged = False for it in range(self.max_iter): old_params = {k: v.copy() for k, v in params.items()} params = self._cavi_sweep(I, J, X, n_edges, params, update_users=users_in_R, update_items=items_in_R) change = relative_param_change(old_params, params) if change < self.tol: converged = True break return FitResult( params=params, objective_trace=[], n_iterations=it + 1, converged=converged, runtime_sec=time.time() - t0, model_family='gaussian_gaussian', inference_type='vi', likelihood='gaussian', prior='gaussian', diagnostics={'radius': radius} ) def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) # ============================================================ # Gaussian-Gamma MAP - vectorized # ============================================================ class GaussianGammaMAP: """Gaussian likelihood + Gamma prior, MAP via softplus parameterization. Uses Adam optimizer with gradient clipping for stable convergence. """ def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0, sigma_x=1.0, lr=0.01, max_iter=500, tol=1e-5, seed=0, grad_clip=5.0, adam_beta1=0.9, adam_beta2=0.999): self.N = N self.M = M self.K = K self.a0 = a0 self.b0 = b0 self.c0 = c0 self.d0 = d0 self.sigma_x = sigma_x self.lr = lr self.max_iter = max_iter self.tol = tol self.seed = seed self.grad_clip = grad_clip self.adam_beta1 = adam_beta1 self.adam_beta2 = adam_beta2 def _softplus(self, x): return np.log1p(np.exp(np.clip(x, -20, 20))) def _softplus_grad(self, x): return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20))) def _inv_softplus(self, y): """Inverse of softplus: log(exp(y) - 1).""" return np.log(np.exp(np.clip(y, 1e-8, 20)) - 1 + 1e-30) def _init_params(self, rng=None, edges=None): if rng is None: rng = np.random.RandomState(self.seed) # Data-informed initialization: use NMF-style init from mean values if edges is not None: I = np.array([e[0] for e in edges], dtype=np.int32) J = np.array([e[1] for e in edges], dtype=np.int32) X = np.array([e[2] for e in edges], dtype=np.float64) # Compute user/item means x_mean = np.abs(X).mean() init_scale = np.sqrt(np.abs(x_mean) / self.K + 0.1) else: init_scale = 0.5 U_init = np.abs(rng.randn(self.N, self.K)) * init_scale + 0.1 V_init = np.abs(rng.randn(self.M, self.K)) * init_scale + 0.1 return { 'alpha': self._inv_softplus(U_init), 'beta': self._inv_softplus(V_init), } def _prepare_edges(self, edges): n_edges = len(edges) I = np.array([e[0] for e in edges], dtype=np.int32) J = np.array([e[1] for e in edges], dtype=np.int32) X = np.array([e[2] for e in edges], dtype=np.float64) return I, J, X, n_edges def compute_objective(self, edges, params): I, J, X, _ = self._prepare_edges(edges) U = self._softplus(params['alpha']) V = self._softplus(params['beta']) pred = np.sum(U[I] * V[J], axis=1) prec_x = 1.0 / (self.sigma_x ** 2) obj = -0.5 * prec_x * np.sum((X - pred) ** 2) obj += np.sum((self.a0 - 1) * np.log(U + 1e-30) - self.b0 * U) obj += np.sum((self.c0 - 1) * np.log(V + 1e-30) - self.d0 * V) return float(obj) def _compute_gradients(self, I, J, X, params, update_users=None, update_items=None): """Compute gradients with clipping.""" U = self._softplus(params['alpha']) V = self._softplus(params['beta']) prec_x = 1.0 / (self.sigma_x ** 2) pred = np.sum(U[I] * V[J], axis=1) residual = X - pred sp_grad_alpha = self._softplus_grad(params['alpha']) sp_grad_beta = self._softplus_grad(params['beta']) grad_U = np.zeros_like(U) for k in range(self.K): contrib = prec_x * residual * V[J, k] np.add.at(grad_U[:, k], I, contrib) prior_grad_U = (self.a0 - 1) / (U + 1e-6) - self.b0 grad_U += prior_grad_U grad_alpha = grad_U * sp_grad_alpha grad_V = np.zeros_like(V) for k in range(self.K): contrib = prec_x * residual * U[I, k] np.add.at(grad_V[:, k], J, contrib) prior_grad_V = (self.c0 - 1) / (V + 1e-6) - self.d0 grad_V += prior_grad_V grad_beta = grad_V * sp_grad_beta # Gradient clipping if self.grad_clip > 0: gnorm_a = np.linalg.norm(grad_alpha) if gnorm_a > self.grad_clip: grad_alpha *= self.grad_clip / gnorm_a gnorm_b = np.linalg.norm(grad_beta) if gnorm_b > self.grad_clip: grad_beta *= self.grad_clip / gnorm_b return grad_alpha, grad_beta def _fit_internal(self, edges, params, max_iter=None, update_users=None, update_items=None): """Internal fit with Adam optimizer.""" t0 = time.time() if max_iter is None: max_iter = self.max_iter I, J, X, n_edges = self._prepare_edges(edges) # Adam state m_alpha = np.zeros_like(params['alpha']) v_alpha = np.zeros_like(params['alpha']) m_beta = np.zeros_like(params['beta']) v_beta = np.zeros_like(params['beta']) eps_adam = 1e-8 obj_trace = [] converged = False for it in range(max_iter): old_params = {k: v.copy() for k, v in params.items()} grad_alpha, grad_beta = self._compute_gradients( I, J, X, params, update_users, update_items) # Adam updates t_adam = it + 1 m_alpha = self.adam_beta1 * m_alpha + (1 - self.adam_beta1) * grad_alpha v_alpha = self.adam_beta2 * v_alpha + (1 - self.adam_beta2) * grad_alpha**2 m_hat_a = m_alpha / (1 - self.adam_beta1**t_adam) v_hat_a = v_alpha / (1 - self.adam_beta2**t_adam) m_beta = self.adam_beta1 * m_beta + (1 - self.adam_beta1) * grad_beta v_beta = self.adam_beta2 * v_beta + (1 - self.adam_beta2) * grad_beta**2 m_hat_b = m_beta / (1 - self.adam_beta1**t_adam) v_hat_b = v_beta / (1 - self.adam_beta2**t_adam) step_alpha = self.lr * m_hat_a / (np.sqrt(v_hat_a) + eps_adam) step_beta = self.lr * m_hat_b / (np.sqrt(v_hat_b) + eps_adam) if update_users is not None: ul = list(update_users) params['alpha'][ul] += step_alpha[ul] else: params['alpha'] = params['alpha'] + step_alpha if update_items is not None: il = list(update_items) params['beta'][il] += step_beta[il] else: params['beta'] = params['beta'] + step_beta change = relative_param_change(old_params, params) if it % 50 == 0: obj_trace.append(self.compute_objective(edges, params)) if change < self.tol: converged = True break return params, obj_trace, it + 1, converged, time.time() - t0 def fit_full(self, edges, config=None, init_params=None): if init_params is not None: params = {k: v.copy() for k, v in init_params.items()} else: params = self._init_params(edges=edges) params, obj_trace, n_iter, converged, runtime = self._fit_internal(edges, params) return FitResult( params=params, objective_trace=obj_trace, n_iterations=n_iter, converged=converged, runtime_sec=runtime, model_family='gaussian_gamma_map', inference_type='map', likelihood='gaussian', prior='gamma', ) def fit_without_edge(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] if init_params is None: raise ValueError("Local unlearning requires init_params") params = {k: v.copy() for k, v in init_params.items()} u2i_orig, i2u_orig, _ = build_adjacency(edges, self.N, self.M) distances = get_deletion_neighborhood(edge_to_remove, u2i_orig, i2u_orig, self.N, self.M, radius) users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius) # Filter edges to neighborhood local_edges = [(i, j, x) for i, j, x in filtered if i in users_in_R or j in items_in_R] params, obj_trace, n_iter, converged, runtime = self._fit_internal( local_edges, params, update_users=users_in_R, update_items=items_in_R) return FitResult( params=params, objective_trace=obj_trace, n_iterations=n_iter, converged=converged, runtime_sec=runtime, model_family='gaussian_gamma_map', inference_type='map', likelihood='gaussian', prior='gamma', diagnostics={'radius': radius, 'n_users_updated': len(users_in_R), 'n_items_updated': len(items_in_R)} ) def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None): i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1]) filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)] return self.fit_full(filtered, config, init_params) def get_model(model_family, N, M, K, **kwargs): """Factory function.""" if model_family == 'poisson_gamma': valid = {'a0', 'b0', 'c0', 'd0', 'max_iter', 'tol', 'damping', 'seed'} return PoissonGammaVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid}) elif model_family == 'gaussian_gaussian': valid = {'sigma_U', 'sigma_V', 'sigma_x', 'max_iter', 'tol', 'damping', 'seed'} return GaussianGaussianVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid}) elif model_family == 'gaussian_gamma_map': valid = {'a0', 'b0', 'c0', 'd0', 'sigma_x', 'lr', 'max_iter', 'tol', 'seed', 'grad_clip'} return GaussianGammaMAP(N, M, K, **{k: v for k, v in kwargs.items() if k in valid}) else: raise ValueError(f"Unknown model family: {model_family}")