"""KAT TutoringRSSM β€” Standalone Architecture for Inference.
This file contains the complete model architecture for the KAT Tutoring World Model,
a DreamerV3-style Recurrent State-Space Model (RSSM) adapted for tutoring domains.
It can be used to load pretrained checkpoints without the full KAT codebase.
Heritage: Abigail core/world_model.py WorldModel, adapted for KAT's
tutoring-specific dimensions and loss functions. Integrates VL-JEPA
Exponential Moving Average (EMA) target encoding for self-supervised
representation learning.
Architecture Overview:
    ┌─────────────┐     ┌─────────────┐     ┌──────────────┐
    │ Observation │────▶│  RSSM Core  │────▶│ Predictions  │
    │   Encoder   │     │   GRU + z   │     │ obs/rew/done │
    └─────────────┘     └─────────────┘     └──────────────┘
          │                    ▲
          │              ┌─────┴─────┐
          │              │  Action   │
          │              │ Embedding │
          │              └───────────┘
          ▼
    ┌─────────────┐
    │ EMA Target  │
    │   Encoder   │
    └─────────────┘
Author: Preston Mills / QRI (Qualia Research Initiative)
License: Apache-2.0
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field, asdict
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Normal
logger = logging.getLogger(__name__)
# ═══════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════
@dataclass
class TutoringWorldModelConfig:
"""Configuration for the Tutoring RSSM world model.
Heritage: Maps to Abigail's WorldModelConfig with tutoring-specific defaults.
Observation space (20-dim):
- Mastery estimates per topic (8 dims)
- Misconception indicators (4 dims)
- Engagement signals (4 dims)
- Session context (4 dims)
Action space (8 discrete actions):
0: clarify, 1: hint_l1, 2: hint_l2, 3: hint_l3,
4: encourage, 5: redirect, 6: assess, 7: summarize
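    Example (illustrative; the defaults below match this standalone file):
        >>> config = TutoringWorldModelConfig()
        >>> config.obs_dim, config.action_dim, config.latent_dim
        (20, 8, 128)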
"""
obs_dim: int = 20
action_dim: int = 8
latent_dim: int = 128
hidden_dim: int = 512
encoder_hidden: int = 256
decoder_hidden: int = 256
dropout: float = 0.1
# EMA target encoder (VL-JEPA heritage)
ema_momentum: float = 0.996
# Multi-step imagination (DreamerV3 heritage)
rollout_horizon: int = 5
rollout_weight: float = 0.5
rollout_discount: float = 0.95
@classmethod
def from_json(cls, path: str) -> "TutoringWorldModelConfig":
"""Load config from a JSON file."""
with open(path) as f:
data = json.load(f)
# Extract config dict if nested
config_data = data.get("config", data)
# Filter to only known fields
known = {f.name for f in cls.__dataclass_fields__.values()}
filtered = {k: v for k, v in config_data.items() if k in known}
return cls(**filtered)
# ═══════════════════════════════════════════════════════════════════════
# COMPONENT MODULES
# ═══════════════════════════════════════════════════════════════════════
class ObservationEncoder(nn.Module):
"""Encode observations into latent embeddings.
    Architecture: Linear → LayerNorm → SiLU → Linear
Heritage: Abigail EncoderNetwork, adapted for tutoring observation space.
"""
def __init__(self, obs_dim: int, latent_dim: int, hidden_dim: int = 256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, latent_dim),
)
def forward(self, obs: Tensor) -> Tensor:
return self.net(obs)
class ObservationDecoder(nn.Module):
"""Decode features back to observation space.
    Architecture: Linear → LayerNorm → SiLU → Linear
Heritage: Abigail DecoderNetwork.
"""
def __init__(self, feature_dim: int, obs_dim: int, hidden_dim: int = 256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, obs_dim),
)
def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class ActionEmbedding(nn.Module):
"""Embed discrete tutoring actions into continuous space."""
def __init__(self, num_actions: int, embed_dim: int):
super().__init__()
self.embed = nn.Embedding(num_actions, embed_dim)
def forward(self, action: Tensor) -> Tensor:
return self.embed(action.long())
class DeterministicTransition(nn.Module):
"""GRU-based deterministic state transition.
Heritage: Abigail RSSM deterministic path.
Projects [z_{t-1}, a_t] to hidden_dim, then feeds through GRU:
x = Linear([z, a])
h_t = GRU(x, h_{t-1})
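    Example (illustrative shapes, batch size 4, random weights):
        >>> trans = DeterministicTransition(hidden_dim=512, latent_dim=128, action_embed_dim=32)
        >>> h = trans(torch.zeros(4, 512), torch.zeros(4, 128), torch.zeros(4, 32))
        >>> h.shape
        torch.Size([4, 512])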
"""
def __init__(self, hidden_dim: int, latent_dim: int, action_embed_dim: int):
super().__init__()
self.pre = nn.Linear(latent_dim + action_embed_dim, hidden_dim)
self.gru = nn.GRUCell(
input_size=hidden_dim,
hidden_size=hidden_dim,
)
def forward(self, h_prev: Tensor, z_prev: Tensor, a_embed: Tensor) -> Tensor:
x = torch.cat([z_prev, a_embed], dim=-1)
x = self.pre(x)
h = self.gru(x, h_prev)
return h
class StochasticLatent(nn.Module):
"""Gaussian stochastic latent variable with prior and posterior.
Heritage: Abigail RSSM stochastic path.
    Prior p(z_t | h_t): 2-layer MLP (hidden_dim → hidden_dim → 2 * latent_dim)
    Posterior q(z_t | h_t, o_t): 2-layer MLP (hidden_dim + obs_embed_dim → hidden_dim → 2 * latent_dim)
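    Example (illustrative shapes, batch size 4, random weights):
        >>> sl = StochasticLatent(hidden_dim=512, latent_dim=128, obs_embed_dim=128)
        >>> _, _, prior = sl.prior(torch.zeros(4, 512))
        >>> _, _, post = sl.posterior(torch.zeros(4, 512), torch.zeros(4, 128))
        >>> StochasticLatent.kl_divergence(post, prior).shape
        torch.Size([4])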
"""
def __init__(self, hidden_dim: int, latent_dim: int, obs_embed_dim: int):
super().__init__()
self.prior_net = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, latent_dim * 2),
)
self.posterior_net = nn.Sequential(
nn.Linear(hidden_dim + obs_embed_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, latent_dim * 2),
)
self.min_std = 0.1
def _split_params(self, params: Tensor) -> tuple[Tensor, Tensor, Normal]:
"""Split into mean and std, return distribution."""
        mu, std_raw = params.chunk(2, dim=-1)
        # Softplus keeps std positive; min_std guards against variance collapse.
        std = F.softplus(std_raw) + self.min_std
return mu, std, Normal(mu, std)
def prior(self, h: Tensor) -> tuple[Tensor, Tensor, Normal]:
return self._split_params(self.prior_net(h))
def posterior(self, h: Tensor, obs_embed: Tensor) -> tuple[Tensor, Tensor, Normal]:
x = torch.cat([h, obs_embed], dim=-1)
return self._split_params(self.posterior_net(x))
@staticmethod
def kl_divergence(posterior: Normal, prior: Normal) -> Tensor:
"""KL(posterior || prior), summed over latent dims."""
return torch.distributions.kl_divergence(posterior, prior).sum(dim=-1)
class RewardPredictor(nn.Module):
"""Predict scalar reward from RSSM features."""
def __init__(self, feature_dim: int, hidden_dim: int = 64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, 1),
)
def forward(self, features: Tensor) -> Tensor:
return self.net(features).squeeze(-1)
class DonePredictor(nn.Module):
"""Predict episode termination (logit) from RSSM features."""
def __init__(self, feature_dim: int, hidden_dim: int = 64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, 1),
)
def forward(self, features: Tensor) -> Tensor:
return self.net(features).squeeze(-1)
# ═══════════════════════════════════════════════════════════════════════
# COMPLETE RSSM MODEL
# ═══════════════════════════════════════════════════════════════════════
class TutoringRSSM(nn.Module):
"""Complete RSSM world model for tutoring domain.
Integrates all components:
    - Observation encoder/decoder (Linear → LayerNorm → SiLU → Linear)
- Action embedding (nn.Embedding)
- Projection + GRU deterministic transition
- Gaussian stochastic prior/posterior (2-layer MLPs)
- Reward and done predictors (2-layer MLPs)
- EMA target encoder (VL-JEPA heritage)
Heritage: Abigail core/world_model.py WorldModel, adapted for
KAT's tutoring-specific dimensions and loss functions.
"""
def __init__(self, config: TutoringWorldModelConfig):
super().__init__()
self.config = config
# Feature dimension: h + z
self.feature_dim = config.hidden_dim + config.latent_dim
# Action embedding (small enough for direct embedding)
action_embed_dim = min(32, config.action_dim * 4)
self.action_embed = ActionEmbedding(config.action_dim, action_embed_dim)
# Observation encoder
self.obs_encoder = ObservationEncoder(
config.obs_dim, config.latent_dim, config.encoder_hidden,
)
# RSSM core
self.transition = DeterministicTransition(
config.hidden_dim, config.latent_dim, action_embed_dim,
)
self.stochastic = StochasticLatent(
config.hidden_dim, config.latent_dim, config.latent_dim,
)
# Predictors
self.obs_decoder = ObservationDecoder(
self.feature_dim, config.obs_dim, config.decoder_hidden,
)
self.reward_pred = RewardPredictor(self.feature_dim)
self.done_pred = DonePredictor(self.feature_dim)
# EMA target encoder (VL-JEPA heritage)
self.target_encoder = ObservationEncoder(
config.obs_dim, config.latent_dim, config.encoder_hidden,
)
# Initialize target encoder from main encoder
self.target_encoder.load_state_dict(self.obs_encoder.state_dict())
for p in self.target_encoder.parameters():
p.requires_grad = False
# Dropout
self.dropout = nn.Dropout(config.dropout)
self._param_count = sum(p.numel() for p in self.parameters() if p.requires_grad)
def initial_state(self, batch_size: int) -> tuple[Tensor, Tensor]:
"""Create initial RSSM state (h_0, z_0)."""
device = next(self.parameters()).device
h = torch.zeros(batch_size, self.config.hidden_dim, device=device)
z = torch.zeros(batch_size, self.config.latent_dim, device=device)
return h, z
def get_features(self, h: Tensor, z: Tensor) -> Tensor:
"""Concatenate deterministic and stochastic state."""
return torch.cat([h, z], dim=-1)
def observe_step(
self,
h_prev: Tensor,
z_prev: Tensor,
action: Tensor,
obs: Tensor,
) -> dict[str, Any]:
"""One observation step: process real observation.
Uses posterior inference for training.
Returns dict with:
h, z, prior_dist, posterior_dist, features,
pred_obs, pred_reward, pred_done
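        Example (illustrative; a freshly initialised model, so values are meaningless):
            >>> model = TutoringRSSM(TutoringWorldModelConfig())
            >>> h, z = model.initial_state(batch_size=1)
            >>> out = model.observe_step(h, z, torch.tensor([6]), torch.randn(1, 20))  # 6 = assess
            >>> StochasticLatent.kl_divergence(out["posterior_dist"], out["prior_dist"]).shape
            torch.Size([1])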
"""
# Embed action
a_embed = self.action_embed(action)
# Deterministic transition
h = self.transition(h_prev, z_prev, a_embed)
# Encode observation
obs_embed = self.obs_encoder(obs)
# Prior and posterior
prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
post_mu, post_sigma, posterior_dist = self.stochastic.posterior(h, obs_embed)
# Sample from posterior (training mode)
z = posterior_dist.rsample()
# Predictions from features
features = self.get_features(h, z)
pred_obs = self.obs_decoder(features)
pred_reward = self.reward_pred(features)
pred_done = self.done_pred(features)
return {
"h": h,
"z": z,
"prior_dist": prior_dist,
"posterior_dist": posterior_dist,
"features": features,
"pred_obs": pred_obs,
"pred_reward": pred_reward,
"pred_done": pred_done,
}
def imagine_step(
self,
h_prev: Tensor,
z_prev: Tensor,
action: Tensor,
) -> dict[str, Any]:
"""One imagination step: predict without observation.
Uses prior only (no posterior β€” for planning/counterfactual).
Returns dict with:
h, z, prior_dist, features, pred_obs, pred_reward, pred_done
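        Example (illustrative one-step imagination from the initial state):
            >>> model = TutoringRSSM(TutoringWorldModelConfig())
            >>> h, z = model.initial_state(batch_size=1)
            >>> out = model.imagine_step(h, z, torch.tensor([1]))  # 1 = hint_l1
            >>> out["pred_obs"].shape
            torch.Size([1, 20])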
"""
a_embed = self.action_embed(action)
h = self.transition(h_prev, z_prev, a_embed)
prior_mu, prior_sigma, prior_dist = self.stochastic.prior(h)
z = prior_dist.rsample()
features = self.get_features(h, z)
pred_obs = self.obs_decoder(features)
pred_reward = self.reward_pred(features)
pred_done = self.done_pred(features)
return {
"h": h,
"z": z,
"prior_dist": prior_dist,
"features": features,
"pred_obs": pred_obs,
"pred_reward": pred_reward,
"pred_done": pred_done,
}
@torch.no_grad()
def update_target_encoder(self) -> None:
"""EMA update of target encoder (VL-JEPA heritage)."""
m = self.config.ema_momentum
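        # In-place EMA update: p_target <- m * p_target + (1 - m) * p_main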
for p_main, p_target in zip(
self.obs_encoder.parameters(),
self.target_encoder.parameters(),
):
p_target.data.mul_(m).add_(p_main.data, alpha=1.0 - m)
@classmethod
def from_pretrained(cls, checkpoint_path: str, device: str = "cpu") -> "TutoringRSSM":
"""Load a pretrained model from a checkpoint file.
Args:
checkpoint_path: Path to .pt checkpoint file.
device: Device to load onto ('cpu', 'cuda', etc.)
Returns:
Loaded TutoringRSSM model in eval mode.
Example:
>>> model = TutoringRSSM.from_pretrained("tutoring_rssm_best.pt")
>>> h, z = model.initial_state(batch_size=1)
>>> obs = torch.randn(1, 20)
>>> action = torch.tensor([2]) # hint_l2
>>> result = model.observe_step(h, z, action, obs)
"""
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
# Extract config
config_dict = checkpoint.get("config", {})
known = {f.name for f in TutoringWorldModelConfig.__dataclass_fields__.values()}
filtered = {k: v for k, v in config_dict.items() if k in known}
config = TutoringWorldModelConfig(**filtered)
# Build model and load weights
model = cls(config)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
logger.info(
"Loaded TutoringRSSM from %s (epoch %d, params %d)",
checkpoint_path,
checkpoint.get("epoch", -1),
sum(p.numel() for p in model.parameters()),
)
return model
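# ═══════════════════════════════════════════════════════════════════════
# USAGE SKETCH
# ═══════════════════════════════════════════════════════════════════════
# Minimal, illustrative smoke test: builds a randomly initialised model with the
# default config and runs a short imagined rollout. For real inference, load
# trained weights with TutoringRSSM.from_pretrained(...) instead.
if __name__ == "__main__":
    cfg = TutoringWorldModelConfig()
    model = TutoringRSSM(cfg)
    model.eval()
    h, z = model.initial_state(batch_size=1)
    action = torch.tensor([1])  # hint_l1
    with torch.no_grad():
        for step in range(cfg.rollout_horizon):
            out = model.imagine_step(h, z, action)
            h, z = out["h"], out["z"]
            print(f"step {step}: pred_reward={out['pred_reward'].item():+.3f}")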