Preston
Upload KAT TutoringRSSM v2 world model β 2.8M params, best eval loss 0.3124 @ epoch 93
76e4ab1 verified | """KAT TutoringRSSM β Standalone Architecture for Inference. | |
| This file contains the complete model architecture for the KAT Tutoring World Model, | |
| a DreamerV3-style Recurrent State-Space Model (RSSM) adapted for tutoring domains. | |
| It can be used to load pretrained checkpoints without the full KAT codebase. | |
| Heritage: Abigail core/world_model.py WorldModel, adapted for KAT's | |
| tutoring-specific dimensions and loss functions. Integrates VL-JEPA | |
| Exponential Moving Average (EMA) target encoding for self-supervised | |
| representation learning. | |
| Architecture Overview: | |
| βββββββββββββββ βββββββββββββββ ββββββββββββββββ | |
| β Observation ββββββΆβ RSSM Core ββββββΆβ Predictions β | |
| β Encoder β β GRU + z β β obs/rew/doneβ | |
| βββββββββββββββ βββββββββββββββ ββββββββββββββββ | |
| β β² | |
| β βββββββ΄ββββββ | |
| β β Action β | |
| β β Embedding β | |
| β βββββββββββββ | |
| βΌ | |
| βββββββββββββββ | |
| β EMA Target β | |
| β Encoder β | |
| βββββββββββββββ | |
| Author: Preston Mills / QRI (Qualia Research Initiative) | |
| License: Apache-2.0 | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| from torch.distributions import Normal | |
# Module-level logger, named after this module's import path (stdlib convention).
logger = logging.getLogger(__name__)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CONFIGURATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
@dataclass
class TutoringWorldModelConfig:
    """Configuration for the Tutoring RSSM world model.

    Heritage: Maps to Abigail's WorldModelConfig with tutoring-specific defaults.

    Observation space (20-dim):
        - Mastery estimates per topic (8 dims)
        - Misconception indicators (4 dims)
        - Engagement signals (4 dims)
        - Session context (4 dims)

    Action space (8 discrete actions):
        0: clarify, 1: hint_l1, 2: hint_l2, 3: hint_l3,
        4: encourage, 5: redirect, 6: assess, 7: summarize
    """

    obs_dim: int = 20
    action_dim: int = 8
    latent_dim: int = 128
    hidden_dim: int = 512
    encoder_hidden: int = 256
    decoder_hidden: int = 256
    dropout: float = 0.1
    # EMA target encoder momentum (VL-JEPA heritage)
    ema_momentum: float = 0.996
    # Multi-step imagination (DreamerV3 heritage)
    rollout_horizon: int = 5
    rollout_weight: float = 0.5
    rollout_discount: float = 0.95

    @classmethod
    def from_json(cls, path: str) -> "TutoringWorldModelConfig":
        """Load a config from a JSON file.

        Args:
            path: Path to a JSON file with the config fields either at the
                top level or nested under a "config" key.

        Returns:
            TutoringWorldModelConfig with unknown keys silently dropped,
            so checkpoints carrying extra metadata still load.
        """
        with open(path) as f:
            data = json.load(f)
        # Checkpoint metadata files may nest the config under "config".
        config_data = data.get("config", data)
        # Keep only keys that are actual dataclass fields.
        known = {f.name for f in cls.__dataclass_fields__.values()}
        filtered = {k: v for k, v in config_data.items() if k in known}
        return cls(**filtered)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # COMPONENT MODULES | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ObservationEncoder(nn.Module): | |
| """Encode observations into latent embeddings. | |
| Architecture: Linear β LayerNorm β SiLU β Linear | |
| Heritage: Abigail EncoderNetwork, adapted for tutoring observation space. | |
| """ | |
| def __init__(self, obs_dim: int, latent_dim: int, hidden_dim: int = 256): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.Linear(obs_dim, hidden_dim), | |
| nn.LayerNorm(hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, latent_dim), | |
| ) | |
| def forward(self, obs: Tensor) -> Tensor: | |
| return self.net(obs) | |
| class ObservationDecoder(nn.Module): | |
| """Decode features back to observation space. | |
| Architecture: Linear β LayerNorm β SiLU β Linear | |
| Heritage: Abigail DecoderNetwork. | |
| """ | |
| def __init__(self, feature_dim: int, obs_dim: int, hidden_dim: int = 256): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.Linear(feature_dim, hidden_dim), | |
| nn.LayerNorm(hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, obs_dim), | |
| ) | |
| def forward(self, features: Tensor) -> Tensor: | |
| return self.net(features) | |
class ActionEmbedding(nn.Module):
    """Look up a dense vector for each discrete tutoring action."""

    def __init__(self, num_actions: int, embed_dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_actions, embed_dim)

    def forward(self, action: Tensor) -> Tensor:
        """Return embedding rows; indices are cast to int64 first."""
        indices = action.long()
        return self.embed(indices)
| class DeterministicTransition(nn.Module): | |
| """GRU-based deterministic state transition. | |
| Heritage: Abigail RSSM deterministic path. | |
| Projects [z_{t-1}, a_t] to hidden_dim, then feeds through GRU: | |
| x = Linear([z, a]) | |
| h_t = GRU(x, h_{t-1}) | |
| """ | |
| def __init__(self, hidden_dim: int, latent_dim: int, action_embed_dim: int): | |
| super().__init__() | |
| self.pre = nn.Linear(latent_dim + action_embed_dim, hidden_dim) | |
| self.gru = nn.GRUCell( | |
| input_size=hidden_dim, | |
| hidden_size=hidden_dim, | |
| ) | |
| def forward(self, h_prev: Tensor, z_prev: Tensor, a_embed: Tensor) -> Tensor: | |
| x = torch.cat([z_prev, a_embed], dim=-1) | |
| x = self.pre(x) | |
| h = self.gru(x, h_prev) | |
| return h | |
| class StochasticLatent(nn.Module): | |
| """Gaussian stochastic latent variable with prior and posterior. | |
| Heritage: Abigail RSSM stochastic path. | |
| Prior: p(z_t | h_t) β 2-layer MLP (hidden_dim β hidden_dim β 2*latent_dim) | |
| Posterior: q(z_t | h_t, o_t) β 2-layer MLP (hidden_dim+latent_dim β hidden_dim β 2*latent_dim) | |
| """ | |
| def __init__(self, hidden_dim: int, latent_dim: int, obs_embed_dim: int): | |
| super().__init__() | |
| self.prior_net = nn.Sequential( | |
| nn.Linear(hidden_dim, hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, latent_dim * 2), | |
| ) | |
| self.posterior_net = nn.Sequential( | |
| nn.Linear(hidden_dim + obs_embed_dim, hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, latent_dim * 2), | |
| ) | |
| self.min_std = 0.1 | |
| def _split_params(self, params: Tensor) -> tuple[Tensor, Tensor, Normal]: | |
| """Split into mean and std, return distribution.""" | |
| mu, log_std = params.chunk(2, dim=-1) | |
| std = F.softplus(log_std) + self.min_std | |
| return mu, std, Normal(mu, std) | |
| def prior(self, h: Tensor) -> tuple[Tensor, Tensor, Normal]: | |
| return self._split_params(self.prior_net(h)) | |
| def posterior(self, h: Tensor, obs_embed: Tensor) -> tuple[Tensor, Tensor, Normal]: | |
| x = torch.cat([h, obs_embed], dim=-1) | |
| return self._split_params(self.posterior_net(x)) | |
def kl_divergence(posterior: Normal, prior: Normal) -> Tensor:
    """Return KL(posterior || prior), summed over the last (latent) dim."""
    per_dim = torch.distributions.kl_divergence(posterior, prior)
    return per_dim.sum(dim=-1)
| class RewardPredictor(nn.Module): | |
| """Predict scalar reward from RSSM features.""" | |
| def __init__(self, feature_dim: int, hidden_dim: int = 64): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.Linear(feature_dim, hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, 1), | |
| ) | |
| def forward(self, features: Tensor) -> Tensor: | |
| return self.net(features).squeeze(-1) | |
| class DonePredictor(nn.Module): | |
| """Predict episode termination (logit) from RSSM features.""" | |
| def __init__(self, feature_dim: int, hidden_dim: int = 64): | |
| super().__init__() | |
| self.net = nn.Sequential( | |
| nn.Linear(feature_dim, hidden_dim), | |
| nn.SiLU(), | |
| nn.Linear(hidden_dim, 1), | |
| ) | |
| def forward(self, features: Tensor) -> Tensor: | |
| return self.net(features).squeeze(-1) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # COMPLETE RSSM MODEL | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class TutoringRSSM(nn.Module):
    """Complete RSSM world model for the tutoring domain.

    Integrates all components:
        - Observation encoder/decoder (Linear -> LayerNorm -> SiLU -> Linear)
        - Action embedding (nn.Embedding)
        - Projection + GRU deterministic transition
        - Gaussian stochastic prior/posterior (2-layer MLPs)
        - Reward and done predictors (2-layer MLPs)
        - EMA target encoder (VL-JEPA heritage)

    Heritage: Abigail core/world_model.py WorldModel, adapted for
    KAT's tutoring-specific dimensions and loss functions.
    """

    def __init__(self, config: TutoringWorldModelConfig):
        super().__init__()
        self.config = config
        # Features fed to the decoder/predictors are the concat [h; z].
        self.feature_dim = config.hidden_dim + config.latent_dim
        # Action embedding (small enough for direct embedding)
        action_embed_dim = min(32, config.action_dim * 4)
        self.action_embed = ActionEmbedding(config.action_dim, action_embed_dim)
        # Observation encoder
        self.obs_encoder = ObservationEncoder(
            config.obs_dim, config.latent_dim, config.encoder_hidden,
        )
        # RSSM core
        self.transition = DeterministicTransition(
            config.hidden_dim, config.latent_dim, action_embed_dim,
        )
        # obs_embed_dim == latent_dim because the encoder outputs latent_dim.
        self.stochastic = StochasticLatent(
            config.hidden_dim, config.latent_dim, config.latent_dim,
        )
        # Predictors
        self.obs_decoder = ObservationDecoder(
            self.feature_dim, config.obs_dim, config.decoder_hidden,
        )
        self.reward_pred = RewardPredictor(self.feature_dim)
        self.done_pred = DonePredictor(self.feature_dim)
        # EMA target encoder (VL-JEPA heritage): same architecture as the main
        # encoder, updated only via update_target_encoder(), never by backprop.
        self.target_encoder = ObservationEncoder(
            config.obs_dim, config.latent_dim, config.encoder_hidden,
        )
        # Initialize target encoder from main encoder, then freeze it.
        self.target_encoder.load_state_dict(self.obs_encoder.state_dict())
        for p in self.target_encoder.parameters():
            p.requires_grad = False
        # Dropout
        self.dropout = nn.Dropout(config.dropout)
        # Trainable parameter count (excludes the frozen target encoder).
        self._param_count = sum(p.numel() for p in self.parameters() if p.requires_grad)

    def initial_state(self, batch_size: int) -> tuple[Tensor, Tensor]:
        """Create the zeroed initial RSSM state (h_0, z_0) on the model's device."""
        device = next(self.parameters()).device
        h = torch.zeros(batch_size, self.config.hidden_dim, device=device)
        z = torch.zeros(batch_size, self.config.latent_dim, device=device)
        return h, z

    def get_features(self, h: Tensor, z: Tensor) -> Tensor:
        """Concatenate deterministic and stochastic state into the feature vector."""
        return torch.cat([h, z], dim=-1)

    def observe_step(
        self,
        h_prev: Tensor,
        z_prev: Tensor,
        action: Tensor,
        obs: Tensor,
    ) -> dict[str, Any]:
        """One observation step: process a real observation.

        Uses posterior inference (conditioned on the observation), i.e. the
        training-time path. Samples z with rsample() so gradients flow through
        the reparameterization.

        Args:
            h_prev: Previous deterministic state, (batch, hidden_dim).
            z_prev: Previous stochastic latent, (batch, latent_dim).
            action: Discrete action indices, (batch,).
            obs: Observation, (batch, obs_dim).

        Returns:
            Dict with h, z, prior_dist, posterior_dist, features,
            pred_obs, pred_reward, pred_done.
        """
        # Embed action
        a_embed = self.action_embed(action)
        # Deterministic transition
        h = self.transition(h_prev, z_prev, a_embed)
        # Encode observation
        obs_embed = self.obs_encoder(obs)
        # Prior and posterior over z_t given h_t (and obs for the posterior)
        prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
        post_mu, post_sigma, posterior_dist = self.stochastic.posterior(h, obs_embed)
        # Sample from posterior (training mode); rsample keeps it differentiable
        z = posterior_dist.rsample()
        # Predictions from features
        features = self.get_features(h, z)
        pred_obs = self.obs_decoder(features)
        pred_reward = self.reward_pred(features)
        pred_done = self.done_pred(features)
        return {
            "h": h,
            "z": z,
            "prior_dist": prior_dist,
            "posterior_dist": posterior_dist,
            "features": features,
            "pred_obs": pred_obs,
            "pred_reward": pred_reward,
            "pred_done": pred_done,
        }

    def imagine_step(
        self,
        h_prev: Tensor,
        z_prev: Tensor,
        action: Tensor,
    ) -> dict[str, Any]:
        """One imagination step: predict without an observation.

        Uses the prior only (no posterior) — for planning / counterfactual
        rollouts where no real observation exists.

        Args:
            h_prev: Previous deterministic state, (batch, hidden_dim).
            z_prev: Previous stochastic latent, (batch, latent_dim).
            action: Discrete action indices, (batch,).

        Returns:
            Dict with h, z, prior_dist, features, pred_obs,
            pred_reward, pred_done.
        """
        a_embed = self.action_embed(action)
        h = self.transition(h_prev, z_prev, a_embed)
        prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
        z = prior_dist.rsample()
        features = self.get_features(h, z)
        pred_obs = self.obs_decoder(features)
        pred_reward = self.reward_pred(features)
        pred_done = self.done_pred(features)
        return {
            "h": h,
            "z": z,
            "prior_dist": prior_dist,
            "features": features,
            "pred_obs": pred_obs,
            "pred_reward": pred_reward,
            "pred_done": pred_done,
        }

    def update_target_encoder(self) -> None:
        """EMA update of the target encoder (VL-JEPA heritage).

        target = m * target + (1 - m) * main, with m = config.ema_momentum.
        Operates on .data so the update never enters the autograd graph.
        """
        m = self.config.ema_momentum
        for p_main, p_target in zip(
            self.obs_encoder.parameters(),
            self.target_encoder.parameters(),
        ):
            p_target.data.mul_(m).add_(p_main.data, alpha=1.0 - m)

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, device: str = "cpu") -> "TutoringRSSM":
        """Load a pretrained model from a checkpoint file.

        Args:
            checkpoint_path: Path to .pt checkpoint file.
            device: Device to load onto ('cpu', 'cuda', etc.)

        Returns:
            Loaded TutoringRSSM model in eval mode.

        Example:
            >>> model = TutoringRSSM.from_pretrained("tutoring_rssm_best.pt")
            >>> h, z = model.initial_state(batch_size=1)
            >>> obs = torch.randn(1, 20)
            >>> action = torch.tensor([2])  # hint_l2
            >>> result = model.observe_step(h, z, action, obs)
        """
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Extract config, dropping any keys that are not dataclass fields.
        config_dict = checkpoint.get("config", {})
        known = {f.name for f in TutoringWorldModelConfig.__dataclass_fields__.values()}
        filtered = {k: v for k, v in config_dict.items() if k in known}
        config = TutoringWorldModelConfig(**filtered)
        # Build model and load weights
        model = cls(config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)
        model.eval()
        logger.info(
            "Loaded TutoringRSSM from %s (epoch %d, params %d)",
            checkpoint_path,
            checkpoint.get("epoch", -1),
            sum(p.numel() for p in model.parameters()),
        )
        return model