KAT-2-RSSM

A Recurrent State-Space Model trained for tutoring state prediction, part of the KAT system by Progga AI.

Model Description

This is a complete world model for predicting tutoring session dynamics: student state transitions, reward signals, and session termination. It uses a DreamerV3-inspired RSSM architecture with VL-JEPA-style EMA target encoding.

Architecture

```
TutoringRSSM (2,802,838 params)
├── ObservationEncoder: obs_dim(20) → encoder_hidden(256) → latent_dim(128)
├── ActionEmbedding: action_dim(8) → embed_dim(32)
├── DeterministicTransition: GRU(hidden_dim=512)
├── StochasticLatent: Diagonal Gaussian prior/posterior (latent_dim=128)
├── ObservationDecoder: feature_dim(640) → decoder_hidden(256) → obs_dim(20)
├── RewardPredictor: feature_dim(640) → 1
├── DonePredictor: feature_dim(640) → 1
└── EMATargetEncoder: momentum=0.996 (VL-JEPA heritage)
```

Feature dimension: hidden_dim + latent_dim = 512 + 128 = 640
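As a sanity check, the 640-dim feature vector can be reproduced by concatenating the deterministic GRU state with the stochastic latent (a sketch; `make_features` is a hypothetical helper, the real model computes this internally):

```python
import torch

hidden_dim, latent_dim = 512, 128

def make_features(h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Concatenate deterministic and stochastic state into the feature
    vector consumed by the decoder and the reward/done heads."""
    return torch.cat([h, z], dim=-1)

h = torch.zeros(1, hidden_dim)   # deterministic GRU state
z = torch.zeros(1, latent_dim)   # stochastic latent sample
features = make_features(h, z)   # shape: torch.Size([1, 640])
```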

Observation Space (20-dim)

The 20-dimensional observation vector encodes tutoring session state:

| Dims  | Signal |
|-------|--------|
| 0-3   | Mastery estimates (per-topic confidence) |
| 4-7   | Engagement signals (attention, participation) |
| 8-11  | Response quality (accuracy, depth, speed) |
| 12-15 | Emotional state (frustration, confidence, curiosity) |
| 16-19 | Session context (time, hint level, attempt count) |
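For illustration, the observation layout above can be addressed with named slices (a sketch; the slice names are hypothetical, only the index ranges come from the table):

```python
import torch

# Named views into the 20-dim observation vector (names are illustrative)
OBS_SLICES = {
    "mastery":    slice(0, 4),    # per-topic confidence
    "engagement": slice(4, 8),    # attention, participation
    "response":   slice(8, 12),   # accuracy, depth, speed
    "emotion":    slice(12, 16),  # frustration, confidence, curiosity
    "context":    slice(16, 20),  # time, hint level, attempt count
}

obs = torch.randn(20)                  # one tutoring-step observation
mastery = obs[OBS_SLICES["mastery"]]   # shape (4,)
```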

Action Space (8 discrete actions)

| Index | Strategy | Description |
|-------|----------|-------------|
| 0 | SOCRATIC | Guided questioning |
| 1 | SCAFFOLDED | Structured support |
| 2 | DIRECT | Direct instruction |
| 3 | EXPLORATORY | Open exploration |
| 4 | REMEDIAL | Error correction |
| 5 | ASSESSMENT | Knowledge check |
| 6 | MOTIVATIONAL | Encouragement |
| 7 | METACOGNITIVE | Reflection |
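The action indices map naturally onto an enum (a sketch; this enum is illustrative and not part of the released code):

```python
from enum import IntEnum

class TutoringAction(IntEnum):
    """Discrete tutoring strategies, in the order used by the model."""
    SOCRATIC = 0       # guided questioning
    SCAFFOLDED = 1     # structured support
    DIRECT = 2         # direct instruction
    EXPLORATORY = 3    # open exploration
    REMEDIAL = 4       # error correction
    ASSESSMENT = 5     # knowledge check
    MOTIVATIONAL = 6   # encouragement
    METACOGNITIVE = 7  # reflection
```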

Training Details

  • Data: 100,901 synthetic tutoring trajectories (95,856 train / 5,045 eval)
  • Epochs: 100 (best at epoch 93)
  • Hardware: NVIDIA A100-SXM4-40GB
  • Optimizer: Adam (lr=3e-4)
  • Training time: ~45 minutes
  • Framework: PyTorch 2.x

Training Metrics (Best Checkpoint β€” Epoch 93)

| Metric | Value |
|--------|-------|
| Total Loss | 0.3124 |
| Reconstruction Loss | 0.1389 |
| KL Divergence | 0.0104 |
| Reward Loss | 0.0820 |
| Done Loss | 0.0640 |
| Rollout Loss | 0.3294 |

Training Curve

Training converged smoothly over 100 epochs with consistent eval loss improvement. No catastrophic forgetting or training instability observed.

Files

| File | Description | Size |
|------|-------------|------|
| tutoring_rssm_best.pt | Best checkpoint (epoch 93, eval loss 0.3124) | 11 MB |
| tutoring_rssm_final.pt | Final checkpoint (epoch 100) | 11 MB |
| tutoring_rssm_epoch{N}.pt | Snapshots every 10 epochs | 11 MB each |
| v1-backup/ | RSSM v1 checkpoints (smaller model) | ~800 KB each |
| training_log.txt | Full training log | ~8 KB |
| config.json | Model configuration | <1 KB |
| architecture.py | Standalone model definition | ~20 KB |

Usage

```python
import torch
from architecture import TutoringRSSM, TutoringWorldModelConfig

# Load model
config = TutoringWorldModelConfig(
    obs_dim=20, action_dim=8,
    latent_dim=128, hidden_dim=512,
    encoder_hidden=256, decoder_hidden=256,
)
model = TutoringRSSM(config).cuda()

ckpt = torch.load("tutoring_rssm_best.pt", map_location="cuda")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Initialize state
h, z = model.initial_state(batch_size=1)

# Observe a tutoring step
obs = torch.randn(1, 20).cuda()    # Student observation
action = torch.tensor([0]).cuda()  # SOCRATIC strategy
result = model.observe_step(h, z, action, obs)

h_new, z_new = result["h"], result["z"]
pred_obs = result["pred_obs"]        # Predicted next observation
pred_reward = result["pred_reward"]  # Predicted reward
pred_done = result["pred_done"]      # Predicted session end

# Imagination (planning without observation)
imagined = model.imagine_step(h_new, z_new, torch.tensor([3]).cuda())
# Returns the predicted next state without requiring a real observation
```
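Building on the API above, multi-step planning can be sketched as an imagined rollout that accumulates predicted reward for a candidate action sequence (a sketch; it assumes `imagine_step` returns the same keys as `observe_step`, and omits device handling):

```python
import torch

@torch.no_grad()
def imagined_return(model, h, z, actions):
    """Roll the world model forward on imagined actions only,
    summing predicted reward until the model predicts session end."""
    total = 0.0
    for a in actions:
        out = model.imagine_step(h, z, torch.tensor([a]))  # .cuda() as needed
        h, z = out["h"], out["z"]
        total += out["pred_reward"].item()
        if out["pred_done"].item() > 0.5:  # predicted termination
            break
    return total

# Compare candidate strategies over a 5-step horizon, e.g.:
# score = imagined_return(model, h_new, z_new, [0, 0, 5, 6, 7])
```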

Evaluation Results (94/94 tests pass)

| Component | Tests | Status |
|-----------|-------|--------|
| Predictive Student Model | 44/44 | ALL PASS |
| Cognition World Model Eval | 2/2 | ALL ACCEPTANCE MET |
| Core PyTorch RSSM | 10/10 | ALL PASS |
| Physics/Causality Micro-Modules | 23/23 | ALL PASS |
| Trained Checkpoint Inference | 7/7 | ALL PASS |
| Advanced Planners (MCTS/Beam) | 8/8 | ALL PASS |

Acceptance Criteria

  • Prediction accuracy: 12.08% error at horizon (target <20%) ✓
  • Planning improvement: +14.5% vs reactive baseline (target >+10%) ✓
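For reference, a relative-error metric consistent with the accuracy criterion can be sketched as follows (an assumption: the exact metric used in the published evaluation is not specified here):

```python
import torch

def horizon_error(pred_obs: torch.Tensor, true_obs: torch.Tensor) -> float:
    """Mean relative L2 error between predicted and actual observations
    at the evaluation horizon."""
    num = torch.linalg.norm(pred_obs - true_obs, dim=-1)
    den = torch.linalg.norm(true_obs, dim=-1).clamp_min(1e-8)
    return (num / den).mean().item()

# Accept if the error is below the 20% target:
# passed = horizon_error(pred, true) < 0.20
```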

Heritage

This model inherits from the Abigail3 cognitive architecture, specifically:

  • RSSM design from abigail/core/world_model.py
  • VL-JEPA EMA target encoding from Meta AI's Joint-Embedding Predictive Architecture
  • DreamerV3-inspired training with KL balancing and rollout losses
  • Governance-first design: generation separated from governance
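The VL-JEPA-style EMA target update mentioned above follows the standard momentum rule theta_target = m * theta_target + (1 - m) * theta_online with m = 0.996 (a sketch; parameter iteration details are illustrative):

```python
import torch

@torch.no_grad()
def ema_update(target: torch.nn.Module, online: torch.nn.Module, m: float = 0.996):
    """Momentum update of the EMA target encoder from the online encoder."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(m).add_(p_o, alpha=1.0 - m)
```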

Ecosystem

This world model is part of the broader KAT system:

  • 23 physics/causality micro-modules (67M params total): intuitive physics simulation
  • MCTS Planner: Monte Carlo Tree Search for action planning
  • Beam Search Planner: anytime approximate planning
  • Causal World Model: structural causal model with do-calculus
  • Predictive Student Model: VL-JEPA/RSSM adapted for tutoring personalization
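As an illustration of how such planners sit on top of the world model, here is a minimal one-step greedy search over imagined actions (a sketch; the released MCTS and beam planners are assumed to be more sophisticated than this):

```python
import torch

@torch.no_grad()
def plan_one_step(model, h, z, n_actions=8):
    """Score each discrete action by its imagined reward; return the best."""
    best_action, best_reward = 0, float("-inf")
    for a in range(n_actions):
        out = model.imagine_step(h, z, torch.tensor([a]))  # .cuda() as needed
        r = out["pred_reward"].item()
        if r > best_reward:
            best_action, best_reward = a, r
    return best_action, best_reward
```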

License

Apache 2.0

Author

Preston Mills, Progga AI

  • Built for the KAT-2 framework
  • Designed by Progga AI
  • February 2026