serliezer committed on
Commit
3148dc6
·
verified ·
1 Parent(s): 8078e25

v2: model.py

Browse files
Files changed (1) hide show
  1. src/model.py +123 -57
src/model.py CHANGED
@@ -191,7 +191,14 @@ class PoissonGammaVI:
191
  self.N, self.M, radius)
192
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
193
 
194
- I, J, X, n_edges = self._prepare_edges(filtered)
 
 
 
 
 
 
 
195
 
196
  converged = False
197
  for it in range(self.max_iter):
@@ -405,7 +412,10 @@ class GaussianGaussianVI:
405
  self.N, self.M, radius)
406
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
407
 
408
- I, J, X, n_edges = self._prepare_edges(filtered)
 
 
 
409
  converged = False
410
  for it in range(self.max_iter):
411
  old_params = {k: v.copy() for k, v in params.items()}
@@ -436,10 +446,14 @@ class GaussianGaussianVI:
436
  # ============================================================
437
 
438
  class GaussianGammaMAP:
439
- """Gaussian likelihood + Gamma prior, MAP via softplus parameterization. Vectorized."""
 
 
 
440
 
441
  def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0,
442
- sigma_x=1.0, lr=0.01, max_iter=200, tol=1e-5, seed=0):
 
443
  self.N = N
444
  self.M = M
445
  self.K = K
@@ -452,6 +466,9 @@ class GaussianGammaMAP:
452
  self.max_iter = max_iter
453
  self.tol = tol
454
  self.seed = seed
 
 
 
455
 
456
  def _softplus(self, x):
457
  return np.log1p(np.exp(np.clip(x, -20, 20)))
@@ -459,12 +476,28 @@ class GaussianGammaMAP:
459
  def _softplus_grad(self, x):
460
  return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))
461
 
462
- def _init_params(self, rng=None):
 
 
 
 
463
  if rng is None:
464
  rng = np.random.RandomState(self.seed)
 
 
 
 
 
 
 
 
 
 
 
 
465
  return {
466
- 'alpha': rng.randn(self.N, self.K) * 0.5,
467
- 'beta': rng.randn(self.M, self.K) * 0.5,
468
  }
469
 
470
  def _prepare_edges(self, edges):
@@ -487,71 +520,98 @@ class GaussianGammaMAP:
487
  obj += np.sum((self.c0 - 1) * np.log(V + 1e-30) - self.d0 * V)
488
  return float(obj)
489
 
490
- def _gradient_step(self, I, J, X, params, update_users=None, update_items=None):
491
- """One gradient step."""
492
  U = self._softplus(params['alpha'])
493
  V = self._softplus(params['beta'])
494
  prec_x = 1.0 / (self.sigma_x ** 2)
495
 
496
- pred = np.sum(U[I] * V[J], axis=1) # (n_edges,)
497
- residual = X - pred # (n_edges,)
498
 
499
  sp_grad_alpha = self._softplus_grad(params['alpha'])
500
  sp_grad_beta = self._softplus_grad(params['beta'])
501
 
502
- # Gradient for alpha (user params)
503
- # dL/dU[i,k] = prec_x * sum_j residual[e] * V[j,k] + (a0-1)/U[i,k] - b0
504
  grad_U = np.zeros_like(U)
505
  for k in range(self.K):
506
  contrib = prec_x * residual * V[J, k]
507
  np.add.at(grad_U[:, k], I, contrib)
508
 
509
- # Prior gradient
510
- prior_grad_U = (self.a0 - 1) / (U + 1e-30) - self.b0
511
  grad_U += prior_grad_U
512
-
513
- # Chain rule through softplus
514
  grad_alpha = grad_U * sp_grad_alpha
515
 
516
- # Gradient for beta (item params)
517
  grad_V = np.zeros_like(V)
518
  for k in range(self.K):
519
  contrib = prec_x * residual * U[I, k]
520
  np.add.at(grad_V[:, k], J, contrib)
521
 
522
- prior_grad_V = (self.c0 - 1) / (V + 1e-30) - self.d0
523
  grad_V += prior_grad_V
524
-
525
  grad_beta = grad_V * sp_grad_beta
526
 
527
- alpha_new = params['alpha'].copy()
528
- beta_new = params['beta'].copy()
529
-
530
- if update_users is not None:
531
- ul = list(update_users)
532
- alpha_new[ul] += self.lr * grad_alpha[ul]
533
- else:
534
- alpha_new += self.lr * grad_alpha
535
 
536
- if update_items is not None:
537
- il = list(update_items)
538
- beta_new[il] += self.lr * grad_beta[il]
539
- else:
540
- beta_new += self.lr * grad_beta
541
-
542
- return {'alpha': alpha_new, 'beta': beta_new}
543
 
