"""
POMDP Parameter Learning via Dirichlet Concentration Parameters.

Standard Active Inference learning: instead of fixed A/B matrices,
store Dirichlet concentration parameters (pseudo-counts). Each
observation increments counts, and the actual matrices are derived
by normalizing the columns.

This gives the model the ability to learn:
- A-matrices: "how reliable are my observations?" (observation model)
- B-matrices: "how do actions change states?" (transition model)

The learning rate controls how fast new evidence overrides the prior.
High concentration (many pseudo-counts) = slow learning, stable model.
Low concentration = fast learning, plastic model.

Reference: Friston et al. (2016) "Active Inference and Learning"
"""

from __future__ import annotations

from typing import Any, Dict

import numpy as np


class DirichletLearner:
    """
    Learns POMDP matrices via Dirichlet concentration parameter updates.

    The key insight: each matrix column is Dirichlet-distributed,
    A[:, s] ~ Dir(alpha[:, s]), so the expected A-matrix is obtained by
    normalizing each column of alpha. Each observed (obs, state) pair
    increments alpha[obs, state].

    This is Bayesian parameter learning — the concentration parameters
    encode both the prior structure AND accumulated evidence.

    Usage:
        learner = DirichletLearner(prior_A, learning_rate=1.0)
        # After observing obs=2 when belief peaks at state=3:
        learner.update(obs_idx=2, state_belief=beliefs)
        # Get the learned matrix:
        A_learned = learner.get_matrix()
    """

    def __init__(
        self,
        prior_matrix: np.ndarray,
        prior_strength: float = 10.0,
        learning_rate: float = 1.0,
    ):
        """
        Args:
            prior_matrix: Initial matrix to learn from (columns should sum to 1).
                Shape: [n_obs x n_states] for A, [n_states x n_states] for B.
            prior_strength: How many pseudo-counts the prior is worth.
                Higher = more prior-dominated, slower to change.
                Lower = more plastic, faster adaptation.
            learning_rate: How much each observation contributes.
                1.0 = standard Bayesian update (one pseudo-count per observation).
                < 1.0 = each observation is down-weighted, slowing adaptation.
        """
        # Convert the prior matrix to concentration parameters (pseudo-counts):
        # alpha = prior_strength * prior_matrix.
        # Cast to float so integer-valued priors don't break the in-place
        # fractional updates in update() / update_transition().
        self.alpha = prior_matrix.astype(float) * prior_strength
        self.learning_rate = learning_rate
        self.n_updates = 0

    def update(
        self,
        obs_idx: int,
        state_belief: np.ndarray,
    ) -> None:
        """
        Update concentration parameters given an observation.

        Since we don't know the true state, we weight the update
        by the current belief over states (soft assignment).

        alpha[obs, s] += lr * q(s)  for all s

        This is the standard "expected sufficient statistics" update
        for Dirichlet-Categorical models.

        Args:
            obs_idx: Which observation was seen (row index)
            state_belief: Current belief over states p(s) [n_states]
        """
        self.alpha[obs_idx, :] += self.learning_rate * state_belief
        self.n_updates += 1

    def update_transition(
        self,
        state_belief_before: np.ndarray,
        state_belief_after: np.ndarray,
    ) -> None:
        """
        Update transition model concentration parameters.

        For B-matrices: we observe the state before and after.
        The outer product of beliefs gives the expected transition count.

        alpha[s', s] += lr * q(s') * q(s)  for all s, s'

        Args:
            state_belief_before: Belief over states before transition
            state_belief_after: Belief over states after transition
        """
        # Outer product: expected transition counts
        transition_counts = np.outer(state_belief_after, state_belief_before)
        self.alpha += self.learning_rate * transition_counts
        self.n_updates += 1

    def get_matrix(self) -> np.ndarray:
        """
        Get the current learned matrix by normalizing concentrations.

        A[:, s] = alpha[:, s] / sum(alpha[:, s])
        """
        matrix = self.alpha.copy()
        for col in range(matrix.shape[1]):
            col_sum = matrix[:, col].sum()
            if col_sum > 0:
                matrix[:, col] /= col_sum
            else:
                matrix[:, col] = 1.0 / matrix.shape[0]
        return matrix

    def get_confidence(self) -> float:
        """
        Get confidence in the learned parameters.

        Higher total concentration = more confident.
        Returns the average concentration per column (how many
        effective observations per state).
        """
        avg_concentration = np.mean(np.sum(self.alpha, axis=0))
        return float(avg_concentration)

    def get_learning_progress(self) -> Dict[str, Any]:
        """Get summary of learning state."""
        matrix = self.get_matrix()
        return {
            "n_updates": self.n_updates,
            "avg_concentration": float(np.mean(np.sum(self.alpha, axis=0))),
            "max_diagonal": float(np.max(np.diag(matrix))) if matrix.shape[0] == matrix.shape[1] else None,
            "learned_matrix": matrix.tolist(),
        }

    def reset(self, prior_matrix: np.ndarray, prior_strength: float = 10.0) -> None:
        """Reset to a new prior."""
        self.alpha = prior_matrix.astype(float) * prior_strength
        self.n_updates = 0


