""" A2C agent loading, inference, and training utilities. """ from __future__ import annotations import os import threading from dataclasses import dataclass, field import numpy as np from stable_baselines3 import A2C from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.callbacks import BaseCallback from core.environment import AITutorEnv MODEL_PATH = "tutor_model" @dataclass class TrainingState: running: bool = False timestep: int = 0 total: int = 0 ep_rewards: list[float] = field(default_factory=list) status: str = "idle" model_ready: bool = False class _LiveCallback(BaseCallback): def __init__(self, state: TrainingState, log_every: int = 200): super().__init__() self._s = state self._le = log_every def _on_step(self) -> bool: if not self._s.running: return False self._s.timestep = self.num_timesteps for info in self.locals.get("infos", []): if "episode" in info: self._s.ep_rewards.append(float(info["episode"]["r"])) pct = self._s.timestep / max(self._s.total, 1) * 100 n_ep = len(self._s.ep_rewards) roll = float(np.mean(self._s.ep_rewards[-20:])) if self._s.ep_rewards else 0 self._s.status = ( f"{pct:.0f}% | Step {self._s.timestep:,}/{self._s.total:,} " f"| Episodes: {n_ep} | Rolling reward: {roll:.3f}" ) return True def _on_training_end(self) -> None: self._s.running = False self._s.model_ready = True self._s.status = f"Training complete — {self._s.timestep:,} steps." 
def load_model(path: str = MODEL_PATH) -> A2C:
    """Load a saved A2C model, training and saving a fallback if loading fails.

    Args:
        path: Location passed to ``A2C.load`` / ``A2C.save``.

    Returns:
        The loaded model, or a freshly trained one (10,000 timesteps) when
        the load raised. The fallback model is saved to ``path``.
    """
    # Function-scope import so this block does not depend on a change to the
    # file-level import list.
    import logging

    env = AITutorEnv()
    try:
        return A2C.load(path, env=env)
    except Exception:
        # Deliberate best-effort fallback: any load failure leads to a fresh
        # model. The failure is logged so that a broken or incompatible
        # checkpoint is not replaced without a trace.
        logging.getLogger(__name__).warning(
            "Could not load A2C model from %r; training a fallback model.",
            path,
            exc_info=True,
        )

    model = A2C("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=10_000)
    model.save(path)
    return model


def get_policy_probs(model: A2C, proficiency_pct: list[float]) -> np.ndarray:
    """Return the policy's action probability distribution for a given state.

    Args:
        model: Trained A2C model whose policy is queried.
        proficiency_pct: State values expressed as percentages; they are
            divided by 100 before being fed to the policy.

    Returns:
        1-D array of action probabilities for the single input state.
    """
    obs = np.array(proficiency_pct, dtype=np.float32) / 100.0
    # The policy expects a batch dimension, hence the reshape to (1, n).
    tensor = model.policy.obs_to_tensor(obs.reshape(1, -1))[0]
    dist = model.policy.get_distribution(tensor)
    # Detach and move to CPU before converting; [0] drops the batch dimension.
    return dist.distribution.probs.detach().cpu().numpy()[0]


def simulate_path(
    model: A2C,
    start_pct: list[float],
    n_steps: int = 20,
    deterministic: bool = True,
) -> list[dict]:
    """Roll out the policy for up to ``n_steps`` steps.

    Args:
        model: Trained A2C model used to pick actions.
        start_pct: Starting state handed to ``env.set_state``.
        n_steps: Maximum number of steps to simulate.
        deterministic: Forwarded to ``model.predict``.

    Returns:
        One dict per executed step with keys ``step`` (1-based), ``action``,
        ``state`` (pre-step state scaled by 100), ``reward``, ``probs`` and
        ``done``. The rollout stops early once the environment reports
        termination or truncation.
    """
    env = AITutorEnv()
    history: list[dict] = []
    try:
        env.set_state(start_pct)
        for step in range(n_steps):
            obs = env.state.copy()
            state_pct = (obs * 100).tolist()
            probs = get_policy_probs(model, state_pct)
            action, _ = model.predict(obs, deterministic=deterministic)
            action = int(action)
            # Only reward and the two end-of-episode flags are needed here.
            _, reward, done, trunc, _ = env.step(action)
            finished = done or trunc
            history.append({
                "step": step + 1,
                "action": action,
                "state": state_pct,
                "reward": reward,
                "probs": probs.tolist(),
                "done": finished,
            })
            if finished:
                break
    finally:
        # Release the throwaway environment even if the rollout raised.
        env.close()
    return history


def start_training(
    total_steps: int,
    state: TrainingState,
    save_path: str = MODEL_PATH,
) -> threading.Thread:
    """Train a new A2C model on a background daemon thread.

    Args:
        total_steps: Number of timesteps passed to ``model.learn``.
        state: Shared progress record. ``running`` and ``total`` are set
            before this function returns; the rest is updated by the
            training callback while the thread runs.
        save_path: Where the trained model is saved on success.

    Returns:
        The already-started worker thread.
    """

    def _run() -> None:
        env = None
        try:
            env = Monitor(AITutorEnv())
            model = A2C(
                "MlpPolicy",
                env,
                verbose=0,
                learning_rate=7e-4,
                gamma=0.99,
                n_steps=5,
                ent_coef=0.01,
            )
            cb = _LiveCallback(state)
            model.learn(total_timesteps=total_steps, callback=cb, progress_bar=False)
            model.save(save_path)
        except Exception as exc:
            # Surface the failure to observers of the shared state, then let
            # the exception propagate so the traceback is still reported.
            state.status = f"Training failed: {exc}"
            raise
        finally:
            # Never leave the state stuck in "running" after the thread ends,
            # and always release the environment.
            state.running = False
            if env is not None:
                env.close()

    # Set these before the thread starts so a caller that polls the state
    # right after this call sees running=True, and so a stop request issued
    # right after the call cannot be overwritten by the worker.
    state.running = True
    state.total = total_steps

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return t