544
- def fit_full(self, edges, config=None, init_params=None):
 
 
545
  t0 = time.time()
 
 
546
  I, J, X, n_edges = self._prepare_edges(edges)
547
 
548
- params = {k: v.copy() for k, v in (init_params or self._init_params()).items()}
 
 
 
 
 
 
549
  obj_trace = []
550
  converged = False
551
 
552
- for it in range(self.max_iter):
553
  old_params = {k: v.copy() for k, v in params.items()}
554
- params = self._gradient_step(I, J, X, params)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
555
  change = relative_param_change(old_params, params)
556
  if it % 50 == 0:
557
  obj_trace.append(self.compute_objective(edges, params))
@@ -559,10 +619,20 @@ class GaussianGammaMAP:
559
  converged = True
560
  break
561
 
 
 
 
 
 
 
 
 
 
 
562
  return FitResult(
563
  params=params, objective_trace=obj_trace,
564
- n_iterations=it + 1, converged=converged,
565
- runtime_sec=time.time() - t0,
566
  model_family='gaussian_gamma_map', inference_type='map',
567
  likelihood='gaussian', prior='gamma',
568
  )
@@ -573,7 +643,6 @@ class GaussianGammaMAP:
573
  return self.fit_full(filtered, config, init_params)
574
 
575
  def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None):
576
- t0 = time.time()
577
  i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1])
578
  filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)]
579
 
@@ -586,24 +655,21 @@ class GaussianGammaMAP:
586
  self.N, self.M, radius)
587
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
588
 
589
- I, J, X, n_edges = self._prepare_edges(filtered)
590
- converged = False
591
- for it in range(self.max_iter):
592
- old_params = {k: v.copy() for k, v in params.items()}
593
- params = self._gradient_step(I, J, X, params,
594
- update_users=users_in_R, update_items=items_in_R)
595
- change = relative_param_change(old_params, params)
596
- if change < self.tol:
597
- converged = True
598
- break
599
 
600
  return FitResult(
601
- params=params, objective_trace=[],
602
- n_iterations=it + 1, converged=converged,
603
- runtime_sec=time.time() - t0,
604
  model_family='gaussian_gamma_map', inference_type='map',
605
  likelihood='gaussian', prior='gamma',
606
- diagnostics={'radius': radius}
 
607
  )
608
 
609
  def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None):
@@ -621,7 +687,7 @@ def get_model(model_family, N, M, K, **kwargs):
621
  valid = {'sigma_U', 'sigma_V', 'sigma_x', 'max_iter', 'tol', 'damping', 'seed'}
622
  return GaussianGaussianVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
623
  elif model_family == 'gaussian_gamma_map':
624
- valid = {'a0', 'b0', 'c0', 'd0', 'sigma_x', 'lr', 'max_iter', 'tol', 'seed'}
625
  return GaussianGammaMAP(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
626
  else:
627
  raise ValueError(f"Unknown model family: {model_family}")
 
191
  self.N, self.M, radius)
192
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
193
 
194
+ # KEY OPTIMIZATION: filter edges to only those touching neighborhood
195
+ # For user i update: need all edges (i, j, x) where i in users_in_R
196
+ # For item j update: need all edges (i, j, x) where j in items_in_R
197
+ # Union: edges where i in users_in_R OR j in items_in_R
198
+ local_edges = [(i, j, x) for i, j, x in filtered
199
+ if i in users_in_R or j in items_in_R]
200
+
201
+ I, J, X, n_edges = self._prepare_edges(local_edges)
202
 
203
  converged = False
204
  for it in range(self.max_iter):
 
412
  self.N, self.M, radius)
413
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
414
 
415
+ # Filter edges to neighborhood
416
+ local_edges = [(i, j, x) for i, j, x in filtered
417
+ if i in users_in_R or j in items_in_R]
418
+ I, J, X, n_edges = self._prepare_edges(local_edges)
419
  converged = False
420
  for it in range(self.max_iter):
421
  old_params = {k: v.copy() for k, v in params.items()}
 
446
  # ============================================================
447
 
448
  class GaussianGammaMAP:
449
+ """Gaussian likelihood + Gamma prior, MAP via softplus parameterization.
450
+
451
+ Uses Adam optimizer with gradient clipping for stable convergence.
452
+ """
453
 
454
  def __init__(self, N, M, K, a0=0.3, b0=1.0, c0=0.3, d0=1.0,
455
+ sigma_x=1.0, lr=0.01, max_iter=500, tol=1e-5, seed=0,
456
+ grad_clip=5.0, adam_beta1=0.9, adam_beta2=0.999):
457
  self.N = N
458
  self.M = M
459
  self.K = K
 
466
  self.max_iter = max_iter
