""" LinUCB Contextual Bandit (Li et al., 2010). Maintains per-action inverse covariance matrices using the Sherman-Morrison rank-1 update formula for O(d^2) updates. For each action a in {0..K-1}: A_inv[a] — d×d inverse covariance (starts as I_d) b[a] — d reward-weighted feature accumulator theta[a] = A_inv[a] @ b[a] (ridge regression estimate) UCB_a(x) = theta[a] @ x + alpha * sqrt(max(0, x @ A_inv[a] @ x)) Action selection: argmax_a UCB_a(x) """ from __future__ import annotations import json import os import random from pathlib import Path from typing import List, Optional, Tuple import numpy as np from rl.types import FEATURE_DIM, NUM_ACTIONS, RepairAction, REPAIR_ACTION_NAMES # Default path — can be overridden by DATA_DIR env var _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data")) WEIGHTS_PATH = _DATA_DIR / "rl_weights.json" class LinUCB: """ LinUCB contextual bandit with Sherman-Morrison updates and alpha decay. Weights are persisted to JSON after every 10 updates. """ def __init__( self, d: int = FEATURE_DIM, K: int = NUM_ACTIONS, alpha: float = 1.5, ) -> None: self.d = d self.K = K self.alpha = alpha self.total_updates = 0 loaded = self._load_weights() if loaded is not None: self.A_inv = loaded["A_inv"] self.b = loaded["b"] self.counts = loaded["counts"] self.total_updates = loaded["total_updates"] else: self.A_inv: List[np.ndarray] = [np.eye(d) for _ in range(K)] self.b: List[np.ndarray] = [np.zeros(d) for _ in range(K)] self.counts: List[int] = [0] * K # ─── Core Interface ────────────────────────────────────────── def select_action(self, x: List[float]) -> Tuple[RepairAction, List[float]]: """ Select the action with highest UCB score. Returns (action, scores_for_all_actions). """ xv = np.array(x, dtype=np.float64) scores = [] for a in range(self.K): theta = self.A_inv[a] @ self.b[a] exploit = float(theta @ xv) quad = float(xv @ self.A_inv[a] @ xv) explore = self.alpha * float(np.sqrt(max(0.0, quad))) scores.append(exploit + explore) # Argmax with random tie-breaking best_action = 0 best_score = scores[0] for a in range(1, self.K): if scores[a] > best_score or ( scores[a] == best_score and random.random() > 0.5 ): best_score = scores[a] best_action = a return RepairAction(best_action), scores def update(self, x: List[float], action: RepairAction, reward: float) -> None: """ Update the model after observing a reward. Uses Sherman-Morrison: (A + xx^T)^{-1} = A^{-1} - (A^{-1}xx^T A^{-1}) / (1 + x^T A^{-1} x) """ a = int(action) xv = np.array(x, dtype=np.float64) A_inv_x = self.A_inv[a] @ xv # shape (d,) denom = 1.0 + float(xv @ A_inv_x) # scalar # Rank-1 downdate self.A_inv[a] -= np.outer(A_inv_x, A_inv_x) / denom # Reward-weighted feature accumulation self.b[a] += reward * xv self.counts[a] += 1 self.total_updates += 1 if self.total_updates % 10 == 0: self.save_weights() def get_estimated_rewards(self, x: List[float]) -> List[float]: """ Return theta^T x for each action (no exploration bonus). Useful for understanding learned policy. """ xv = np.array(x, dtype=np.float64) return [float((self.A_inv[a] @ self.b[a]) @ xv) for a in range(self.K)] def get_action_counts(self) -> List[int]: return list(self.counts) def get_total_updates(self) -> int: return self.total_updates def get_alpha(self) -> float: return self.alpha def decay_alpha(self, min_alpha: float = 0.3) -> None: """Decay exploration coefficient toward exploitation.""" self.alpha = max(min_alpha, self.alpha * 0.995) def get_action_distribution(self) -> dict: total = sum(self.counts) or 1 return { REPAIR_ACTION_NAMES[RepairAction(a)]: self.counts[a] / total for a in range(self.K) } # ─── Persistence ───────────────────────────────────────────── def save_weights(self) -> None: try: WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True) data = { "A_inv": [m.tolist() for m in self.A_inv], "b": [v.tolist() for v in self.b], "counts": self.counts, "total_updates": self.total_updates, "alpha": self.alpha, } WEIGHTS_PATH.write_text(json.dumps(data)) except Exception: pass # Non-fatal def _load_weights(self) -> Optional[dict]: try: if not WEIGHTS_PATH.exists(): return None raw = json.loads(WEIGHTS_PATH.read_text()) A_inv = [np.array(m, dtype=np.float64) for m in raw["A_inv"]] b = [np.array(v, dtype=np.float64) for v in raw["b"]] # Validate dimensions if ( len(A_inv) == self.K and A_inv[0].shape == (self.d, self.d) and len(b) == self.K and b[0].shape == (self.d,) ): return { "A_inv": A_inv, "b": b, "counts": raw["counts"], "total_updates": raw["total_updates"], } return None except Exception: return None def reset(self) -> None: self.A_inv = [np.eye(self.d) for _ in range(self.K)] self.b = [np.zeros(self.d) for _ in range(self.K)] self.counts = [0] * self.K self.total_updates = 0 self.alpha = 1.5 try: WEIGHTS_PATH.unlink(missing_ok=True) except Exception: pass