---
language: en
license: apache-2.0
tags:
- world-model
- rssm
- tutoring
- predictive-model
- pytorch
- kat
library_name: pytorch
pipeline_tag: reinforcement-learning
model-index:
- name: kat-2-RSSM
results:
- task:
type: world-modeling
name: Tutoring State Prediction
metrics:
- name: Eval Loss (best)
type: loss
value: 0.3124
- name: Reconstruction Loss
type: loss
value: 0.1389
- name: KL Divergence
type: loss
value: 0.0104
- name: Reward Loss
type: loss
value: 0.082
- name: Done Loss
type: loss
value: 0.064
---
# KAT-2-RSSM
A **Recurrent State-Space Model** trained for tutoring state prediction, part of the **KAT** system by [Progga AI](https://progga.ai).
## Model Description
This is a complete world model for predicting tutoring session dynamics: student state transitions, reward signals, and session termination. It uses a DreamerV3-inspired RSSM architecture with VL-JEPA-style EMA target encoding.
### Architecture
```
TutoringRSSM (2,802,838 params)
├── ObservationEncoder: obs_dim(20) → encoder_hidden(256) → latent_dim(128)
├── ActionEmbedding: action_dim(8) → embed_dim(32)
├── DeterministicTransition: GRU(hidden_dim=512)
├── StochasticLatent: Diagonal Gaussian prior/posterior (latent_dim=128)
├── ObservationDecoder: feature_dim(640) → decoder_hidden(256) → obs_dim(20)
├── RewardPredictor: feature_dim(640) → 1
├── DonePredictor: feature_dim(640) → 1
└── EMATargetEncoder: momentum=0.996 (VL-JEPA heritage)
```
**Feature dimension**: `hidden_dim + latent_dim = 512 + 128 = 640`
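The feature vector consumed by the decoder and the reward/done heads is simply the deterministic GRU state concatenated with the stochastic latent. A minimal sketch of that concatenation (tensor names are illustrative, not the model's internal identifiers):

```python
import torch

# Deterministic GRU state (hidden_dim=512) and stochastic latent (latent_dim=128)
h = torch.zeros(1, 512)
z = torch.zeros(1, 128)

# The decoder, reward, and done heads all consume this concatenated feature
feature = torch.cat([h, z], dim=-1)
print(feature.shape)  # torch.Size([1, 640])
```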
### Observation Space (20-dim)
The 20-dimensional observation vector encodes tutoring session state:
| Dims | Signal |
|------|--------|
| 0-3 | Mastery estimates (per-topic confidence) |
| 4-7 | Engagement signals (attention, participation) |
| 8-11 | Response quality (accuracy, depth, speed) |
| 12-15 | Emotional state (frustration, confidence, curiosity) |
| 16-19 | Session context (time, hint level, attempt count) |
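The table above can be read as five 4-dim signal groups stacked into one vector. A hedged sketch of assembling such an observation (the group names are taken from the table, but the concrete values and scaling are assumptions, not the actual dataset schema):

```python
import torch

# Illustrative construction of the 20-dim observation vector
mastery    = torch.rand(4)   # dims 0-3:   per-topic mastery estimates
engagement = torch.rand(4)   # dims 4-7:   attention / participation signals
quality    = torch.rand(4)   # dims 8-11:  accuracy, depth, speed
emotion    = torch.rand(4)   # dims 12-15: frustration, confidence, curiosity
context    = torch.rand(4)   # dims 16-19: time, hint level, attempt count

obs = torch.cat([mastery, engagement, quality, emotion, context]).unsqueeze(0)
print(obs.shape)  # torch.Size([1, 20])
```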
### Action Space (8 discrete actions)
| Index | Strategy |
|-------|----------|
| 0 | SOCRATIC – Guided questioning |
| 1 | SCAFFOLDED – Structured support |
| 2 | DIRECT – Direct instruction |
| 3 | EXPLORATORY – Open exploration |
| 4 | REMEDIAL – Error correction |
| 5 | ASSESSMENT – Knowledge check |
| 6 | MOTIVATIONAL – Encouragement |
| 7 | METACOGNITIVE – Reflection |
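To avoid magic numbers when feeding actions to the model, the indices above can be wrapped in an enum. This is a hypothetical convenience helper; the checkpoint itself only ever sees the integer indices:

```python
from enum import IntEnum

class TutoringStrategy(IntEnum):
    """Action indices for the 8 discrete tutoring strategies."""
    SOCRATIC = 0       # Guided questioning
    SCAFFOLDED = 1     # Structured support
    DIRECT = 2         # Direct instruction
    EXPLORATORY = 3    # Open exploration
    REMEDIAL = 4       # Error correction
    ASSESSMENT = 5     # Knowledge check
    MOTIVATIONAL = 6   # Encouragement
    METACOGNITIVE = 7  # Reflection

print(int(TutoringStrategy.SOCRATIC))  # 0
```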
## Training Details
- **Data**: 100,901 synthetic tutoring trajectories (95,856 train / 5,045 eval)
- **Epochs**: 100 (best at epoch 93)
- **Hardware**: NVIDIA A100-SXM4-40GB
- **Optimizer**: Adam (lr=3e-4)
- **Training time**: ~45 minutes
- **Framework**: PyTorch 2.x
### Training Metrics (Best Checkpoint – Epoch 93)
| Metric | Value |
|--------|-------|
| **Total Loss** | 0.3124 |
| Reconstruction Loss | 0.1389 |
| KL Divergence | 0.0104 |
| Reward Loss | 0.0820 |
| Done Loss | 0.0640 |
| Rollout Loss | 0.3294 |
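The components above suggest a DreamerV3-style composite objective. A minimal sketch of how such losses could combine; the actual loss weights, free-nats floor, and KL-balancing coefficients are not published on this card, so the values below are placeholders:

```python
import torch
import torch.nn.functional as F
import torch.distributions as D

def rssm_loss(pred_obs, obs, prior, posterior, pred_reward, reward,
              pred_done, done, kl_weight=1.0, free_nats=1.0):
    """DreamerV3-style composite loss (weights and free-nats are illustrative)."""
    recon = F.mse_loss(pred_obs, obs)                     # Reconstruction loss
    kl = D.kl_divergence(posterior, prior).mean()         # KL divergence term
    kl = torch.clamp(kl, min=free_nats)                   # free-nats floor
    reward_loss = F.mse_loss(pred_reward, reward)         # Reward head loss
    done_loss = F.binary_cross_entropy_with_logits(pred_done, done)
    return recon + kl_weight * kl + reward_loss + done_loss
```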
### Training Curve
Training converged smoothly over 100 epochs with consistent eval loss improvement. No catastrophic forgetting or training instability observed.
## Files
| File | Description | Size |
|------|-------------|------|
| `tutoring_rssm_best.pt` | Best checkpoint (epoch 93, eval loss 0.3124) | 11 MB |
| `tutoring_rssm_final.pt` | Final checkpoint (epoch 100) | 11 MB |
| `tutoring_rssm_epoch{N}.pt` | Snapshots every 10 epochs | 11 MB each |
| `v1-backup/` | RSSM v1 checkpoints (smaller model) | ~800 KB each |
| `training_log.txt` | Full training log | ~8 KB |
| `config.json` | Model configuration | <1 KB |
| `architecture.py` | Standalone model definition | ~20 KB |
## Usage
```python
import torch

from architecture import TutoringRSSM, TutoringWorldModelConfig

# Load model (falls back to CPU if no GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
config = TutoringWorldModelConfig(
    obs_dim=20, action_dim=8,
    latent_dim=128, hidden_dim=512,
    encoder_hidden=256, decoder_hidden=256,
)
model = TutoringRSSM(config).to(device)
ckpt = torch.load("tutoring_rssm_best.pt", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Initialize state
h, z = model.initial_state(batch_size=1)

# Observe a tutoring step
obs = torch.randn(1, 20, device=device)    # Student observation
action = torch.tensor([0], device=device)  # SOCRATIC strategy
result = model.observe_step(h, z, action, obs)

h_new, z_new = result["h"], result["z"]
pred_obs = result["pred_obs"]        # Predicted next observation
pred_reward = result["pred_reward"]  # Predicted reward
pred_done = result["pred_done"]      # Predicted session end

# Imagination (planning without observation)
imagined = model.imagine_step(h_new, z_new, torch.tensor([3], device=device))
# Returns predicted state without requiring a real observation
```
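Chaining `imagine_step` lets a planner score candidate strategy sequences without any real observations. A hedged sketch, assuming `imagine_step` returns the same `h`/`z`/`pred_reward` keys as `observe_step` in the snippet above (an assumption, not a documented API):

```python
import torch

@torch.no_grad()
def imagined_return(model, h, z, actions, gamma=0.99):
    """Roll a candidate action sequence through the RSSM's imagination
    and sum the discounted predicted rewards (illustrative planner scoring)."""
    total, discount = 0.0, 1.0
    for a in actions:
        out = model.imagine_step(h, z, torch.tensor([a], device=h.device))
        h, z = out["h"], out["z"]
        total += discount * out["pred_reward"].item()
        discount *= gamma
    return total

# Compare two candidate sequences, e.g. SOCRATIC-heavy vs DIRECT-heavy:
# best = max([[0, 0, 5], [2, 2, 5]],
#            key=lambda seq: imagined_return(model, h, z, seq))
```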
## Evaluation Results (94/94 tests pass)
| Component | Tests | Status |
|-----------|-------|--------|
| Predictive Student Model | 44/44 | ALL PASS |
| Cognition World Model Eval | 2/2 | ALL ACCEPTANCE MET |
| Core PyTorch RSSM | 10/10 | ALL PASS |
| Physics/Causality Micro-Modules | 23/23 | ALL PASS |
| Trained Checkpoint Inference | 7/7 | ALL PASS |
| Advanced Planners (MCTS/Beam) | 8/8 | ALL PASS |
### Acceptance Criteria
- **Prediction accuracy**: 12.08% error at horizon (target <20%) ✓
- **Planning improvement**: +14.5% vs reactive baseline (target >+10%) ✓
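The horizon-error criterion can be phrased as a relative error between predicted and ground-truth observations at the end of a rollout. A minimal sketch under that assumption; the exact evaluation protocol is not specified on this card:

```python
import torch

def horizon_relative_error(pred_obs, true_obs, eps=1e-8):
    """Mean relative L2 error between predicted and ground-truth observations
    at the rollout horizon (illustrative; the actual protocol may differ)."""
    num = torch.linalg.norm(pred_obs - true_obs, dim=-1)
    den = torch.linalg.norm(true_obs, dim=-1) + eps
    return (num / den).mean()

# err = horizon_relative_error(pred, truth)   # acceptance target: < 0.20
```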
## Heritage
This model inherits from the **Abigail3 cognitive architecture**, specifically:
- RSSM design from `abigail/core/world_model.py`
- VL-JEPA EMA target encoding from Meta AI's Joint-Embedding Predictive Architecture
- DreamerV3-inspired training with KL balancing and rollout losses
- Governance-first design: generation separated from governance
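The VL-JEPA EMA target encoding listed above keeps a slow-moving copy of the online encoder for stable prediction targets. A minimal sketch of the momentum update with the card's momentum=0.996 (the function name is illustrative, not the model's API):

```python
import torch

@torch.no_grad()
def ema_update(online, target, momentum=0.996):
    """Move target-encoder weights toward the online encoder:
    target <- momentum * target + (1 - momentum) * online."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(momentum).add_(p_online, alpha=1 - momentum)
```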
## Ecosystem
This world model is part of the broader KAT system:
- **23 physics/causality micro-modules** (67M params total) – intuitive physics simulation
- **MCTS Planner** – Monte Carlo Tree Search for action planning
- **Beam Search Planner** – Anytime approximate planning
- **Causal World Model** – Structural causal model with do-calculus
- **Predictive Student Model** – VL-JEPA/RSSM adapted for tutoring personalization
## License
Apache 2.0
## Author
**Preston Mills** – Progga AI
- Built for KAT-2 framework
- Designed by Progga AI
- February 2026