Spaces:
Running
Running
"""Bayesian engine for selecting and updating response strategies.

Implements a lightweight contextual bandit (Thompson-sampling style) over
presentation strategies. Supports:

- soft preferences (a bias, not a lock)
- an optional hard lock (the user can disable exploration)
- "Option B" corrective exploration after negative feedback:
    * temperature boost
    * strong repeat penalty
    * posterior damping (reduce confidence in the last chosen arm)
"""
| import numpy as np | |
| import random | |
| from . import config | |
| from .utils import sigmoid, mean_uncertainty | |
| class BayesianEngine: | |
| def __init__(self): | |
| # Keep posterior state for all known strategies (enabled + disabled). | |
| # Selection/visualization should still use `config.STRATEGY_NAMES` (enabled only). | |
| initial_ids = set(getattr(config, "STRATEGY_ITEMS", {}) or {}).union(set(config.STRATEGY_NAMES)) | |
| self.global_mu = {k: np.zeros(config.D) for k in initial_ids} | |
| self.global_sinv = {k: np.eye(config.D) * 0.1 for k in initial_ids} | |
| self.users = {} | |
| self.global_n = 0 | |
| def _all_strategy_ids(self) -> set[str]: | |
| items = getattr(config, "STRATEGY_ITEMS", {}) or {} | |
| return set(items.keys()) if items else set(config.STRATEGY_NAMES) | |
| def reconcile_strategies(self) -> None: | |
| """ | |
| Reconcile engine posterior dictionaries with the current strategy store. | |
| - If new strategies appear: add priors (zero mean, low precision identity). | |
| - If strategies are hard-deleted: remove their posterior state. | |
| - If strategies are disabled: keep state but they won't be selected/visualized | |
| because selection/summary loops use `config.STRATEGY_NAMES`. | |
| """ | |
| desired_ids = self._all_strategy_ids() | |
| # Add missing strategy keys. | |
| for sid in desired_ids: | |
| if sid not in self.global_mu: | |
| self.global_mu[sid] = np.zeros(config.D) | |
| self.global_sinv[sid] = np.eye(config.D) * 0.1 | |
| # Remove hard-deleted strategies. | |
| current_ids = set(self.global_mu.keys()) | |
| to_remove = current_ids - desired_ids | |
| for sid in to_remove: | |
| self.global_mu.pop(sid, None) | |
| self.global_sinv.pop(sid, None) | |
| for user in self.users.values(): | |
| user.get("mu", {}).pop(sid, None) | |
| user.get("sigma_inv", {}).pop(sid, None) | |
| prefs = user.get("prefs") | |
| if isinstance(prefs, set) and sid in prefs: | |
| prefs.discard(sid) | |
| if user.get("locked_strategy") == sid: | |
| user["locked_strategy"] = None | |
| if user.get("pending_strategy") == sid: | |
| user["pending_strategy"] = None | |
| # Ensure all existing users have keys for every known strategy. | |
| for user in self.users.values(): | |
| mu = user.get("mu") or {} | |
| sinv = user.get("sigma_inv") or {} | |
| for sid in desired_ids: | |
| if sid not in mu: | |
| mu[sid] = self.global_mu[sid].copy() | |
| if sid not in sinv: | |
| sinv[sid] = self.global_sinv[sid].copy() | |
| user["mu"] = mu | |
| user["sigma_inv"] = sinv | |
| def _new_user(self): | |
| all_ids = self._all_strategy_ids() | |
| return { | |
| "mu": {k: self.global_mu[k].copy() for k in all_ids}, | |
| "sigma_inv": {k: self.global_sinv[k].copy() for k in all_ids}, | |
| "history": [], | |
| "reward_log": [], | |
| "last_message": "", | |
| "last_response": "", | |
| "last_strategy": None, | |
| "last_x": None, | |
| "msg_count": 0, | |
| "prefs": set(), | |
| "locked_strategy": None, | |
| "pending_strategy": None, | |
| } | |
| def get_user(self, uid: str): | |
| if uid not in self.users: | |
| self.users[uid] = self._new_user() | |
| return self.users[uid] | |
| def featurize(self, message: str, user: dict) -> np.ndarray: | |
| words = message.split() | |
| msg_len = min(len(message) / 500, 1.0) | |
| word_ct = min(len(words) / 100, 1.0) | |
| has_q = 1.0 if "?" in message else 0.0 | |
| is_long = 1.0 if len(words) > 40 else 0.0 | |
| informal = {"lol","gonna","wanna","yo","omg","idk","wtf","lmao"} | |
| formal = 0.0 if any(w in message.lower() for w in informal) else 1.0 | |
| rl = user["reward_log"] | |
| avg_r = sum(r for _, r in rl[-5:]) / max(len(rl[-5:]), 1) if rl else 0.5 | |
| msg_num = min(user["msg_count"] / 20.0, 1.0) | |
| # Map last strategy into [0,1] by its rank among enabled strategies. | |
| # If the strategy is currently disabled, fall back to 0.5. | |
| if user.get("last_strategy") and user["last_strategy"] in config.STRATEGY_NAMES and config.K > 1: | |
| si = config.STRATEGY_NAMES.index(user["last_strategy"]) / (config.K - 1) | |
| else: | |
| si = 0.5 | |
| trend = 0.0 | |
| if len(rl) >= 3: | |
| ys = [r for _, r in rl[-5:]] | |
| xs = list(range(len(ys))) | |
| mx, my = sum(xs)/len(xs), sum(ys)/len(ys) | |
| num = sum((xi-mx)*(yi-my) for xi, yi in zip(xs, ys)) | |
| den = sum((xi-mx)**2 for xi in xs) or 1e-8 | |
| trend = float(np.clip(num/den, -1, 1)) | |
| # last dim is just a small noise to break ties | |
| return np.array([msg_len, word_ct, has_q, is_long, | |
| formal, avg_r, msg_num, si, trend, random.random()]) | |
| def _damp_posterior(self, user: dict, strategy: str, strength: float): | |
| """Option B: reduce confidence in the last chosen strategy.""" | |
| if strategy not in user.get("mu", {}): | |
| return | |
| s = float(np.clip(strength, 0.0, 1.0)) | |
| if s <= 0: | |
| return | |
| # shrink mean magnitude | |
| mu = user["mu"][strategy] | |
| user["mu"][strategy] = mu * (1.0 - config.NEG_MU_SHRINK * s) | |
| # shrink precision -> increases covariance (more uncertainty) | |
| fac = max(0.15, 1.0 - config.NEG_SINV_SHRINK * s) | |
| user["sigma_inv"][strategy] = user["sigma_inv"][strategy] * fac | |
| def select(self, uid: str, message: str, *, | |
| force_explore: bool = False, | |
| neg_strength: float = 0.0, | |
| explicit_strategy: str | None = None): | |
| """Return (chosen, scores, x, prev_strategy).""" | |
| user = self.get_user(uid) | |
| prev = user.get("last_strategy") | |
| x = self.featurize(message, user) | |
| # One-time override: honor the upfront user-selected format once, | |
| # then return to adaptive Thompson Sampling on later turns. | |
| pending = user.get("pending_strategy") | |
| if pending in self.global_mu: | |
| user["pending_strategy"] = None | |
| return pending, {k: 0.0 for k in config.STRATEGY_NAMES}, x, prev | |
| # If user explicitly asked for a format, obey immediately. | |
| # Hard-lock disables exploration, but should NOT block an explicit request like | |
| # "compare X vs Y" or "put it in a table". | |
| locked = user.get("locked_strategy") | |
| if locked in self.global_mu: | |
| force_explore = False | |
| if explicit_strategy is None: | |
| explicit_strategy = locked | |
| # Corrective exploration: damp posterior on prev to avoid getting stuck. | |
| if force_explore and prev: | |
| self._damp_posterior(user, prev, neg_strength) | |
| # Build TS scores | |
| temp = config.TS_TEMPERATURE * (config.EXPLORE_TEMP_BOOST if force_explore else 1.0) | |
| scores = {} | |
| for k in config.STRATEGY_NAMES: | |
| sigma = np.linalg.inv(user["sigma_inv"][k]) | |
| beta = np.random.multivariate_normal(user["mu"][k], sigma * temp) | |
| scores[k] = float(sigmoid(x @ beta)) | |
| # Soft preference nudges (does NOT lock) | |
| pref_boost = 0.06 | |
| for s in user.get("prefs", set()): | |
| if s in scores: | |
| scores[s] = float(min(0.999, scores[s] + pref_boost)) | |
| # Strong anti-repeat penalty when exploring (unless explicit override) | |
| if force_explore and prev and prev in scores and explicit_strategy is None and len(config.STRATEGY_NAMES) > 1: | |
| scores[prev] = float(scores[prev] - config.EXPLORE_SCORE_PENALTY) | |
| # Choose | |
| # Explicit user request (e.g., "compare", "table", "graph") should win even if the user | |
| # previously hard-locked a default style. The lock only disables exploration. | |
| if explicit_strategy in config.STRATEGY_NAMES: | |
| chosen = explicit_strategy | |
| elif locked in self.global_mu and locked in config.STRATEGY_NAMES: | |
| chosen = locked | |
| else: | |
| chosen = max(scores, key=scores.get) | |
| return chosen, scores, x, prev | |
| def update(self, uid: str, strategy: str, x: np.ndarray, reward: float): | |
| user = self.get_user(uid) | |
| if strategy not in user.get("mu", {}) or strategy not in user.get("sigma_inv", {}): | |
| # Strategy might have been hard-deleted after selection; ignore update safely. | |
| return | |
| mu_old = user["mu"][strategy] | |
| si_old = user["sigma_inv"][strategy] | |
| r_hat = sigmoid(float(x @ mu_old)) | |
| w = r_hat * (1 - r_hat) | |
| si_new = config.GAMMA * si_old + np.outer(x, x) * w + config.LAMBDA * np.eye(config.D) | |
| s_new = np.linalg.inv(si_new) | |
| user["mu"][strategy] = mu_old + s_new @ x * (reward - r_hat) | |
| user["sigma_inv"][strategy] = si_new | |
| user["reward_log"].append((strategy, reward)) | |
| # Global update | |
| gmu = self.global_mu[strategy] | |
| gsi = self.global_sinv[strategy] | |
| gr_hat = sigmoid(float(x @ gmu)) | |
| gsi_n = config.GAMMA*gsi + np.outer(x,x)*gr_hat*(1-gr_hat)*config.ALPHA_G + config.LAMBDA*np.eye(config.D) | |
| gs_n = np.linalg.inv(gsi_n) | |
| self.global_mu[strategy] = gmu + gs_n @ x * (reward - gr_hat) * config.ALPHA_G | |
| self.global_sinv[strategy] = gsi_n | |
| self.global_n += 1 | |
| def apply_preferences(self, uid: str, strategy_names, *, lock: bool = False): | |
| """Apply soft preferences and optional hard lock. | |
| Behavior: | |
| - if lock=True and exactly one strategy is chosen: always use that strategy | |
| - if lock=False and exactly one strategy is chosen: force it ONCE on the next turn, | |
| then return to adaptive TS with preference bias | |
| """ | |
| user = self.get_user(uid) | |
| chosen = [s for s in (strategy_names or []) if s in config.STRATEGY_NAMES] | |
| user["prefs"] = set(chosen) | |
| # hard lock only if explicitly requested | |
| user["locked_strategy"] = chosen[0] if (lock and len(chosen) == 1) else None | |
| # one-turn override for the very next assistant response | |
| user["pending_strategy"] = chosen[0] if (not lock and len(chosen) == 1) else None | |
| # warm-start: nudge mean for preferred arms | |
| for s in user["prefs"]: | |
| user["mu"][s] = user["mu"][s] + 0.5 | |
| def posterior_summary(self, mu_dict, sinv_dict, x=None): | |
| if x is None: | |
| x = np.ones(config.D) * 0.5 | |
| return {k: {"r": round(float(sigmoid(x @ mu_dict[k])), 4), | |
| "u": round(mean_uncertainty(sinv_dict[k]), 4)} | |
| for k in config.STRATEGY_NAMES} | |
| def user_posterior(self, uid: str, x=None): | |
| u = self.get_user(uid) | |
| return self.posterior_summary(u["mu"], u["sigma_inv"], x) | |
| def global_posterior(self, x=None): | |
| return self.posterior_summary(self.global_mu, self.global_sinv, x) | |
| def reset_user(self, uid: str): | |
| self.users.pop(uid, None) | |
| def clear_conversation_thread(self, uid: str) -> None: | |
| """Drop chat history and last-turn pointers without resetting posteriors or reward log.""" | |
| u = self.get_user(uid) | |
| u["history"] = [] | |
| u["last_message"] = "" | |
| u["last_response"] = "" | |
| u["last_strategy"] = None | |
| u["last_x"] = None | |
# Reserved user id for the "user B" profile.
# NOTE(review): what this profile is used for is not visible here — confirm.
USERB_ID = "__user_b__"

# Module-level engine instance, created at import time.
engine = BayesianEngine()

# Create the USERB_ID state up front (get_user creates it on first access).
engine.get_user(USERB_ID)