467
  self.tol = tol
468
  self.seed = seed
469
+ self.grad_clip = grad_clip
470
+ self.adam_beta1 = adam_beta1
471
+ self.adam_beta2 = adam_beta2
472
 
473
  def _softplus(self, x):
474
  return np.log1p(np.exp(np.clip(x, -20, 20)))
 
476
  def _softplus_grad(self, x):
477
  return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))
478
 
479
+ def _inv_softplus(self, y):
480
+ """Inverse of softplus: log(exp(y) - 1)."""
481
+ return np.log(np.exp(np.clip(y, 1e-8, 20)) - 1 + 1e-30)
482
+
483
+ def _init_params(self, rng=None, edges=None):
484
  if rng is None:
485
  rng = np.random.RandomState(self.seed)
486
+ # Data-informed initialization: use NMF-style init from mean values
487
+ if edges is not None:
488
+ I = np.array([e[0] for e in edges], dtype=np.int32)
489
+ J = np.array([e[1] for e in edges], dtype=np.int32)
490
+ X = np.array([e[2] for e in edges], dtype=np.float64)
491
+ # Global mean absolute value of the observed entries (sets the init scale)
492
+ x_mean = np.abs(X).mean()
493
+ init_scale = np.sqrt(np.abs(x_mean) / self.K + 0.1)
494
+ else:
495
+ init_scale = 0.5
496
+ U_init = np.abs(rng.randn(self.N, self.K)) * init_scale + 0.1
497
+ V_init = np.abs(rng.randn(self.M, self.K)) * init_scale + 0.1
498
  return {
499
+ 'alpha': self._inv_softplus(U_init),
500
+ 'beta': self._inv_softplus(V_init),
501
  }
502
 
503
  def _prepare_edges(self, edges):
 
520
  obj += np.sum((self.c0 - 1) * np.log(V + 1e-30) - self.d0 * V)
521
  return float(obj)
522
 
523
+ def _compute_gradients(self, I, J, X, params, update_users=None, update_items=None):
524
+ """Compute gradients with clipping."""
525
  U = self._softplus(params['alpha'])
526
  V = self._softplus(params['beta'])
527
  prec_x = 1.0 / (self.sigma_x ** 2)
528
 
529
+ pred = np.sum(U[I] * V[J], axis=1)
530
+ residual = X - pred
531
 
532
  sp_grad_alpha = self._softplus_grad(params['alpha'])
533
  sp_grad_beta = self._softplus_grad(params['beta'])
534
 
 
 
535
  grad_U = np.zeros_like(U)
536
  for k in range(self.K):
537
  contrib = prec_x * residual * V[J, k]
538
  np.add.at(grad_U[:, k], I, contrib)
539
 
540
+ prior_grad_U = (self.a0 - 1) / (U + 1e-6) - self.b0
 
541
  grad_U += prior_grad_U
 
 
542
  grad_alpha = grad_U * sp_grad_alpha
543
 
 
544
  grad_V = np.zeros_like(V)
545
  for k in range(self.K):
546
  contrib = prec_x * residual * U[I, k]
547
  np.add.at(grad_V[:, k], J, contrib)
548
 
549
+ prior_grad_V = (self.c0 - 1) / (V + 1e-6) - self.d0
550
  grad_V += prior_grad_V
 
551
  grad_beta = grad_V * sp_grad_beta
552
 
553
+ # Gradient clipping
554
+ if self.grad_clip > 0:
555
+ gnorm_a = np.linalg.norm(grad_alpha)
556
+ if gnorm_a > self.grad_clip:
557
+ grad_alpha *= self.grad_clip / gnorm_a
558
+ gnorm_b = np.linalg.norm(grad_beta)
559
+ if gnorm_b > self.grad_clip:
560
+ grad_beta *= self.grad_clip / gnorm_b
561
 
562
+ return grad_alpha, grad_beta
 
 
 
 
 
 
563
 
564
+ def _fit_internal(self, edges, params, max_iter=None,
565
+ update_users=None, update_items=None):
566
+ """Internal fit with Adam optimizer."""
567
  t0 = time.time()
568
+ if max_iter is None:
569
+ max_iter = self.max_iter
570
  I, J, X, n_edges = self._prepare_edges(edges)
571
 
572
+ # Adam state
573
+ m_alpha = np.zeros_like(params['alpha'])
574
+ v_alpha = np.zeros_like(params['alpha'])
575
+ m_beta = np.zeros_like(params['beta'])
576
+ v_beta = np.zeros_like(params['beta'])
577
+ eps_adam = 1e-8
578
+
579
  obj_trace = []
580
  converged = False
581
 
582
+ for it in range(max_iter):
583
  old_params = {k: v.copy() for k, v in params.items()}
