""" POMDP Parameter Learning via Dirichlet Concentration Parameters. Standard Active Inference learning: instead of fixed A/B matrices, store Dirichlet concentration parameters (pseudo-counts). Each observation increments counts, and the actual matrices are derived by normalizing the columns. This gives the model the ability to learn: - A-matrices: "how reliable are my observations?" (observation model) - B-matrices: "how do actions change states?" (transition model) The learning rate controls how fast new evidence overrides the prior. High concentration (many pseudo-counts) = slow learning, stable model. Low concentration = fast learning, plastic model. Reference: Friston et al. (2016) "Active Inference and Learning" """ from __future__ import annotations from typing import Any, Dict, List, Optional, Tuple import numpy as np from .utils import normalize class DirichletLearner: """ Learns POMDP matrices via Dirichlet concentration parameter updates. The key insight: a matrix column A[:, s] = Dir(alpha[:, s]). The expected A-matrix is just normalize(alpha). Each observation (obs, state) pair increments alpha[obs, state]. This is Bayesian parameter learning — the concentration parameters encode both the prior structure AND accumulated evidence. Usage: learner = DirichletLearner(prior_A, learning_rate=1.0) # After observing obs=2 when belief peaks at state=3: learner.update(obs_idx=2, state_belief=beliefs) # Get the learned matrix: A_learned = learner.get_matrix() """ def __init__( self, prior_matrix: np.ndarray, prior_strength: float = 10.0, learning_rate: float = 1.0, ): """ Args: prior_matrix: Initial matrix to learn from (columns should sum to 1). Shape: [n_obs x n_states] for A, [n_states x n_states] for B. prior_strength: How many pseudo-counts the prior is worth. Higher = more prior-dominated, slower to change. Lower = more plastic, faster adaptation. learning_rate: How much each observation contributes. 1.0 = standard Bayesian update. < 1.0 = discounted (for non-stationary environments). """ # Convert prior matrix to concentration parameters # alpha = prior_strength * prior_matrix self.alpha = prior_matrix.copy() * prior_strength self.learning_rate = learning_rate self.n_updates = 0 def update( self, obs_idx: int, state_belief: np.ndarray, ) -> None: """ Update concentration parameters given an observation. Since we don't know the true state, we weight the update by the current belief over states (soft assignment). alpha[obs, s] += lr * q(s) for all s This is the standard "expected sufficient statistics" update for Dirichlet-Categorical models. Args: obs_idx: Which observation was seen (row index) state_belief: Current belief over states p(s) [n_states] """ self.alpha[obs_idx, :] += self.learning_rate * state_belief self.n_updates += 1 def update_transition( self, state_belief_before: np.ndarray, state_belief_after: np.ndarray, ) -> None: """ Update transition model concentration parameters. For B-matrices: we observe the state before and after. The outer product of beliefs gives the expected transition count. alpha[s', s] += lr * q(s') * q(s) for all s, s' Args: state_belief_before: Belief over states before transition state_belief_after: Belief over states after transition """ # Outer product: expected transition counts transition_counts = np.outer(state_belief_after, state_belief_before) self.alpha += self.learning_rate * transition_counts self.n_updates += 1 def get_matrix(self) -> np.ndarray: """ Get the current learned matrix by normalizing concentrations. A[:, s] = alpha[:, s] / sum(alpha[:, s]) """ matrix = self.alpha.copy() for col in range(matrix.shape[1]): col_sum = matrix[:, col].sum() if col_sum > 0: matrix[:, col] /= col_sum else: matrix[:, col] = 1.0 / matrix.shape[0] return matrix def get_confidence(self) -> float: """ Get confidence in the learned parameters. Higher total concentration = more confident. Returns the average concentration per column (how many effective observations per state). """ avg_concentration = np.mean(np.sum(self.alpha, axis=0)) return float(avg_concentration) def get_learning_progress(self) -> Dict[str, Any]: """Get summary of learning state.""" matrix = self.get_matrix() return { "n_updates": self.n_updates, "avg_concentration": float(np.mean(np.sum(self.alpha, axis=0))), "max_diagonal": float(np.max(np.diag(matrix))) if matrix.shape[0] == matrix.shape[1] else None, "learned_matrix": matrix.tolist(), } def reset(self, prior_matrix: np.ndarray, prior_strength: float = 10.0) -> None: """Reset to a new prior.""" self.alpha = prior_matrix.copy() * prior_strength self.n_updates = 0 class ModelLearner: """ Manages learning for all POMDP factors in the SphereModel. Wraps DirichletLearner instances for each factor's A-matrix and optionally B-matrices. Provides a clean interface for the CoachingAgent to call on each observation. """ def __init__( self, model, a_learning_rate: float = 1.0, b_learning_rate: float = 0.5, a_prior_strength: float = 10.0, b_prior_strength: float = 20.0, ): """ Args: model: SphereModel instance a_learning_rate: Learning rate for observation models b_learning_rate: Learning rate for transition models a_prior_strength: Prior strength for A-matrices b_prior_strength: Prior strength for B-matrices (higher = more stable) """ self.model = model # A-matrix learners (one per factor) self.a_learners: Dict[str, DirichletLearner] = {} for factor_name, A in model.A.items(): self.a_learners[factor_name] = DirichletLearner( prior_matrix=A, prior_strength=a_prior_strength, learning_rate=a_learning_rate, ) # B-matrix learners (one per factor per action) # Only learn B for friction factors — skill transitions are too # stable within a single session to learn meaningfully self.b_learners: Dict[str, Dict[int, DirichletLearner]] = {} friction_factors = list(model.spec.friction_factors.keys()) for factor_name in friction_factors: if factor_name in model.B: self.b_learners[factor_name] = {} n_actions = model.B[factor_name].shape[0] for action_idx in range(n_actions): self.b_learners[factor_name][action_idx] = DirichletLearner( prior_matrix=model.B[factor_name][action_idx], prior_strength=b_prior_strength, learning_rate=b_learning_rate, ) def learn_from_observation( self, factor_name: str, obs_idx: int, state_belief: np.ndarray, ) -> None: """ Update A-matrix for a factor given an observation. Called after each belief update to refine the observation model. """ if factor_name in self.a_learners: self.a_learners[factor_name].update(obs_idx, state_belief) # Update the model's A-matrix in place self.model.A[factor_name] = self.a_learners[factor_name].get_matrix() def learn_from_transition( self, factor_name: str, action_idx: int, belief_before: np.ndarray, belief_after: np.ndarray, ) -> None: """ Update B-matrix for a factor given a state transition. Called when we observe how the user's state changed after an action. """ if factor_name in self.b_learners and action_idx in self.b_learners[factor_name]: learner = self.b_learners[factor_name][action_idx] learner.update_transition(belief_before, belief_after) # Update the model's B-matrix in place self.model.B[factor_name][action_idx] = learner.get_matrix() def get_learning_summary(self) -> Dict[str, Any]: """Get summary of all learning progress.""" summary = {} for factor_name, learner in self.a_learners.items(): progress = learner.get_learning_progress() summary[f"A_{factor_name}"] = { "n_updates": progress["n_updates"], "avg_concentration": progress["avg_concentration"], } for factor_name, action_learners in self.b_learners.items(): total_updates = sum(l.n_updates for l in action_learners.values()) summary[f"B_{factor_name}"] = { "n_updates": total_updates, } return summary def reset(self) -> None: """Reset all learners to their priors.""" for learner in self.a_learners.values(): learner.reset(learner.alpha / max(learner.get_confidence(), 1.0)) for action_learners in self.b_learners.values(): for learner in action_learners.values(): learner.reset(learner.alpha / max(learner.get_confidence(), 1.0))