class ModelLearner:
    """
    Manages learning for all POMDP factors in the SphereModel.

    Wraps DirichletLearner instances for each factor's A-matrix
    and optionally B-matrices. Provides a clean interface for
    the CoachingAgent to call on each observation.
    """

    def __init__(
        self,
        model,
        a_learning_rate: float = 1.0,
        b_learning_rate: float = 0.5,
        a_prior_strength: float = 10.0,
        b_prior_strength: float = 20.0,
    ):
        """
        Args:
            model: SphereModel instance
            a_learning_rate: Learning rate for observation models
            b_learning_rate: Learning rate for transition models
            a_prior_strength: Prior strength for A-matrices
            b_prior_strength: Prior strength for B-matrices (higher = more stable)
        """
        self.model = model

        # A-matrix learners (one per factor)
        self.a_learners: Dict[str, DirichletLearner] = {}
        for factor_name, A in model.A.items():
            self.a_learners[factor_name] = DirichletLearner(
                prior_matrix=A,
                prior_strength=a_prior_strength,
                learning_rate=a_learning_rate,
            )

        # B-matrix learners (one per factor per action)
        # Only learn B for friction factors — skill transitions are too
        # stable within a single session to learn meaningfully
        self.b_learners: Dict[str, Dict[int, DirichletLearner]] = {}
        friction_factors = list(model.spec.friction_factors.keys())
        for factor_name in friction_factors:
            if factor_name in model.B:
                self.b_learners[factor_name] = {}
                n_actions = model.B[factor_name].shape[0]
                for action_idx in range(n_actions):
                    self.b_learners[factor_name][action_idx] = DirichletLearner(
                        prior_matrix=model.B[factor_name][action_idx],
                        prior_strength=b_prior_strength,
                        learning_rate=b_learning_rate,
                    )

    def learn_from_observation(
        self,
        factor_name: str,
        obs_idx: int,
        state_belief: np.ndarray,
    ) -> None:
        """
        Update A-matrix for a factor given an observation.

        Called after each belief update to refine the observation model.
        """
        if factor_name in self.a_learners:
            self.a_learners[factor_name].update(obs_idx, state_belief)
            # Update the model's A-matrix in place
            self.model.A[factor_name] = self.a_learners[factor_name].get_matrix()

    def learn_from_transition(
        self,
        factor_name: str,
        action_idx: int,
        belief_before: np.ndarray,
        belief_after: np.ndarray,
    ) -> None:
        """
        Update B-matrix for a factor given a state transition.

        Called when we observe how the user's state changed after an action.
        """
        if factor_name in self.b_learners and action_idx in self.b_learners[factor_name]:
            learner = self.b_learners[factor_name][action_idx]
            learner.update_transition(belief_before, belief_after)
            # Update the model's B-matrix in place
            self.model.B[factor_name][action_idx] = learner.get_matrix()

    def get_learning_summary(self) -> Dict[str, Any]:
        """Get summary of all learning progress."""
        summary = {}
        for factor_name, learner in self.a_learners.items():
            progress = learner.get_learning_progress()
            summary[f"A_{factor_name}"] = {
                "n_updates": progress["n_updates"],
                "avg_concentration": progress["avg_concentration"],
            }
        for factor_name, action_learners in self.b_learners.items():
            total_updates = sum(l.n_updates for l in action_learners.values())
            summary[f"B_{factor_name}"] = {
                "n_updates": total_updates,
            }
        return summary

    def reset(self) -> None:
        """
        Shrink every learner back to the default prior strength.

        The original priors are not stored, so this keeps each learner's
        current learned matrix but discounts its accumulated pseudo-counts,
        restoring plasticity without erasing what was learned. (Normalizing
        alpha by the average column mass, as a shortcut, is only correct
        when all columns carry equal mass; get_matrix() normalizes
        per-column.)
        """
        for learner in self.a_learners.values():
            learner.reset(learner.get_matrix())
        for action_learners in self.b_learners.values():
            for learner in action_learners.values():
                learner.reset(learner.get_matrix())