"""KAT TutoringRSSM — Standalone Architecture for Inference. This file contains the complete model architecture for the KAT Tutoring World Model, a DreamerV3-style Recurrent State-Space Model (RSSM) adapted for tutoring domains. It can be used to load pretrained checkpoints without the full KAT codebase. Heritage: Abigail core/world_model.py WorldModel, adapted for KAT's tutoring-specific dimensions and loss functions. Integrates VL-JEPA Exponential Moving Average (EMA) target encoding for self-supervised representation learning. Architecture Overview: ┌─────────────┐ ┌─────────────┐ ┌──────────────┐ │ Observation │────▶│ RSSM Core │────▶│ Predictions │ │ Encoder │ │ GRU + z │ │ obs/rew/done│ └─────────────┘ └─────────────┘ └──────────────┘ │ ▲ │ ┌─────┴─────┐ │ │ Action │ │ │ Embedding │ │ └───────────┘ ▼ ┌─────────────┐ │ EMA Target │ │ Encoder │ └─────────────┘ Author: Preston Mills / QRI (Qualia Research Initiative) License: Apache-2.0 """ from __future__ import annotations import json import logging from dataclasses import dataclass, field, asdict from typing import Any import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from torch.distributions import Normal logger = logging.getLogger(__name__) # ═══════════════════════════════════════════════════════════════════════ # CONFIGURATION # ═══════════════════════════════════════════════════════════════════════ @dataclass class TutoringWorldModelConfig: """Configuration for the Tutoring RSSM world model. Heritage: Maps to Abigail's WorldModelConfig with tutoring-specific defaults. 
Observation space (20-dim): - Mastery estimates per topic (8 dims) - Misconception indicators (4 dims) - Engagement signals (4 dims) - Session context (4 dims) Action space (8 discrete actions): 0: clarify, 1: hint_l1, 2: hint_l2, 3: hint_l3, 4: encourage, 5: redirect, 6: assess, 7: summarize """ obs_dim: int = 20 action_dim: int = 8 latent_dim: int = 128 hidden_dim: int = 512 encoder_hidden: int = 256 decoder_hidden: int = 256 dropout: float = 0.1 # EMA target encoder (VL-JEPA heritage) ema_momentum: float = 0.996 # Multi-step imagination (DreamerV3 heritage) rollout_horizon: int = 5 rollout_weight: float = 0.5 rollout_discount: float = 0.95 @classmethod def from_json(cls, path: str) -> "TutoringWorldModelConfig": """Load config from a JSON file.""" with open(path) as f: data = json.load(f) # Extract config dict if nested config_data = data.get("config", data) # Filter to only known fields known = {f.name for f in cls.__dataclass_fields__.values()} filtered = {k: v for k, v in config_data.items() if k in known} return cls(**filtered) # ═══════════════════════════════════════════════════════════════════════ # COMPONENT MODULES # ═══════════════════════════════════════════════════════════════════════ class ObservationEncoder(nn.Module): """Encode observations into latent embeddings. Architecture: Linear → LayerNorm → SiLU → Linear Heritage: Abigail EncoderNetwork, adapted for tutoring observation space. """ def __init__(self, obs_dim: int, latent_dim: int, hidden_dim: int = 256): super().__init__() self.net = nn.Sequential( nn.Linear(obs_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, latent_dim), ) def forward(self, obs: Tensor) -> Tensor: return self.net(obs) class ObservationDecoder(nn.Module): """Decode features back to observation space. Architecture: Linear → LayerNorm → SiLU → Linear Heritage: Abigail DecoderNetwork. 
""" def __init__(self, feature_dim: int, obs_dim: int, hidden_dim: int = 256): super().__init__() self.net = nn.Sequential( nn.Linear(feature_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, obs_dim), ) def forward(self, features: Tensor) -> Tensor: return self.net(features) class ActionEmbedding(nn.Module): """Embed discrete tutoring actions into continuous space.""" def __init__(self, num_actions: int, embed_dim: int): super().__init__() self.embed = nn.Embedding(num_actions, embed_dim) def forward(self, action: Tensor) -> Tensor: return self.embed(action.long()) class DeterministicTransition(nn.Module): """GRU-based deterministic state transition. Heritage: Abigail RSSM deterministic path. Projects [z_{t-1}, a_t] to hidden_dim, then feeds through GRU: x = Linear([z, a]) h_t = GRU(x, h_{t-1}) """ def __init__(self, hidden_dim: int, latent_dim: int, action_embed_dim: int): super().__init__() self.pre = nn.Linear(latent_dim + action_embed_dim, hidden_dim) self.gru = nn.GRUCell( input_size=hidden_dim, hidden_size=hidden_dim, ) def forward(self, h_prev: Tensor, z_prev: Tensor, a_embed: Tensor) -> Tensor: x = torch.cat([z_prev, a_embed], dim=-1) x = self.pre(x) h = self.gru(x, h_prev) return h class StochasticLatent(nn.Module): """Gaussian stochastic latent variable with prior and posterior. Heritage: Abigail RSSM stochastic path. 
Prior: p(z_t | h_t) — 2-layer MLP (hidden_dim → hidden_dim → 2*latent_dim) Posterior: q(z_t | h_t, o_t) — 2-layer MLP (hidden_dim+latent_dim → hidden_dim → 2*latent_dim) """ def __init__(self, hidden_dim: int, latent_dim: int, obs_embed_dim: int): super().__init__() self.prior_net = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, latent_dim * 2), ) self.posterior_net = nn.Sequential( nn.Linear(hidden_dim + obs_embed_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, latent_dim * 2), ) self.min_std = 0.1 def _split_params(self, params: Tensor) -> tuple[Tensor, Tensor, Normal]: """Split into mean and std, return distribution.""" mu, log_std = params.chunk(2, dim=-1) std = F.softplus(log_std) + self.min_std return mu, std, Normal(mu, std) def prior(self, h: Tensor) -> tuple[Tensor, Tensor, Normal]: return self._split_params(self.prior_net(h)) def posterior(self, h: Tensor, obs_embed: Tensor) -> tuple[Tensor, Tensor, Normal]: x = torch.cat([h, obs_embed], dim=-1) return self._split_params(self.posterior_net(x)) @staticmethod def kl_divergence(posterior: Normal, prior: Normal) -> Tensor: """KL(posterior || prior), summed over latent dims.""" return torch.distributions.kl_divergence(posterior, prior).sum(dim=-1) class RewardPredictor(nn.Module): """Predict scalar reward from RSSM features.""" def __init__(self, feature_dim: int, hidden_dim: int = 64): super().__init__() self.net = nn.Sequential( nn.Linear(feature_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1), ) def forward(self, features: Tensor) -> Tensor: return self.net(features).squeeze(-1) class DonePredictor(nn.Module): """Predict episode termination (logit) from RSSM features.""" def __init__(self, feature_dim: int, hidden_dim: int = 64): super().__init__() self.net = nn.Sequential( nn.Linear(feature_dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, 1), ) def forward(self, features: Tensor) -> Tensor: return self.net(features).squeeze(-1) # 
# ═══════════════════════════════════════════════════════════════════════
# COMPLETE RSSM MODEL
# ═══════════════════════════════════════════════════════════════════════


class TutoringRSSM(nn.Module):
    """Complete RSSM world model for tutoring domain.

    Integrates all components:
        - Observation encoder/decoder (Linear → LayerNorm → SiLU → Linear)
        - Action embedding (nn.Embedding)
        - Projection + GRU deterministic transition
        - Gaussian stochastic prior/posterior (2-layer MLPs)
        - Reward and done predictors (2-layer MLPs)
        - EMA target encoder (VL-JEPA heritage)

    Heritage: Abigail core/world_model.py WorldModel, adapted for KAT's
    tutoring-specific dimensions and loss functions.
    """

    def __init__(self, config: TutoringWorldModelConfig):
        super().__init__()
        self.config = config

        # Feature dimension: h + z (deterministic state ++ stochastic latent)
        self.feature_dim = config.hidden_dim + config.latent_dim

        # Action embedding (small enough for direct embedding)
        action_embed_dim = min(32, config.action_dim * 4)
        self.action_embed = ActionEmbedding(config.action_dim, action_embed_dim)

        # Observation encoder
        self.obs_encoder = ObservationEncoder(
            config.obs_dim,
            config.latent_dim,
            config.encoder_hidden,
        )

        # RSSM core: deterministic GRU transition + Gaussian stochastic latent.
        self.transition = DeterministicTransition(
            config.hidden_dim,
            config.latent_dim,
            action_embed_dim,
        )
        # obs_embed_dim == latent_dim because the encoder outputs latent_dim.
        self.stochastic = StochasticLatent(
            config.hidden_dim,
            config.latent_dim,
            config.latent_dim,
        )

        # Predictors (all operate on concatenated [h, z] features)
        self.obs_decoder = ObservationDecoder(
            self.feature_dim,
            config.obs_dim,
            config.decoder_hidden,
        )
        self.reward_pred = RewardPredictor(self.feature_dim)
        self.done_pred = DonePredictor(self.feature_dim)

        # EMA target encoder (VL-JEPA heritage): a frozen copy of the main
        # encoder, updated only via update_target_encoder().
        self.target_encoder = ObservationEncoder(
            config.obs_dim,
            config.latent_dim,
            config.encoder_hidden,
        )
        # Initialize target encoder from main encoder, then freeze it so the
        # optimizer never touches it.
        self.target_encoder.load_state_dict(self.obs_encoder.state_dict())
        for p in self.target_encoder.parameters():
            p.requires_grad = False

        # Dropout module. NOTE(review): not applied in observe_step /
        # imagine_step; kept for interface compatibility with existing
        # checkpoints and callers.
        self.dropout = nn.Dropout(config.dropout)

        # Trainable parameter count (excludes the frozen EMA target encoder,
        # which would otherwise double-count the encoder weights).
        self._param_count = sum(p.numel() for p in self.parameters() if p.requires_grad)

    def initial_state(self, batch_size: int) -> tuple[Tensor, Tensor]:
        """Create initial RSSM state (h_0, z_0) as zeros on the model's device."""
        device = next(self.parameters()).device
        h = torch.zeros(batch_size, self.config.hidden_dim, device=device)
        z = torch.zeros(batch_size, self.config.latent_dim, device=device)
        return h, z

    def get_features(self, h: Tensor, z: Tensor) -> Tensor:
        """Concatenate deterministic and stochastic state into one feature vector."""
        return torch.cat([h, z], dim=-1)

    def observe_step(
        self,
        h_prev: Tensor,
        z_prev: Tensor,
        action: Tensor,
        obs: Tensor,
    ) -> dict[str, Any]:
        """One observation step: process a real observation.

        Uses posterior inference (conditioned on the observation), which is
        the training-time path; ``z`` is drawn with ``rsample`` so gradients
        flow through the reparameterization.

        Args:
            h_prev: Previous deterministic state, ``(batch, hidden_dim)``.
            z_prev: Previous stochastic latent, ``(batch, latent_dim)``.
            action: Discrete action indices, ``(batch,)``.
            obs: Observation, ``(batch, obs_dim)``.

        Returns:
            Dict with: h, z, prior_dist, posterior_dist, features,
            pred_obs, pred_reward, pred_done.
        """
        # Embed action
        a_embed = self.action_embed(action)

        # Deterministic transition
        h = self.transition(h_prev, z_prev, a_embed)

        # Encode observation
        obs_embed = self.obs_encoder(obs)

        # Prior and posterior over z_t (both returned so the caller can
        # compute the KL term)
        prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
        post_mu, post_sigma, posterior_dist = self.stochastic.posterior(h, obs_embed)

        # Sample from posterior (training mode)
        z = posterior_dist.rsample()

        # Predictions from concatenated features
        features = self.get_features(h, z)
        pred_obs = self.obs_decoder(features)
        pred_reward = self.reward_pred(features)
        pred_done = self.done_pred(features)

        return {
            "h": h,
            "z": z,
            "prior_dist": prior_dist,
            "posterior_dist": posterior_dist,
            "features": features,
            "pred_obs": pred_obs,
            "pred_reward": pred_reward,
            "pred_done": pred_done,
        }

    def imagine_step(
        self,
        h_prev: Tensor,
        z_prev: Tensor,
        action: Tensor,
    ) -> dict[str, Any]:
        """One imagination step: predict without an observation.

        Uses the prior only (no posterior) — for planning / counterfactual
        rollouts where no real observation exists.

        Args:
            h_prev: Previous deterministic state, ``(batch, hidden_dim)``.
            z_prev: Previous stochastic latent, ``(batch, latent_dim)``.
            action: Discrete action indices, ``(batch,)``.

        Returns:
            Dict with: h, z, prior_dist, features, pred_obs, pred_reward,
            pred_done.
        """
        a_embed = self.action_embed(action)
        h = self.transition(h_prev, z_prev, a_embed)

        prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
        z = prior_dist.rsample()

        features = self.get_features(h, z)
        pred_obs = self.obs_decoder(features)
        pred_reward = self.reward_pred(features)
        pred_done = self.done_pred(features)

        return {
            "h": h,
            "z": z,
            "prior_dist": prior_dist,
            "features": features,
            "pred_obs": pred_obs,
            "pred_reward": pred_reward,
            "pred_done": pred_done,
        }

    @torch.no_grad()
    def update_target_encoder(self) -> None:
        """EMA update of target encoder (VL-JEPA heritage).

        target ← m * target + (1 - m) * main, applied parameter-wise.
        Parameters only; the encoder's Linear/LayerNorm layers register
        no buffers, so there is nothing else to sync.
        """
        m = self.config.ema_momentum
        for p_main, p_target in zip(
            self.obs_encoder.parameters(),
            self.target_encoder.parameters(),
        ):
            p_target.data.mul_(m).add_(p_main.data, alpha=1.0 - m)

    @classmethod
    def from_pretrained(cls, checkpoint_path: str, device: str = "cpu") -> "TutoringRSSM":
        """Load a pretrained model from a checkpoint file.

        Args:
            checkpoint_path: Path to .pt checkpoint file.
            device: Device to load onto ('cpu', 'cuda', etc.)

        Returns:
            Loaded TutoringRSSM model in eval mode.

        Example:
            >>> model = TutoringRSSM.from_pretrained("tutoring_rssm_best.pt")
            >>> h, z = model.initial_state(batch_size=1)
            >>> obs = torch.randn(1, 20)
            >>> action = torch.tensor([2])  # hint_l2
            >>> result = model.observe_step(h, z, action, obs)
        """
        # NOTE(security): weights_only=False runs the full pickle machinery,
        # which can execute arbitrary code — only load checkpoints from
        # trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

        # Extract config, ignoring unknown keys (mirrors
        # TutoringWorldModelConfig.from_json)
        config_dict = checkpoint.get("config", {})
        known = {f.name for f in TutoringWorldModelConfig.__dataclass_fields__.values()}
        filtered = {k: v for k, v in config_dict.items() if k in known}
        config = TutoringWorldModelConfig(**filtered)

        # Build model and load weights
        model = cls(config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)
        model.eval()

        # Report the trainable count (excludes the frozen EMA target encoder,
        # which would double-count the encoder weights).
        logger.info(
            "Loaded TutoringRSSM from %s (epoch %d, trainable params %d)",
            checkpoint_path,
            checkpoint.get("epoch", -1),
            model._param_count,
        )
        return model