v2: model.py
Browse files- src/model.py +123 -57
src/model.py
CHANGED
|
@@ -191,7 +191,14 @@ class PoissonGammaVI:
|
|
| 191 |
self.N, self.M, radius)
|
| 192 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 193 |
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
converged = False
|
| 197 |
for it in range(self.max_iter):
|
|
@@ -405,7 +412,10 @@ class GaussianGaussianVI:
|
|
| 405 |
self.N, self.M, radius)
|
| 406 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 407 |
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
| 409 |
converged = False
|
| 410 |
for it in range(self.max_iter):
|
| 411 |
old_params = {k: v.copy() for k, v in params.items()}
|
|
@@ -436,10 +446,14 @@ class GaussianGaussianVI:
|
|
| 436 |
# ============================================================
|
| 437 |
|
| 438 |
class GaussianGammaMAP:
|
| 439 |
-
"""Gaussian likelihood + Gamma prior, MAP via softplus parameterization.
|
|
|
|
|
|
|
|
|
|
| 440 |
|
| 441 |
def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0,
|
| 442 |
-
sigma_x=1.0, lr=0.01, max_iter=
|
|
|
|
| 443 |
self.N = N
|
| 444 |
self.M = M
|
| 445 |
self.K = K
|
|
@@ -452,6 +466,9 @@ class GaussianGammaMAP:
|
|
| 452 |
self.max_iter = max_iter
|
| 453 |
self.tol = tol
|
| 454 |
self.seed = seed
|
|
|
|
|
|
|
|
|
|
| 455 |
|
| 456 |
def _softplus(self, x):
|
| 457 |
return np.log1p(np.exp(np.clip(x, -20, 20)))
|
|
@@ -459,12 +476,28 @@ class GaussianGammaMAP:
|
|
| 459 |
def _softplus_grad(self, x):
|
| 460 |
return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))
|
| 461 |
|
| 462 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
if rng is None:
|
| 464 |
rng = np.random.RandomState(self.seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
return {
|
| 466 |
-
'alpha':
|
| 467 |
-
'beta':
|
| 468 |
}
|
| 469 |
|
| 470 |
def _prepare_edges(self, edges):
|
|
@@ -487,71 +520,98 @@ class GaussianGammaMAP:
|
|
| 487 |
obj += np.sum((self.c0 - 1) * np.log(V + 1e-30) - self.d0 * V)
|
| 488 |
return float(obj)
|
| 489 |
|
| 490 |
-
def
|
| 491 |
-
"""
|
| 492 |
U = self._softplus(params['alpha'])
|
| 493 |
V = self._softplus(params['beta'])
|
| 494 |
prec_x = 1.0 / (self.sigma_x ** 2)
|
| 495 |
|
| 496 |
-
pred = np.sum(U[I] * V[J], axis=1)
|
| 497 |
-
residual = X - pred
|
| 498 |
|
| 499 |
sp_grad_alpha = self._softplus_grad(params['alpha'])
|
| 500 |
sp_grad_beta = self._softplus_grad(params['beta'])
|
| 501 |
|
| 502 |
-
# Gradient for alpha (user params)
|
| 503 |
-
# dL/dU[i,k] = prec_x * sum_j residual[e] * V[j,k] + (a0-1)/U[i,k] - b0
|
| 504 |
grad_U = np.zeros_like(U)
|
| 505 |
for k in range(self.K):
|
| 506 |
contrib = prec_x * residual * V[J, k]
|
| 507 |
np.add.at(grad_U[:, k], I, contrib)
|
| 508 |
|
| 509 |
-
|
| 510 |
-
prior_grad_U = (self.a0 - 1) / (U + 1e-30) - self.b0
|
| 511 |
grad_U += prior_grad_U
|
| 512 |
-
|
| 513 |
-
# Chain rule through softplus
|
| 514 |
grad_alpha = grad_U * sp_grad_alpha
|
| 515 |
|
| 516 |
-
# Gradient for beta (item params)
|
| 517 |
grad_V = np.zeros_like(V)
|
| 518 |
for k in range(self.K):
|
| 519 |
contrib = prec_x * residual * U[I, k]
|
| 520 |
np.add.at(grad_V[:, k], J, contrib)
|
| 521 |
|
| 522 |
-
prior_grad_V = (self.c0 - 1) / (V + 1e-
|
| 523 |
grad_V += prior_grad_V
|
| 524 |
-
|
| 525 |
grad_beta = grad_V * sp_grad_beta
|
| 526 |
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
|
| 536 |
-
|
| 537 |
-
il = list(update_items)
|
| 538 |
-
beta_new[il] += self.lr * grad_beta[il]
|
| 539 |
-
else:
|
| 540 |
-
beta_new += self.lr * grad_beta
|
| 541 |
-
|
| 542 |
-
return {'alpha': alpha_new, 'beta': beta_new}
|
| 543 |
|
| 544 |
-
def
|
|
|
|
|
|
|
| 545 |
t0 = time.time()
|
|
|
|
|
|
|
| 546 |
I, J, X, n_edges = self._prepare_edges(edges)
|
| 547 |
|
| 548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
obj_trace = []
|
| 550 |
converged = False
|
| 551 |
|
| 552 |
-
for it in range(
|
| 553 |
old_params = {k: v.copy() for k, v in params.items()}
|
| 554 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 555 |
change = relative_param_change(old_params, params)
|
| 556 |
if it % 50 == 0:
|
| 557 |
obj_trace.append(self.compute_objective(edges, params))
|
|
@@ -559,10 +619,20 @@ class GaussianGammaMAP:
|
|
| 559 |
converged = True
|
| 560 |
break
|
| 561 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
return FitResult(
|
| 563 |
params=params, objective_trace=obj_trace,
|
| 564 |
-
n_iterations=
|
| 565 |
-
runtime_sec=
|
| 566 |
model_family='gaussian_gamma_map', inference_type='map',
|
| 567 |
likelihood='gaussian', prior='gamma',
|
| 568 |
)
|
|
@@ -573,7 +643,6 @@ class GaussianGammaMAP:
|
|
| 573 |
return self.fit_full(filtered, config, init_params)
|
| 574 |
|
| 575 |
def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None):
|
| 576 |
-
t0 = time.time()
|
| 577 |
i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1])
|
| 578 |
filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)]
|
| 579 |
|
|
@@ -586,24 +655,21 @@ class GaussianGammaMAP:
|
|
| 586 |
self.N, self.M, radius)
|
| 587 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 588 |
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
change = relative_param_change(old_params, params)
|
| 596 |
-
if change < self.tol:
|
| 597 |
-
converged = True
|
| 598 |
-
break
|
| 599 |
|
| 600 |
return FitResult(
|
| 601 |
-
params=params, objective_trace=
|
| 602 |
-
n_iterations=
|
| 603 |
-
runtime_sec=
|
| 604 |
model_family='gaussian_gamma_map', inference_type='map',
|
| 605 |
likelihood='gaussian', prior='gamma',
|
| 606 |
-
diagnostics={'radius': radius
|
|
|
|
| 607 |
)
|
| 608 |
|
| 609 |
def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None):
|
|
@@ -621,7 +687,7 @@ def get_model(model_family, N, M, K, **kwargs):
|
|
| 621 |
valid = {'sigma_U', 'sigma_V', 'sigma_x', 'max_iter', 'tol', 'damping', 'seed'}
|
| 622 |
return GaussianGaussianVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
|
| 623 |
elif model_family == 'gaussian_gamma_map':
|
| 624 |
-
valid = {'a0', 'b0', 'c0', 'd0', 'sigma_x', 'lr', 'max_iter', 'tol', 'seed'}
|
| 625 |
return GaussianGammaMAP(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
|
| 626 |
else:
|
| 627 |
raise ValueError(f"Unknown model family: {model_family}")
|
|
|
|
| 191 |
self.N, self.M, radius)
|
| 192 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 193 |
|
| 194 |
+
# KEY OPTIMIZATION: filter edges to only those touching neighborhood
|
| 195 |
+
# For user i update: need all edges (i, j, x) where i in users_in_R
|
| 196 |
+
# For item j update: need all edges (i, j, x) where j in items_in_R
|
| 197 |
+
# Union: edges where i in users_in_R OR j in items_in_R
|
| 198 |
+
local_edges = [(i, j, x) for i, j, x in filtered
|
| 199 |
+
if i in users_in_R or j in items_in_R]
|
| 200 |
+
|
| 201 |
+
I, J, X, n_edges = self._prepare_edges(local_edges)
|
| 202 |
|
| 203 |
converged = False
|
| 204 |
for it in range(self.max_iter):
|
|
|
|
| 412 |
self.N, self.M, radius)
|
| 413 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 414 |
|
| 415 |
+
# Filter edges to neighborhood
|
| 416 |
+
local_edges = [(i, j, x) for i, j, x in filtered
|
| 417 |
+
if i in users_in_R or j in items_in_R]
|
| 418 |
+
I, J, X, n_edges = self._prepare_edges(local_edges)
|
| 419 |
converged = False
|
| 420 |
for it in range(self.max_iter):
|
| 421 |
old_params = {k: v.copy() for k, v in params.items()}
|
|
|
|
| 446 |
# ============================================================
|
| 447 |
|
| 448 |
class GaussianGammaMAP:
|
| 449 |
+
"""Gaussian likelihood + Gamma prior, MAP via softplus parameterization.
|
| 450 |
+
|
| 451 |
+
Uses Adam optimizer with gradient clipping for stable convergence.
|
| 452 |
+
"""
|
| 453 |
|
| 454 |
def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0,
|
| 455 |
+
sigma_x=1.0, lr=0.01, max_iter=500, tol=1e-5, seed=0,
|
| 456 |
+
grad_clip=5.0, adam_beta1=0.9, adam_beta2=0.999):
|
| 457 |
self.N = N
|
| 458 |
self.M = M
|
| 459 |
self.K = K
|
|
|
|
| 466 |
self.max_iter = max_iter
|
| 467 |
self.tol = tol
|
| 468 |
self.seed = seed
|
| 469 |
+
self.grad_clip = grad_clip
|
| 470 |
+
self.adam_beta1 = adam_beta1
|
| 471 |
+
self.adam_beta2 = adam_beta2
|
| 472 |
|
| 473 |
def _softplus(self, x):
|
| 474 |
return np.log1p(np.exp(np.clip(x, -20, 20)))
|
|
|
|
| 476 |
def _softplus_grad(self, x):
|
| 477 |
return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))
|
| 478 |
|
| 479 |
+
def _inv_softplus(self, y):
|
| 480 |
+
"""Inverse of softplus: log(exp(y) - 1)."""
|
| 481 |
+
return np.log(np.exp(np.clip(y, 1e-8, 20)) - 1 + 1e-30)
|
| 482 |
+
|
| 483 |
+
def _init_params(self, rng=None, edges=None):
|
| 484 |
if rng is None:
|
| 485 |
rng = np.random.RandomState(self.seed)
|
| 486 |
+
# Data-informed initialization: use NMF-style init from mean values
|
| 487 |
+
if edges is not None:
|
| 488 |
+
I = np.array([e[0] for e in edges], dtype=np.int32)
|
| 489 |
+
J = np.array([e[1] for e in edges], dtype=np.int32)
|
| 490 |
+
X = np.array([e[2] for e in edges], dtype=np.float64)
|
| 491 |
+
# Compute user/item means
|
| 492 |
+
x_mean = np.abs(X).mean()
|
| 493 |
+
init_scale = np.sqrt(np.abs(x_mean) / self.K + 0.1)
|
| 494 |
+
else:
|
| 495 |
+
init_scale = 0.5
|
| 496 |
+
U_init = np.abs(rng.randn(self.N, self.K)) * init_scale + 0.1
|
| 497 |
+
V_init = np.abs(rng.randn(self.M, self.K)) * init_scale + 0.1
|
| 498 |
return {
|
| 499 |
+
'alpha': self._inv_softplus(U_init),
|
| 500 |
+
'beta': self._inv_softplus(V_init),
|
| 501 |
}
|
| 502 |
|
| 503 |
def _prepare_edges(self, edges):
|
|
|
|
| 520 |
obj += np.sum((self.c0 - 1) * np.log(V + 1e-30) - self.d0 * V)
|
| 521 |
return float(obj)
|
| 522 |
|
| 523 |
+
def _compute_gradients(self, I, J, X, params, update_users=None, update_items=None):
|
| 524 |
+
"""Compute gradients with clipping."""
|
| 525 |
U = self._softplus(params['alpha'])
|
| 526 |
V = self._softplus(params['beta'])
|
| 527 |
prec_x = 1.0 / (self.sigma_x ** 2)
|
| 528 |
|
| 529 |
+
pred = np.sum(U[I] * V[J], axis=1)
|
| 530 |
+
residual = X - pred
|
| 531 |
|
| 532 |
sp_grad_alpha = self._softplus_grad(params['alpha'])
|
| 533 |
sp_grad_beta = self._softplus_grad(params['beta'])
|
| 534 |
|
|
|
|
|
|
|
| 535 |
grad_U = np.zeros_like(U)
|
| 536 |
for k in range(self.K):
|
| 537 |
contrib = prec_x * residual * V[J, k]
|
| 538 |
np.add.at(grad_U[:, k], I, contrib)
|
| 539 |
|
| 540 |
+
prior_grad_U = (self.a0 - 1) / (U + 1e-6) - self.b0
|
|
|
|
| 541 |
grad_U += prior_grad_U
|
|
|
|
|
|
|
| 542 |
grad_alpha = grad_U * sp_grad_alpha
|
| 543 |
|
|
|
|
| 544 |
grad_V = np.zeros_like(V)
|
| 545 |
for k in range(self.K):
|
| 546 |
contrib = prec_x * residual * U[I, k]
|
| 547 |
np.add.at(grad_V[:, k], J, contrib)
|
| 548 |
|
| 549 |
+
prior_grad_V = (self.c0 - 1) / (V + 1e-6) - self.d0
|
| 550 |
grad_V += prior_grad_V
|
|
|
|
| 551 |
grad_beta = grad_V * sp_grad_beta
|
| 552 |
|
| 553 |
+
# Gradient clipping
|
| 554 |
+
if self.grad_clip > 0:
|
| 555 |
+
gnorm_a = np.linalg.norm(grad_alpha)
|
| 556 |
+
if gnorm_a > self.grad_clip:
|
| 557 |
+
grad_alpha *= self.grad_clip / gnorm_a
|
| 558 |
+
gnorm_b = np.linalg.norm(grad_beta)
|
| 559 |
+
if gnorm_b > self.grad_clip:
|
| 560 |
+
grad_beta *= self.grad_clip / gnorm_b
|
| 561 |
|
| 562 |
+
return grad_alpha, grad_beta
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
|
| 564 |
+
def _fit_internal(self, edges, params, max_iter=None,
|
| 565 |
+
update_users=None, update_items=None):
|
| 566 |
+
"""Internal fit with Adam optimizer."""
|
| 567 |
t0 = time.time()
|
| 568 |
+
if max_iter is None:
|
| 569 |
+
max_iter = self.max_iter
|
| 570 |
I, J, X, n_edges = self._prepare_edges(edges)
|
| 571 |
|
| 572 |
+
# Adam state
|
| 573 |
+
m_alpha = np.zeros_like(params['alpha'])
|
| 574 |
+
v_alpha = np.zeros_like(params['alpha'])
|
| 575 |
+
m_beta = np.zeros_like(params['beta'])
|
| 576 |
+
v_beta = np.zeros_like(params['beta'])
|
| 577 |
+
eps_adam = 1e-8
|
| 578 |
+
|
| 579 |
obj_trace = []
|
| 580 |
converged = False
|
| 581 |
|
| 582 |
+
for it in range(max_iter):
|
| 583 |
old_params = {k: v.copy() for k, v in params.items()}
|
| 584 |
+
|
| 585 |
+
grad_alpha, grad_beta = self._compute_gradients(
|
| 586 |
+
I, J, X, params, update_users, update_items)
|
| 587 |
+
|
| 588 |
+
# Adam updates
|
| 589 |
+
t_adam = it + 1
|
| 590 |
+
m_alpha = self.adam_beta1 * m_alpha + (1 - self.adam_beta1) * grad_alpha
|
| 591 |
+
v_alpha = self.adam_beta2 * v_alpha + (1 - self.adam_beta2) * grad_alpha**2
|
| 592 |
+
m_hat_a = m_alpha / (1 - self.adam_beta1**t_adam)
|
| 593 |
+
v_hat_a = v_alpha / (1 - self.adam_beta2**t_adam)
|
| 594 |
+
|
| 595 |
+
m_beta = self.adam_beta1 * m_beta + (1 - self.adam_beta1) * grad_beta
|
| 596 |
+
v_beta = self.adam_beta2 * v_beta + (1 - self.adam_beta2) * grad_beta**2
|
| 597 |
+
m_hat_b = m_beta / (1 - self.adam_beta1**t_adam)
|
| 598 |
+
v_hat_b = v_beta / (1 - self.adam_beta2**t_adam)
|
| 599 |
+
|
| 600 |
+
step_alpha = self.lr * m_hat_a / (np.sqrt(v_hat_a) + eps_adam)
|
| 601 |
+
step_beta = self.lr * m_hat_b / (np.sqrt(v_hat_b) + eps_adam)
|
| 602 |
+
|
| 603 |
+
if update_users is not None:
|
| 604 |
+
ul = list(update_users)
|
| 605 |
+
params['alpha'][ul] += step_alpha[ul]
|
| 606 |
+
else:
|
| 607 |
+
params['alpha'] = params['alpha'] + step_alpha
|
| 608 |
+
|
| 609 |
+
if update_items is not None:
|
| 610 |
+
il = list(update_items)
|
| 611 |
+
params['beta'][il] += step_beta[il]
|
| 612 |
+
else:
|
| 613 |
+
params['beta'] = params['beta'] + step_beta
|
| 614 |
+
|
| 615 |
change = relative_param_change(old_params, params)
|
| 616 |
if it % 50 == 0:
|
| 617 |
obj_trace.append(self.compute_objective(edges, params))
|
|
|
|
| 619 |
converged = True
|
| 620 |
break
|
| 621 |
|
| 622 |
+
return params, obj_trace, it + 1, converged, time.time() - t0
|
| 623 |
+
|
| 624 |
+
def fit_full(self, edges, config=None, init_params=None):
|
| 625 |
+
if init_params is not None:
|
| 626 |
+
params = {k: v.copy() for k, v in init_params.items()}
|
| 627 |
+
else:
|
| 628 |
+
params = self._init_params(edges=edges)
|
| 629 |
+
|
| 630 |
+
params, obj_trace, n_iter, converged, runtime = self._fit_internal(edges, params)
|
| 631 |
+
|
| 632 |
return FitResult(
|
| 633 |
params=params, objective_trace=obj_trace,
|
| 634 |
+
n_iterations=n_iter, converged=converged,
|
| 635 |
+
runtime_sec=runtime,
|
| 636 |
model_family='gaussian_gamma_map', inference_type='map',
|
| 637 |
likelihood='gaussian', prior='gamma',
|
| 638 |
)
|
|
|
|
| 643 |
return self.fit_full(filtered, config, init_params)
|
| 644 |
|
| 645 |
def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None):
|
|
|
|
| 646 |
i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1])
|
| 647 |
filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)]
|
| 648 |
|
|
|
|
| 655 |
self.N, self.M, radius)
|
| 656 |
users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
|
| 657 |
|
| 658 |
+
# Filter edges to neighborhood
|
| 659 |
+
local_edges = [(i, j, x) for i, j, x in filtered
|
| 660 |
+
if i in users_in_R or j in items_in_R]
|
| 661 |
+
|
| 662 |
+
params, obj_trace, n_iter, converged, runtime = self._fit_internal(
|
| 663 |
+
local_edges, params, update_users=users_in_R, update_items=items_in_R)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
|
| 665 |
return FitResult(
|
| 666 |
+
params=params, objective_trace=obj_trace,
|
| 667 |
+
n_iterations=n_iter, converged=converged,
|
| 668 |
+
runtime_sec=runtime,
|
| 669 |
model_family='gaussian_gamma_map', inference_type='map',
|
| 670 |
likelihood='gaussian', prior='gamma',
|
| 671 |
+
diagnostics={'radius': radius, 'n_users_updated': len(users_in_R),
|
| 672 |
+
'n_items_updated': len(items_in_R)}
|
| 673 |
)
|
| 674 |
|
| 675 |
def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None):
|
|
|
|
| 687 |
valid = {'sigma_U', 'sigma_V', 'sigma_x', 'max_iter', 'tol', 'damping', 'seed'}
|
| 688 |
return GaussianGaussianVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
|
| 689 |
elif model_family == 'gaussian_gamma_map':
|
| 690 |
+
valid = {'a0', 'b0', 'c0', 'd0', 'sigma_x', 'lr', 'max_iter', 'tol', 'seed', 'grad_clip'}
|
| 691 |
return GaussianGammaMAP(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
|
| 692 |
else:
|
| 693 |
raise ValueError(f"Unknown model family: {model_family}")
|