ar9avg's picture
Initial submission: SQL Agent OpenEnv for Meta+HF hackathon
3c665d2
"""
LinUCB Contextual Bandit (Li et al., 2010).
Maintains per-action inverse covariance matrices using the
Sherman-Morrison rank-1 update formula for O(d^2) updates.
For each action a in {0..K-1}:
A_inv[a] — d×d inverse covariance (starts as I_d)
b[a] — length-d reward-weighted feature accumulator
theta[a] = A_inv[a] @ b[a] (ridge regression estimate)
UCB_a(x) = theta[a] @ x + alpha * sqrt(max(0, x @ A_inv[a] @ x))
Action selection: argmax_a UCB_a(x)
"""
from __future__ import annotations
import json
import os
import random
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from rl.types import FEATURE_DIM, NUM_ACTIONS, RepairAction, REPAIR_ACTION_NAMES
# Location of the persisted bandit weights. The directory defaults to
# `<this module's directory>/../data`; set the DATA_DIR environment variable
# to redirect it.
_DEFAULT_DATA_DIR = Path(__file__).parent.parent / "data"
_DATA_DIR = Path(os.environ.get("DATA_DIR", _DEFAULT_DATA_DIR))
WEIGHTS_PATH = _DATA_DIR / "rl_weights.json"
class LinUCB:
    """
    LinUCB contextual bandit with Sherman-Morrison updates and alpha decay.

    Per-action state (``A_inv``, ``b``, ``counts``), ``total_updates`` and the
    current exploration coefficient ``alpha`` are persisted to ``WEIGHTS_PATH``
    as JSON after every 10 updates. They are restored on construction when a
    compatible file (same ``d`` and ``K``) exists.
    """

    # Number of updates between automatic saves.
    _SAVE_EVERY = 10

    A_inv: List[np.ndarray]  # per-action d x d inverse covariance
    b: List[np.ndarray]  # per-action length-d reward-weighted accumulator
    counts: List[int]  # per-action number of updates

    def __init__(
        self,
        d: int = FEATURE_DIM,
        K: int = NUM_ACTIONS,
        alpha: float = 1.5,
    ) -> None:
        """
        Create a bandit, restoring persisted weights when available.

        Args:
            d: Feature dimension of the context vectors.
            K: Number of actions (arms).
            alpha: Initial exploration coefficient. It is used for a fresh
                model and is the value ``reset()`` returns to. When a saved
                model that stores an alpha is loaded, the stored (possibly
                decayed) alpha takes precedence.
        """
        self.d = d
        self.K = K
        # Remembered so reset() returns to the configured value instead of a
        # hard-coded constant.
        self._initial_alpha = alpha
        self.alpha = alpha
        self.total_updates = 0
        loaded = self._load_weights()
        if loaded is not None:
            self.A_inv = loaded["A_inv"]
            self.b = loaded["b"]
            self.counts = loaded["counts"]
            self.total_updates = loaded["total_updates"]
            if loaded["alpha"] is not None:
                # Resume the decayed exploration level rather than restarting
                # exploration from scratch after every reload.
                self.alpha = loaded["alpha"]
        else:
            self._init_fresh_state()

    def _init_fresh_state(self) -> None:
        """Reset every action to the prior: A_inv = I_d, b = 0, count = 0."""
        self.A_inv = [np.eye(self.d) for _ in range(self.K)]
        self.b = [np.zeros(self.d) for _ in range(self.K)]
        self.counts = [0] * self.K
        self.total_updates = 0

    # ─── Core Interface ──────────────────────────────────────────

    def select_action(self, x: List[float]) -> Tuple[RepairAction, List[float]]:
        """
        Select the action with the highest UCB score.

        UCB_a(x) = theta_a . x + alpha * sqrt(max(0, x^T A_inv_a x)),
        where theta_a = A_inv_a @ b_a. Exact ties for the maximum are broken
        uniformly at random.

        Args:
            x: Context feature vector of length ``d``.

        Returns:
            ``(action, scores)``, where ``scores[a]`` is the UCB score of
            action ``a``.
        """
        xv = np.array(x, dtype=np.float64)
        scores: List[float] = []
        for a in range(self.K):
            theta = self.A_inv[a] @ self.b[a]
            exploit = float(theta @ xv)
            # The quadratic form is non-negative in exact arithmetic (A_inv is
            # positive definite); the clamp guards against float round-off.
            quad = float(xv @ self.A_inv[a] @ xv)
            explore = self.alpha * float(np.sqrt(max(0.0, quad)))
            scores.append(exploit + explore)
        return RepairAction(self._argmax_random_tie(scores)), scores

    @staticmethod
    def _argmax_random_tie(scores: List[float]) -> int:
        """
        Return an index of the maximum score, chosen uniformly among exact ties.

        Drawing uniformly from the full tie set matters because on a fresh
        model every action scores identically. A sequential pairwise coin
        flip would favour later indices. If ``scores[0]`` is NaN, the result
        is 0.
        """
        best_score = max(scores)
        ties = [a for a, s in enumerate(scores) if s == best_score]
        # `ties` is empty only when best_score is NaN (NaN != NaN).
        return random.choice(ties) if ties else 0

    def update(self, x: List[float], action: RepairAction, reward: float) -> None:
        """
        Incorporate one observed (context, action, reward) triple.

        Uses Sherman-Morrison:
        (A + xx^T)^{-1} = A^{-1} - (A^{-1}xx^T A^{-1}) / (1 + x^T A^{-1} x)
        so each update costs O(d^2) instead of a fresh O(d^3) inversion.
        Weights are saved every ``_SAVE_EVERY`` updates.

        Args:
            x: Context feature vector of length ``d``.
            action: The action that was taken (converted with ``int()``).
            reward: Observed reward for that action.
        """
        a = int(action)
        xv = np.array(x, dtype=np.float64)
        A_inv_x = self.A_inv[a] @ xv  # shape (d,)
        # At least 1 while A_inv is positive definite, so the division is safe.
        denom = 1.0 + float(xv @ A_inv_x)
        # A_inv stays symmetric (it starts as the identity and only symmetric
        # outer products are subtracted), so A^{-1} x x^T A^{-1} equals
        # outer(A^{-1} x, A^{-1} x).
        self.A_inv[a] -= np.outer(A_inv_x, A_inv_x) / denom
        # Reward-weighted feature accumulation.
        self.b[a] += reward * xv
        self.counts[a] += 1
        self.total_updates += 1
        if self.total_updates % self._SAVE_EVERY == 0:
            self.save_weights()

    def get_estimated_rewards(self, x: List[float]) -> List[float]:
        """
        Return theta_a^T x for each action (no exploration bonus).

        Useful for inspecting the learned policy.
        """
        xv = np.array(x, dtype=np.float64)
        return [float((self.A_inv[a] @ self.b[a]) @ xv) for a in range(self.K)]

    def get_action_counts(self) -> List[int]:
        """Return a copy of the per-action update counts."""
        return list(self.counts)

    def get_total_updates(self) -> int:
        """Return the total number of updates across all actions."""
        return self.total_updates

    def get_alpha(self) -> float:
        """Return the current exploration coefficient."""
        return self.alpha

    def decay_alpha(self, min_alpha: float = 0.3, decay: float = 0.995) -> None:
        """
        Decay the exploration coefficient toward exploitation.

        Args:
            min_alpha: Lower bound that alpha never drops below.
            decay: Multiplicative factor applied per call.
        """
        self.alpha = max(min_alpha, self.alpha * decay)

    def get_action_distribution(self) -> dict:
        """
        Return the fraction of updates per action, keyed by action name.

        Every fraction is 0.0 before any update has been recorded.
        """
        total = sum(self.counts) or 1
        return {
            REPAIR_ACTION_NAMES[RepairAction(a)]: self.counts[a] / total
            for a in range(self.K)
        }

    # ─── Persistence ─────────────────────────────────────────────

    def save_weights(self) -> None:
        """
        Persist the model to ``WEIGHTS_PATH`` as JSON, on a best-effort basis.

        The JSON is written to a sibling temporary file and moved into place
        with ``os.replace``. On POSIX this is atomic within one filesystem,
        so readers never see a truncated file. Filesystem and serialization
        errors are swallowed so that a failed save never interrupts learning.
        """
        tmp_path = WEIGHTS_PATH.with_name(WEIGHTS_PATH.name + ".tmp")
        try:
            WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)
            data = {
                "A_inv": [m.tolist() for m in self.A_inv],
                "b": [v.tolist() for v in self.b],
                "counts": list(self.counts),
                "total_updates": self.total_updates,
                "alpha": self.alpha,
            }
            tmp_path.write_text(json.dumps(data))
            os.replace(tmp_path, WEIGHTS_PATH)
        except (OSError, TypeError, ValueError):
            # Non-fatal: drop any partially written temp file and carry on.
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass

    def _load_weights(self) -> Optional[dict]:
        """
        Load persisted weights from ``WEIGHTS_PATH``.

        Returns:
            A dict with keys ``A_inv``, ``b``, ``counts``, ``total_updates``
            and ``alpha``. ``alpha`` is ``None`` when the file has no usable
            alpha entry. Returns ``None`` instead when the file is missing,
            unreadable, malformed, contains non-finite weights, or was saved
            for a different ``d``/``K``.
        """
        try:
            if not WEIGHTS_PATH.exists():
                return None
            raw = json.loads(WEIGHTS_PATH.read_text())
            if not isinstance(raw, dict):
                return None
            A_inv = [np.array(m, dtype=np.float64) for m in raw["A_inv"]]
            b = [np.array(v, dtype=np.float64) for v in raw["b"]]
            counts = [int(c) for c in raw["counts"]]
            total_updates = int(raw["total_updates"])
            raw_alpha = raw.get("alpha")
            alpha = None if raw_alpha is None else float(raw_alpha)
        except (OSError, ValueError, TypeError, KeyError):
            # ValueError also covers json.JSONDecodeError and
            # UnicodeDecodeError.
            return None

        # Validate every entry, not just the first, so a partially corrupted
        # file cannot slip through.
        if not (len(A_inv) == self.K and len(b) == self.K and len(counts) == self.K):
            return None
        if any(m.shape != (self.d, self.d) or not np.all(np.isfinite(m)) for m in A_inv):
            return None
        if any(v.shape != (self.d,) or not np.all(np.isfinite(v)) for v in b):
            return None
        if alpha is not None and not np.isfinite(alpha):
            alpha = None
        return {
            "A_inv": A_inv,
            "b": b,
            "counts": counts,
            "total_updates": total_updates,
            "alpha": alpha,
        }

    def reset(self) -> None:
        """
        Forget everything learned.

        Restores the prior state and the constructor's alpha, then deletes
        the saved weights file on a best-effort basis.
        """
        self._init_fresh_state()
        self.alpha = self._initial_alpha
        try:
            WEIGHTS_PATH.unlink(missing_ok=True)
        except OSError:
            pass