584
+
585
+ grad_alpha, grad_beta = self._compute_gradients(
586
+ I, J, X, params, update_users, update_items)
587
+
588
+ # Adam updates
589
+ t_adam = it + 1
590
+ m_alpha = self.adam_beta1 * m_alpha + (1 - self.adam_beta1) * grad_alpha
591
+ v_alpha = self.adam_beta2 * v_alpha + (1 - self.adam_beta2) * grad_alpha**2
592
+ m_hat_a = m_alpha / (1 - self.adam_beta1**t_adam)
593
+ v_hat_a = v_alpha / (1 - self.adam_beta2**t_adam)
594
+
595
+ m_beta = self.adam_beta1 * m_beta + (1 - self.adam_beta1) * grad_beta
596
+ v_beta = self.adam_beta2 * v_beta + (1 - self.adam_beta2) * grad_beta**2
597
+ m_hat_b = m_beta / (1 - self.adam_beta1**t_adam)
598
+ v_hat_b = v_beta / (1 - self.adam_beta2**t_adam)
599
+
600
+ step_alpha = self.lr * m_hat_a / (np.sqrt(v_hat_a) + eps_adam)
601
+ step_beta = self.lr * m_hat_b / (np.sqrt(v_hat_b) + eps_adam)
602
+
603
+ if update_users is not None:
604
+ ul = list(update_users)
605
+ params['alpha'][ul] += step_alpha[ul]
606
+ else:
607
+ params['alpha'] = params['alpha'] + step_alpha
608
+
609
+ if update_items is not None:
610
+ il = list(update_items)
611
+ params['beta'][il] += step_beta[il]
612
+ else:
613
+ params['beta'] = params['beta'] + step_beta
614
+
615
  change = relative_param_change(old_params, params)
616
  if it % 50 == 0:
617
  obj_trace.append(self.compute_objective(edges, params))
 
619
  converged = True
620
  break
621
 
622
+ return params, obj_trace, it + 1, converged, time.time() - t0
623
+
624
+ def fit_full(self, edges, config=None, init_params=None):
625
+ if init_params is not None:
626
+ params = {k: v.copy() for k, v in init_params.items()}
627
+ else:
628
+ params = self._init_params(edges=edges)
629
+
630
+ params, obj_trace, n_iter, converged, runtime = self._fit_internal(edges, params)
631
+
632
  return FitResult(
633
  params=params, objective_trace=obj_trace,
634
+ n_iterations=n_iter, converged=converged,
635
+ runtime_sec=runtime,
636
  model_family='gaussian_gamma_map', inference_type='map',
637
  likelihood='gaussian', prior='gamma',
638
  )
 
643
  return self.fit_full(filtered, config, init_params)
644
 
645
  def fit_local(self, edges, edge_to_remove, radius, config=None, init_params=None):
 
646
  i_del, j_del = int(edge_to_remove[0]), int(edge_to_remove[1])
647
  filtered = [(i, j, x) for i, j, x in edges if not (i == i_del and j == j_del)]
648
 
 
655
  self.N, self.M, radius)
656
  users_in_R, items_in_R = get_user_item_sets_in_radius(distances, self.N, radius)
657
 
658
+ # Filter edges to neighborhood
659
+ local_edges = [(i, j, x) for i, j, x in filtered
660
+ if i in users_in_R or j in items_in_R]
661
+
662
+ params, obj_trace, n_iter, converged, runtime = self._fit_internal(
663
+ local_edges, params, update_users=users_in_R, update_items=items_in_R)
 
 
 
 
664
 
665
  return FitResult(
666
+ params=params, objective_trace=obj_trace,
667
+ n_iterations=n_iter, converged=converged,
668
+ runtime_sec=runtime,
669
  model_family='gaussian_gamma_map', inference_type='map',
670
  likelihood='gaussian', prior='gamma',
671
+ diagnostics={'radius': radius, 'n_users_updated': len(users_in_R),
672
+ 'n_items_updated': len(items_in_R)}
673
  )
674
 
675
  def fit_warm_start_global(self, edges, edge_to_remove, config=None, init_params=None):
 
687
  valid = {'sigma_U', 'sigma_V', 'sigma_x', 'max_iter', 'tol', 'damping', 'seed'}
688
  return GaussianGaussianVI(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
689
  elif model_family == 'gaussian_gamma_map':
690
+ valid = {'a0', 'b0', 'c0', 'd0', 'sigma_x', 'lr', 'max_iter', 'tol', 'seed', 'grad_clip'}
691
  return GaussianGammaMAP(N, M, K, **{k: v for k, v in kwargs.items() if k in valid})
692
  else:
693
  raise ValueError(f"Unknown model family: {model_family}")