update
#2
by Clementio - opened
- .gitattributes +36 -1
- README.md +3 -52
- app.py +394 -483
- data/skill_encoder.csv +0 -0
- data/skill_encoder_real.csv +20 -0
- data/skill_encoder_v2.csv +0 -0
- index.html +0 -609
- {data/knowledge_maps → knowledge_maps}/cs_dag.json +0 -0
- {data/knowledge_maps → knowledge_maps}/math_dag.json +0 -0
- plrs/__init__.py +0 -30
- plrs/constraints/__init__.py +0 -3
- plrs/constraints/dag.py +0 -201
- plrs/curriculum/__init__.py +0 -3
- plrs/curriculum/loader.py +0 -144
- plrs/model/__init__.py +0 -5
- plrs/model/evaluator.py +0 -374
- plrs/model/model_loader.py +0 -116
- plrs/model/sakt.py +0 -219
- plrs/model/sakt_decay.py +0 -253
- plrs/model/trainer.py +0 -437
- plrs/pipeline.py +0 -236
- plrs/ranking/__init__.py +0 -3
- plrs/ranking/ranker.py +0 -189
- requirements.txt +3 -5
- sakt_decay_best.pt +3 -0
- models/sakt_model.pt → sakt_model.pt +0 -0
- sakt_vanilla_best.pt +3 -0
- training_curves.png +3 -0
.gitattributes
CHANGED
|
@@ -1 +1,36 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
training_curves.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,52 +1,3 @@
|
|
| 1 |
-
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
colorFrom: blue
|
| 5 |
-
colorTo: indigo
|
| 6 |
-
sdk: streamlit
|
| 7 |
-
sdk_version: 1.33.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: true
|
| 10 |
-
license: mit
|
| 11 |
-
tags:
|
| 12 |
-
- education
|
| 13 |
-
- knowledge-tracing
|
| 14 |
-
- recommendation-system
|
| 15 |
-
- pytorch
|
| 16 |
-
- transformers
|
| 17 |
-
---
|
| 18 |
-
|
| 19 |
-
# PLRS — Personalized Learning Recommendation System
|
| 20 |
-
|
| 21 |
-
> Constraint-aware personalized learning recommendations powered by Self-Attentive Knowledge Tracing (SAKT) and DAG prerequisite constraints.
|
| 22 |
-
|
| 23 |
-
## What it does
|
| 24 |
-
|
| 25 |
-
PLRS combines a SAKT transformer model with a curriculum knowledge graph to generate recommendations that are both **personalized** and **pedagogically sound**. Topics are classified into three tiers:
|
| 26 |
-
|
| 27 |
-
- ✅ **Approved** — prerequisites met, ready to learn
|
| 28 |
-
- ⚠️ **Challenging** — prerequisites partially met
|
| 29 |
-
- ❌ **Vetoed** — prerequisites not met, blocked
|
| 30 |
-
|
| 31 |
-
## Key results
|
| 32 |
-
|
| 33 |
-
| Metric | PLRS | Collaborative Filtering |
|
| 34 |
-
|--------|------|------------------------|
|
| 35 |
-
| Val AUC | **0.7692** | — |
|
| 36 |
-
| Prerequisite Violation Rate | **0.0%** | 81.3% |
|
| 37 |
-
|
| 38 |
-
## Bundled curricula
|
| 39 |
-
|
| 40 |
-
- **Nigerian Secondary School Mathematics** (38 topics, 45 edges, JSS3–SS2)
|
| 41 |
-
- **CS Fundamentals / Digital Technologies** (31 topics, 39 edges)
|
| 42 |
-
|
| 43 |
-
## Architecture
|
| 44 |
-
|
| 45 |
-
```
|
| 46 |
-
Student History → SAKT → Mastery Vector → DAG Constraint Layer → Ranker → Recommendations
|
| 47 |
-
```
|
| 48 |
-
|
| 49 |
-
## Links
|
| 50 |
-
|
| 51 |
-
- 📦 GitHub: [clementina-tom/plrs](https://github.com/clementina-tom/plrs)
|
| 52 |
-
- 📄 Paper/Report: Final Year Project, Computer Science
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
|
@@ -1,492 +1,403 @@
|
|
| 1 |
-
"""
|
| 2 |
-
PLRS — Logic Engine
|
| 3 |
-
HuggingFace Space entry point.
|
| 4 |
-
|
| 5 |
-
Loads SAKT model weights from HF Hub (Clementio/PLRS).
|
| 6 |
-
Bundles the plrs package inline (until PyPI release).
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
import json
|
| 10 |
-
import sys
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
|
| 13 |
-
import numpy as np
|
| 14 |
import streamlit as st
|
| 15 |
import torch
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
from
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
st.set_page_config(
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
.
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
.
|
| 60 |
-
.
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
.stat-card.green::before { --accent: #22c55e; }
|
| 76 |
-
.stat-card.amber::before { --accent: #f59e0b; }
|
| 77 |
-
.stat-card.red::before { --accent: #ef4444; }
|
| 78 |
-
.stat-card.blue::before { --accent: #3d8bcd; }
|
| 79 |
-
.stat-label { font-family: 'DM Mono', monospace; font-size: 0.62rem; color: #4a5568; letter-spacing: 0.12em; text-transform: uppercase; margin-bottom: 0.25rem; }
|
| 80 |
-
.stat-value { font-size: 1.6rem; font-weight: 700; color: #e8edf5; line-height: 1; }
|
| 81 |
-
.stat-sub { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #4a5568; margin-top: 0.2rem; }
|
| 82 |
-
|
| 83 |
-
.rec-card {
|
| 84 |
-
background: #0d1221; border: 1px solid #1e2a40; border-radius: 4px;
|
| 85 |
-
padding: 0.9rem 1rem; margin-bottom: 0.5rem;
|
| 86 |
-
}
|
| 87 |
-
.rec-card.approved { border-left: 3px solid #22c55e; }
|
| 88 |
-
.rec-card.challenging { border-left: 3px solid #f59e0b; }
|
| 89 |
-
.rec-card.vetoed { border-left: 3px solid #ef4444; opacity: 0.6; }
|
| 90 |
-
.rec-title { font-size: 0.95rem; font-weight: 700; color: #e8edf5; margin-bottom: 0.15rem; }
|
| 91 |
-
.rec-meta { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #4a5568; letter-spacing: 0.06em; }
|
| 92 |
-
.rec-reason { font-size: 0.75rem; color: #8899aa; margin-top: 0.35rem; padding-top: 0.35rem; border-top: 1px solid #1e2a40; }
|
| 93 |
-
.score-bar-wrap { background: #131a2e; border-radius: 2px; height: 3px; margin-top: 0.5rem; overflow: hidden; }
|
| 94 |
-
.score-bar { height: 100%; border-radius: 2px; background: var(--bar-color, #3d8bcd); }
|
| 95 |
-
|
| 96 |
-
.section-label {
|
| 97 |
-
font-family: 'DM Mono', monospace; font-size: 0.65rem; letter-spacing: 0.14em;
|
| 98 |
-
text-transform: uppercase; color: #4a5568; border-bottom: 1px solid #1e2a40;
|
| 99 |
-
padding-bottom: 0.4rem; margin-bottom: 0.75rem; margin-top: 1.25rem;
|
| 100 |
-
}
|
| 101 |
-
.unlock-chip {
|
| 102 |
-
display: inline-block; font-family: 'DM Mono', monospace; font-size: 0.65rem;
|
| 103 |
-
background: #131a2e; border: 1px solid #1e3a5f; border-radius: 2px;
|
| 104 |
-
padding: 2px 7px; margin: 2px 3px 2px 0; color: #3d8bcd;
|
| 105 |
-
}
|
| 106 |
-
.blocked-chip {
|
| 107 |
-
display: inline-block; font-family: 'DM Mono', monospace; font-size: 0.65rem;
|
| 108 |
-
background: #1a1010; border: 1px solid #3f1e1e; border-radius: 2px;
|
| 109 |
-
padding: 2px 7px; margin: 2px 3px 2px 0; color: #ef4444;
|
| 110 |
-
}
|
| 111 |
-
|
| 112 |
-
.stTabs [data-baseweb="tab-list"] { gap: 0; border-bottom: 1px solid #1e2a40; background: transparent; }
|
| 113 |
-
.stTabs [data-baseweb="tab"] { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.08em; color: #4a5568; padding: 0.5rem 1.25rem; border-bottom: 2px solid transparent; }
|
| 114 |
-
.stTabs [aria-selected="true"] { color: #3d8bcd; border-bottom-color: #3d8bcd; background: transparent; }
|
| 115 |
-
</style>
|
| 116 |
-
""", unsafe_allow_html=True)
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# ── Model + pipeline loading ──────────────────────────────────────────────────
|
| 120 |
-
|
| 121 |
-
@st.cache_resource(show_spinner="Loading curriculum & model from HuggingFace...")
|
| 122 |
-
def load_pipelines():
|
| 123 |
-
from plrs.model.model_loader import load_model_from_hub
|
| 124 |
-
|
| 125 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 126 |
-
maps = ROOT / "data" / "knowledge_maps"
|
| 127 |
-
|
| 128 |
-
# Load model (tries decay, vanilla, then base)
|
| 129 |
-
model, model_type = load_model_from_hub(device=str(device))
|
| 130 |
-
|
| 131 |
-
pipelines = {}
|
| 132 |
-
for domain, fname in [("math", "math_dag.json"), ("cs", "cs_dag.json")]:
|
| 133 |
-
path = maps / fname
|
| 134 |
-
if path.exists():
|
| 135 |
-
curriculum = load_dag(path)
|
| 136 |
-
pipeline = PLRSPipeline(curriculum)
|
| 137 |
-
if model:
|
| 138 |
-
pipeline._model = model
|
| 139 |
-
pipelines[domain] = pipeline
|
| 140 |
-
|
| 141 |
-
return pipelines, model is not None, model_type
|
| 142 |
-
|
| 143 |
|
| 144 |
@st.cache_data
|
| 145 |
def load_skill_encoder():
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
else:
|
| 245 |
-
|
| 246 |
-
seed
|
| 247 |
np.random.seed(int(seed))
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
else:
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
else:
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
mastery_pct = int(summary["mastery_rate"] * 100)
|
| 326 |
-
vrate_pct = int(stats["prerequisite_violation_rate"] * 100)
|
| 327 |
-
|
| 328 |
-
st.markdown(f"""
|
| 329 |
-
<div class="stat-row">
|
| 330 |
-
<div class="stat-card blue">
|
| 331 |
-
<div class="stat-label">Mastered</div>
|
| 332 |
-
<div class="stat-value">{summary['mastered']}<span style="font-size:0.9rem;color:#4a5568;">/{summary['total_topics']}</span></div>
|
| 333 |
-
<div class="stat-sub">{mastery_pct}% rate</div>
|
| 334 |
-
</div>
|
| 335 |
-
<div class="stat-card green">
|
| 336 |
-
<div class="stat-label">Approved</div>
|
| 337 |
-
<div class="stat-value">{stats['approved_count']}</div>
|
| 338 |
-
<div class="stat-sub">ready to learn</div>
|
| 339 |
-
</div>
|
| 340 |
-
<div class="stat-card amber">
|
| 341 |
-
<div class="stat-label">Challenging</div>
|
| 342 |
-
<div class="stat-value">{stats['challenging_count']}</div>
|
| 343 |
-
<div class="stat-sub">partial prereqs</div>
|
| 344 |
-
</div>
|
| 345 |
-
<div class="stat-card red">
|
| 346 |
-
<div class="stat-label">Violation rate</div>
|
| 347 |
-
<div class="stat-value">{vrate_pct}<span style="font-size:0.9rem;color:#4a5568;">%</span></div>
|
| 348 |
-
<div class="stat-sub">blocked topics</div>
|
| 349 |
-
</div>
|
| 350 |
-
</div>
|
| 351 |
-
""", unsafe_allow_html=True)
|
| 352 |
-
|
| 353 |
-
if results["approved"]:
|
| 354 |
-
st.markdown('<div class="section-label">✅ Approved Recommendations</div>', unsafe_allow_html=True)
|
| 355 |
-
for i, rec in enumerate(results["approved"]):
|
| 356 |
-
score_pct = int(rec["score"] * 100)
|
| 357 |
-
st.markdown(f"""
|
| 358 |
-
<div class="rec-card approved">
|
| 359 |
-
<div class="rec-title">{i+1}. {rec['topic_label']}</div>
|
| 360 |
-
<div class="rec-meta">score: {rec['score']:.3f} · mastery: {int(rec['mastery']*100)}% · unlocks: {rec['downstream_count']}</div>
|
| 361 |
-
<div class="rec-reason">{rec['reasoning']}</div>
|
| 362 |
-
<div class="score-bar-wrap"><div class="score-bar" style="width:{score_pct}%;--bar-color:#22c55e;"></div></div>
|
| 363 |
-
</div>
|
| 364 |
-
""", unsafe_allow_html=True)
|
| 365 |
-
else:
|
| 366 |
-
st.info("No approved topics — lower the mastery threshold or set some mastery levels.")
|
| 367 |
-
|
| 368 |
-
if results["challenging"]:
|
| 369 |
-
st.markdown('<div class="section-label">⚠️ Challenging</div>', unsafe_allow_html=True)
|
| 370 |
-
for rec in results["challenging"]:
|
| 371 |
-
score_pct = int(rec["score"] * 100)
|
| 372 |
-
unmet = ", ".join(rec["unmet_prerequisites"]) or "—"
|
| 373 |
-
st.markdown(f"""
|
| 374 |
-
<div class="rec-card challenging">
|
| 375 |
-
<div class="rec-title">{rec['topic_label']}</div>
|
| 376 |
-
<div class="rec-meta">score: {rec['score']:.3f} · strengthen: {unmet}</div>
|
| 377 |
-
<div class="rec-reason">{rec['reasoning']}</div>
|
| 378 |
-
<div class="score-bar-wrap"><div class="score-bar" style="width:{score_pct}%;--bar-color:#f59e0b;"></div></div>
|
| 379 |
-
</div>
|
| 380 |
-
""", unsafe_allow_html=True)
|
| 381 |
-
|
| 382 |
-
if results["vetoed"]:
|
| 383 |
-
with st.expander(f"❌ Vetoed topics ({stats['vetoed_count']} total — prerequisite check failed)"):
|
| 384 |
-
for rec in results["vetoed"]:
|
| 385 |
-
unmet = ", ".join(rec["unmet_prerequisites"]) or "—"
|
| 386 |
-
st.markdown(f"""
|
| 387 |
-
<div class="rec-card vetoed">
|
| 388 |
-
<div class="rec-title">{rec['topic_label']}</div>
|
| 389 |
-
<div class="rec-meta">blocked by: {unmet}</div>
|
| 390 |
-
</div>
|
| 391 |
-
""", unsafe_allow_html=True)
|
| 392 |
-
else:
|
| 393 |
-
st.markdown("""
|
| 394 |
-
<div style="height:280px;display:flex;align-items:center;justify-content:center;
|
| 395 |
-
border:1px dashed #1e2a40;border-radius:4px;color:#2a3a50;">
|
| 396 |
-
<div style="text-align:center;">
|
| 397 |
-
<div style="font-size:2rem;margin-bottom:0.5rem;">⚡</div>
|
| 398 |
-
<div style="font-family:'DM Mono',monospace;font-size:0.7rem;letter-spacing:0.1em;">
|
| 399 |
-
SET MASTERY LEVELS · THEN GENERATE
|
| 400 |
-
</div>
|
| 401 |
-
</div>
|
| 402 |
-
</div>
|
| 403 |
-
""", unsafe_allow_html=True)
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 407 |
-
# TAB 2 — WHAT-IF SIMULATOR
|
| 408 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 409 |
-
with tab2:
|
| 410 |
-
st.markdown('<div class="section-label">Prerequisite Impact Simulator</div>', unsafe_allow_html=True)
|
| 411 |
-
st.markdown('<p style="font-size:0.8rem;color:#8899aa;">Select any topic to see what it unlocks and what currently blocks it.</p>', unsafe_allow_html=True)
|
| 412 |
-
|
| 413 |
-
node_options = {curriculum.label(n): n for n in curriculum.nodes}
|
| 414 |
-
selected_label = st.selectbox("Select topic", list(node_options.keys()))
|
| 415 |
-
selected_id = node_options[selected_label]
|
| 416 |
-
wi = pipeline.what_if(selected_id)
|
| 417 |
-
|
| 418 |
-
col_a, col_b = st.columns(2, gap="large")
|
| 419 |
-
|
| 420 |
-
with col_a:
|
| 421 |
-
st.markdown('<div class="section-label">🔓 What This Unlocks</div>', unsafe_allow_html=True)
|
| 422 |
-
if wi["direct_unlocks"]:
|
| 423 |
-
st.markdown("**Directly unlocks:**")
|
| 424 |
-
st.markdown("".join(f'<span class="unlock-chip">{u["label"]}</span>' for u in wi["direct_unlocks"]), unsafe_allow_html=True)
|
| 425 |
-
else:
|
| 426 |
-
st.markdown('<span style="color:#4a5568;font-size:0.8rem;">Leaf node — no further topics.</span>', unsafe_allow_html=True)
|
| 427 |
-
|
| 428 |
-
if wi["all_unlocks"]:
|
| 429 |
-
st.markdown(f"**All downstream ({wi['total_unlocked']}):**")
|
| 430 |
-
st.markdown("".join(f'<span class="unlock-chip">{u["label"]}</span>' for u in wi["all_unlocks"]), unsafe_allow_html=True)
|
| 431 |
-
|
| 432 |
-
st.markdown(f"""
|
| 433 |
-
<div class="stat-card blue" style="margin-top:1rem;max-width:180px;">
|
| 434 |
-
<div class="stat-label">Total Unlocked</div>
|
| 435 |
-
<div class="stat-value">{wi['total_unlocked']}</div>
|
| 436 |
-
</div>
|
| 437 |
-
""", unsafe_allow_html=True)
|
| 438 |
-
|
| 439 |
-
with col_b:
|
| 440 |
-
st.markdown('<div class="section-label">🔒 What Blocks This</div>', unsafe_allow_html=True)
|
| 441 |
-
if wi["blocked_by"]:
|
| 442 |
-
st.markdown("**Prerequisites:**")
|
| 443 |
-
st.markdown("".join(f'<span class="blocked-chip">{b["label"]}</span>' for b in wi["blocked_by"]), unsafe_allow_html=True)
|
| 444 |
-
else:
|
| 445 |
-
st.markdown('<span style="color:#22c55e;font-size:0.8rem;font-family:\'DM Mono\',monospace;">Root topic — no prerequisites.</span>', unsafe_allow_html=True)
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 449 |
-
# TAB 3 — CURRICULUM MAP
|
| 450 |
-
# ══════════════════════════════════════════════════════════════════════════════
|
| 451 |
-
with tab3:
|
| 452 |
-
st.markdown('<div class="section-label">Curriculum Knowledge Graph</div>', unsafe_allow_html=True)
|
| 453 |
-
|
| 454 |
-
col_info, col_table = st.columns([1, 2], gap="large")
|
| 455 |
-
|
| 456 |
-
with col_info:
|
| 457 |
-
roots = [n for n in curriculum.nodes if not curriculum.prerequisites(n)]
|
| 458 |
-
leaves = [n for n in curriculum.nodes if not curriculum.successors(n)]
|
| 459 |
-
|
| 460 |
-
st.markdown(f"""
|
| 461 |
-
<div class="stat-card blue" style="margin-bottom:0.75rem;">
|
| 462 |
-
<div class="stat-label">Domain</div>
|
| 463 |
-
<div style="font-size:0.85rem;font-weight:700;color:#e8edf5;">{curriculum.domain}</div>
|
| 464 |
-
</div>
|
| 465 |
-
<div class="stat-card green" style="margin-bottom:0.75rem;">
|
| 466 |
-
<div class="stat-label">Topics</div><div class="stat-value">{curriculum.num_nodes}</div>
|
| 467 |
-
</div>
|
| 468 |
-
<div class="stat-card amber">
|
| 469 |
-
<div class="stat-label">Prerequisite Edges</div><div class="stat-value">{curriculum.num_edges}</div>
|
| 470 |
-
</div>
|
| 471 |
-
""", unsafe_allow_html=True)
|
| 472 |
-
|
| 473 |
-
st.markdown('<div class="section-label">Root Topics</div>', unsafe_allow_html=True)
|
| 474 |
-
st.markdown("".join(f'<span class="unlock-chip">{curriculum.label(r)}</span>' for r in roots), unsafe_allow_html=True)
|
| 475 |
-
|
| 476 |
-
st.markdown('<div class="section-label">Leaf Topics</div>', unsafe_allow_html=True)
|
| 477 |
-
st.markdown("".join(f'<span class="blocked-chip">{curriculum.label(l)}</span>' for l in leaves), unsafe_allow_html=True)
|
| 478 |
-
|
| 479 |
-
with col_table:
|
| 480 |
-
import pandas as pd
|
| 481 |
-
st.markdown('<div class="section-label">All Topics</div>', unsafe_allow_html=True)
|
| 482 |
-
rows = []
|
| 483 |
-
for node in curriculum.nodes:
|
| 484 |
-
rows.append({
|
| 485 |
-
"Topic": curriculum.label(node),
|
| 486 |
-
"Level": curriculum.level(node),
|
| 487 |
-
"Prerequisites": len(curriculum.prerequisites(node)),
|
| 488 |
-
"Unlocks (direct)": len(curriculum.successors(node)),
|
| 489 |
-
"Total Downstream": len(curriculum.descendants(node)),
|
| 490 |
-
})
|
| 491 |
-
df = pd.DataFrame(rows).sort_values("Total Downstream", ascending=False)
|
| 492 |
-
st.dataframe(df, use_container_width=True, height=480, hide_index=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import json
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import networkx as nx
|
| 7 |
+
import numpy as np
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
from typing import Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
st.set_page_config(page_title='Logic Engine', page_icon='🧠', layout='wide')
|
| 12 |
+
|
| 13 |
+
HF_REPO = 'Clementio/PLRS'
|
| 14 |
+
|
| 15 |
+
@st.cache_resource
|
| 16 |
+
def load_model():
|
| 17 |
+
config_path = hf_hub_download(repo_id=HF_REPO, filename='config.json')
|
| 18 |
+
with open(config_path) as f:
|
| 19 |
+
config = json.load(f)
|
| 20 |
+
model_path = hf_hub_download(repo_id=HF_REPO, filename='sakt_model.pt')
|
| 21 |
+
class SAKT(nn.Module):
|
| 22 |
+
def __init__(self, num_skills, embed_dim, num_heads, num_layers, max_seq_len, dropout):
|
| 23 |
+
super(SAKT, self).__init__()
|
| 24 |
+
self.num_skills = num_skills
|
| 25 |
+
self.interaction_embed = nn.Embedding(num_skills * 2 + 1, embed_dim, padding_idx=0)
|
| 26 |
+
self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
|
| 27 |
+
self.pos_embed = nn.Embedding(max_seq_len + 1, embed_dim)
|
| 28 |
+
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True, dim_feedforward=embed_dim * 4, norm_first=True)
|
| 29 |
+
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
|
| 30 |
+
self.dropout = nn.Dropout(dropout)
|
| 31 |
+
self.output = nn.Linear(embed_dim, 1)
|
| 32 |
+
def forward(self, interactions, target_skills, mask, return_attention=False):
|
| 33 |
+
batch_size, seq_len = interactions.shape
|
| 34 |
+
positions = torch.arange(seq_len, device=interactions.device).unsqueeze(0).expand(batch_size, -1)
|
| 35 |
+
x = self.interaction_embed(interactions)
|
| 36 |
+
x = x + self.pos_embed(positions)
|
| 37 |
+
x = x * mask.unsqueeze(-1).float()
|
| 38 |
+
x = self.dropout(x)
|
| 39 |
+
causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
|
| 40 |
+
x = self.transformer(x, mask=causal_mask, is_causal=False)
|
| 41 |
+
x = x * mask.unsqueeze(-1).float()
|
| 42 |
+
x = x + self.skill_embed(target_skills)
|
| 43 |
+
return self.output(x).squeeze(-1)
|
| 44 |
+
device = torch.device('cpu')
|
| 45 |
+
model = SAKT(num_skills=config['num_skills'], embed_dim=config['embed_dim'], num_heads=config['num_heads'], num_layers=config['num_layers'], max_seq_len=config['max_seq_len'], dropout=config['dropout'])
|
| 46 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 47 |
+
model.eval()
|
| 48 |
+
return model, config, device
|
| 49 |
+
|
| 50 |
+
@st.cache_resource
|
| 51 |
+
def load_knowledge_maps():
|
| 52 |
+
def load_dag(path):
|
| 53 |
+
with open(path) as f:
|
| 54 |
+
data = json.load(f)
|
| 55 |
+
G = nx.DiGraph()
|
| 56 |
+
for node in data['nodes']:
|
| 57 |
+
G.add_node(node['id'], label=node['label'], level=node['level'], term=node['term'])
|
| 58 |
+
for edge in data['edges']:
|
| 59 |
+
G.add_edge(edge['from'], edge['to'])
|
| 60 |
+
return G
|
| 61 |
+
return load_dag('knowledge_maps/math_dag.json'), load_dag('knowledge_maps/cs_dag.json')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
@st.cache_data
|
| 64 |
def load_skill_encoder():
|
| 65 |
+
return pd.read_csv('data/skill_encoder.csv')
|
| 66 |
+
|
| 67 |
+
class MasteryVector:
|
| 68 |
+
def __init__(self, graph, threshold=0.70):
|
| 69 |
+
self.graph = graph
|
| 70 |
+
self.threshold = threshold
|
| 71 |
+
self.mastery = {node: 0.0 for node in graph.nodes}
|
| 72 |
+
def update(self, topic_id, probability):
|
| 73 |
+
if topic_id in self.mastery: self.mastery[topic_id] = probability
|
| 74 |
+
def is_mastered(self, topic_id):
|
| 75 |
+
return self.mastery.get(topic_id, 0.0) >= self.threshold
|
| 76 |
+
def get_mastery(self, topic_id):
|
| 77 |
+
return self.mastery.get(topic_id, 0.0)
|
| 78 |
+
def get_mastery_summary(self):
|
| 79 |
+
mastered = [t for t in self.mastery if self.is_mastered(t)]
|
| 80 |
+
return {'total_topics': len(self.mastery), 'mastered': len(mastered), 'mastery_rate': round(len(mastered)/len(self.mastery), 3), 'mastered_topics': mastered}
|
| 81 |
+
|
| 82 |
+
class DAGConstraintLayer:
|
| 83 |
+
def __init__(self, graph, threshold=0.70, soft_threshold=0.50):
|
| 84 |
+
self.graph = graph
|
| 85 |
+
self.threshold = threshold
|
| 86 |
+
self.soft_threshold = soft_threshold # below full threshold but above this = challenging
|
| 87 |
+
def validate(self, topic_id, mastery_vector):
|
| 88 |
+
if topic_id not in self.graph.nodes: return 'vetoed', 'Topic not found.'
|
| 89 |
+
prerequisites = list(self.graph.predecessors(topic_id))
|
| 90 |
+
label = self.graph.nodes[topic_id].get('label', topic_id)
|
| 91 |
+
if not prerequisites: return 'approved', f'✅ Foundational topic — no prerequisites.'
|
| 92 |
+
hard_fails = []
|
| 93 |
+
soft_fails = []
|
| 94 |
+
for p in prerequisites:
|
| 95 |
+
m = mastery_vector.get_mastery(p)
|
| 96 |
+
plabel = self.graph.nodes[p].get('label', p)
|
| 97 |
+
if m < self.soft_threshold:
|
| 98 |
+
hard_fails.append((plabel, m))
|
| 99 |
+
elif m < self.threshold:
|
| 100 |
+
soft_fails.append((plabel, m))
|
| 101 |
+
if hard_fails:
|
| 102 |
+
gaps = ', '.join([f"{l} ({m:.0%} mastered, need {self.threshold:.0%})" for l,m in hard_fails])
|
| 103 |
+
return 'vetoed', f'❌ Prerequisites not met: {gaps}'
|
| 104 |
+
elif soft_fails:
|
| 105 |
+
gaps = ', '.join([f"{l} ({m:.0%} mastered, need {self.threshold:.0%})" for l,m in soft_fails])
|
| 106 |
+
return 'challenging', f'⚠️ Challenging — prerequisites nearly met: {gaps}. Proceed with caution.'
|
| 107 |
+
else:
|
| 108 |
+
prereq_labels = [self.graph.nodes[p].get('label',p) for p in prerequisites]
|
| 109 |
+
return 'approved', f'✅ Prerequisites mastered: {", ".join(prereq_labels)}'
|
| 110 |
+
|
| 111 |
+
class RankingFunction:
|
| 112 |
+
def __init__(self, graph, threshold=0.70, w_gap=0.40, w_ready=0.35, w_downstream=0.25):
|
| 113 |
+
self.graph=graph; self.threshold=threshold; self.w_gap=w_gap; self.w_ready=w_ready; self.w_downstream=w_downstream
|
| 114 |
+
scores = {n: len(nx.descendants(graph, n)) for n in graph.nodes}
|
| 115 |
+
mx = max(scores.values()) if scores else 1
|
| 116 |
+
self._downstream = {n: s/mx for n,s in scores.items()}
|
| 117 |
+
def score(self, topic_id, mastery_vector):
|
| 118 |
+
current = mastery_vector.get_mastery(topic_id)
|
| 119 |
+
gap = min(max(0.0, self.threshold-current)/self.threshold, 1.0)
|
| 120 |
+
prereqs = list(self.graph.predecessors(topic_id))
|
| 121 |
+
readiness = 1.0 if not prereqs else sum(1 for p in prereqs if mastery_vector.is_mastered(p))/len(prereqs)
|
| 122 |
+
downstream = self._downstream.get(topic_id, 0.0)
|
| 123 |
+
# Near-mastery boost: topics the student has already started
|
| 124 |
+
# rank higher than untouched topics with the same gap score
|
| 125 |
+
near_mastery_boost = 0.0
|
| 126 |
+
if 0.10 <= current < self.threshold:
|
| 127 |
+
near_mastery_boost = 0.15 * (current / self.threshold)
|
| 128 |
+
return round(self.w_gap*gap + self.w_ready*readiness + self.w_downstream*downstream + near_mastery_boost, 3)
|
| 129 |
+
|
| 130 |
+
class LearningRecommendationPipeline:
|
| 131 |
+
def __init__(self, graph, threshold=0.70, soft_threshold=0.50, top_n=5):
|
| 132 |
+
self.graph=graph
|
| 133 |
+
self.constraint=DAGConstraintLayer(graph, threshold, soft_threshold)
|
| 134 |
+
self.ranker=RankingFunction(graph, threshold)
|
| 135 |
+
self.top_n=top_n
|
| 136 |
+
def run(self, mastery_vector):
|
| 137 |
+
approved, challenging, vetoed = [], [], []
|
| 138 |
+
for topic_id in self.graph.nodes:
|
| 139 |
+
status, reasoning = self.constraint.validate(topic_id, mastery_vector)
|
| 140 |
+
entry = {'topic_id': topic_id, 'topic_label': self.graph.nodes[topic_id].get('label', topic_id), 'mastery': round(mastery_vector.get_mastery(topic_id),3), 'reasoning': reasoning, 'status': status}
|
| 141 |
+
if status == 'approved' and not mastery_vector.is_mastered(topic_id):
|
| 142 |
+
entry['score'] = self.ranker.score(topic_id, mastery_vector)
|
| 143 |
+
approved.append(entry)
|
| 144 |
+
elif status == 'challenging' and not mastery_vector.is_mastered(topic_id):
|
| 145 |
+
entry['score'] = self.ranker.score(topic_id, mastery_vector) * 0.8 # slight penalty
|
| 146 |
+
challenging.append(entry)
|
| 147 |
+
elif status == 'vetoed':
|
| 148 |
+
vetoed.append(entry)
|
| 149 |
+
approved.sort(key=lambda x: x['score'], reverse=True)
|
| 150 |
+
challenging.sort(key=lambda x: x['score'], reverse=True)
|
| 151 |
+
return {'top_recommendations': approved[:self.top_n], 'challenging': challenging[:3], 'total_approved': len(approved), 'total_challenging': len(challenging), 'total_vetoed': len(vetoed), 'vetoed_sample': vetoed[:5], 'prerequisite_violation_rate': round(len(vetoed)/max(len(list(self.graph.nodes)),1),3)}
|
| 152 |
+
|
| 153 |
+
ACTIVITY_TO_MATH = {'oucontent':'algebraic_expressions','forumng':'statistics_basic','homepage':'whole_numbers','subpage':'plane_shapes','resource':'indices','url':'number_bases','ouwiki':'proportion_variation','glossary':'algebraic_factorization','quiz':'quadratic_equations'}
|
| 154 |
+
ACTIVITY_TO_CS = {'oucontent':'programming_concepts','forumng':'ethics_technology','homepage':'computer_basics','subpage':'html_basics','resource':'networking_fundamentals','url':'internet_basics','ouwiki':'cloud_basics','glossary':'intro_databases','quiz':'python_basics'}
|
| 155 |
+
|
| 156 |
+
def run_sakt_inference(model, config, skill_seq, correct_seq, device):
    """Run a SAKT forward pass over one student's interaction history.

    Returns {skill_id: predicted P(correct)} for every non-padding
    prediction position. When a skill occurs multiple times, the later
    (more recent) probability overwrites the earlier one.
    """
    max_len = config['max_seq_len']
    n_skills = config['num_skills']
    # Keep only the most recent window the model supports.
    if len(skill_seq) > max_len:
        skill_seq = skill_seq[-max_len:]
        correct_seq = correct_seq[-max_len:]
    # Encode each past (skill, correctness) pair as one interaction id;
    # the model predicts the *next* skill, hence the off-by-one slicing.
    interactions = [s + c * n_skills for s, c in zip(skill_seq[:-1], correct_seq[:-1])]
    target_skills = skill_seq[1:]
    seq_len = len(interactions)
    pad_len = max_len - seq_len
    # Left-pad so real interactions sit at the end of the window.
    interactions = [0] * pad_len + interactions
    target_skills = [0] * pad_len + target_skills
    mask = [False] * pad_len + [True] * seq_len
    with torch.no_grad():
        logits = model(
            torch.LongTensor([interactions]).to(device),
            torch.LongTensor([target_skills]).to(device),
            torch.BoolTensor([mask]).to(device),
        )
        probs = torch.sigmoid(logits).squeeze(0)
    real_probs = probs[torch.BoolTensor(mask)].cpu().numpy()
    real_skills = target_skills[pad_len:]
    return {int(skill): float(p) for skill, p in zip(real_skills, real_probs)}
|
| 169 |
+
|
| 170 |
+
def build_mastery_vector(skill_probs, graph, skill_encoder_df, domain, threshold, soft_threshold):
    """Translate per-skill SAKT probabilities into a topic-level MasteryVector.

    Each skill id is looked up in the encoder table, mapped to a
    curriculum topic via its activity type, and when multiple skills
    land on the same topic the highest probability wins.
    `soft_threshold` is accepted for interface symmetry but unused here.
    """
    vector = MasteryVector(graph, threshold)
    mapping = ACTIVITY_TO_MATH if domain == 'math' else ACTIVITY_TO_CS
    best_per_topic = {}
    for skill_id, prob in skill_probs.items():
        match = skill_encoder_df[skill_encoder_df['skill_id'] == skill_id]
        if match.empty:
            continue
        activity = match['activity_type'].values[0] if 'activity_type' in match.columns else None
        topic = mapping.get(activity) if activity else None
        if topic:
            # Keep the strongest evidence seen for this topic.
            best_per_topic[topic] = max(best_per_topic.get(topic, 0.0), prob)
    for topic, score in best_per_topic.items():
        vector.update(topic, score)
    return vector
|
| 181 |
+
|
| 182 |
+
def what_if_analysis(topic_id, graph):
    """Summarise what mastering `topic_id` unlocks and what blocks it.

    Returns labels for the topic's direct successors, all transitive
    descendants, and its direct prerequisites, plus the total count of
    transitively unlocked topics.
    """
    def labels_for(node_ids):
        # Fall back to the node id when a human-readable label is absent.
        return [graph.nodes[n].get('label', n) for n in node_ids]

    downstream = list(nx.descendants(graph, topic_id))
    return {
        'direct_unlocks': labels_for(graph.successors(topic_id)),
        'all_unlocks': labels_for(downstream),
        'blocked_by': labels_for(graph.predecessors(topic_id)),
        'total_unlocked': len(downstream),
    }
|
| 190 |
+
|
| 191 |
+
def cascade_mastery(mastery_vector, graph):
    """Propagate mastery evidence *upward* through the prerequisite DAG.

    If a student demonstrates reasonable mastery of a topic, its
    prerequisites are almost certainly known too (e.g. scoring 80% on
    Modular Arithmetic implies knowing Whole Numbers). Each prerequisite
    is raised to at least 85% of the descendant's mastery, capped at
    0.95, and the sweep repeats until a fixed point is reached.
    """
    made_progress = True
    while made_progress:
        made_progress = False
        for node in graph.nodes:
            evidence = mastery_vector.get_mastery(node)
            # Too weak a signal to infer anything about prerequisites.
            if evidence < 0.40:
                continue
            # Inferred prerequisite mastery: 85% of the descendant's, capped.
            inferred = min(evidence * 0.85, 0.95)
            for prereq in graph.predecessors(node):
                if inferred > mastery_vector.get_mastery(prereq):
                    mastery_vector.update(prereq, inferred)
                    made_progress = True
    return mastery_vector
|
| 214 |
+
|
| 215 |
+
# NOTE(review): this is an exact duplicate of the cascade_mastery defined
# immediately above — this later definition shadows the earlier one at
# import time. One of the two copies should be deleted.
def cascade_mastery(mastery_vector, graph):
    """
    If a student has high mastery on a topic, infer that their
    prerequisites are also likely mastered (propagate upward).
    A student who scores 80% on Modular Arithmetic almost certainly
    knows Whole Numbers — cascade fills these realistic gaps.
    """
    # Fixed-point loop: keep sweeping the graph until no update is made.
    changed = True
    while changed:
        changed = False
        for node in graph.nodes:
            node_mastery = mastery_vector.get_mastery(node)
            # Below 0.40 the evidence is too weak to infer anything.
            if node_mastery < 0.40:
                continue
            # For each prerequisite of this node
            for prereq in graph.predecessors(node):
                prereq_mastery = mastery_vector.get_mastery(prereq)
                # Infer prerequisite mastery as at least 85% of descendant mastery
                inferred = min(node_mastery * 0.85, 0.95)
                if inferred > prereq_mastery:
                    mastery_vector.update(prereq, inferred)
                    changed = True
    return mastery_vector
|
| 238 |
+
|
| 239 |
+
def get_attention_weights(model, config, skill_seq, correct_seq, device):
    """Run a forward pass and return the tail of the prediction sequence.

    Despite its name, this helper does not extract attention weights —
    the original hook machinery was declared but never registered, so it
    has been removed along with an unused re-computation of the
    positional embeddings. Behavior of the return value is unchanged:
    the last 10 target skills, their predicted probabilities, and the
    effective (unpadded) sequence length.

    NOTE(review): `device` is accepted but unused — all tensors stay on
    CPU and `.numpy()` is called without `.cpu()`, so this assumes a
    CPU-resident model. Confirm before running on GPU.
    """
    max_len = config['max_seq_len']
    n_skills = config['num_skills']
    # Keep only the most recent window the model supports.
    if len(skill_seq) > max_len:
        skill_seq = skill_seq[-max_len:]
        correct_seq = correct_seq[-max_len:]
    # Encode past (skill, correctness) pairs; the model predicts the next skill.
    interactions = [s + c * n_skills for s, c in zip(skill_seq[:-1], correct_seq[:-1])]
    target_skills = skill_seq[1:]
    seq_len = len(interactions)
    pad_len = max_len - seq_len
    # Left-pad so real interactions sit at the end of the window.
    interactions = [0] * pad_len + interactions
    target_skills = [0] * pad_len + target_skills
    mask_list = [False] * pad_len + [True] * seq_len
    interactions_t = torch.LongTensor([interactions])
    target_t = torch.LongTensor([target_skills])
    mask_t = torch.BoolTensor([mask_list])
    with torch.no_grad():
        real_mask = mask_t.squeeze(0)
        real_skills = target_skills[pad_len:]
        real_probs = torch.sigmoid(model(interactions_t, target_t, mask_t)).squeeze(0)[real_mask].numpy()
    return real_skills[-10:], real_probs[-10:], seq_len
|
| 259 |
+
|
| 260 |
+
def main():
    """Streamlit entry point for the Logic Engine recommender UI.

    Loads the model, knowledge maps and skill encoder once, builds the
    sidebar configuration, then renders four tabs: recommendations,
    what-if simulator, knowledge map, and diagnostics.
    """
    # Load model artifacts and curriculum DAGs (helpers defined elsewhere in this file).
    model, config, device = load_model()
    math_graph, cs_graph = load_knowledge_maps()
    skill_encoder = load_skill_encoder()
    st.title('🧠 Logic Engine')
    st.subheader('Domain-Agnostic Constraint-Aware Learning Recommender')
    st.markdown('---')
    # --- Sidebar configuration ---
    st.sidebar.title('⚙️ Configuration')
    domain = st.sidebar.selectbox('Select Domain', ['Mathematics', 'CS Fundamentals'])
    threshold = st.sidebar.slider('Mastery Threshold', 0.50, 0.90, 0.70, 0.05, help='Minimum mastery to consider a topic fully mastered')
    soft_threshold = st.sidebar.slider('Challenging Threshold', 0.30, 0.70, 0.50, 0.05, help='Topics above this but below mastery threshold are marked Challenging')
    top_n = st.sidebar.slider('Top N Recommendations', 3, 10, 5)
    # Select the DAG matching the chosen domain.
    graph = math_graph if domain=='Mathematics' else cs_graph
    domain_key = 'math' if domain=='Mathematics' else 'cs'
    pipeline = LearningRecommendationPipeline(graph, threshold, soft_threshold, top_n)
    st.sidebar.markdown('---')
    st.sidebar.markdown('**About**')
    st.sidebar.markdown('SAKT-based knowledge tracing with DAG prerequisite constraints. Three-tier recommendations: ✅ Approved, ⚠️ Challenging, ❌ Vetoed.')
    tab1, tab2, tab3, tab4 = st.tabs(['🎯 Recommendations','🔍 What-If Simulator','🗺️ Knowledge Map','📊 Diagnostics'])

    # --- Tab 1: learner profile + recommendations ---
    with tab1:
        st.header('Learner Profile')
        mode = st.radio('Input Mode', ['Manual Mastery Input','Simulate Student Sequence'], horizontal=True)
        mastery_vector = MasteryVector(graph, threshold)
        if mode=='Manual Mastery Input':
            # One slider per topic, laid out in two columns.
            st.markdown('Set your current mastery level for each topic:')
            cols=st.columns(2); nodes=list(graph.nodes)
            for i,node in enumerate(nodes):
                label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level','')
                val=cols[i%2].slider(f'{label} ({level})',0.0,1.0,0.0,0.05,key=f'mastery_{node}')
                mastery_vector.update(node,val)
        else:
            # Simulated learner: seeded random beta-distributed mastery,
            # scaled by the chosen sequence length.
            seq_length=st.slider('Sequence Length',10,200,50)
            seed=st.number_input('Student Seed',1,1000,42,1)
            np.random.seed(int(seed))
            topic_nodes = list(graph.nodes)
            n_topics = len(topic_nodes)
            raw_scores = np.random.beta(1.5, 3.0, size=n_topics)
            # Longer sequences imply more evidence, hence higher mastery ceiling.
            scale = min(seq_length / 200.0 * 1.4, 1.0)
            scores = np.clip(raw_scores * scale, 0.0, 1.0)
            for topic_id, score in zip(topic_nodes, scores):
                mastery_vector.update(topic_id, float(score))
            mastery_df = pd.DataFrame({
                'Topic': [graph.nodes[t].get('label', t)[:25] for t in topic_nodes],
                'Mastery': [round(float(s), 3) for s in scores]
            }).sort_values('Mastery', ascending=False).head(10)
            st.markdown('**📈 Simulated Learner Mastery Signal (top 10 topics):**')
            st.bar_chart(mastery_df.set_index('Topic'))
            # Cascade mastery upward through DAG
            mastery_vector = cascade_mastery(mastery_vector, graph)
            n_mastered = sum(1 for t in topic_nodes if mastery_vector.is_mastered(t))
            st.success(f'Learner simulation complete — {n_mastered}/{n_topics} topics above mastery threshold')
        if st.button('🚀 Generate Recommendations', type='primary'):
            output=pipeline.run(mastery_vector)
            summary=mastery_vector.get_mastery_summary()
            # Headline metrics row.
            col1,col2,col3,col4,col5=st.columns(5)
            col1.metric('Topics Mastered',f"{summary['mastered']} / {summary['total_topics']}")
            col2.metric('Mastery Rate',f"{summary['mastery_rate']:.1%}")
            col3.metric('✅ Approved',output['total_approved'])
            col4.metric('⚠️ Challenging',output['total_challenging'])
            col5.metric('Violation Rate',f"{output['prerequisite_violation_rate']:.1%}")
            st.markdown('---')
            st.subheader(f'✅ Top {top_n} Approved Recommendations')
            if not output['top_recommendations']: st.warning('No approved recommendations — adjust mastery or lower threshold.')
            else:
                # Expand the top three recommendations by default.
                for i,rec in enumerate(output['top_recommendations'],1):
                    with st.expander(f"{i}. {rec['topic_label']} — Score: {rec['score']} | Mastery: {rec['mastery']:.1%}", expanded=(i<=3)):
                        st.markdown(f"**Reasoning:** {rec['reasoning']}")
                        st.progress(rec['mastery'])
            if output['challenging']:
                st.markdown('---')
                st.subheader('⚠️ Challenging Topics (proceed with caution)')
                for rec in output['challenging']:
                    with st.expander(f"{rec['topic_label']} | Mastery: {rec['mastery']:.1%}"):
                        st.markdown(f"**Reasoning:** {rec['reasoning']}")
                        st.progress(rec['mastery'])
            if output['vetoed_sample']:
                st.markdown('---'); st.subheader('❌ Sample Vetoed Topics')
                for rec in output['vetoed_sample']:
                    with st.expander(f"✗ {rec['topic_label']}"):
                        st.markdown(f"**Reason:** {rec['reasoning']}")

    # --- Tab 2: prerequisite what-if simulator ---
    with tab2:
        st.header('🔍 What-If Prerequisite Simulator')
        st.markdown('Explore how mastering a topic unlocks future learning paths — or what is blocking you from starting it.')
        nodes_list = list(graph.nodes)
        labels_list = [graph.nodes[n].get('label',n) for n in nodes_list]
        selected_label = st.selectbox('Select a topic to analyse:', labels_list)
        # Map the chosen label back to its node id.
        selected_node = nodes_list[labels_list.index(selected_label)]
        if st.button('🔍 Analyse Topic', type='primary'):
            result = what_if_analysis(selected_node, graph)
            col1, col2 = st.columns(2)
            with col1:
                st.subheader('🔓 If you master this topic...')
                if result['direct_unlocks']:
                    st.markdown(f"**Directly unlocks {len(result['direct_unlocks'])} topic(s):**")
                    for t in result['direct_unlocks']: st.markdown(f' → {t}')
                else:
                    st.info('This is a terminal topic — it does not unlock further topics in this map.')
                if result['all_unlocks']:
                    st.markdown(f"**Total topics eventually unlocked: {result['total_unlocked']}**")
            with col2:
                st.subheader('🔒 To start this topic you need...')
                if result['blocked_by']:
                    st.markdown('**Prerequisites required:**')
                    for t in result['blocked_by']: st.markdown(f' ✓ {t}')
                else:
                    st.success('This is a foundational topic — no prerequisites needed. You can start it now!')
            if result['all_unlocks']:
                st.markdown('---')
                st.markdown('**Full learning path unlocked:**')
                # Show at most 8 downstream topics, with an ellipsis if truncated.
                st.markdown(' → '.join([selected_label] + result['all_unlocks'][:8]) + ('...' if len(result['all_unlocks'])>8 else ''))

    # --- Tab 3: knowledge map table ---
    with tab3:
        st.header(f'{domain} Knowledge Map')
        st.markdown(f"**{graph.number_of_nodes()} topics** | **{graph.number_of_edges()} prerequisite relationships**")
        rows=[]
        for node in graph.nodes:
            label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level',''); term=graph.nodes[node].get('term','')
            prereqs=[graph.nodes[p].get('label',p) for p in graph.predecessors(node)]
            rows.append({'Topic':label,'Level':level,'Term':term,'Prerequisites':', '.join(prereqs) if prereqs else 'None (Foundational)'})
        st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
        longest=nx.dag_longest_path(graph)
        st.markdown('**Longest prerequisite chain:**')
        st.markdown(' → '.join([graph.nodes[n].get('label',n) for n in longest]))

    # --- Tab 4: diagnostics ---
    with tab4:
        st.header('System Diagnostics')
        col1,col2=st.columns(2)
        with col1: st.subheader('Model Configuration'); st.json(config)
        with col2:
            st.subheader('DAG Statistics')
            st.json({'domain':domain,'nodes':graph.number_of_nodes(),'edges':graph.number_of_edges(),'is_valid_dag':nx.is_directed_acyclic_graph(graph),'longest_path':len(nx.dag_longest_path(graph))})
        st.subheader('Constraint Layer')
        st.markdown(f'**Mastery threshold:** {threshold:.0%} — topics above this are considered mastered')
        st.markdown(f'**Challenging threshold:** {soft_threshold:.0%} — topics between this and mastery threshold are marked ⚠️ Challenging')
        st.markdown('**Hard veto:** topics with prerequisites below challenging threshold are fully blocked')
        st.subheader('Domain Switching')
        dcol1,dcol2=st.columns(2)
        with dcol1: st.metric('Math DAG',f'{math_graph.number_of_nodes()} topics')
        with dcol2: st.metric('CS DAG',f'{cs_graph.number_of_nodes()} topics')
| 401 |
+
|
| 402 |
+
# Script entry point — only run the UI when executed directly, not on import.
if __name__ == '__main__':
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data/skill_encoder.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/skill_encoder_real.csv
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
assessment_type,code_module,skill_id
|
| 2 |
+
Exam,AAA,1
|
| 3 |
+
TMA,AAA,2
|
| 4 |
+
CMA,BBB,3
|
| 5 |
+
Exam,BBB,4
|
| 6 |
+
TMA,BBB,5
|
| 7 |
+
CMA,CCC,6
|
| 8 |
+
Exam,CCC,7
|
| 9 |
+
TMA,CCC,8
|
| 10 |
+
CMA,DDD,9
|
| 11 |
+
Exam,DDD,10
|
| 12 |
+
TMA,DDD,11
|
| 13 |
+
Exam,EEE,12
|
| 14 |
+
TMA,EEE,13
|
| 15 |
+
CMA,FFF,14
|
| 16 |
+
Exam,FFF,15
|
| 17 |
+
TMA,FFF,16
|
| 18 |
+
CMA,GGG,17
|
| 19 |
+
Exam,GGG,18
|
| 20 |
+
TMA,GGG,19
|
data/skill_encoder_v2.csv
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
index.html
DELETED
|
@@ -1,609 +0,0 @@
|
|
| 1 |
-
<!DOCTYPE html>
|
| 2 |
-
<html lang="en">
|
| 3 |
-
<head>
|
| 4 |
-
<meta charset="UTF-8" />
|
| 5 |
-
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 6 |
-
<title>PLRS — Personalized Learning Recommendation System</title>
|
| 7 |
-
<meta name="description" content="Constraint-aware personalized learning recommendations. Plug in your curriculum, get intelligent recommendations out." />
|
| 8 |
-
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
| 9 |
-
<link href="https://fonts.googleapis.com/css2?family=DM+Mono:ital,wght@0,300;0,400;0,500;1,300&family=Syne:wght@400;600;700;800&display=swap" rel="stylesheet" />
|
| 10 |
-
|
| 11 |
-
<style>
|
| 12 |
-
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 13 |
-
|
| 14 |
-
:root {
|
| 15 |
-
--bg: #080c18;
|
| 16 |
-
--bg2: #0d1221;
|
| 17 |
-
--bg3: #131a2e;
|
| 18 |
-
--border: #1e2a40;
|
| 19 |
-
--border2: #1e3a5f;
|
| 20 |
-
--text: #c8d0e0;
|
| 21 |
-
--text-dim: #4a5568;
|
| 22 |
-
--text-hi: #e8edf5;
|
| 23 |
-
--blue: #3d8bcd;
|
| 24 |
-
--green: #22c55e;
|
| 25 |
-
--amber: #f59e0b;
|
| 26 |
-
--red: #ef4444;
|
| 27 |
-
--mono: 'DM Mono', monospace;
|
| 28 |
-
--sans: 'Syne', sans-serif;
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
html { scroll-behavior: smooth; }
|
| 32 |
-
|
| 33 |
-
body {
|
| 34 |
-
background: var(--bg);
|
| 35 |
-
color: var(--text);
|
| 36 |
-
font-family: var(--sans);
|
| 37 |
-
line-height: 1.6;
|
| 38 |
-
overflow-x: hidden;
|
| 39 |
-
}
|
| 40 |
-
|
| 41 |
-
/* ── Noise overlay ── */
|
| 42 |
-
body::before {
|
| 43 |
-
content: '';
|
| 44 |
-
position: fixed; inset: 0;
|
| 45 |
-
background-image: url("data:image/svg+xml,%3Csvg viewBox='0 0 256 256' xmlns='http://www.w3.org/2000/svg'%3E%3Cfilter id='n'%3E%3CfeTurbulence type='fractalNoise' baseFrequency='0.9' numOctaves='4' stitchTiles='stitch'/%3E%3C/filter%3E%3Crect width='100%25' height='100%25' filter='url(%23n)' opacity='0.03'/%3E%3C/svg%3E");
|
| 46 |
-
pointer-events: none;
|
| 47 |
-
z-index: 0;
|
| 48 |
-
opacity: 0.4;
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
/* ── Nav ── */
|
| 52 |
-
nav {
|
| 53 |
-
position: fixed; top: 0; left: 0; right: 0;
|
| 54 |
-
display: flex; align-items: center; justify-content: space-between;
|
| 55 |
-
padding: 1rem 2.5rem;
|
| 56 |
-
background: rgba(8, 12, 24, 0.85);
|
| 57 |
-
backdrop-filter: blur(12px);
|
| 58 |
-
border-bottom: 1px solid var(--border);
|
| 59 |
-
z-index: 100;
|
| 60 |
-
}
|
| 61 |
-
.nav-logo {
|
| 62 |
-
font-weight: 800; font-size: 1.1rem; color: var(--text-hi);
|
| 63 |
-
letter-spacing: -0.02em; text-decoration: none;
|
| 64 |
-
}
|
| 65 |
-
.nav-logo span { color: var(--blue); }
|
| 66 |
-
.nav-links { display: flex; gap: 2rem; align-items: center; }
|
| 67 |
-
.nav-links a {
|
| 68 |
-
font-family: var(--mono); font-size: 0.7rem; letter-spacing: 0.1em;
|
| 69 |
-
color: var(--text-dim); text-decoration: none; text-transform: uppercase;
|
| 70 |
-
transition: color 0.2s;
|
| 71 |
-
}
|
| 72 |
-
.nav-links a:hover { color: var(--blue); }
|
| 73 |
-
.btn {
|
| 74 |
-
display: inline-flex; align-items: center; gap: 0.5rem;
|
| 75 |
-
padding: 0.5rem 1.1rem; border-radius: 3px; font-family: var(--mono);
|
| 76 |
-
font-size: 0.7rem; letter-spacing: 0.08em; text-decoration: none;
|
| 77 |
-
transition: all 0.2s; cursor: pointer; border: none;
|
| 78 |
-
}
|
| 79 |
-
.btn-primary {
|
| 80 |
-
background: var(--blue); color: #fff;
|
| 81 |
-
}
|
| 82 |
-
.btn-primary:hover { background: #4d9bdd; }
|
| 83 |
-
.btn-outline {
|
| 84 |
-
background: transparent; color: var(--blue);
|
| 85 |
-
border: 1px solid var(--border2);
|
| 86 |
-
}
|
| 87 |
-
.btn-outline:hover { border-color: var(--blue); background: rgba(61,139,205,0.07); }
|
| 88 |
-
|
| 89 |
-
/* ── Hero ── */
|
| 90 |
-
.hero {
|
| 91 |
-
min-height: 100vh;
|
| 92 |
-
display: flex; flex-direction: column; justify-content: center;
|
| 93 |
-
padding: 8rem 2.5rem 5rem;
|
| 94 |
-
max-width: 1100px; margin: 0 auto;
|
| 95 |
-
position: relative;
|
| 96 |
-
}
|
| 97 |
-
.hero-eyebrow {
|
| 98 |
-
font-family: var(--mono); font-size: 0.7rem; letter-spacing: 0.18em;
|
| 99 |
-
color: var(--blue); text-transform: uppercase; margin-bottom: 1.5rem;
|
| 100 |
-
display: flex; align-items: center; gap: 0.75rem;
|
| 101 |
-
}
|
| 102 |
-
.hero-eyebrow::before {
|
| 103 |
-
content: ''; display: block; width: 2rem; height: 1px; background: var(--blue);
|
| 104 |
-
}
|
| 105 |
-
.hero h1 {
|
| 106 |
-
font-size: clamp(2.8rem, 6vw, 5rem);
|
| 107 |
-
font-weight: 800; line-height: 1.05;
|
| 108 |
-
letter-spacing: -0.03em; color: var(--text-hi);
|
| 109 |
-
margin-bottom: 1.5rem;
|
| 110 |
-
}
|
| 111 |
-
.hero h1 em {
|
| 112 |
-
font-style: normal; color: var(--blue);
|
| 113 |
-
}
|
| 114 |
-
.hero-sub {
|
| 115 |
-
font-size: 1.1rem; color: var(--text-dim);
|
| 116 |
-
max-width: 560px; margin-bottom: 2.5rem;
|
| 117 |
-
line-height: 1.7;
|
| 118 |
-
}
|
| 119 |
-
.hero-ctas { display: flex; gap: 0.75rem; flex-wrap: wrap; margin-bottom: 4rem; }
|
| 120 |
-
.btn-hero {
|
| 121 |
-
padding: 0.75rem 1.5rem; font-size: 0.8rem;
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
/* ── Stat strip ── */
|
| 125 |
-
.stat-strip {
|
| 126 |
-
display: flex; gap: 2.5rem; flex-wrap: wrap;
|
| 127 |
-
border-top: 1px solid var(--border);
|
| 128 |
-
padding-top: 2rem;
|
| 129 |
-
}
|
| 130 |
-
.stat-item {}
|
| 131 |
-
.stat-num {
|
| 132 |
-
font-size: 2rem; font-weight: 800; color: var(--text-hi);
|
| 133 |
-
line-height: 1;
|
| 134 |
-
}
|
| 135 |
-
.stat-num span { color: var(--green); }
|
| 136 |
-
.stat-label {
|
| 137 |
-
font-family: var(--mono); font-size: 0.65rem; letter-spacing: 0.1em;
|
| 138 |
-
color: var(--text-dim); text-transform: uppercase; margin-top: 0.2rem;
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
/* ── Grid background decoration ── */
|
| 142 |
-
.hero-grid {
|
| 143 |
-
position: absolute; top: 0; right: -5%; bottom: 0; width: 50%;
|
| 144 |
-
background-image:
|
| 145 |
-
linear-gradient(var(--border) 1px, transparent 1px),
|
| 146 |
-
linear-gradient(90deg, var(--border) 1px, transparent 1px);
|
| 147 |
-
background-size: 40px 40px;
|
| 148 |
-
mask-image: linear-gradient(to left, rgba(0,0,0,0.15), transparent 70%);
|
| 149 |
-
pointer-events: none;
|
| 150 |
-
}
|
| 151 |
-
|
| 152 |
-
/* ── Section ── */
|
| 153 |
-
section {
|
| 154 |
-
max-width: 1100px; margin: 0 auto;
|
| 155 |
-
padding: 5rem 2.5rem;
|
| 156 |
-
}
|
| 157 |
-
.section-label {
|
| 158 |
-
font-family: var(--mono); font-size: 0.65rem; letter-spacing: 0.18em;
|
| 159 |
-
color: var(--blue); text-transform: uppercase;
|
| 160 |
-
display: flex; align-items: center; gap: 0.75rem;
|
| 161 |
-
margin-bottom: 1rem;
|
| 162 |
-
}
|
| 163 |
-
.section-label::before {
|
| 164 |
-
content: ''; display: block; width: 1.5rem; height: 1px; background: var(--blue);
|
| 165 |
-
}
|
| 166 |
-
.section-title {
|
| 167 |
-
font-size: clamp(1.8rem, 3.5vw, 2.5rem);
|
| 168 |
-
font-weight: 800; letter-spacing: -0.02em; color: var(--text-hi);
|
| 169 |
-
margin-bottom: 1rem;
|
| 170 |
-
}
|
| 171 |
-
.section-body {
|
| 172 |
-
color: var(--text-dim); font-size: 0.95rem; max-width: 600px;
|
| 173 |
-
line-height: 1.8; margin-bottom: 2.5rem;
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
/* ── Architecture flow ── */
|
| 177 |
-
.arch-flow {
|
| 178 |
-
display: flex; align-items: center; flex-wrap: wrap;
|
| 179 |
-
gap: 0; margin: 2.5rem 0;
|
| 180 |
-
}
|
| 181 |
-
.arch-node {
|
| 182 |
-
background: var(--bg2); border: 1px solid var(--border);
|
| 183 |
-
border-radius: 4px; padding: 0.7rem 1rem;
|
| 184 |
-
font-family: var(--mono); font-size: 0.72rem; color: var(--text);
|
| 185 |
-
letter-spacing: 0.04em; position: relative;
|
| 186 |
-
}
|
| 187 |
-
.arch-node.highlight { border-color: var(--blue); color: var(--blue); }
|
| 188 |
-
.arch-arrow {
|
| 189 |
-
font-family: var(--mono); color: var(--border2); padding: 0 0.4rem;
|
| 190 |
-
font-size: 0.9rem;
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
/* ── Three-tier cards ── */
|
| 194 |
-
.tier-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 1rem; margin-top: 2rem; }
|
| 195 |
-
.tier-card {
|
| 196 |
-
background: var(--bg2); border: 1px solid var(--border);
|
| 197 |
-
border-radius: 4px; padding: 1.5rem;
|
| 198 |
-
position: relative; overflow: hidden;
|
| 199 |
-
}
|
| 200 |
-
.tier-card::before {
|
| 201 |
-
content: ''; position: absolute; top: 0; left: 0; right: 0; height: 2px;
|
| 202 |
-
background: var(--accent);
|
| 203 |
-
}
|
| 204 |
-
.tier-card.green { --accent: var(--green); }
|
| 205 |
-
.tier-card.amber { --accent: var(--amber); }
|
| 206 |
-
.tier-card.red { --accent: var(--red); }
|
| 207 |
-
.tier-icon { font-size: 1.5rem; margin-bottom: 0.75rem; }
|
| 208 |
-
.tier-name {
|
| 209 |
-
font-weight: 700; font-size: 1rem; color: var(--text-hi);
|
| 210 |
-
margin-bottom: 0.35rem;
|
| 211 |
-
}
|
| 212 |
-
.tier-desc { font-size: 0.8rem; color: var(--text-dim); line-height: 1.6; }
|
| 213 |
-
|
| 214 |
-
/* ── Results table ── */
|
| 215 |
-
.results-table {
|
| 216 |
-
width: 100%; border-collapse: collapse;
|
| 217 |
-
font-family: var(--mono); font-size: 0.78rem;
|
| 218 |
-
margin-top: 2rem;
|
| 219 |
-
}
|
| 220 |
-
.results-table th {
|
| 221 |
-
text-align: left; padding: 0.6rem 1rem;
|
| 222 |
-
color: var(--text-dim); letter-spacing: 0.1em; text-transform: uppercase;
|
| 223 |
-
font-size: 0.65rem; border-bottom: 1px solid var(--border);
|
| 224 |
-
}
|
| 225 |
-
.results-table td {
|
| 226 |
-
padding: 0.75rem 1rem; border-bottom: 1px solid var(--border);
|
| 227 |
-
color: var(--text);
|
| 228 |
-
}
|
| 229 |
-
.results-table tr:last-child td { border-bottom: none; }
|
| 230 |
-
.results-table tr.highlight-row td { color: var(--text-hi); }
|
| 231 |
-
.badge-green {
|
| 232 |
-
background: rgba(34,197,94,0.1); color: var(--green);
|
| 233 |
-
border: 1px solid rgba(34,197,94,0.3);
|
| 234 |
-
padding: 1px 7px; border-radius: 2px; font-size: 0.65rem;
|
| 235 |
-
}
|
| 236 |
-
.badge-red {
|
| 237 |
-
background: rgba(239,68,68,0.1); color: var(--red);
|
| 238 |
-
border: 1px solid rgba(239,68,68,0.3);
|
| 239 |
-
padding: 1px 7px; border-radius: 2px; font-size: 0.65rem;
|
| 240 |
-
}
|
| 241 |
-
|
| 242 |
-
/* ── Code block ── */
|
| 243 |
-
.code-wrap {
|
| 244 |
-
background: var(--bg2); border: 1px solid var(--border);
|
| 245 |
-
border-radius: 4px; overflow: hidden; margin-top: 2rem;
|
| 246 |
-
}
|
| 247 |
-
.code-header {
|
| 248 |
-
display: flex; align-items: center; justify-content: space-between;
|
| 249 |
-
padding: 0.6rem 1rem; border-bottom: 1px solid var(--border);
|
| 250 |
-
background: var(--bg3);
|
| 251 |
-
}
|
| 252 |
-
.code-dots { display: flex; gap: 5px; }
|
| 253 |
-
.code-dots span {
|
| 254 |
-
width: 10px; height: 10px; border-radius: 50%;
|
| 255 |
-
background: var(--border2);
|
| 256 |
-
}
|
| 257 |
-
.code-lang {
|
| 258 |
-
font-family: var(--mono); font-size: 0.62rem;
|
| 259 |
-
color: var(--text-dim); letter-spacing: 0.1em;
|
| 260 |
-
}
|
| 261 |
-
pre {
|
| 262 |
-
padding: 1.5rem;
|
| 263 |
-
font-family: var(--mono); font-size: 0.78rem;
|
| 264 |
-
line-height: 1.7; color: var(--text);
|
| 265 |
-
overflow-x: auto;
|
| 266 |
-
}
|
| 267 |
-
.cm { color: #4a5568; } /* comment */
|
| 268 |
-
.ck { color: #3d8bcd; } /* keyword */
|
| 269 |
-
.cs { color: #22c55e; } /* string */
|
| 270 |
-
.cn { color: #f59e0b; } /* number / name */
|
| 271 |
-
.cf { color: #c084fc; } /* function */
|
| 272 |
-
|
| 273 |
-
/* ── Feature grid ── */
|
| 274 |
-
.feature-grid { display: grid; grid-template-columns: repeat(2, 1fr); gap: 1px; background: var(--border); margin-top: 2rem; border: 1px solid var(--border); border-radius: 4px; overflow: hidden; }
|
| 275 |
-
.feature-cell {
|
| 276 |
-
background: var(--bg); padding: 1.5rem;
|
| 277 |
-
}
|
| 278 |
-
.feature-icon { font-size: 1.2rem; margin-bottom: 0.75rem; }
|
| 279 |
-
.feature-title { font-weight: 700; color: var(--text-hi); margin-bottom: 0.35rem; font-size: 0.9rem; }
|
| 280 |
-
.feature-desc { font-size: 0.78rem; color: var(--text-dim); line-height: 1.6; }
|
| 281 |
-
|
| 282 |
-
/* ── CTA section ── */
|
| 283 |
-
.cta-section {
|
| 284 |
-
background: var(--bg2);
|
| 285 |
-
border-top: 1px solid var(--border);
|
| 286 |
-
border-bottom: 1px solid var(--border);
|
| 287 |
-
padding: 5rem 2.5rem;
|
| 288 |
-
text-align: center;
|
| 289 |
-
}
|
| 290 |
-
.cta-inner { max-width: 600px; margin: 0 auto; }
|
| 291 |
-
.cta-title { font-size: 2.2rem; font-weight: 800; letter-spacing: -0.02em; color: var(--text-hi); margin-bottom: 1rem; }
|
| 292 |
-
.cta-sub { color: var(--text-dim); margin-bottom: 2rem; line-height: 1.7; }
|
| 293 |
-
.cta-btns { display: flex; gap: 0.75rem; justify-content: center; flex-wrap: wrap; }
|
| 294 |
-
|
| 295 |
-
/* ── Footer ── */
|
| 296 |
-
footer {
|
| 297 |
-
border-top: 1px solid var(--border);
|
| 298 |
-
padding: 2rem 2.5rem;
|
| 299 |
-
display: flex; justify-content: space-between; align-items: center;
|
| 300 |
-
flex-wrap: wrap; gap: 1rem;
|
| 301 |
-
max-width: 100%;
|
| 302 |
-
}
|
| 303 |
-
.footer-left { font-family: var(--mono); font-size: 0.65rem; color: var(--text-dim); }
|
| 304 |
-
.footer-links { display: flex; gap: 1.5rem; }
|
| 305 |
-
.footer-links a { font-family: var(--mono); font-size: 0.65rem; color: var(--text-dim); text-decoration: none; }
|
| 306 |
-
.footer-links a:hover { color: var(--blue); }
|
| 307 |
-
|
| 308 |
-
/* ── Animations ── */
|
| 309 |
-
@keyframes fadeUp {
|
| 310 |
-
from { opacity: 0; transform: translateY(20px); }
|
| 311 |
-
to { opacity: 1; transform: translateY(0); }
|
| 312 |
-
}
|
| 313 |
-
.hero-eyebrow { animation: fadeUp 0.5s ease 0.1s both; }
|
| 314 |
-
.hero h1 { animation: fadeUp 0.5s ease 0.2s both; }
|
| 315 |
-
.hero-sub { animation: fadeUp 0.5s ease 0.3s both; }
|
| 316 |
-
.hero-ctas { animation: fadeUp 0.5s ease 0.4s both; }
|
| 317 |
-
.stat-strip { animation: fadeUp 0.5s ease 0.5s both; }
|
| 318 |
-
|
| 319 |
-
/* ── Responsive ── */
|
| 320 |
-
@media (max-width: 768px) {
|
| 321 |
-
nav { padding: 0.75rem 1.25rem; }
|
| 322 |
-
.nav-links .btn { display: none; }
|
| 323 |
-
.hero { padding: 7rem 1.25rem 4rem; }
|
| 324 |
-
.tier-grid { grid-template-columns: 1fr; }
|
| 325 |
-
.feature-grid { grid-template-columns: 1fr; }
|
| 326 |
-
section { padding: 3rem 1.25rem; }
|
| 327 |
-
.arch-flow { gap: 0.25rem; }
|
| 328 |
-
}
|
| 329 |
-
</style>
|
| 330 |
-
</head>
|
| 331 |
-
<body>
|
| 332 |
-
|
| 333 |
-
<!-- ── Nav ── -->
|
| 334 |
-
<nav>
|
| 335 |
-
<a href="#" class="nav-logo">PL<span>RS</span></a>
|
| 336 |
-
<div class="nav-links">
|
| 337 |
-
<a href="#how-it-works">How it works</a>
|
| 338 |
-
<a href="#results">Results</a>
|
| 339 |
-
<a href="#quickstart">Quickstart</a>
|
| 340 |
-
<a href="https://github.com/clementina-tom/plrs" target="_blank">GitHub</a>
|
| 341 |
-
<a href="https://huggingface.co/spaces/Clementio/PLRS" class="btn btn-primary btn-hero" target="_blank">Live Demo →</a>
|
| 342 |
-
</div>
|
| 343 |
-
</nav>
|
| 344 |
-
|
| 345 |
-
<!-- ── Hero ── -->
|
| 346 |
-
<div class="hero">
|
| 347 |
-
<div class="hero-grid"></div>
|
| 348 |
-
|
| 349 |
-
<div class="hero-eyebrow">Knowledge Tracing · Constraint-Aware · Open Source</div>
|
| 350 |
-
|
| 351 |
-
<h1>Recommendations that<br/><em>respect</em> how learning works.</h1>
|
| 352 |
-
|
| 353 |
-
<p class="hero-sub">
|
| 354 |
-
PLRS combines Self-Attentive Knowledge Tracing with a DAG prerequisite constraint layer
|
| 355 |
-
to generate personalized learning recommendations that are pedagogically sound —
|
| 356 |
-
not just statistically optimal.
|
| 357 |
-
</p>
|
| 358 |
-
|
| 359 |
-
<div class="hero-ctas">
|
| 360 |
-
<a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank" class="btn btn-primary btn-hero">
|
| 361 |
-
Try the live demo
|
| 362 |
-
</a>
|
| 363 |
-
<a href="https://github.com/clementina-tom/plrs" target="_blank" class="btn btn-outline btn-hero">
|
| 364 |
-
View on GitHub
|
| 365 |
-
</a>
|
| 366 |
-
<a href="#quickstart" class="btn btn-outline btn-hero">
|
| 367 |
-
Quickstart
|
| 368 |
-
</a>
|
| 369 |
-
</div>
|
| 370 |
-
|
| 371 |
-
<div class="stat-strip">
|
| 372 |
-
<div class="stat-item">
|
| 373 |
-
<div class="stat-num"><span>0.0</span>%</div>
|
| 374 |
-
<div class="stat-label">Prerequisite violation rate</div>
|
| 375 |
-
</div>
|
| 376 |
-
<div class="stat-item">
|
| 377 |
-
<div class="stat-num">0.7692</div>
|
| 378 |
-
<div class="stat-label">SAKT Val AUC (OULAD)</div>
|
| 379 |
-
</div>
|
| 380 |
-
<div class="stat-item">
|
| 381 |
-
<div class="stat-num">69</div>
|
| 382 |
-
<div class="stat-label">Curriculum topics (2 domains)</div>
|
| 383 |
-
</div>
|
| 384 |
-
<div class="stat-item">
|
| 385 |
-
<div class="stat-num">52</div>
|
| 386 |
-
<div class="stat-label">Tests passing</div>
|
| 387 |
-
</div>
|
| 388 |
-
</div>
|
| 389 |
-
</div>
|
| 390 |
-
|
| 391 |
-
<!-- ── How it works ── -->
|
| 392 |
-
<section id="how-it-works">
|
| 393 |
-
<div class="section-label">Architecture</div>
|
| 394 |
-
<h2 class="section-title">Three layers. One guarantee.</h2>
|
| 395 |
-
<p class="section-body">
|
| 396 |
-
Standard recommendation systems optimise for engagement or accuracy —
|
| 397 |
-
they will happily recommend Calculus to a student who hasn't mastered Algebra.
|
| 398 |
-
PLRS adds a constraint layer that makes this <em>structurally impossible</em>.
|
| 399 |
-
</p>
|
| 400 |
-
|
| 401 |
-
<div class="arch-flow">
|
| 402 |
-
<div class="arch-node">Student History</div>
|
| 403 |
-
<div class="arch-arrow">→</div>
|
| 404 |
-
<div class="arch-node highlight">SAKT Model</div>
|
| 405 |
-
<div class="arch-arrow">→</div>
|
| 406 |
-
<div class="arch-node">Mastery Vector</div>
|
| 407 |
-
<div class="arch-arrow">→</div>
|
| 408 |
-
<div class="arch-node highlight">DAG Constraints</div>
|
| 409 |
-
<div class="arch-arrow">→</div>
|
| 410 |
-
<div class="arch-node">Multi-Objective Ranker</div>
|
| 411 |
-
<div class="arch-arrow">→</div>
|
| 412 |
-
<div class="arch-node highlight">Recommendations</div>
|
| 413 |
-
</div>
|
| 414 |
-
|
| 415 |
-
<div class="tier-grid">
|
| 416 |
-
<div class="tier-card green">
|
| 417 |
-
<div class="tier-icon">✅</div>
|
| 418 |
-
<div class="tier-name">Approved</div>
|
| 419 |
-
<div class="tier-desc">All prerequisites met above the mastery threshold. Student is ready to learn this topic now.</div>
|
| 420 |
-
</div>
|
| 421 |
-
<div class="tier-card amber">
|
| 422 |
-
<div class="tier-icon">⚠️</div>
|
| 423 |
-
<div class="tier-name">Challenging</div>
|
| 424 |
-
<div class="tier-desc">Prerequisites partially met — above the soft threshold but below full mastery. Proceed with awareness.</div>
|
| 425 |
-
</div>
|
| 426 |
-
<div class="tier-card red">
|
| 427 |
-
<div class="tier-icon">❌</div>
|
| 428 |
-
<div class="tier-name">Vetoed</div>
|
| 429 |
-
<div class="tier-desc">One or more prerequisites not met. Structurally blocked until foundations are solid.</div>
|
| 430 |
-
</div>
|
| 431 |
-
</div>
|
| 432 |
-
</section>
|
| 433 |
-
|
| 434 |
-
<!-- ── Results ── -->
|
| 435 |
-
<section id="results" style="border-top: 1px solid var(--border);">
|
| 436 |
-
<div class="section-label">Evaluation</div>
|
| 437 |
-
<h2 class="section-title">0% violation rate. Not a tuning choice.</h2>
|
| 438 |
-
<p class="section-body">
|
| 439 |
-
Evaluated on the Open University Learning Analytics Dataset (OULAD) with
|
| 440 |
-
Nigerian secondary school curriculum knowledge maps. The 0% violation rate
|
| 441 |
-
is a structural guarantee from the DAG constraint layer — not a hyperparameter.
|
| 442 |
-
</p>
|
| 443 |
-
|
| 444 |
-
<table class="results-table">
|
| 445 |
-
<thead>
|
| 446 |
-
<tr>
|
| 447 |
-
<th>Model</th>
|
| 448 |
-
<th>Val AUC</th>
|
| 449 |
-
<th>Prerequisite Violation Rate</th>
|
| 450 |
-
<th>Coverage</th>
|
| 451 |
-
</tr>
|
| 452 |
-
</thead>
|
| 453 |
-
<tbody>
|
| 454 |
-
<tr class="highlight-row">
|
| 455 |
-
<td><strong>PLRS (SAKT + DAG)</strong></td>
|
| 456 |
-
<td><strong>0.7692</strong></td>
|
| 457 |
-
<td><span class="badge-green">0.0%</span></td>
|
| 458 |
-
<td>Full curriculum</td>
|
| 459 |
-
</tr>
|
| 460 |
-
<tr>
|
| 461 |
-
<td>Collaborative Filtering</td>
|
| 462 |
-
<td>—</td>
|
| 463 |
-
<td><span class="badge-red">81.3%</span></td>
|
| 464 |
-
<td>Partial</td>
|
| 465 |
-
</tr>
|
| 466 |
-
<tr>
|
| 467 |
-
<td>Matrix Factorization</td>
|
| 468 |
-
<td>—</td>
|
| 469 |
-
<td><span class="badge-red">83.7%</span></td>
|
| 470 |
-
<td>Partial</td>
|
| 471 |
-
</tr>
|
| 472 |
-
<tr>
|
| 473 |
-
<td>BKT (baseline)</td>
|
| 474 |
-
<td>~0.67</td>
|
| 475 |
-
<td><span class="badge-red">No constraint layer</span></td>
|
| 476 |
-
<td>Partial</td>
|
| 477 |
-
</tr>
|
| 478 |
-
</tbody>
|
| 479 |
-
</table>
|
| 480 |
-
</section>
|
| 481 |
-
|
| 482 |
-
<!-- ── Quickstart ── -->
|
| 483 |
-
<section id="quickstart" style="border-top: 1px solid var(--border);">
|
| 484 |
-
<div class="section-label">Quickstart</div>
|
| 485 |
-
<h2 class="section-title">Plug in your curriculum.</h2>
|
| 486 |
-
<p class="section-body">
|
| 487 |
-
PLRS is curriculum-agnostic. Define your knowledge graph in a simple JSON format
|
| 488 |
-
and get recommendations immediately. No retraining required for new domains.
|
| 489 |
-
</p>
|
| 490 |
-
|
| 491 |
-
<div class="code-wrap">
|
| 492 |
-
<div class="code-header">
|
| 493 |
-
<div class="code-dots"><span></span><span></span><span></span></div>
|
| 494 |
-
<div class="code-lang">PYTHON</div>
|
| 495 |
-
</div>
|
| 496 |
-
<pre><span class="ck">from</span> plrs <span class="ck">import</span> PLRSPipeline
|
| 497 |
-
<span class="ck">from</span> plrs.curriculum <span class="ck">import</span> load_dag
|
| 498 |
-
|
| 499 |
-
<span class="cm"># Load your curriculum (JSON knowledge graph)</span>
|
| 500 |
-
curriculum = <span class="cf">load_dag</span>(<span class="cs">"math_dag.json"</span>)
|
| 501 |
-
|
| 502 |
-
<span class="cm"># Create pipeline — no model needed for mastery-dict mode</span>
|
| 503 |
-
pipeline = <span class="cf">PLRSPipeline</span>(curriculum)
|
| 504 |
-
|
| 505 |
-
<span class="cm"># Get recommendations from student mastery scores</span>
|
| 506 |
-
results = pipeline.<span class="cf">recommend_from_mastery</span>({
|
| 507 |
-
<span class="cs">"whole_numbers"</span>: <span class="cn">0.90</span>,
|
| 508 |
-
<span class="cs">"algebraic_expressions"</span>: <span class="cn">0.75</span>,
|
| 509 |
-
<span class="cs">"quadratic_equations"</span>: <span class="cn">0.40</span>,
|
| 510 |
-
})
|
| 511 |
-
|
| 512 |
-
<span class="ck">for</span> rec <span class="ck">in</span> results[<span class="cs">"approved"</span>]:
|
| 513 |
-
<span class="cf">print</span>(<span class="cs">f"✅ {rec['topic_label']} (score={rec['score']})"</span>)
|
| 514 |
-
<span class="cf">print</span>(<span class="cs">f" {rec['reasoning']}"</span>)
|
| 515 |
-
|
| 516 |
-
<span class="cm"># What-if: what does mastering this topic unlock?</span>
|
| 517 |
-
wi = pipeline.<span class="cf">what_if</span>(<span class="cs">"algebraic_expressions"</span>)
|
| 518 |
-
<span class="cf">print</span>(<span class="cs">f"Unlocks {wi['total_unlocked']} downstream topics"</span>)</pre>
|
| 519 |
-
</div>
|
| 520 |
-
|
| 521 |
-
<div class="code-wrap" style="margin-top: 1rem;">
|
| 522 |
-
<div class="code-header">
|
| 523 |
-
<div class="code-dots"><span></span><span></span><span></span></div>
|
| 524 |
-
<div class="code-lang">REST API</div>
|
| 525 |
-
</div>
|
| 526 |
-
<pre><span class="cm"># Start the server</span>
|
| 527 |
-
$ python scripts/serve.py
|
| 528 |
-
<span class="cm"># → http://127.0.0.1:8000/docs</span>
|
| 529 |
-
|
| 530 |
-
<span class="cm"># Get recommendations</span>
|
| 531 |
-
$ curl -X POST http://localhost:<span class="cn">8000</span>/recommend \
|
| 532 |
-
-H <span class="cs">"Content-Type: application/json"</span> \
|
| 533 |
-
-d <span class="cs">'{"domain":"math","mastery_scores":{"whole_numbers":0.9}}'</span></pre>
|
| 534 |
-
</div>
|
| 535 |
-
</section>
|
| 536 |
-
|
| 537 |
-
<!-- ── Features ── -->
|
| 538 |
-
<section style="border-top: 1px solid var(--border);">
|
| 539 |
-
<div class="section-label">Features</div>
|
| 540 |
-
<h2 class="section-title">Built for real deployment.</h2>
|
| 541 |
-
|
| 542 |
-
<div class="feature-grid">
|
| 543 |
-
<div class="feature-cell">
|
| 544 |
-
<div class="feature-icon">🔌</div>
|
| 545 |
-
<div class="feature-title">Curriculum-agnostic</div>
|
| 546 |
-
<div class="feature-desc">Define any knowledge graph in a simple JSON format. Ships with Nigerian secondary school Maths and CS Fundamentals (NERDC JSS3–SS2).</div>
|
| 547 |
-
</div>
|
| 548 |
-
<div class="feature-cell">
|
| 549 |
-
<div class="feature-icon">⚡</div>
|
| 550 |
-
<div class="feature-title">FastAPI REST backend</div>
|
| 551 |
-
<div class="feature-desc">Production-ready API with <code>/recommend</code>, <code>/what-if</code>, and <code>/curriculum</code> endpoints. Auto-generated OpenAPI docs.</div>
|
| 552 |
-
</div>
|
| 553 |
-
<div class="feature-cell">
|
| 554 |
-
<div class="feature-icon">🧠</div>
|
| 555 |
-
<div class="feature-title">SAKT + Forgetting Curve</div>
|
| 556 |
-
<div class="feature-desc">Self-Attentive Knowledge Tracing with optional Ebbinghaus decay attention — older interactions contribute less to current mastery estimates.</div>
|
| 557 |
-
</div>
|
| 558 |
-
<div class="feature-cell">
|
| 559 |
-
<div class="feature-icon">🔍</div>
|
| 560 |
-
<div class="feature-title">What-If Simulator</div>
|
| 561 |
-
<div class="feature-desc">"If I master Trigonometry now, what unlocks?" — live DAG traversal shows direct and transitive downstream topics.</div>
|
| 562 |
-
</div>
|
| 563 |
-
<div class="feature-cell">
|
| 564 |
-
<div class="feature-icon">📦</div>
|
| 565 |
-
<div class="feature-title">PyPI-ready package</div>
|
| 566 |
-
<div class="feature-desc"><code>pip install plrs</code> — modular architecture with clean public API. Full type annotations throughout.</div>
|
| 567 |
-
</div>
|
| 568 |
-
<div class="feature-cell">
|
| 569 |
-
<div class="feature-icon">🧪</div>
|
| 570 |
-
<div class="feature-title">52 tests, CI on 3 Python versions</div>
|
| 571 |
-
<div class="feature-desc">Unit tests, API integration tests, and evaluator tests. GitHub Actions runs on Python 3.10, 3.11, and 3.12.</div>
|
| 572 |
-
</div>
|
| 573 |
-
</div>
|
| 574 |
-
</section>
|
| 575 |
-
|
| 576 |
-
<!-- ── CTA ── -->
|
| 577 |
-
<div class="cta-section">
|
| 578 |
-
<div class="cta-inner">
|
| 579 |
-
<div class="cta-title">Try it now — no setup required.</div>
|
| 580 |
-
<p class="cta-sub">
|
| 581 |
-
The live demo runs the full pipeline in your browser.
|
| 582 |
-
Adjust mastery sliders, simulate student sequences, explore the curriculum graph.
|
| 583 |
-
</p>
|
| 584 |
-
<div class="cta-btns">
|
| 585 |
-
<a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank" class="btn btn-primary btn-hero">
|
| 586 |
-
Open live demo →
|
| 587 |
-
</a>
|
| 588 |
-
<a href="https://github.com/clementina-tom/plrs" target="_blank" class="btn btn-outline btn-hero">
|
| 589 |
-
Star on GitHub
|
| 590 |
-
</a>
|
| 591 |
-
</div>
|
| 592 |
-
</div>
|
| 593 |
-
</div>
|
| 594 |
-
|
| 595 |
-
<!-- ── Footer ── */
|
| 596 |
-
<footer>
|
| 597 |
-
<div class="footer-left">
|
| 598 |
-
PLRS — Personalized Learning Recommendation System<br/>
|
| 599 |
-
MIT License · Built by <a href="https://github.com/clementina-tom" style="color:var(--blue);text-decoration:none;">Clementina Tom</a>
|
| 600 |
-
</div>
|
| 601 |
-
<div class="footer-links">
|
| 602 |
-
<a href="https://github.com/clementina-tom/plrs" target="_blank">GitHub</a>
|
| 603 |
-
<a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank">HuggingFace</a>
|
| 604 |
-
<a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank">Live Demo</a>
|
| 605 |
-
</div>
|
| 606 |
-
</footer>
|
| 607 |
-
|
| 608 |
-
</body>
|
| 609 |
-
</html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{data/knowledge_maps → knowledge_maps}/cs_dag.json
RENAMED
|
File without changes
|
{data/knowledge_maps → knowledge_maps}/math_dag.json
RENAMED
|
File without changes
|
plrs/__init__.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
PLRS — Personalized Learning Recommendation System
|
| 3 |
-
====================================================
|
| 4 |
-
Constraint-aware personalized learning recommendations.
|
| 5 |
-
Plug in your curriculum DAG, get intelligent recommendations out.
|
| 6 |
-
|
| 7 |
-
Quick start:
|
| 8 |
-
from plrs import PLRSPipeline
|
| 9 |
-
from plrs.curriculum import load_dag
|
| 10 |
-
|
| 11 |
-
graph = load_dag("my_curriculum.json")
|
| 12 |
-
pipeline = PLRSPipeline(graph)
|
| 13 |
-
results = pipeline.recommend(student_history)
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
from plrs.pipeline import PLRSPipeline
|
| 17 |
-
from plrs.model.sakt import SAKTModel
|
| 18 |
-
from plrs.constraints.dag import DAGConstraintLayer
|
| 19 |
-
from plrs.ranking.ranker import MultiObjectiveRanker
|
| 20 |
-
from plrs.curriculum.loader import load_dag, CurriculumGraph
|
| 21 |
-
|
| 22 |
-
__version__ = "0.1.0"
|
| 23 |
-
__all__ = [
|
| 24 |
-
"PLRSPipeline",
|
| 25 |
-
"SAKTModel",
|
| 26 |
-
"DAGConstraintLayer",
|
| 27 |
-
"MultiObjectiveRanker",
|
| 28 |
-
"load_dag",
|
| 29 |
-
"CurriculumGraph",
|
| 30 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/constraints/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from plrs.constraints.dag import DAGConstraintLayer, MasteryVector, ConstraintResult
|
| 2 |
-
|
| 3 |
-
__all__ = ["DAGConstraintLayer", "MasteryVector", "ConstraintResult"]
|
|
|
|
|
|
|
|
|
|
|
|
plrs/constraints/dag.py
DELETED
|
@@ -1,201 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.constraints.dag
|
| 3 |
-
====================
|
| 4 |
-
DAG-based prerequisite constraint layer.
|
| 5 |
-
|
| 6 |
-
Three-tier classification:
|
| 7 |
-
- approved : prerequisites met, topic is ready
|
| 8 |
-
- challenging : prerequisites partially met (above soft threshold)
|
| 9 |
-
- vetoed : prerequisites not met, topic is blocked
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from dataclasses import dataclass, field
|
| 15 |
-
from typing import Literal
|
| 16 |
-
|
| 17 |
-
from plrs.curriculum.loader import CurriculumGraph
|
| 18 |
-
|
| 19 |
-
Status = Literal["approved", "challenging", "vetoed"]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class MasteryVector:
|
| 23 |
-
"""
|
| 24 |
-
Holds a student's estimated mastery probability per topic.
|
| 25 |
-
|
| 26 |
-
Parameters
|
| 27 |
-
----------
|
| 28 |
-
curriculum : CurriculumGraph
|
| 29 |
-
threshold : float
|
| 30 |
-
Mastery threshold — above this, a topic is considered mastered (default 0.70).
|
| 31 |
-
soft_threshold : float
|
| 32 |
-
Soft threshold — above this but below threshold, a topic is "challenging" (default 0.50).
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
def __init__(
|
| 36 |
-
self,
|
| 37 |
-
curriculum: CurriculumGraph,
|
| 38 |
-
threshold: float = 0.70,
|
| 39 |
-
soft_threshold: float = 0.50,
|
| 40 |
-
) -> None:
|
| 41 |
-
self.curriculum = curriculum
|
| 42 |
-
self.threshold = threshold
|
| 43 |
-
self.soft_threshold = soft_threshold
|
| 44 |
-
self._mastery: dict[str, float] = {node: 0.0 for node in curriculum.nodes}
|
| 45 |
-
|
| 46 |
-
# ------------------------------------------------------------------ #
|
| 47 |
-
# Mutations #
|
| 48 |
-
# ------------------------------------------------------------------ #
|
| 49 |
-
|
| 50 |
-
def update(self, topic_id: str, probability: float) -> None:
|
| 51 |
-
"""Set mastery probability for a topic (clamped to [0, 1])."""
|
| 52 |
-
if topic_id in self._mastery:
|
| 53 |
-
self._mastery[topic_id] = max(0.0, min(1.0, probability))
|
| 54 |
-
|
| 55 |
-
def update_batch(self, updates: dict[str, float]) -> None:
|
| 56 |
-
"""Update multiple topics at once."""
|
| 57 |
-
for topic_id, prob in updates.items():
|
| 58 |
-
self.update(topic_id, prob)
|
| 59 |
-
|
| 60 |
-
def cascade_up(self) -> None:
|
| 61 |
-
"""
|
| 62 |
-
Propagate mastery scores upward through the DAG.
|
| 63 |
-
|
| 64 |
-
If a student has high mastery on a topic, infer that their
|
| 65 |
-
prerequisites are also likely mastered.
|
| 66 |
-
"""
|
| 67 |
-
changed = True
|
| 68 |
-
while changed:
|
| 69 |
-
changed = False
|
| 70 |
-
for node in self.curriculum.nodes:
|
| 71 |
-
node_mastery = self.get(node)
|
| 72 |
-
if node_mastery < 0.40:
|
| 73 |
-
continue
|
| 74 |
-
# For each prerequisite of this node
|
| 75 |
-
for prereq in self.curriculum.prerequisites(node):
|
| 76 |
-
prereq_mastery = self.get(prereq)
|
| 77 |
-
# Infer prerequisite mastery as at least 85% of descendant mastery
|
| 78 |
-
inferred = min(node_mastery * 0.85, 0.95)
|
| 79 |
-
if inferred > prereq_mastery:
|
| 80 |
-
self.update(prereq, inferred)
|
| 81 |
-
changed = True
|
| 82 |
-
|
| 83 |
-
# ------------------------------------------------------------------ #
|
| 84 |
-
# Queries #
|
| 85 |
-
# ------------------------------------------------------------------ #
|
| 86 |
-
|
| 87 |
-
def get(self, topic_id: str) -> float:
|
| 88 |
-
return self._mastery.get(topic_id, 0.0)
|
| 89 |
-
|
| 90 |
-
def is_mastered(self, topic_id: str) -> bool:
|
| 91 |
-
return self.get(topic_id) >= self.threshold
|
| 92 |
-
|
| 93 |
-
def is_partial(self, topic_id: str) -> bool:
|
| 94 |
-
"""Between soft_threshold and threshold — partially mastered."""
|
| 95 |
-
v = self.get(topic_id)
|
| 96 |
-
return self.soft_threshold <= v < self.threshold
|
| 97 |
-
|
| 98 |
-
def summary(self) -> dict:
|
| 99 |
-
mastered = [t for t in self._mastery if self.is_mastered(t)]
|
| 100 |
-
partial = [t for t in self._mastery if self.is_partial(t)]
|
| 101 |
-
return {
|
| 102 |
-
"total_topics": len(self._mastery),
|
| 103 |
-
"mastered": len(mastered),
|
| 104 |
-
"partial": len(partial),
|
| 105 |
-
"not_started": len(self._mastery) - len(mastered) - len(partial),
|
| 106 |
-
"mastery_rate": round(len(mastered) / max(len(self._mastery), 1), 3),
|
| 107 |
-
"mastered_topics": mastered,
|
| 108 |
-
}
|
| 109 |
-
|
| 110 |
-
def to_dict(self) -> dict[str, float]:
|
| 111 |
-
return dict(self._mastery)
|
| 112 |
-
|
| 113 |
-
def __repr__(self) -> str:
|
| 114 |
-
s = self.summary()
|
| 115 |
-
return (
|
| 116 |
-
f"MasteryVector(mastered={s['mastered']}/{s['total_topics']}, "
|
| 117 |
-
f"rate={s['mastery_rate']:.1%})"
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
@dataclass
|
| 122 |
-
class ConstraintResult:
|
| 123 |
-
topic_id: str
|
| 124 |
-
topic_label: str
|
| 125 |
-
status: Status
|
| 126 |
-
mastery: float
|
| 127 |
-
reasoning: str
|
| 128 |
-
score: float = 0.0
|
| 129 |
-
prerequisites: list[str] = field(default_factory=list)
|
| 130 |
-
unmet_prerequisites: list[str] = field(default_factory=list)
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
class DAGConstraintLayer:
|
| 134 |
-
"""
|
| 135 |
-
Validates topic recommendations against curriculum prerequisite structure.
|
| 136 |
-
|
| 137 |
-
Uses three-tier soft constraint logic:
|
| 138 |
-
- mastery >= threshold on ALL prerequisites → approved
|
| 139 |
-
- mastery >= soft_threshold on ALL prereqs → challenging
|
| 140 |
-
- any prerequisite below soft_threshold → vetoed
|
| 141 |
-
"""
|
| 142 |
-
|
| 143 |
-
def __init__(self, curriculum: CurriculumGraph) -> None:
|
| 144 |
-
self.curriculum = curriculum
|
| 145 |
-
|
| 146 |
-
def validate(
|
| 147 |
-
self,
|
| 148 |
-
topic_id: str,
|
| 149 |
-
mastery: MasteryVector,
|
| 150 |
-
) -> ConstraintResult:
|
| 151 |
-
label = self.curriculum.label(topic_id)
|
| 152 |
-
prereqs = self.curriculum.prerequisites(topic_id)
|
| 153 |
-
topic_mastery = mastery.get(topic_id)
|
| 154 |
-
|
| 155 |
-
if not prereqs:
|
| 156 |
-
return ConstraintResult(
|
| 157 |
-
topic_id=topic_id,
|
| 158 |
-
topic_label=label,
|
| 159 |
-
status="approved",
|
| 160 |
-
mastery=topic_mastery,
|
| 161 |
-
reasoning="No prerequisites required.",
|
| 162 |
-
prerequisites=[],
|
| 163 |
-
unmet_prerequisites=[],
|
| 164 |
-
)
|
| 165 |
-
|
| 166 |
-
prereq_labels = [self.curriculum.label(p) for p in prereqs]
|
| 167 |
-
unmet_hard = [p for p in prereqs if not mastery.is_mastered(p)]
|
| 168 |
-
unmet_soft = [p for p in prereqs if mastery.get(p) < mastery.soft_threshold]
|
| 169 |
-
|
| 170 |
-
if not unmet_soft:
|
| 171 |
-
# All prereqs above soft threshold — at least challenging
|
| 172 |
-
if not unmet_hard:
|
| 173 |
-
status = "approved"
|
| 174 |
-
reasoning = f"All {len(prereqs)} prerequisite(s) met."
|
| 175 |
-
else:
|
| 176 |
-
status = "challenging"
|
| 177 |
-
unmet_labels = [self.curriculum.label(p) for p in unmet_hard]
|
| 178 |
-
reasoning = (
|
| 179 |
-
f"Prerequisite(s) partially met. "
|
| 180 |
-
f"Strengthen: {', '.join(unmet_labels)}."
|
| 181 |
-
)
|
| 182 |
-
else:
|
| 183 |
-
status = "vetoed"
|
| 184 |
-
unmet_labels = [self.curriculum.label(p) for p in unmet_soft]
|
| 185 |
-
reasoning = (
|
| 186 |
-
f"Blocked. Master first: {', '.join(unmet_labels)}."
|
| 187 |
-
)
|
| 188 |
-
|
| 189 |
-
return ConstraintResult(
|
| 190 |
-
topic_id=topic_id,
|
| 191 |
-
topic_label=label,
|
| 192 |
-
status=status,
|
| 193 |
-
mastery=topic_mastery,
|
| 194 |
-
reasoning=reasoning,
|
| 195 |
-
prerequisites=prereq_labels,
|
| 196 |
-
unmet_prerequisites=[self.curriculum.label(p) for p in (unmet_hard if status == "challenging" else unmet_soft)],
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def validate_all(self, mastery: MasteryVector) -> list[ConstraintResult]:
|
| 200 |
-
"""Validate every topic in the curriculum."""
|
| 201 |
-
return [self.validate(node, mastery) for node in self.curriculum.nodes]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/curriculum/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from plrs.curriculum.loader import load_dag, CurriculumGraph
|
| 2 |
-
|
| 3 |
-
__all__ = ["load_dag", "CurriculumGraph"]
|
|
|
|
|
|
|
|
|
|
|
|
plrs/curriculum/loader.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.curriculum.loader
|
| 3 |
-
======================
|
| 4 |
-
Load and validate curriculum knowledge graphs from JSON.
|
| 5 |
-
|
| 6 |
-
The JSON schema is deliberately simple so educators can author their own:
|
| 7 |
-
|
| 8 |
-
{
|
| 9 |
-
"domain": "Mathematics",
|
| 10 |
-
"nodes": [
|
| 11 |
-
{"id": "algebra_basics", "label": "Algebra Basics", "level": "JSS3"},
|
| 12 |
-
{"id": "quadratic_equations", "label": "Quadratic Equations", "level": "SS1"}
|
| 13 |
-
],
|
| 14 |
-
"edges": [
|
| 15 |
-
{"from": "algebra_basics", "to": "quadratic_equations"}
|
| 16 |
-
]
|
| 17 |
-
}
|
| 18 |
-
"""
|
| 19 |
-
|
| 20 |
-
from __future__ import annotations
|
| 21 |
-
|
| 22 |
-
import json
|
| 23 |
-
from dataclasses import dataclass, field
|
| 24 |
-
from pathlib import Path
|
| 25 |
-
from typing import Any
|
| 26 |
-
|
| 27 |
-
import networkx as nx
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
@dataclass
|
| 31 |
-
class CurriculumGraph:
|
| 32 |
-
"""Thin wrapper around a NetworkX DiGraph with domain metadata."""
|
| 33 |
-
|
| 34 |
-
domain: str
|
| 35 |
-
graph: nx.DiGraph
|
| 36 |
-
meta: dict[str, Any] = field(default_factory=dict)
|
| 37 |
-
|
| 38 |
-
# ------------------------------------------------------------------ #
|
| 39 |
-
# Properties #
|
| 40 |
-
# ------------------------------------------------------------------ #
|
| 41 |
-
|
| 42 |
-
@property
|
| 43 |
-
def nodes(self) -> list[str]:
|
| 44 |
-
return list(self.graph.nodes)
|
| 45 |
-
|
| 46 |
-
@property
|
| 47 |
-
def num_nodes(self) -> int:
|
| 48 |
-
return self.graph.number_of_nodes()
|
| 49 |
-
|
| 50 |
-
@property
|
| 51 |
-
def num_edges(self) -> int:
|
| 52 |
-
return self.graph.number_of_edges()
|
| 53 |
-
|
| 54 |
-
def label(self, node_id: str) -> str:
|
| 55 |
-
return self.graph.nodes[node_id].get("label", node_id)
|
| 56 |
-
|
| 57 |
-
def level(self, node_id: str) -> str:
|
| 58 |
-
return self.graph.nodes[node_id].get("level", "")
|
| 59 |
-
|
| 60 |
-
def prerequisites(self, node_id: str) -> list[str]:
|
| 61 |
-
return list(self.graph.predecessors(node_id))
|
| 62 |
-
|
| 63 |
-
def successors(self, node_id: str) -> list[str]:
|
| 64 |
-
return list(self.graph.successors(node_id))
|
| 65 |
-
|
| 66 |
-
def descendants(self, node_id: str) -> list[str]:
|
| 67 |
-
return list(nx.descendants(self.graph, node_id))
|
| 68 |
-
|
| 69 |
-
def validate(self) -> list[str]:
|
| 70 |
-
"""Return a list of validation warnings (empty = all good)."""
|
| 71 |
-
warnings: list[str] = []
|
| 72 |
-
if not nx.is_directed_acyclic_graph(self.graph):
|
| 73 |
-
warnings.append("Graph contains cycles — prerequisite checking will be unreliable.")
|
| 74 |
-
isolates = list(nx.isolates(self.graph))
|
| 75 |
-
if isolates:
|
| 76 |
-
warnings.append(f"{len(isolates)} isolated nodes (no edges): {isolates[:5]}")
|
| 77 |
-
return warnings
|
| 78 |
-
|
| 79 |
-
def __repr__(self) -> str:
|
| 80 |
-
return (
|
| 81 |
-
f"CurriculumGraph(domain={self.domain!r}, "
|
| 82 |
-
f"nodes={self.num_nodes}, edges={self.num_edges})"
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def load_dag(path: str | Path) -> CurriculumGraph:
|
| 87 |
-
"""
|
| 88 |
-
Load a curriculum DAG from a JSON file.
|
| 89 |
-
|
| 90 |
-
Parameters
|
| 91 |
-
----------
|
| 92 |
-
path : str or Path
|
| 93 |
-
Path to the curriculum JSON file.
|
| 94 |
-
|
| 95 |
-
Returns
|
| 96 |
-
-------
|
| 97 |
-
CurriculumGraph
|
| 98 |
-
|
| 99 |
-
Raises
|
| 100 |
-
------
|
| 101 |
-
FileNotFoundError
|
| 102 |
-
If the file does not exist.
|
| 103 |
-
ValueError
|
| 104 |
-
If the JSON schema is invalid.
|
| 105 |
-
"""
|
| 106 |
-
path = Path(path)
|
| 107 |
-
if not path.exists():
|
| 108 |
-
raise FileNotFoundError(f"Curriculum file not found: {path}")
|
| 109 |
-
|
| 110 |
-
with open(path) as f:
|
| 111 |
-
data = json.load(f)
|
| 112 |
-
|
| 113 |
-
_validate_schema(data, path)
|
| 114 |
-
|
| 115 |
-
domain = data.get("domain", path.stem)
|
| 116 |
-
meta = {k: v for k, v in data.items() if k not in ("nodes", "edges", "domain")}
|
| 117 |
-
|
| 118 |
-
G = nx.DiGraph()
|
| 119 |
-
for node in data["nodes"]:
|
| 120 |
-
G.add_node(node["id"], **{k: v for k, v in node.items() if k != "id"})
|
| 121 |
-
for edge in data["edges"]:
|
| 122 |
-
G.add_edge(edge["from"], edge["to"])
|
| 123 |
-
|
| 124 |
-
curriculum = CurriculumGraph(domain=domain, graph=G, meta=meta)
|
| 125 |
-
|
| 126 |
-
warnings = curriculum.validate()
|
| 127 |
-
for w in warnings:
|
| 128 |
-
import warnings as _w
|
| 129 |
-
_w.warn(f"[PLRS] {w}", stacklevel=2)
|
| 130 |
-
|
| 131 |
-
return curriculum
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
def _validate_schema(data: dict, path: Path) -> None:
|
| 135 |
-
if "nodes" not in data:
|
| 136 |
-
raise ValueError(f"{path}: Missing required key 'nodes'")
|
| 137 |
-
if "edges" not in data:
|
| 138 |
-
raise ValueError(f"{path}: Missing required key 'edges'")
|
| 139 |
-
for i, node in enumerate(data["nodes"]):
|
| 140 |
-
if "id" not in node:
|
| 141 |
-
raise ValueError(f"{path}: Node at index {i} missing required key 'id'")
|
| 142 |
-
for i, edge in enumerate(data["edges"]):
|
| 143 |
-
if "from" not in edge or "to" not in edge:
|
| 144 |
-
raise ValueError(f"{path}: Edge at index {i} missing 'from' or 'to'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
from plrs.model.sakt import SAKTModel
|
| 2 |
-
from plrs.model.sakt_decay import SAKTWithDecay
|
| 3 |
-
from plrs.model.trainer import SAKTTrainer, TrainerConfig, load_sequences_from_csv
|
| 4 |
-
|
| 5 |
-
__all__ = ["SAKTModel", "SAKTWithDecay", "SAKTTrainer", "TrainerConfig", "load_sequences_from_csv"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/evaluator.py
DELETED
|
@@ -1,374 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.model.evaluator
|
| 3 |
-
====================
|
| 4 |
-
Evaluation suite for PLRS.
|
| 5 |
-
|
| 6 |
-
Metrics:
|
| 7 |
-
- Knowledge Tracing: AUC-ROC, Accuracy, Binary Cross-Entropy
|
| 8 |
-
- Recommendation: Prerequisite Violation Rate, Coverage, Diversity
|
| 9 |
-
- Baselines: Random, Popularity, BKT (Bayesian Knowledge Tracing)
|
| 10 |
-
|
| 11 |
-
Usage:
|
| 12 |
-
from plrs.model.evaluator import PLRSEvaluator
|
| 13 |
-
evaluator = PLRSEvaluator(pipeline, curriculum)
|
| 14 |
-
report = evaluator.evaluate(test_sequences, skill_to_topic)
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import time
|
| 20 |
-
from dataclasses import dataclass, field
|
| 21 |
-
from typing import Any
|
| 22 |
-
|
| 23 |
-
import numpy as np
|
| 24 |
-
|
| 25 |
-
try:
|
| 26 |
-
from sklearn.metrics import roc_auc_score, accuracy_score, log_loss
|
| 27 |
-
HAS_SKLEARN = True
|
| 28 |
-
except ImportError:
|
| 29 |
-
HAS_SKLEARN = False
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
# ── Baseline models ───────────────────────────────────────────────────────────
|
| 33 |
-
|
| 34 |
-
class RandomBaseline:
|
| 35 |
-
"""Predicts 0.5 for every interaction."""
|
| 36 |
-
def predict(self, skill_seq, correct_seq):
|
| 37 |
-
return {i: 0.5 for i in range(len(skill_seq))}
|
| 38 |
-
|
| 39 |
-
def recommend(self, curriculum, n=5):
|
| 40 |
-
import random
|
| 41 |
-
return random.sample(curriculum.nodes, min(n, len(curriculum.nodes)))
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class PopularityBaseline:
|
| 45 |
-
"""Recommends the most-seen skills; predicts by global correctness rate."""
|
| 46 |
-
|
| 47 |
-
def __init__(self):
|
| 48 |
-
self.skill_correct: dict[int, list[float]] = {}
|
| 49 |
-
self.topic_count: dict[str, int] = {}
|
| 50 |
-
|
| 51 |
-
def fit(self, sequences, skill_to_topic=None):
|
| 52 |
-
for skill_seq, correct_seq in sequences:
|
| 53 |
-
for skill, correct in zip(skill_seq, correct_seq):
|
| 54 |
-
self.skill_correct.setdefault(skill, []).append(float(correct))
|
| 55 |
-
if skill_to_topic:
|
| 56 |
-
topic = skill_to_topic.get(skill)
|
| 57 |
-
if topic:
|
| 58 |
-
self.topic_count[topic] = self.topic_count.get(topic, 0) + 1
|
| 59 |
-
|
| 60 |
-
def predict_prob(self, skill_id: int) -> float:
|
| 61 |
-
history = self.skill_correct.get(skill_id, [])
|
| 62 |
-
return float(np.mean(history)) if history else 0.5
|
| 63 |
-
|
| 64 |
-
def recommend(self, curriculum, n=5):
|
| 65 |
-
if not self.topic_count:
|
| 66 |
-
return curriculum.nodes[:n]
|
| 67 |
-
sorted_topics = sorted(self.topic_count, key=self.topic_count.get, reverse=True)
|
| 68 |
-
return [t for t in sorted_topics if t in curriculum.nodes][:n]
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
class BKTBaseline:
|
| 72 |
-
"""
|
| 73 |
-
Bayesian Knowledge Tracing (per-skill).
|
| 74 |
-
Simple 4-parameter model: p_init, p_transit, p_slip, p_guess.
|
| 75 |
-
"""
|
| 76 |
-
|
| 77 |
-
def __init__(self, p_init=0.3, p_transit=0.1, p_slip=0.1, p_guess=0.2):
|
| 78 |
-
self.p_init = p_init
|
| 79 |
-
self.p_transit = p_transit
|
| 80 |
-
self.p_slip = p_slip
|
| 81 |
-
self.p_guess = p_guess
|
| 82 |
-
self._mastery: dict[int, float] = {}
|
| 83 |
-
|
| 84 |
-
def _update(self, skill: int, correct: int) -> float:
|
| 85 |
-
p = self._mastery.get(skill, self.p_init)
|
| 86 |
-
# Bayes update
|
| 87 |
-
if correct:
|
| 88 |
-
num = p * (1 - self.p_slip)
|
| 89 |
-
den = num + (1 - p) * self.p_guess
|
| 90 |
-
else:
|
| 91 |
-
num = p * self.p_slip
|
| 92 |
-
den = num + (1 - p) * (1 - self.p_guess)
|
| 93 |
-
p_post = num / max(den, 1e-9)
|
| 94 |
-
# Learning
|
| 95 |
-
p_post = p_post + (1 - p_post) * self.p_transit
|
| 96 |
-
self._mastery[skill] = p_post
|
| 97 |
-
return p_post
|
| 98 |
-
|
| 99 |
-
def predict_sequence(self, skill_seq: list[int], correct_seq: list[int]) -> list[float]:
|
| 100 |
-
self._mastery = {}
|
| 101 |
-
probs = []
|
| 102 |
-
for skill, correct in zip(skill_seq[:-1], correct_seq[:-1]):
|
| 103 |
-
self._update(skill, correct)
|
| 104 |
-
next_skill = skill_seq[len(probs) + 1]
|
| 105 |
-
probs.append(self._mastery.get(next_skill, self.p_init))
|
| 106 |
-
return probs
|
| 107 |
-
|
| 108 |
-
def get_mastery(self) -> dict[int, float]:
|
| 109 |
-
return dict(self._mastery)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
# ── Result dataclasses ────────────────────────────────────────────────────────
|
| 113 |
-
|
| 114 |
-
@dataclass
|
| 115 |
-
class KTMetrics:
|
| 116 |
-
"""Knowledge tracing evaluation metrics."""
|
| 117 |
-
model_name: str
|
| 118 |
-
auc: float
|
| 119 |
-
accuracy: float
|
| 120 |
-
log_loss: float
|
| 121 |
-
n_samples: int
|
| 122 |
-
elapsed_s: float
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
@dataclass
|
| 126 |
-
class RecommendMetrics:
|
| 127 |
-
"""Recommendation quality metrics."""
|
| 128 |
-
violation_rate: float # fraction of recommendations that violate prerequisites
|
| 129 |
-
coverage: float # fraction of curriculum covered by recommendations
|
| 130 |
-
avg_downstream: float # avg topics unlocked by recommendations
|
| 131 |
-
mastery_rate: float # avg student mastery in test set
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
@dataclass
|
| 135 |
-
class EvaluationReport:
|
| 136 |
-
"""Full evaluation report."""
|
| 137 |
-
kt_metrics: list[KTMetrics]
|
| 138 |
-
rec_metrics: RecommendMetrics | None
|
| 139 |
-
config: dict[str, Any]
|
| 140 |
-
timestamp: str
|
| 141 |
-
|
| 142 |
-
def print(self) -> None:
|
| 143 |
-
print("\n" + "=" * 62)
|
| 144 |
-
print(" PLRS EVALUATION REPORT")
|
| 145 |
-
print("=" * 62)
|
| 146 |
-
|
| 147 |
-
print(f"\n{'Model':<22} {'AUC':>8} {'Accuracy':>10} {'Log Loss':>10} {'Samples':>8}")
|
| 148 |
-
print("-" * 62)
|
| 149 |
-
for m in self.kt_metrics:
|
| 150 |
-
print(f"{m.model_name:<22} {m.auc:>8.4f} {m.accuracy:>10.4f} {m.log_loss:>10.4f} {m.n_samples:>8,}")
|
| 151 |
-
|
| 152 |
-
if self.rec_metrics:
|
| 153 |
-
r = self.rec_metrics
|
| 154 |
-
print(f"\n{'Recommendation Metrics':}")
|
| 155 |
-
print(f" Prerequisite violation rate : {r.violation_rate:.1%}")
|
| 156 |
-
print(f" Curriculum coverage : {r.coverage:.1%}")
|
| 157 |
-
print(f" Avg downstream unlocked : {r.avg_downstream:.1f}")
|
| 158 |
-
print(f" Avg student mastery rate : {r.mastery_rate:.1%}")
|
| 159 |
-
|
| 160 |
-
print("=" * 62 + "\n")
|
| 161 |
-
|
| 162 |
-
def to_dict(self) -> dict:
|
| 163 |
-
return {
|
| 164 |
-
"kt_metrics": [
|
| 165 |
-
{
|
| 166 |
-
"model": m.model_name,
|
| 167 |
-
"auc": round(m.auc, 6),
|
| 168 |
-
"accuracy": round(m.accuracy, 6),
|
| 169 |
-
"log_loss": round(m.log_loss, 6),
|
| 170 |
-
"n_samples": m.n_samples,
|
| 171 |
-
"elapsed_s": round(m.elapsed_s, 3),
|
| 172 |
-
}
|
| 173 |
-
for m in self.kt_metrics
|
| 174 |
-
],
|
| 175 |
-
"rec_metrics": {
|
| 176 |
-
"violation_rate": round(self.rec_metrics.violation_rate, 6),
|
| 177 |
-
"coverage": round(self.rec_metrics.coverage, 6),
|
| 178 |
-
"avg_downstream": round(self.rec_metrics.avg_downstream, 3),
|
| 179 |
-
"mastery_rate": round(self.rec_metrics.mastery_rate, 6),
|
| 180 |
-
} if self.rec_metrics else None,
|
| 181 |
-
"config": self.config,
|
| 182 |
-
"timestamp": self.timestamp,
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
# ── Main evaluator ────────────────────────────────────────────────────────────
|
| 187 |
-
|
| 188 |
-
class PLRSEvaluator:
|
| 189 |
-
"""
|
| 190 |
-
Evaluate PLRS against baselines on held-out student sequences.
|
| 191 |
-
|
| 192 |
-
Parameters
|
| 193 |
-
----------
|
| 194 |
-
pipeline : PLRSPipeline
|
| 195 |
-
A loaded pipeline (with or without SAKT model).
|
| 196 |
-
"""
|
| 197 |
-
|
| 198 |
-
def __init__(self, pipeline) -> None:
|
| 199 |
-
self.pipeline = pipeline
|
| 200 |
-
self.curriculum = pipeline.curriculum
|
| 201 |
-
|
| 202 |
-
def evaluate(
|
| 203 |
-
self,
|
| 204 |
-
test_sequences: list[tuple[list[int], list[int]]],
|
| 205 |
-
skill_to_topic: dict[int, str] | None = None,
|
| 206 |
-
train_sequences: list[tuple[list[int], list[int]]] | None = None,
|
| 207 |
-
include_baselines: bool = True,
|
| 208 |
-
) -> EvaluationReport:
|
| 209 |
-
"""
|
| 210 |
-
Run full evaluation.
|
| 211 |
-
|
| 212 |
-
Parameters
|
| 213 |
-
----------
|
| 214 |
-
test_sequences : list of (skill_seq, correct_seq)
|
| 215 |
-
skill_to_topic : dict mapping skill_id → curriculum topic_id
|
| 216 |
-
train_sequences : used to fit popularity baseline
|
| 217 |
-
include_baselines : whether to evaluate BKT and popularity baselines
|
| 218 |
-
|
| 219 |
-
Returns
|
| 220 |
-
-------
|
| 221 |
-
EvaluationReport
|
| 222 |
-
"""
|
| 223 |
-
import datetime
|
| 224 |
-
|
| 225 |
-
kt_metrics: list[KTMetrics] = []
|
| 226 |
-
|
| 227 |
-
# ── SAKT evaluation ──────────────────────────────────────────
|
| 228 |
-
if self.pipeline._model is not None:
|
| 229 |
-
kt_metrics.append(
|
| 230 |
-
self._eval_sakt(test_sequences)
|
| 231 |
-
)
|
| 232 |
-
|
| 233 |
-
# ── Baselines ────────────────────────────────────────────────
|
| 234 |
-
if include_baselines:
|
| 235 |
-
kt_metrics.append(self._eval_random(test_sequences))
|
| 236 |
-
kt_metrics.append(self._eval_bkt(test_sequences))
|
| 237 |
-
|
| 238 |
-
pop = PopularityBaseline()
|
| 239 |
-
pop.fit(train_sequences or test_sequences, skill_to_topic)
|
| 240 |
-
kt_metrics.append(self._eval_popularity(test_sequences, pop))
|
| 241 |
-
|
| 242 |
-
# ── Recommendation metrics ───────────────────────────────────
|
| 243 |
-
rec_metrics = self._eval_recommendations(test_sequences, skill_to_topic)
|
| 244 |
-
|
| 245 |
-
return EvaluationReport(
|
| 246 |
-
kt_metrics=kt_metrics,
|
| 247 |
-
rec_metrics=rec_metrics,
|
| 248 |
-
config={
|
| 249 |
-
"threshold": self.pipeline.threshold,
|
| 250 |
-
"soft_threshold": self.pipeline.soft_threshold,
|
| 251 |
-
"top_n": self.pipeline.top_n,
|
| 252 |
-
"n_test_students": len(test_sequences),
|
| 253 |
-
},
|
| 254 |
-
timestamp=datetime.datetime.now().isoformat(),
|
| 255 |
-
)
|
| 256 |
-
|
| 257 |
-
# ── KT evaluation helpers ─────────────────────────────────────────────────
|
| 258 |
-
|
| 259 |
-
def _eval_sakt(self, sequences) -> KTMetrics:
|
| 260 |
-
t0 = time.time()
|
| 261 |
-
all_probs, all_labels = [], []
|
| 262 |
-
|
| 263 |
-
for skill_seq, correct_seq in sequences:
|
| 264 |
-
if len(skill_seq) < 2:
|
| 265 |
-
continue
|
| 266 |
-
probs = self.pipeline._model.predict_mastery(skill_seq, correct_seq)
|
| 267 |
-
for skill_id, prob in probs.items():
|
| 268 |
-
if skill_id < len(correct_seq):
|
| 269 |
-
all_probs.append(prob)
|
| 270 |
-
all_labels.append(float(correct_seq[skill_id]))
|
| 271 |
-
|
| 272 |
-
return self._compute_kt_metrics("SAKT", all_probs, all_labels, time.time() - t0)
|
| 273 |
-
|
| 274 |
-
def _eval_random(self, sequences) -> KTMetrics:
|
| 275 |
-
t0 = time.time()
|
| 276 |
-
all_probs, all_labels = [], []
|
| 277 |
-
for skill_seq, correct_seq in sequences:
|
| 278 |
-
for correct in correct_seq[1:]:
|
| 279 |
-
all_probs.append(0.5)
|
| 280 |
-
all_labels.append(float(correct))
|
| 281 |
-
return self._compute_kt_metrics("Random (baseline)", all_probs, all_labels, time.time() - t0)
|
| 282 |
-
|
| 283 |
-
def _eval_bkt(self, sequences) -> KTMetrics:
|
| 284 |
-
t0 = time.time()
|
| 285 |
-
all_probs, all_labels = [], []
|
| 286 |
-
bkt = BKTBaseline()
|
| 287 |
-
for skill_seq, correct_seq in sequences:
|
| 288 |
-
if len(skill_seq) < 2:
|
| 289 |
-
continue
|
| 290 |
-
probs = bkt.predict_sequence(skill_seq, correct_seq)
|
| 291 |
-
labels = [float(c) for c in correct_seq[1:len(probs) + 1]]
|
| 292 |
-
all_probs.extend(probs)
|
| 293 |
-
all_labels.extend(labels)
|
| 294 |
-
return self._compute_kt_metrics("BKT (baseline)", all_probs, all_labels, time.time() - t0)
|
| 295 |
-
|
| 296 |
-
def _eval_popularity(self, sequences, pop: PopularityBaseline) -> KTMetrics:
|
| 297 |
-
t0 = time.time()
|
| 298 |
-
all_probs, all_labels = [], []
|
| 299 |
-
for skill_seq, correct_seq in sequences:
|
| 300 |
-
for skill, correct in zip(skill_seq[1:], correct_seq[1:]):
|
| 301 |
-
all_probs.append(pop.predict_prob(skill))
|
| 302 |
-
all_labels.append(float(correct))
|
| 303 |
-
return self._compute_kt_metrics("Popularity (baseline)", all_probs, all_labels, time.time() - t0)
|
| 304 |
-
|
| 305 |
-
@staticmethod
|
| 306 |
-
def _compute_kt_metrics(name, probs, labels, elapsed) -> KTMetrics:
|
| 307 |
-
probs_arr = np.nan_to_num(np.array(probs), nan=0.5)
|
| 308 |
-
labels_arr = np.nan_to_num(np.array(labels), nan=0.0)
|
| 309 |
-
n = len(probs_arr)
|
| 310 |
-
|
| 311 |
-
if HAS_SKLEARN and n > 0 and len(np.unique(labels_arr)) > 1:
|
| 312 |
-
auc = float(roc_auc_score(labels_arr, probs_arr))
|
| 313 |
-
acc = float(accuracy_score(labels_arr, (probs_arr >= 0.5).astype(int)))
|
| 314 |
-
loss = float(log_loss(labels_arr, np.clip(probs_arr, 1e-7, 1 - 1e-7)))
|
| 315 |
-
else:
|
| 316 |
-
auc = 0.5
|
| 317 |
-
acc = float(((probs_arr >= 0.5) == labels_arr).mean()) if n > 0 else 0.0
|
| 318 |
-
loss = float(-np.mean(
|
| 319 |
-
labels_arr * np.log(probs_arr + 1e-7) +
|
| 320 |
-
(1 - labels_arr) * np.log(1 - probs_arr + 1e-7)
|
| 321 |
-
)) if n > 0 else 0.0
|
| 322 |
-
|
| 323 |
-
return KTMetrics(
|
| 324 |
-
model_name=name, auc=auc, accuracy=acc,
|
| 325 |
-
log_loss=loss, n_samples=n, elapsed_s=elapsed,
|
| 326 |
-
)
|
| 327 |
-
|
| 328 |
-
# ── Recommendation evaluation ─────────────────────────────────────────────
|
| 329 |
-
|
| 330 |
-
def _eval_recommendations(
|
| 331 |
-
self,
|
| 332 |
-
sequences,
|
| 333 |
-
skill_to_topic,
|
| 334 |
-
) -> RecommendMetrics:
|
| 335 |
-
violation_rates, coverages, downstreams, mastery_rates = [], [], [], []
|
| 336 |
-
|
| 337 |
-
for skill_seq, correct_seq in sequences:
|
| 338 |
-
# Build mastery from sequence
|
| 339 |
-
if skill_to_topic:
|
| 340 |
-
topic_scores: dict[str, float] = {}
|
| 341 |
-
for skill, correct in zip(skill_seq, correct_seq):
|
| 342 |
-
topic = skill_to_topic.get(skill)
|
| 343 |
-
if topic and topic in self.curriculum.nodes:
|
| 344 |
-
topic_scores[topic] = max(topic_scores.get(topic, 0.0), float(correct))
|
| 345 |
-
mastery_scores = {n: 0.0 for n in self.curriculum.nodes}
|
| 346 |
-
mastery_scores.update(topic_scores)
|
| 347 |
-
else:
|
| 348 |
-
mastery_scores = {n: 0.0 for n in self.curriculum.nodes}
|
| 349 |
-
|
| 350 |
-
results = self.pipeline.recommend_from_mastery(mastery_scores)
|
| 351 |
-
stats = results["stats"]
|
| 352 |
-
summary = results["mastery_summary"]
|
| 353 |
-
|
| 354 |
-
violation_rates.append(stats["prerequisite_violation_rate"])
|
| 355 |
-
mastery_rates.append(summary["mastery_rate"])
|
| 356 |
-
|
| 357 |
-
# Coverage: fraction of curriculum represented in approved+challenging
|
| 358 |
-
rec_topics = set(
|
| 359 |
-
r["topic_id"] for r in results["approved"] + results["challenging"]
|
| 360 |
-
)
|
| 361 |
-
coverages.append(len(rec_topics) / max(self.curriculum.num_nodes, 1))
|
| 362 |
-
|
| 363 |
-
# Avg downstream unlock value
|
| 364 |
-
if results["approved"]:
|
| 365 |
-
downstreams.append(
|
| 366 |
-
np.mean([r["downstream_count"] for r in results["approved"]])
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
return RecommendMetrics(
|
| 370 |
-
violation_rate=float(np.mean(violation_rates)) if violation_rates else 0.0,
|
| 371 |
-
coverage=float(np.mean(coverages)) if coverages else 0.0,
|
| 372 |
-
avg_downstream=float(np.mean(downstreams)) if downstreams else 0.0,
|
| 373 |
-
mastery_rate=float(np.mean(mastery_rates)) if mastery_rates else 0.0,
|
| 374 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/model_loader.py
DELETED
|
@@ -1,116 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
HF Space model loader — updated for SAKTWithDecay (v0.2.0 weights).
|
| 3 |
-
|
| 4 |
-
Drop this file into your HF Space as `model_loader.py` and call
|
| 5 |
-
`load_model_from_hub()` in app.py instead of the old loading logic.
|
| 6 |
-
|
| 7 |
-
The v0.2.0 weights (sakt_decay_best.pt) are saved with our new format:
|
| 8 |
-
{
|
| 9 |
-
"state_dict": {...},
|
| 10 |
-
"model_type": "SAKTWithDecay",
|
| 11 |
-
"config": {"num_skills": 20, "embed_dim": 64, ...}
|
| 12 |
-
}
|
| 13 |
-
|
| 14 |
-
Falls back gracefully to mastery-dict mode if weights can't be loaded.
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import json
|
| 20 |
-
from pathlib import Path
|
| 21 |
-
|
| 22 |
-
import torch
|
| 23 |
-
|
| 24 |
-
HF_REPO = "Clementio/PLRS"
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
def load_model_from_hub(device: str = "cpu"):
|
| 28 |
-
"""
|
| 29 |
-
Load SAKT model weights from HuggingFace Hub.
|
| 30 |
-
|
| 31 |
-
Tries files in priority order:
|
| 32 |
-
1. sakt_decay_best.pt (v0.2.0 — decay attention)
|
| 33 |
-
2. sakt_vanilla_best.pt (v0.2.0 — vanilla transformer)
|
| 34 |
-
3. sakt_model.pt (v0.1.0 — synthetic baseline)
|
| 35 |
-
|
| 36 |
-
Returns (model, model_type_str) or (None, "unavailable").
|
| 37 |
-
"""
|
| 38 |
-
try:
|
| 39 |
-
from huggingface_hub import hf_hub_download
|
| 40 |
-
except ImportError:
|
| 41 |
-
return None, "huggingface_hub not installed"
|
| 42 |
-
|
| 43 |
-
for filename, model_type in [
|
| 44 |
-
("models/sakt_decay_best.pt", "SAKTWithDecay"),
|
| 45 |
-
("models/sakt_vanilla_best.pt", "SAKTModel"),
|
| 46 |
-
("models/sakt_model.pt", "SAKTModel"),
|
| 47 |
-
]:
|
| 48 |
-
try:
|
| 49 |
-
path = hf_hub_download(repo_id=HF_REPO, filename=filename)
|
| 50 |
-
model = _load_weights(path, model_type, device)
|
| 51 |
-
if model is not None:
|
| 52 |
-
return model, model_type
|
| 53 |
-
except Exception:
|
| 54 |
-
continue
|
| 55 |
-
|
| 56 |
-
return None, "unavailable"
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def _load_weights(path: str, preferred_type: str, device: str):
|
| 60 |
-
"""Load model weights from a .pt file, handling both old and new formats."""
|
| 61 |
-
try:
|
| 62 |
-
payload = torch.load(path, map_location=device, weights_only=False)
|
| 63 |
-
except Exception:
|
| 64 |
-
return None
|
| 65 |
-
|
| 66 |
-
# ── New format (v0.2.0): {"state_dict": ..., "model_type": ..., "config": ...}
|
| 67 |
-
if isinstance(payload, dict) and "state_dict" in payload:
|
| 68 |
-
cfg = payload.get("config", {})
|
| 69 |
-
model_type = payload.get("model_type", preferred_type)
|
| 70 |
-
|
| 71 |
-
if model_type == "SAKTWithDecay":
|
| 72 |
-
from plrs.model.sakt_decay import SAKTWithDecay
|
| 73 |
-
model = SAKTWithDecay(
|
| 74 |
-
num_skills=cfg.get("num_skills", 5737),
|
| 75 |
-
embed_dim=cfg.get("embed_dim", 64),
|
| 76 |
-
num_heads=cfg.get("num_heads", 8),
|
| 77 |
-
dropout=cfg.get("dropout", 0.2),
|
| 78 |
-
max_seq_len=cfg.get("max_seq_len", 100),
|
| 79 |
-
decay_init=cfg.get("decay_init", 1.0),
|
| 80 |
-
)
|
| 81 |
-
else:
|
| 82 |
-
from plrs.model.sakt import SAKTModel
|
| 83 |
-
model = SAKTModel(
|
| 84 |
-
num_skills=cfg.get("num_skills", 5737),
|
| 85 |
-
embed_dim=cfg.get("embed_dim", 64),
|
| 86 |
-
num_heads=cfg.get("num_heads", 8),
|
| 87 |
-
dropout=cfg.get("dropout", 0.2),
|
| 88 |
-
max_seq_len=cfg.get("max_seq_len", 100),
|
| 89 |
-
)
|
| 90 |
-
|
| 91 |
-
try:
|
| 92 |
-
model.load_state_dict(payload["state_dict"], strict=False)
|
| 93 |
-
model.eval()
|
| 94 |
-
model.to(device)
|
| 95 |
-
return model
|
| 96 |
-
except Exception:
|
| 97 |
-
return None
|
| 98 |
-
|
| 99 |
-
# ── Old format (v0.1.0 FYP): raw state_dict + separate config.json
|
| 100 |
-
try:
|
| 101 |
-
config_path = Path(path).parent / "config.json"
|
| 102 |
-
if config_path.exists():
|
| 103 |
-
config = json.loads(config_path.read_text())
|
| 104 |
-
else:
|
| 105 |
-
config = {"num_skills": 5736, "embed_dim": 64}
|
| 106 |
-
|
| 107 |
-
from plrs.model.sakt import SAKTModel
|
| 108 |
-
model = SAKTModel(
|
| 109 |
-
num_skills=config.get("num_skills", 5736),
|
| 110 |
-
embed_dim=config.get("embed_dim", 64),
|
| 111 |
-
)
|
| 112 |
-
model.load_state_dict(payload, strict=False)
|
| 113 |
-
model.eval()
|
| 114 |
-
return model
|
| 115 |
-
except Exception:
|
| 116 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/sakt.py
DELETED
|
@@ -1,219 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.model.sakt
|
| 3 |
-
===============
|
| 4 |
-
Self-Attentive Knowledge Tracing (SAKT) model.
|
| 5 |
-
|
| 6 |
-
Architecture: transformer-style attention over student interaction sequences.
|
| 7 |
-
Each interaction is encoded as (skill_id + correctness * n_skills).
|
| 8 |
-
|
| 9 |
-
Reference: Pandey & Karypis, 2019 — "A Self-Attentive model for Knowledge Tracing"
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
from pathlib import Path
|
| 15 |
-
from typing import Any
|
| 16 |
-
|
| 17 |
-
import torch
|
| 18 |
-
import torch.nn as nn
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class SAKTModel(nn.Module):
|
| 22 |
-
"""
|
| 23 |
-
SAKT: Self-Attentive Knowledge Tracing.
|
| 24 |
-
|
| 25 |
-
Parameters
|
| 26 |
-
----------
|
| 27 |
-
num_skills : int
|
| 28 |
-
Total number of unique skills in the dataset.
|
| 29 |
-
embed_dim : int
|
| 30 |
-
Embedding dimension for interactions and positions.
|
| 31 |
-
num_heads : int
|
| 32 |
-
Number of attention heads.
|
| 33 |
-
dropout : float
|
| 34 |
-
Dropout rate.
|
| 35 |
-
max_seq_len : int
|
| 36 |
-
Maximum interaction sequence length.
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
def __init__(
|
| 40 |
-
self,
|
| 41 |
-
num_skills: int,
|
| 42 |
-
embed_dim: int = 64,
|
| 43 |
-
num_heads: int = 8,
|
| 44 |
-
dropout: float = 0.2,
|
| 45 |
-
max_seq_len: int = 100,
|
| 46 |
-
) -> None:
|
| 47 |
-
super().__init__()
|
| 48 |
-
self.num_skills = num_skills
|
| 49 |
-
self.embed_dim = embed_dim
|
| 50 |
-
self.max_seq_len = max_seq_len
|
| 51 |
-
|
| 52 |
-
# Interaction embedding: (skill, correct) → dense vector
|
| 53 |
-
self.interaction_embed = nn.Embedding(2 * num_skills + 2, embed_dim, padding_idx=0) # +2: shift+1 means max index = 2*n+1
|
| 54 |
-
# Positional embedding
|
| 55 |
-
self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
|
| 56 |
-
|
| 57 |
-
# Self-attention layer
|
| 58 |
-
self.self_attn = nn.MultiheadAttention(
|
| 59 |
-
embed_dim=embed_dim,
|
| 60 |
-
num_heads=num_heads,
|
| 61 |
-
dropout=dropout,
|
| 62 |
-
batch_first=True,
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
self.layer_norm1 = nn.LayerNorm(embed_dim)
|
| 66 |
-
self.layer_norm2 = nn.LayerNorm(embed_dim)
|
| 67 |
-
|
| 68 |
-
self.ffn = nn.Sequential(
|
| 69 |
-
nn.Linear(embed_dim, embed_dim * 2),
|
| 70 |
-
nn.ReLU(),
|
| 71 |
-
nn.Dropout(dropout),
|
| 72 |
-
nn.Linear(embed_dim * 2, embed_dim),
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
# Skill query embedding for target prediction
|
| 76 |
-
self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
|
| 77 |
-
|
| 78 |
-
self.output_layer = nn.Linear(embed_dim * 2, 1)
|
| 79 |
-
self.dropout = nn.Dropout(dropout)
|
| 80 |
-
|
| 81 |
-
def forward(
|
| 82 |
-
self,
|
| 83 |
-
interactions: torch.Tensor, # (batch, seq_len)
|
| 84 |
-
target_skills: torch.Tensor, # (batch, seq_len)
|
| 85 |
-
mask: torch.Tensor, # (batch, seq_len) bool — True = real token
|
| 86 |
-
) -> torch.Tensor:
|
| 87 |
-
"""
|
| 88 |
-
Forward pass.
|
| 89 |
-
|
| 90 |
-
Returns
|
| 91 |
-
-------
|
| 92 |
-
torch.Tensor of shape (batch, seq_len) — logits per position.
|
| 93 |
-
"""
|
| 94 |
-
batch_size, seq_len = interactions.shape
|
| 95 |
-
positions = torch.arange(seq_len, device=interactions.device).unsqueeze(0)
|
| 96 |
-
|
| 97 |
-
x = self.interaction_embed(interactions) + self.pos_embed(positions)
|
| 98 |
-
x = self.dropout(x)
|
| 99 |
-
|
| 100 |
-
# Causal mask — bool upper-triangular (MHA handles conversion internally)
|
| 101 |
-
causal_mask = torch.triu(
|
| 102 |
-
torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
|
| 103 |
-
diagonal=1,
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
# Key padding mask: True = ignore (PyTorch MHA convention)
|
| 107 |
-
key_padding_mask = ~mask # (batch, seq_len) bool
|
| 108 |
-
|
| 109 |
-
x_attn, _ = self.self_attn(
|
| 110 |
-
query=x,
|
| 111 |
-
key=x,
|
| 112 |
-
value=x,
|
| 113 |
-
attn_mask=causal_mask,
|
| 114 |
-
key_padding_mask=key_padding_mask,
|
| 115 |
-
)
|
| 116 |
-
# Replace any NaN in attention output (from fully-masked rows) with 0
|
| 117 |
-
x_attn = torch.nan_to_num(x_attn, nan=0.0)
|
| 118 |
-
x = self.layer_norm1(x + x_attn)
|
| 119 |
-
x = self.layer_norm2(x + self.ffn(x))
|
| 120 |
-
|
| 121 |
-
# Concatenate with target skill embedding for final prediction
|
| 122 |
-
skill_x = self.skill_embed(target_skills)
|
| 123 |
-
out = self.output_layer(torch.cat([x, skill_x], dim=-1)).squeeze(-1)
|
| 124 |
-
|
| 125 |
-
return out # (batch, seq_len) logits
|
| 126 |
-
|
| 127 |
-
# ------------------------------------------------------------------ #
|
| 128 |
-
# Inference helpers #
|
| 129 |
-
# ------------------------------------------------------------------ #
|
| 130 |
-
|
| 131 |
-
@torch.no_grad()
|
| 132 |
-
def predict_mastery(
|
| 133 |
-
self,
|
| 134 |
-
skill_seq: list[int],
|
| 135 |
-
correct_seq: list[int],
|
| 136 |
-
device: torch.device | str = "cpu",
|
| 137 |
-
) -> dict[int, float]:
|
| 138 |
-
"""
|
| 139 |
-
Run inference on a student's interaction history.
|
| 140 |
-
|
| 141 |
-
Parameters
|
| 142 |
-
----------
|
| 143 |
-
skill_seq : list[int]
|
| 144 |
-
Sequence of skill IDs the student interacted with.
|
| 145 |
-
correct_seq : list[int]
|
| 146 |
-
Corresponding correctness (1 = correct, 0 = incorrect).
|
| 147 |
-
device : str or torch.device
|
| 148 |
-
|
| 149 |
-
Returns
|
| 150 |
-
-------
|
| 151 |
-
dict[int, float]
|
| 152 |
-
Mapping from skill_id → predicted mastery probability.
|
| 153 |
-
"""
|
| 154 |
-
if len(skill_seq) < 2:
|
| 155 |
-
return {}
|
| 156 |
-
|
| 157 |
-
if len(skill_seq) > self.max_seq_len:
|
| 158 |
-
skill_seq = skill_seq[-self.max_seq_len:]
|
| 159 |
-
correct_seq = correct_seq[-self.max_seq_len:]
|
| 160 |
-
|
| 161 |
-
interactions = [s + c * self.num_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
|
| 162 |
-
target_skills = skill_seq[1:]
|
| 163 |
-
|
| 164 |
-
seq_len = len(interactions)
|
| 165 |
-
pad_len = self.max_seq_len - seq_len
|
| 166 |
-
|
| 167 |
-
interactions_padded = [0] * pad_len + interactions
|
| 168 |
-
target_padded = [0] * pad_len + target_skills
|
| 169 |
-
mask = [False] * pad_len + [True] * seq_len
|
| 170 |
-
|
| 171 |
-
interactions_t = torch.LongTensor([interactions_padded]).to(device)
|
| 172 |
-
target_t = torch.LongTensor([target_padded]).to(device)
|
| 173 |
-
mask_t = torch.BoolTensor([mask]).to(device)
|
| 174 |
-
|
| 175 |
-
self.eval()
|
| 176 |
-
self.to(device)
|
| 177 |
-
|
| 178 |
-
logits = self(interactions_t, target_t, mask_t)
|
| 179 |
-
probs = torch.sigmoid(logits).squeeze(0)
|
| 180 |
-
|
| 181 |
-
real_probs = probs[torch.BoolTensor(mask)].cpu().numpy()
|
| 182 |
-
mastery = {
|
| 183 |
-
int(skill_id): float(prob)
|
| 184 |
-
for skill_id, prob in zip(target_skills, real_probs)
|
| 185 |
-
}
|
| 186 |
-
return mastery
|
| 187 |
-
|
| 188 |
-
# ------------------------------------------------------------------ #
|
| 189 |
-
# Serialisation #
|
| 190 |
-
# ------------------------------------------------------------------ #
|
| 191 |
-
|
| 192 |
-
def save(self, path: str | Path, config: dict[str, Any] | None = None) -> None:
|
| 193 |
-
"""Save model weights and config to a .pt file."""
|
| 194 |
-
payload = {
|
| 195 |
-
"state_dict": self.state_dict(),
|
| 196 |
-
"config": config or {
|
| 197 |
-
"num_skills": self.num_skills,
|
| 198 |
-
"embed_dim": self.embed_dim,
|
| 199 |
-
"max_seq_len": self.max_seq_len,
|
| 200 |
-
},
|
| 201 |
-
}
|
| 202 |
-
torch.save(payload, path)
|
| 203 |
-
|
| 204 |
-
@classmethod
|
| 205 |
-
def load(cls, path: str | Path, device: str | torch.device = "cpu") -> "SAKTModel":
|
| 206 |
-
"""Load a saved SAKT model."""
|
| 207 |
-
payload = torch.load(path, map_location=device, weights_only=False)
|
| 208 |
-
config = payload["config"]
|
| 209 |
-
model = cls(
|
| 210 |
-
num_skills=config["num_skills"],
|
| 211 |
-
embed_dim=config.get("embed_dim", 64),
|
| 212 |
-
num_heads=config.get("num_heads", 8),
|
| 213 |
-
dropout=config.get("dropout", 0.2),
|
| 214 |
-
max_seq_len=config.get("max_seq_len", 100),
|
| 215 |
-
)
|
| 216 |
-
model.load_state_dict(payload["state_dict"])
|
| 217 |
-
model.to(device)
|
| 218 |
-
model.eval()
|
| 219 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/sakt_decay.py
DELETED
|
@@ -1,253 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.model.sakt_decay
|
| 3 |
-
=====================
|
| 4 |
-
SAKT with Ebbinghaus Forgetting Curve Decay.
|
| 5 |
-
|
| 6 |
-
Extends the base SAKT model by applying exponential temporal decay to
|
| 7 |
-
attention weights, reflecting that older interactions contribute less to
|
| 8 |
-
current mastery estimates.
|
| 9 |
-
|
| 10 |
-
The decay function follows the Ebbinghaus retention curve:
|
| 11 |
-
R(t) = exp(-t / decay_rate)
|
| 12 |
-
|
| 13 |
-
Where t is the time gap between interaction j and the current position i,
|
| 14 |
-
measured in interaction steps (or elapsed time if timestamps are available).
|
| 15 |
-
|
| 16 |
-
This typically improves val AUC by 0.01–0.02 over vanilla SAKT.
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
from __future__ import annotations
|
| 20 |
-
|
| 21 |
-
import math
|
| 22 |
-
from pathlib import Path
|
| 23 |
-
from typing import Any
|
| 24 |
-
|
| 25 |
-
import torch
|
| 26 |
-
import torch.nn as nn
|
| 27 |
-
import torch.nn.functional as F
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class DecayAttention(nn.Module):
|
| 31 |
-
"""
|
| 32 |
-
Multi-head attention with Ebbinghaus forgetting curve decay.
|
| 33 |
-
|
| 34 |
-
Applies position-based temporal decay to attention logits before softmax:
|
| 35 |
-
attention_logits[i, j] -= decay_rate_learned * log(1 + |i - j|)
|
| 36 |
-
|
| 37 |
-
The decay rate is a learned scalar per head, initialised from a prior.
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
embed_dim: int,
|
| 43 |
-
num_heads: int,
|
| 44 |
-
dropout: float = 0.2,
|
| 45 |
-
decay_init: float = 1.0,
|
| 46 |
-
) -> None:
|
| 47 |
-
super().__init__()
|
| 48 |
-
self.embed_dim = embed_dim
|
| 49 |
-
self.num_heads = num_heads
|
| 50 |
-
self.head_dim = embed_dim // num_heads
|
| 51 |
-
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
| 52 |
-
|
| 53 |
-
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 54 |
-
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 55 |
-
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 56 |
-
self.out_proj = nn.Linear(embed_dim, embed_dim)
|
| 57 |
-
self.dropout = nn.Dropout(dropout)
|
| 58 |
-
|
| 59 |
-
# Learned decay rate per head — initialised to decay_init
|
| 60 |
-
# Constrained positive via softplus during forward
|
| 61 |
-
self.decay_logit = nn.Parameter(
|
| 62 |
-
torch.full((num_heads,), math.log(math.exp(decay_init) - 1))
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
def forward(
|
| 66 |
-
self,
|
| 67 |
-
x: torch.Tensor, # (batch, seq_len, embed_dim)
|
| 68 |
-
causal_mask: torch.Tensor, # (seq_len, seq_len) bool — True = block
|
| 69 |
-
key_padding_mask: torch.Tensor, # (batch, seq_len) bool — True = pad
|
| 70 |
-
) -> torch.Tensor:
|
| 71 |
-
B, L, D = x.shape
|
| 72 |
-
H, Hd = self.num_heads, self.head_dim
|
| 73 |
-
|
| 74 |
-
Q = self.q_proj(x).view(B, L, H, Hd).transpose(1, 2) # (B, H, L, Hd)
|
| 75 |
-
K = self.k_proj(x).view(B, L, H, Hd).transpose(1, 2)
|
| 76 |
-
V = self.v_proj(x).view(B, L, H, Hd).transpose(1, 2)
|
| 77 |
-
|
| 78 |
-
# Scaled dot-product attention scores
|
| 79 |
-
scale = math.sqrt(self.head_dim)
|
| 80 |
-
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale # (B, H, L, L)
|
| 81 |
-
|
| 82 |
-
# ── Ebbinghaus decay ──────────────────────────────────────── #
|
| 83 |
-
# Build temporal distance matrix: dist[i, j] = |i - j|
|
| 84 |
-
positions = torch.arange(L, device=x.device)
|
| 85 |
-
dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float() # (L, L)
|
| 86 |
-
|
| 87 |
-
# decay = softplus(decay_logit) ensures strictly positive rates
|
| 88 |
-
decay_rate = F.softplus(self.decay_logit) # (H,)
|
| 89 |
-
|
| 90 |
-
# Decay penalty: rate_h * log(1 + dist) — shape (H, L, L)
|
| 91 |
-
decay_penalty = decay_rate.view(H, 1, 1) * torch.log1p(dist).unsqueeze(0)
|
| 92 |
-
scores = scores - decay_penalty.unsqueeze(0) # broadcast over batch
|
| 93 |
-
# ─────────────────────────────────────────────────────────── #
|
| 94 |
-
|
| 95 |
-
# Apply causal mask
|
| 96 |
-
scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), -1e9)
|
| 97 |
-
|
| 98 |
-
# Apply padding mask
|
| 99 |
-
if key_padding_mask is not None:
|
| 100 |
-
scores = scores.masked_fill(
|
| 101 |
-
key_padding_mask.unsqueeze(1).unsqueeze(2), -1e9
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
attn = F.softmax(scores, dim=-1)
|
| 105 |
-
attn = self.dropout(attn)
|
| 106 |
-
|
| 107 |
-
out = torch.matmul(attn, V) # (B, H, L, Hd)
|
| 108 |
-
out = out.transpose(1, 2).contiguous().view(B, L, D) # (B, L, D)
|
| 109 |
-
return self.out_proj(out)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class SAKTWithDecay(nn.Module):
|
| 113 |
-
"""
|
| 114 |
-
SAKT + Ebbinghaus Forgetting Curve Decay.
|
| 115 |
-
|
| 116 |
-
Drop-in replacement for SAKTModel with improved AUC through
|
| 117 |
-
temporal decay attention. All other architecture details are identical.
|
| 118 |
-
|
| 119 |
-
Parameters
|
| 120 |
-
----------
|
| 121 |
-
num_skills : int
|
| 122 |
-
embed_dim : int
|
| 123 |
-
num_heads : int
|
| 124 |
-
dropout : float
|
| 125 |
-
max_seq_len : int
|
| 126 |
-
decay_init : float
|
| 127 |
-
Initial decay rate (higher = faster forgetting). Default 1.0.
|
| 128 |
-
"""
|
| 129 |
-
|
| 130 |
-
def __init__(
|
| 131 |
-
self,
|
| 132 |
-
num_skills: int,
|
| 133 |
-
embed_dim: int = 64,
|
| 134 |
-
num_heads: int = 8,
|
| 135 |
-
dropout: float = 0.2,
|
| 136 |
-
max_seq_len: int = 100,
|
| 137 |
-
decay_init: float = 1.0,
|
| 138 |
-
) -> None:
|
| 139 |
-
super().__init__()
|
| 140 |
-
self.num_skills = num_skills
|
| 141 |
-
self.embed_dim = embed_dim
|
| 142 |
-
self.max_seq_len = max_seq_len
|
| 143 |
-
|
| 144 |
-
self.interaction_embed = nn.Embedding(2 * num_skills + 2, embed_dim, padding_idx=0) # +2: shift+1 means max index = 2*n+1
|
| 145 |
-
self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
|
| 146 |
-
|
| 147 |
-
# Decay-aware attention replaces nn.MultiheadAttention
|
| 148 |
-
self.decay_attn = DecayAttention(embed_dim, num_heads, dropout, decay_init)
|
| 149 |
-
|
| 150 |
-
self.layer_norm1 = nn.LayerNorm(embed_dim)
|
| 151 |
-
self.layer_norm2 = nn.LayerNorm(embed_dim)
|
| 152 |
-
self.ffn = nn.Sequential(
|
| 153 |
-
nn.Linear(embed_dim, embed_dim * 2),
|
| 154 |
-
nn.ReLU(),
|
| 155 |
-
nn.Dropout(dropout),
|
| 156 |
-
nn.Linear(embed_dim * 2, embed_dim),
|
| 157 |
-
)
|
| 158 |
-
|
| 159 |
-
self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
|
| 160 |
-
self.output_layer = nn.Linear(embed_dim * 2, 1)
|
| 161 |
-
self.dropout = nn.Dropout(dropout)
|
| 162 |
-
|
| 163 |
-
def forward(
|
| 164 |
-
self,
|
| 165 |
-
interactions: torch.Tensor,
|
| 166 |
-
target_skills: torch.Tensor,
|
| 167 |
-
mask: torch.Tensor,
|
| 168 |
-
) -> torch.Tensor:
|
| 169 |
-
B, L = interactions.shape
|
| 170 |
-
positions = torch.arange(L, device=interactions.device).unsqueeze(0)
|
| 171 |
-
|
| 172 |
-
x = self.interaction_embed(interactions) + self.pos_embed(positions)
|
| 173 |
-
x = self.dropout(x)
|
| 174 |
-
|
| 175 |
-
causal_mask = torch.triu(
|
| 176 |
-
torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
|
| 177 |
-
)
|
| 178 |
-
key_padding_mask = ~mask # True = ignore
|
| 179 |
-
|
| 180 |
-
x_attn = self.decay_attn(x, causal_mask, key_padding_mask)
|
| 181 |
-
x = self.layer_norm1(x + x_attn)
|
| 182 |
-
x = self.layer_norm2(x + self.ffn(x))
|
| 183 |
-
|
| 184 |
-
skill_x = self.skill_embed(target_skills)
|
| 185 |
-
out = self.output_layer(torch.cat([x, skill_x], dim=-1)).squeeze(-1)
|
| 186 |
-
return out
|
| 187 |
-
|
| 188 |
-
@torch.no_grad()
|
| 189 |
-
def predict_mastery(
|
| 190 |
-
self,
|
| 191 |
-
skill_seq: list[int],
|
| 192 |
-
correct_seq: list[int],
|
| 193 |
-
device: torch.device | str = "cpu",
|
| 194 |
-
) -> dict[int, float]:
|
| 195 |
-
"""Same interface as SAKTModel.predict_mastery."""
|
| 196 |
-
if len(skill_seq) < 2:
|
| 197 |
-
return {}
|
| 198 |
-
|
| 199 |
-
if len(skill_seq) > self.max_seq_len:
|
| 200 |
-
skill_seq = skill_seq[-self.max_seq_len:]
|
| 201 |
-
correct_seq = correct_seq[-self.max_seq_len:]
|
| 202 |
-
|
| 203 |
-
interactions = [s + c * self.num_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
|
| 204 |
-
target_skills = skill_seq[1:]
|
| 205 |
-
seq_len = len(interactions)
|
| 206 |
-
pad_len = self.max_seq_len - seq_len
|
| 207 |
-
|
| 208 |
-
interactions_padded = [0] * pad_len + interactions
|
| 209 |
-
target_padded = [0] * pad_len + target_skills
|
| 210 |
-
mask_list = [False] * pad_len + [True] * seq_len
|
| 211 |
-
|
| 212 |
-
self.eval()
|
| 213 |
-
self.to(device)
|
| 214 |
-
|
| 215 |
-
logits = self(
|
| 216 |
-
torch.LongTensor([interactions_padded]).to(device),
|
| 217 |
-
torch.LongTensor([target_padded]).to(device),
|
| 218 |
-
torch.BoolTensor([mask_list]).to(device),
|
| 219 |
-
)
|
| 220 |
-
probs = torch.sigmoid(logits).squeeze(0)
|
| 221 |
-
real_probs = probs[torch.BoolTensor(mask_list)].cpu().numpy()
|
| 222 |
-
|
| 223 |
-
return {int(sid): float(p) for sid, p in zip(target_skills, real_probs)}
|
| 224 |
-
|
| 225 |
-
def save(self, path: str | Path, config: dict[str, Any] | None = None) -> None:
|
| 226 |
-
payload = {
|
| 227 |
-
"state_dict": self.state_dict(),
|
| 228 |
-
"model_type": "SAKTWithDecay",
|
| 229 |
-
"config": config or {
|
| 230 |
-
"num_skills": self.num_skills,
|
| 231 |
-
"embed_dim": self.embed_dim,
|
| 232 |
-
"max_seq_len": self.max_seq_len,
|
| 233 |
-
"model_type": "SAKTWithDecay",
|
| 234 |
-
},
|
| 235 |
-
}
|
| 236 |
-
torch.save(payload, path)
|
| 237 |
-
|
| 238 |
-
@classmethod
|
| 239 |
-
def load(cls, path: str | Path, device: str | torch.device = "cpu") -> "SAKTWithDecay":
|
| 240 |
-
payload = torch.load(path, map_location=device, weights_only=False)
|
| 241 |
-
cfg = payload["config"]
|
| 242 |
-
model = cls(
|
| 243 |
-
num_skills=cfg["num_skills"],
|
| 244 |
-
embed_dim=cfg.get("embed_dim", 64),
|
| 245 |
-
num_heads=cfg.get("num_heads", 8),
|
| 246 |
-
dropout=cfg.get("dropout", 0.2),
|
| 247 |
-
max_seq_len=cfg.get("max_seq_len", 100),
|
| 248 |
-
decay_init=cfg.get("decay_init", 1.0),
|
| 249 |
-
)
|
| 250 |
-
model.load_state_dict(payload["state_dict"])
|
| 251 |
-
model.to(device)
|
| 252 |
-
model.eval()
|
| 253 |
-
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/model/trainer.py
DELETED
|
@@ -1,437 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.model.trainer
|
| 3 |
-
==================
|
| 4 |
-
Training loop for the SAKT knowledge tracing model.
|
| 5 |
-
|
| 6 |
-
Handles:
|
| 7 |
-
- Dataset preparation from raw interaction logs
|
| 8 |
-
- Train / validation split
|
| 9 |
-
- Training with early stopping
|
| 10 |
-
- Checkpoint saving (best val AUC)
|
| 11 |
-
- Metrics: AUC, accuracy, loss
|
| 12 |
-
|
| 13 |
-
Expected input format (CSV or DataFrame):
|
| 14 |
-
student_id | skill_id | correct | timestamp (optional)
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
from __future__ import annotations
|
| 18 |
-
|
| 19 |
-
import time
|
| 20 |
-
from dataclasses import dataclass, field
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
from typing import Iterator
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
import torch
|
| 26 |
-
import torch.nn as nn
|
| 27 |
-
from torch.utils.data import DataLoader, Dataset
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
from sklearn.metrics import roc_auc_score
|
| 31 |
-
HAS_SKLEARN = True
|
| 32 |
-
except ImportError:
|
| 33 |
-
HAS_SKLEARN = False
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
# ------------------------------------------------------------------ #
|
| 37 |
-
# Dataset #
|
| 38 |
-
# ------------------------------------------------------------------ #
|
| 39 |
-
|
| 40 |
-
class KTDataset(Dataset):
|
| 41 |
-
"""
|
| 42 |
-
Knowledge Tracing dataset.
|
| 43 |
-
|
| 44 |
-
Each sample is one student's full interaction sequence, windowed to
|
| 45 |
-
max_seq_len. Long sequences are split into multiple windows.
|
| 46 |
-
|
| 47 |
-
Parameters
|
| 48 |
-
----------
|
| 49 |
-
sequences : list of (skill_seq, correct_seq)
|
| 50 |
-
Each element is a tuple of parallel lists.
|
| 51 |
-
max_seq_len : int
|
| 52 |
-
n_skills : int
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
sequences: list[tuple[list[int], list[int]]],
|
| 58 |
-
max_seq_len: int = 100,
|
| 59 |
-
n_skills: int = 5736,
|
| 60 |
-
) -> None:
|
| 61 |
-
self.max_seq_len = max_seq_len
|
| 62 |
-
self.n_skills = n_skills
|
| 63 |
-
self.samples: list[tuple[list[int], list[int]]] = []
|
| 64 |
-
|
| 65 |
-
for skill_seq, correct_seq in sequences:
|
| 66 |
-
# Window long sequences
|
| 67 |
-
for start in range(0, max(1, len(skill_seq) - 1), max_seq_len // 2):
|
| 68 |
-
end = start + max_seq_len + 1
|
| 69 |
-
s = skill_seq[start:end]
|
| 70 |
-
c = correct_seq[start:end]
|
| 71 |
-
if len(s) >= 2:
|
| 72 |
-
self.samples.append((s, c))
|
| 73 |
-
|
| 74 |
-
def __len__(self) -> int:
|
| 75 |
-
return len(self.samples)
|
| 76 |
-
|
| 77 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
| 78 |
-
skill_seq, correct_seq = self.samples[idx]
|
| 79 |
-
|
| 80 |
-
if len(skill_seq) > self.max_seq_len + 1:
|
| 81 |
-
skill_seq = skill_seq[-self.max_seq_len - 1:]
|
| 82 |
-
correct_seq = correct_seq[-self.max_seq_len - 1:]
|
| 83 |
-
|
| 84 |
-
interactions = [s + c * self.n_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
|
| 85 |
-
target_skills = skill_seq[1:]
|
| 86 |
-
target_correct = correct_seq[1:]
|
| 87 |
-
|
| 88 |
-
seq_len = len(interactions)
|
| 89 |
-
pad_len = self.max_seq_len - seq_len
|
| 90 |
-
|
| 91 |
-
interactions_padded = [0] * pad_len + interactions
|
| 92 |
-
target_padded = [0] * pad_len + target_skills
|
| 93 |
-
correct_padded = [0] * pad_len + target_correct
|
| 94 |
-
mask = [False] * pad_len + [True] * seq_len
|
| 95 |
-
|
| 96 |
-
return {
|
| 97 |
-
"interactions": torch.LongTensor(interactions_padded),
|
| 98 |
-
"target_skills": torch.LongTensor(target_padded),
|
| 99 |
-
"target_correct": torch.FloatTensor(correct_padded),
|
| 100 |
-
"mask": torch.BoolTensor(mask),
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor]:
|
| 105 |
-
return {k: torch.stack([b[k] for b in batch]) for k in batch[0]}
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ------------------------------------------------------------------ #
|
| 109 |
-
# Trainer config #
|
| 110 |
-
# ------------------------------------------------------------------ #
|
| 111 |
-
|
| 112 |
-
@dataclass
|
| 113 |
-
class TrainerConfig:
|
| 114 |
-
# Model
|
| 115 |
-
num_skills: int = 5736
|
| 116 |
-
embed_dim: int = 64
|
| 117 |
-
num_heads: int = 8
|
| 118 |
-
dropout: float = 0.2
|
| 119 |
-
max_seq_len: int = 100
|
| 120 |
-
|
| 121 |
-
# Training
|
| 122 |
-
epochs: int = 50
|
| 123 |
-
batch_size: int = 64
|
| 124 |
-
lr: float = 1e-3
|
| 125 |
-
weight_decay: float = 1e-5
|
| 126 |
-
val_split: float = 0.1
|
| 127 |
-
|
| 128 |
-
# Early stopping
|
| 129 |
-
patience: int = 5
|
| 130 |
-
min_delta: float = 1e-4
|
| 131 |
-
|
| 132 |
-
# Output
|
| 133 |
-
output_dir: str = "checkpoints"
|
| 134 |
-
run_name: str = "sakt_run"
|
| 135 |
-
|
| 136 |
-
# Device
|
| 137 |
-
device: str = "auto" # "auto" | "cpu" | "cuda" | "mps"
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
# ------------------------------------------------------------------ #
|
| 141 |
-
# Trainer #
|
| 142 |
-
# ------------------------------------------------------------------ #
|
| 143 |
-
|
| 144 |
-
@dataclass
|
| 145 |
-
class EpochMetrics:
|
| 146 |
-
epoch: int
|
| 147 |
-
train_loss: float
|
| 148 |
-
val_loss: float
|
| 149 |
-
val_auc: float
|
| 150 |
-
val_acc: float
|
| 151 |
-
elapsed: float
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
class SAKTTrainer:
|
| 155 |
-
"""
|
| 156 |
-
Trainer for the SAKT knowledge tracing model.
|
| 157 |
-
|
| 158 |
-
Parameters
|
| 159 |
-
----------
|
| 160 |
-
config : TrainerConfig
|
| 161 |
-
"""
|
| 162 |
-
|
| 163 |
-
def __init__(self, config: TrainerConfig) -> None:
|
| 164 |
-
self.config = config
|
| 165 |
-
self.device = self._resolve_device(config.device)
|
| 166 |
-
self.output_dir = Path(config.output_dir)
|
| 167 |
-
self.output_dir.mkdir(parents=True, exist_ok=True)
|
| 168 |
-
|
| 169 |
-
# ---------------------------------------------------------------- #
|
| 170 |
-
# Public API #
|
| 171 |
-
# ---------------------------------------------------------------- #
|
| 172 |
-
|
| 173 |
-
def fit(
|
| 174 |
-
self,
|
| 175 |
-
sequences: list[tuple[list[int], list[int]]],
|
| 176 |
-
val_sequences: list[tuple[list[int], list[int]]] | None = None,
|
| 177 |
-
) -> list[EpochMetrics]:
|
| 178 |
-
"""
|
| 179 |
-
Train the SAKT model on interaction sequences.
|
| 180 |
-
|
| 181 |
-
Parameters
|
| 182 |
-
----------
|
| 183 |
-
sequences : list of (skill_seq, correct_seq)
|
| 184 |
-
Training data. Each element is a student's full history.
|
| 185 |
-
val_sequences : list of (skill_seq, correct_seq), optional
|
| 186 |
-
If None, val_split fraction of sequences is held out.
|
| 187 |
-
|
| 188 |
-
Returns
|
| 189 |
-
-------
|
| 190 |
-
list[EpochMetrics] — training history
|
| 191 |
-
"""
|
| 192 |
-
from plrs.model.sakt import SAKTModel
|
| 193 |
-
|
| 194 |
-
cfg = self.config
|
| 195 |
-
|
| 196 |
-
# Split if no explicit val set
|
| 197 |
-
if val_sequences is None:
|
| 198 |
-
n_val = max(1, int(len(sequences) * cfg.val_split))
|
| 199 |
-
idx = np.random.permutation(len(sequences))
|
| 200 |
-
val_sequences = [sequences[i] for i in idx[:n_val]]
|
| 201 |
-
train_sequences = [sequences[i] for i in idx[n_val:]]
|
| 202 |
-
else:
|
| 203 |
-
train_sequences = sequences
|
| 204 |
-
|
| 205 |
-
print(f"Training samples : {len(train_sequences)} students")
|
| 206 |
-
print(f"Validation samples: {len(val_sequences)} students")
|
| 207 |
-
print(f"Device: {self.device}")
|
| 208 |
-
|
| 209 |
-
train_ds = KTDataset(train_sequences, cfg.max_seq_len, cfg.num_skills)
|
| 210 |
-
val_ds = KTDataset(val_sequences, cfg.max_seq_len, cfg.num_skills)
|
| 211 |
-
|
| 212 |
-
train_loader = DataLoader(
|
| 213 |
-
train_ds, batch_size=cfg.batch_size, shuffle=True,
|
| 214 |
-
collate_fn=collate_fn, num_workers=0,
|
| 215 |
-
)
|
| 216 |
-
val_loader = DataLoader(
|
| 217 |
-
val_ds, batch_size=cfg.batch_size * 2, shuffle=False,
|
| 218 |
-
collate_fn=collate_fn, num_workers=0,
|
| 219 |
-
)
|
| 220 |
-
|
| 221 |
-
model = SAKTModel(
|
| 222 |
-
num_skills=cfg.num_skills,
|
| 223 |
-
embed_dim=cfg.embed_dim,
|
| 224 |
-
num_heads=cfg.num_heads,
|
| 225 |
-
dropout=cfg.dropout,
|
| 226 |
-
max_seq_len=cfg.max_seq_len,
|
| 227 |
-
).to(self.device)
|
| 228 |
-
|
| 229 |
-
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 230 |
-
|
| 231 |
-
optimizer = torch.optim.Adam(
|
| 232 |
-
model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
# Zero out NaN gradients that arise from softmax backward over fully-padded rows.
|
| 236 |
-
# This is a known issue with nn.MultiheadAttention + bool key_padding_mask.
|
| 237 |
-
# The hook is safe: it only zeroes truly NaN gradients, never valid ones.
|
| 238 |
-
def _zero_nan_grad(grad: torch.Tensor) -> torch.Tensor:
|
| 239 |
-
return torch.nan_to_num(grad, nan=0.0)
|
| 240 |
-
for p in model.parameters():
|
| 241 |
-
p.register_hook(_zero_nan_grad)
|
| 242 |
-
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 243 |
-
optimizer, mode="max", patience=2, factor=0.5
|
| 244 |
-
)
|
| 245 |
-
criterion = nn.BCEWithLogitsLoss()
|
| 246 |
-
|
| 247 |
-
history: list[EpochMetrics] = []
|
| 248 |
-
best_auc = 0.0
|
| 249 |
-
patience_counter = 0
|
| 250 |
-
best_path = self.output_dir / f"{cfg.run_name}_best.pt"
|
| 251 |
-
|
| 252 |
-
print(f"\n{'Epoch':>6} {'Train Loss':>11} {'Val Loss':>9} {'Val AUC':>9} {'Val Acc':>9} {'Time':>7}")
|
| 253 |
-
print("-" * 58)
|
| 254 |
-
|
| 255 |
-
for epoch in range(1, cfg.epochs + 1):
|
| 256 |
-
t0 = time.time()
|
| 257 |
-
|
| 258 |
-
train_loss = self._train_epoch(model, train_loader, optimizer, criterion)
|
| 259 |
-
val_loss, val_auc, val_acc = self._val_epoch(model, val_loader, criterion)
|
| 260 |
-
|
| 261 |
-
scheduler.step(val_auc)
|
| 262 |
-
elapsed = time.time() - t0
|
| 263 |
-
|
| 264 |
-
metrics = EpochMetrics(
|
| 265 |
-
epoch=epoch,
|
| 266 |
-
train_loss=train_loss,
|
| 267 |
-
val_loss=val_loss,
|
| 268 |
-
val_auc=val_auc,
|
| 269 |
-
val_acc=val_acc,
|
| 270 |
-
elapsed=elapsed,
|
| 271 |
-
)
|
| 272 |
-
history.append(metrics)
|
| 273 |
-
|
| 274 |
-
print(
|
| 275 |
-
f"{epoch:>6} {train_loss:>11.4f} {val_loss:>9.4f} "
|
| 276 |
-
f"{val_auc:>9.4f} {val_acc:>9.4f} {elapsed:>6.1f}s"
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
# Save best
|
| 280 |
-
if val_auc > best_auc + cfg.min_delta:
|
| 281 |
-
best_auc = val_auc
|
| 282 |
-
patience_counter = 0
|
| 283 |
-
model.save(best_path, config=self._model_config())
|
| 284 |
-
print(f" ✅ New best AUC: {best_auc:.4f} → saved to {best_path}")
|
| 285 |
-
else:
|
| 286 |
-
patience_counter += 1
|
| 287 |
-
if patience_counter >= cfg.patience:
|
| 288 |
-
print(f"\nEarly stopping at epoch {epoch} (patience={cfg.patience})")
|
| 289 |
-
break
|
| 290 |
-
|
| 291 |
-
print(f"\nTraining complete. Best val AUC: {best_auc:.4f}")
|
| 292 |
-
print(f"Best model: {best_path}")
|
| 293 |
-
return history
|
| 294 |
-
|
| 295 |
-
# ---------------------------------------------------------------- #
|
| 296 |
-
# Internal #
|
| 297 |
-
# ---------------------------------------------------------------- #
|
| 298 |
-
|
| 299 |
-
def _train_epoch(self, model, loader, optimizer, criterion) -> float:
|
| 300 |
-
model.train()
|
| 301 |
-
total_loss = 0.0
|
| 302 |
-
|
| 303 |
-
for batch in loader:
|
| 304 |
-
interactions = batch["interactions"].to(self.device)
|
| 305 |
-
target_skills = batch["target_skills"].to(self.device)
|
| 306 |
-
target_correct = batch["target_correct"].to(self.device)
|
| 307 |
-
mask = batch["mask"].to(self.device)
|
| 308 |
-
|
| 309 |
-
optimizer.zero_grad()
|
| 310 |
-
logits = model(interactions, target_skills, mask)
|
| 311 |
-
|
| 312 |
-
# Only compute loss on real (non-padded) positions
|
| 313 |
-
real_logits = logits[mask]
|
| 314 |
-
real_targets = target_correct[mask]
|
| 315 |
-
|
| 316 |
-
loss = criterion(real_logits, real_targets)
|
| 317 |
-
loss.backward()
|
| 318 |
-
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 319 |
-
optimizer.step()
|
| 320 |
-
|
| 321 |
-
total_loss += loss.item()
|
| 322 |
-
|
| 323 |
-
return total_loss / max(len(loader), 1)
|
| 324 |
-
|
| 325 |
-
@torch.no_grad()
|
| 326 |
-
def _val_epoch(self, model, loader, criterion) -> tuple[float, float, float]:
|
| 327 |
-
model.eval()
|
| 328 |
-
total_loss = 0.0
|
| 329 |
-
all_probs: list[float] = []
|
| 330 |
-
all_labels: list[float] = []
|
| 331 |
-
|
| 332 |
-
for batch in loader:
|
| 333 |
-
interactions = batch["interactions"].to(self.device)
|
| 334 |
-
target_skills = batch["target_skills"].to(self.device)
|
| 335 |
-
target_correct = batch["target_correct"].to(self.device)
|
| 336 |
-
mask = batch["mask"].to(self.device)
|
| 337 |
-
|
| 338 |
-
logits = model(interactions, target_skills, mask)
|
| 339 |
-
real_logits = logits[mask]
|
| 340 |
-
real_targets = target_correct[mask]
|
| 341 |
-
|
| 342 |
-
loss = criterion(real_logits, real_targets)
|
| 343 |
-
total_loss += loss.item()
|
| 344 |
-
|
| 345 |
-
probs = torch.sigmoid(real_logits).cpu().numpy()
|
| 346 |
-
labels = real_targets.cpu().numpy()
|
| 347 |
-
all_probs.extend(probs.tolist())
|
| 348 |
-
all_labels.extend(labels.tolist())
|
| 349 |
-
|
| 350 |
-
avg_loss = total_loss / max(len(loader), 1)
|
| 351 |
-
all_probs_arr = np.array(all_probs)
|
| 352 |
-
all_labels_arr = np.array(all_labels)
|
| 353 |
-
|
| 354 |
-
# Guard against NaN (can occur with very small val sets)
|
| 355 |
-
all_probs_arr = np.nan_to_num(all_probs_arr, nan=0.5)
|
| 356 |
-
all_labels_arr = np.nan_to_num(all_labels_arr, nan=0.0)
|
| 357 |
-
|
| 358 |
-
if HAS_SKLEARN and len(np.unique(all_labels_arr)) > 1:
|
| 359 |
-
auc = float(roc_auc_score(all_labels_arr, all_probs_arr))
|
| 360 |
-
else:
|
| 361 |
-
auc = 0.5 # fallback (single class or no sklearn)
|
| 362 |
-
|
| 363 |
-
acc = float(((all_probs_arr >= 0.5) == all_labels_arr).mean())
|
| 364 |
-
return avg_loss, auc, acc
|
| 365 |
-
|
| 366 |
-
def _model_config(self) -> dict:
|
| 367 |
-
cfg = self.config
|
| 368 |
-
return {
|
| 369 |
-
"num_skills": cfg.num_skills,
|
| 370 |
-
"embed_dim": cfg.embed_dim,
|
| 371 |
-
"num_heads": cfg.num_heads,
|
| 372 |
-
"dropout": cfg.dropout,
|
| 373 |
-
"max_seq_len": cfg.max_seq_len,
|
| 374 |
-
}
|
| 375 |
-
|
| 376 |
-
@staticmethod
|
| 377 |
-
def _resolve_device(device: str) -> torch.device:
|
| 378 |
-
if device == "auto":
|
| 379 |
-
if torch.cuda.is_available():
|
| 380 |
-
return torch.device("cuda")
|
| 381 |
-
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 382 |
-
return torch.device("mps")
|
| 383 |
-
return torch.device("cpu")
|
| 384 |
-
return torch.device(device)
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
# ------------------------------------------------------------------ #
|
| 388 |
-
# Utilities #
|
| 389 |
-
# ------------------------------------------------------------------ #
|
| 390 |
-
|
| 391 |
-
def load_sequences_from_csv(
|
| 392 |
-
path: str | Path,
|
| 393 |
-
student_col: str = "student_id",
|
| 394 |
-
skill_col: str = "skill_id",
|
| 395 |
-
correct_col: str = "correct",
|
| 396 |
-
timestamp_col: str | None = "timestamp",
|
| 397 |
-
min_seq_len: int = 5,
|
| 398 |
-
) -> list[tuple[list[int], list[int]]]:
|
| 399 |
-
"""
|
| 400 |
-
Load student interaction sequences from a CSV file.
|
| 401 |
-
|
| 402 |
-
Parameters
|
| 403 |
-
----------
|
| 404 |
-
path : str or Path
|
| 405 |
-
CSV with columns: student_id, skill_id, correct, [timestamp]
|
| 406 |
-
student_col, skill_col, correct_col : str
|
| 407 |
-
Column names.
|
| 408 |
-
timestamp_col : str or None
|
| 409 |
-
If provided, sort interactions by this column within each student.
|
| 410 |
-
min_seq_len : int
|
| 411 |
-
Drop students with fewer than this many interactions.
|
| 412 |
-
|
| 413 |
-
Returns
|
| 414 |
-
-------
|
| 415 |
-
list of (skill_seq, correct_seq) tuples
|
| 416 |
-
"""
|
| 417 |
-
import pandas as pd
|
| 418 |
-
|
| 419 |
-
df = pd.read_csv(path)
|
| 420 |
-
|
| 421 |
-
required = [student_col, skill_col, correct_col]
|
| 422 |
-
missing = [c for c in required if c not in df.columns]
|
| 423 |
-
if missing:
|
| 424 |
-
raise ValueError(f"Missing columns in CSV: {missing}. Found: {df.columns.tolist()}")
|
| 425 |
-
|
| 426 |
-
if timestamp_col and timestamp_col in df.columns:
|
| 427 |
-
df = df.sort_values([student_col, timestamp_col])
|
| 428 |
-
|
| 429 |
-
sequences = []
|
| 430 |
-
for _, group in df.groupby(student_col):
|
| 431 |
-
skills = group[skill_col].astype(int).tolist()
|
| 432 |
-
corrects = group[correct_col].astype(int).tolist()
|
| 433 |
-
if len(skills) >= min_seq_len:
|
| 434 |
-
sequences.append((skills, corrects))
|
| 435 |
-
|
| 436 |
-
print(f"Loaded {len(sequences)} student sequences from {path}")
|
| 437 |
-
return sequences
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/pipeline.py
DELETED
|
@@ -1,236 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.pipeline
|
| 3 |
-
=============
|
| 4 |
-
PLRSPipeline: the main entry point.
|
| 5 |
-
|
| 6 |
-
Orchestrates SAKT inference → DAG constraint validation → multi-objective ranking.
|
| 7 |
-
|
| 8 |
-
Usage
|
| 9 |
-
-----
|
| 10 |
-
from plrs import PLRSPipeline
|
| 11 |
-
from plrs.curriculum import load_dag
|
| 12 |
-
|
| 13 |
-
curriculum = load_dag("math_dag.json")
|
| 14 |
-
pipeline = PLRSPipeline(curriculum, model_path="sakt_model.pt")
|
| 15 |
-
|
| 16 |
-
# From raw interaction history
|
| 17 |
-
results = pipeline.recommend_from_history(
|
| 18 |
-
skill_seq=[12, 45, 3, 78],
|
| 19 |
-
correct_seq=[1, 0, 1, 1],
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
# From pre-computed mastery dict
|
| 23 |
-
results = pipeline.recommend_from_mastery(
|
| 24 |
-
mastery_scores={"algebra_basics": 0.85, "quadratic_equations": 0.42}
|
| 25 |
-
)
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
from __future__ import annotations
|
| 29 |
-
|
| 30 |
-
from pathlib import Path
|
| 31 |
-
from typing import Any
|
| 32 |
-
|
| 33 |
-
from plrs.constraints.dag import DAGConstraintLayer, MasteryVector
|
| 34 |
-
from plrs.curriculum.loader import CurriculumGraph
|
| 35 |
-
from plrs.ranking.ranker import MultiObjectiveRanker, RankedRecommendation
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class PLRSPipeline:
|
| 39 |
-
"""
|
| 40 |
-
End-to-end PLRS recommendation pipeline.
|
| 41 |
-
|
| 42 |
-
Parameters
|
| 43 |
-
----------
|
| 44 |
-
curriculum : CurriculumGraph
|
| 45 |
-
model_path : str or Path, optional
|
| 46 |
-
Path to a trained SAKT .pt file. If None, only mastery-dict mode is available.
|
| 47 |
-
threshold : float
|
| 48 |
-
Mastery threshold (default 0.70).
|
| 49 |
-
soft_threshold : float
|
| 50 |
-
Soft constraint threshold (default 0.50).
|
| 51 |
-
top_n : int
|
| 52 |
-
Number of top approved recommendations (default 5).
|
| 53 |
-
w_gap, w_readiness, w_downstream : float
|
| 54 |
-
Ranker objective weights.
|
| 55 |
-
device : str
|
| 56 |
-
PyTorch device for model inference (default "cpu").
|
| 57 |
-
"""
|
| 58 |
-
|
| 59 |
-
def __init__(
|
| 60 |
-
self,
|
| 61 |
-
curriculum: CurriculumGraph,
|
| 62 |
-
model_path: str | Path | None = None,
|
| 63 |
-
threshold: float = 0.70,
|
| 64 |
-
soft_threshold: float = 0.50,
|
| 65 |
-
top_n: int = 5,
|
| 66 |
-
w_gap: float = 0.4,
|
| 67 |
-
w_readiness: float = 0.4,
|
| 68 |
-
w_downstream: float = 0.2,
|
| 69 |
-
device: str = "cpu",
|
| 70 |
-
) -> None:
|
| 71 |
-
self.curriculum = curriculum
|
| 72 |
-
self.threshold = threshold
|
| 73 |
-
self.soft_threshold = soft_threshold
|
| 74 |
-
self.top_n = top_n
|
| 75 |
-
self.device = device
|
| 76 |
-
|
| 77 |
-
self.constraint_layer = DAGConstraintLayer(curriculum)
|
| 78 |
-
self.ranker = MultiObjectiveRanker(
|
| 79 |
-
curriculum,
|
| 80 |
-
w_gap=w_gap,
|
| 81 |
-
w_readiness=w_readiness,
|
| 82 |
-
w_downstream=w_downstream,
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
self._model = None
|
| 86 |
-
if model_path is not None:
|
| 87 |
-
self._load_model(model_path)
|
| 88 |
-
|
| 89 |
-
# ------------------------------------------------------------------ #
|
| 90 |
-
# Public API #
|
| 91 |
-
# ------------------------------------------------------------------ #
|
| 92 |
-
|
| 93 |
-
def recommend_from_mastery(
|
| 94 |
-
self,
|
| 95 |
-
mastery_scores: dict[str, float],
|
| 96 |
-
cascade: bool = False,
|
| 97 |
-
) -> dict[str, Any]:
|
| 98 |
-
"""
|
| 99 |
-
Generate recommendations from a pre-computed mastery dict.
|
| 100 |
-
|
| 101 |
-
Parameters
|
| 102 |
-
----------
|
| 103 |
-
mastery_scores : dict[str, float]
|
| 104 |
-
Mapping from topic_id → mastery probability [0, 1].
|
| 105 |
-
cascade : bool
|
| 106 |
-
If True, propagate mastery upward through prerequisites.
|
| 107 |
-
|
| 108 |
-
Returns
|
| 109 |
-
-------
|
| 110 |
-
dict with keys: approved, challenging, vetoed, stats, mastery_summary
|
| 111 |
-
"""
|
| 112 |
-
mastery = self._build_mastery_vector(mastery_scores)
|
| 113 |
-
if cascade:
|
| 114 |
-
mastery.cascade_up()
|
| 115 |
-
return self._run(mastery)
|
| 116 |
-
|
| 117 |
-
def recommend_from_history(
|
| 118 |
-
self,
|
| 119 |
-
skill_seq: list[int],
|
| 120 |
-
correct_seq: list[int],
|
| 121 |
-
skill_to_topic: dict[int, str] | None = None,
|
| 122 |
-
cascade: bool = False,
|
| 123 |
-
) -> dict[str, Any]:
|
| 124 |
-
"""
|
| 125 |
-
Generate recommendations from raw student interaction history.
|
| 126 |
-
|
| 127 |
-
Requires a loaded SAKT model (pass model_path to __init__).
|
| 128 |
-
|
| 129 |
-
Parameters
|
| 130 |
-
----------
|
| 131 |
-
skill_seq : list[int]
|
| 132 |
-
Sequence of skill IDs from the student's history.
|
| 133 |
-
correct_seq : list[int]
|
| 134 |
-
Corresponding correctness flags (1/0).
|
| 135 |
-
skill_to_topic : dict[int, str], optional
|
| 136 |
-
Mapping from SAKT skill_id → curriculum topic_id.
|
| 137 |
-
Required to map model output back to DAG nodes.
|
| 138 |
-
cascade : bool
|
| 139 |
-
If True, propagate mastery upward through prerequisites.
|
| 140 |
-
|
| 141 |
-
Returns
|
| 142 |
-
-------
|
| 143 |
-
dict with keys: approved, challenging, vetoed, stats, mastery_summary
|
| 144 |
-
"""
|
| 145 |
-
if self._model is None:
|
| 146 |
-
raise RuntimeError(
|
| 147 |
-
"No model loaded. Pass model_path to PLRSPipeline() to use history-based inference."
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
skill_probs = self._model.predict_mastery(skill_seq, correct_seq, device=self.device)
|
| 151 |
-
|
| 152 |
-
if skill_to_topic:
|
| 153 |
-
mastery_scores = {}
|
| 154 |
-
for skill_id, prob in skill_probs.items():
|
| 155 |
-
topic_id = skill_to_topic.get(skill_id)
|
| 156 |
-
if topic_id:
|
| 157 |
-
mastery_scores[topic_id] = max(mastery_scores.get(topic_id, 0.0), prob)
|
| 158 |
-
else:
|
| 159 |
-
# Without mapping, return raw skill probabilities (limited utility)
|
| 160 |
-
mastery_scores = {str(k): v for k, v in skill_probs.items()}
|
| 161 |
-
|
| 162 |
-
mastery = self._build_mastery_vector(mastery_scores)
|
| 163 |
-
if cascade:
|
| 164 |
-
mastery.cascade_up()
|
| 165 |
-
return self._run(mastery)
|
| 166 |
-
|
| 167 |
-
def what_if(self, topic_id: str) -> dict[str, Any]:
|
| 168 |
-
"""
|
| 169 |
-
What-if analysis: what unlocks if a student masters this topic?
|
| 170 |
-
|
| 171 |
-
Parameters
|
| 172 |
-
----------
|
| 173 |
-
topic_id : str
|
| 174 |
-
|
| 175 |
-
Returns
|
| 176 |
-
-------
|
| 177 |
-
dict with direct_unlocks, all_unlocks, blocked_by, total_unlocked
|
| 178 |
-
"""
|
| 179 |
-
graph = self.curriculum.graph
|
| 180 |
-
direct = self.curriculum.successors(topic_id)
|
| 181 |
-
all_unlocks = self.curriculum.descendants(topic_id)
|
| 182 |
-
blocked_by = self.curriculum.prerequisites(topic_id)
|
| 183 |
-
|
| 184 |
-
return {
|
| 185 |
-
"topic_id": topic_id,
|
| 186 |
-
"topic_label": self.curriculum.label(topic_id),
|
| 187 |
-
"direct_unlocks": [
|
| 188 |
-
{"id": n, "label": self.curriculum.label(n)} for n in direct
|
| 189 |
-
],
|
| 190 |
-
"all_unlocks": [
|
| 191 |
-
{"id": n, "label": self.curriculum.label(n)} for n in all_unlocks
|
| 192 |
-
],
|
| 193 |
-
"blocked_by": [
|
| 194 |
-
{"id": n, "label": self.curriculum.label(n)} for n in blocked_by
|
| 195 |
-
],
|
| 196 |
-
"total_unlocked": len(all_unlocks),
|
| 197 |
-
}
|
| 198 |
-
|
| 199 |
-
# ------------------------------------------------------------------ #
|
| 200 |
-
# Internal helpers #
|
| 201 |
-
# ------------------------------------------------------------------ #
|
| 202 |
-
|
| 203 |
-
def _build_mastery_vector(self, mastery_scores: dict[str, float]) -> MasteryVector:
|
| 204 |
-
mv = MasteryVector(self.curriculum, self.threshold, self.soft_threshold)
|
| 205 |
-
mv.update_batch(mastery_scores)
|
| 206 |
-
return mv
|
| 207 |
-
|
| 208 |
-
def _run(self, mastery: MasteryVector) -> dict[str, Any]:
|
| 209 |
-
constraint_results = self.constraint_layer.validate_all(mastery)
|
| 210 |
-
ranked = self.ranker.rank(constraint_results, mastery, top_n=self.top_n)
|
| 211 |
-
ranked["mastery_summary"] = mastery.summary()
|
| 212 |
-
|
| 213 |
-
# Serialise to plain dicts for API/JSON friendliness
|
| 214 |
-
for key in ("approved", "challenging", "vetoed"):
|
| 215 |
-
ranked[key] = [self._rec_to_dict(r) for r in ranked[key]]
|
| 216 |
-
|
| 217 |
-
return ranked
|
| 218 |
-
|
| 219 |
-
def _load_model(self, path: str | Path) -> None:
|
| 220 |
-
from plrs.model.sakt import SAKTModel
|
| 221 |
-
self._model = SAKTModel.load(path, device=self.device)
|
| 222 |
-
|
| 223 |
-
@staticmethod
|
| 224 |
-
def _rec_to_dict(rec: RankedRecommendation) -> dict[str, Any]:
|
| 225 |
-
return {
|
| 226 |
-
"topic_id": rec.topic_id,
|
| 227 |
-
"topic_label": rec.topic_label,
|
| 228 |
-
"status": rec.status,
|
| 229 |
-
"mastery": rec.mastery,
|
| 230 |
-
"score": rec.score,
|
| 231 |
-
"reasoning": rec.reasoning,
|
| 232 |
-
"prerequisites": rec.prerequisites,
|
| 233 |
-
"unmet_prerequisites": rec.unmet_prerequisites,
|
| 234 |
-
"downstream_count": rec.downstream_count,
|
| 235 |
-
"score_breakdown": rec.score_breakdown,
|
| 236 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plrs/ranking/__init__.py
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
from plrs.ranking.ranker import MultiObjectiveRanker, RankedRecommendation
|
| 2 |
-
|
| 3 |
-
__all__ = ["MultiObjectiveRanker", "RankedRecommendation"]
|
|
|
|
|
|
|
|
|
|
|
|
plrs/ranking/ranker.py
DELETED
|
@@ -1,189 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
plrs.ranking.ranker
|
| 3 |
-
===================
|
| 4 |
-
Multi-objective ranking function for approved/challenging topics.
|
| 5 |
-
|
| 6 |
-
Scoring signals:
|
| 7 |
-
1. Mastery gap — how close the student is to mastering this topic
|
| 8 |
-
2. Readiness — fraction of prerequisites met
|
| 9 |
-
3. Downstream value — how many future topics this unlocks (normalised)
|
| 10 |
-
|
| 11 |
-
Weights are configurable. Default: gap=0.4, readiness=0.4, downstream=0.2
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
from dataclasses import dataclass
|
| 17 |
-
|
| 18 |
-
import networkx as nx
|
| 19 |
-
|
| 20 |
-
from plrs.constraints.dag import ConstraintResult, MasteryVector
|
| 21 |
-
from plrs.curriculum.loader import CurriculumGraph
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@dataclass
|
| 25 |
-
class RankedRecommendation:
|
| 26 |
-
topic_id: str
|
| 27 |
-
topic_label: str
|
| 28 |
-
status: str # "approved" | "challenging"
|
| 29 |
-
mastery: float
|
| 30 |
-
score: float
|
| 31 |
-
reasoning: str
|
| 32 |
-
prerequisites: list[str]
|
| 33 |
-
unmet_prerequisites: list[str]
|
| 34 |
-
downstream_count: int
|
| 35 |
-
score_breakdown: dict[str, float]
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class MultiObjectiveRanker:
|
| 39 |
-
"""
|
| 40 |
-
Ranks constraint-validated topics by a weighted combination of signals.
|
| 41 |
-
|
| 42 |
-
Parameters
|
| 43 |
-
----------
|
| 44 |
-
curriculum : CurriculumGraph
|
| 45 |
-
w_gap : float
|
| 46 |
-
Weight for mastery gap signal (default 0.4).
|
| 47 |
-
w_readiness : float
|
| 48 |
-
Weight for prerequisite readiness signal (default 0.4).
|
| 49 |
-
w_downstream : float
|
| 50 |
-
Weight for downstream unlock value (default 0.2).
|
| 51 |
-
"""
|
| 52 |
-
|
| 53 |
-
def __init__(
|
| 54 |
-
self,
|
| 55 |
-
curriculum: CurriculumGraph,
|
| 56 |
-
w_gap: float = 0.4,
|
| 57 |
-
w_readiness: float = 0.4,
|
| 58 |
-
w_downstream: float = 0.2,
|
| 59 |
-
) -> None:
|
| 60 |
-
self.curriculum = curriculum
|
| 61 |
-
self.w_gap = w_gap
|
| 62 |
-
self.w_readiness = w_readiness
|
| 63 |
-
self.w_downstream = w_downstream
|
| 64 |
-
|
| 65 |
-
# Pre-compute downstream counts (expensive on large graphs; cache it)
|
| 66 |
-
self._downstream_counts = self._compute_downstream_counts()
|
| 67 |
-
max_d = max(self._downstream_counts.values(), default=1)
|
| 68 |
-
self._downstream_norm = {
|
| 69 |
-
node: count / max(max_d, 1)
|
| 70 |
-
for node, count in self._downstream_counts.items()
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
-
def _compute_downstream_counts(self) -> dict[str, int]:
|
| 74 |
-
return {
|
| 75 |
-
node: len(nx.descendants(self.curriculum.graph, node))
|
| 76 |
-
for node in self.curriculum.nodes
|
| 77 |
-
}
|
| 78 |
-
|
| 79 |
-
def score(self, result: ConstraintResult, mastery: MasteryVector) -> float:
|
| 80 |
-
"""Compute composite score for a single topic."""
|
| 81 |
-
topic_id = result.topic_id
|
| 82 |
-
|
| 83 |
-
# 1. Mastery gap: student is close but not mastered → higher priority
|
| 84 |
-
gap = max(0.0, mastery.threshold - mastery.get(topic_id))
|
| 85 |
-
gap_score = gap / mastery.threshold # normalise to [0, 1]
|
| 86 |
-
|
| 87 |
-
# 2. Readiness: fraction of prerequisites above soft threshold
|
| 88 |
-
prereqs = self.curriculum.prerequisites(topic_id)
|
| 89 |
-
if prereqs:
|
| 90 |
-
readiness = sum(
|
| 91 |
-
1 for p in prereqs if mastery.get(p) >= mastery.soft_threshold
|
| 92 |
-
) / len(prereqs)
|
| 93 |
-
else:
|
| 94 |
-
readiness = 1.0
|
| 95 |
-
|
| 96 |
-
# 3. Downstream value
|
| 97 |
-
downstream = self._downstream_norm.get(topic_id, 0.0)
|
| 98 |
-
|
| 99 |
-
score = (
|
| 100 |
-
self.w_gap * gap_score
|
| 101 |
-
+ self.w_readiness * readiness
|
| 102 |
-
+ self.w_downstream * downstream
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
return round(score, 4)
|
| 106 |
-
|
| 107 |
-
def rank(
|
| 108 |
-
self,
|
| 109 |
-
results: list[ConstraintResult],
|
| 110 |
-
mastery: MasteryVector,
|
| 111 |
-
top_n: int = 5,
|
| 112 |
-
challenging_penalty: float = 0.8,
|
| 113 |
-
) -> dict[str, list[RankedRecommendation]]:
|
| 114 |
-
"""
|
| 115 |
-
Rank a list of constraint results into approved / challenging / vetoed.
|
| 116 |
-
|
| 117 |
-
Parameters
|
| 118 |
-
----------
|
| 119 |
-
results : list[ConstraintResult]
|
| 120 |
-
mastery : MasteryVector
|
| 121 |
-
top_n : int
|
| 122 |
-
Number of top approved recommendations to return.
|
| 123 |
-
challenging_penalty : float
|
| 124 |
-
Score multiplier applied to challenging topics (default 0.8).
|
| 125 |
-
|
| 126 |
-
Returns
|
| 127 |
-
-------
|
| 128 |
-
dict with keys: "approved", "challenging", "vetoed", "stats"
|
| 129 |
-
"""
|
| 130 |
-
approved: list[RankedRecommendation] = []
|
| 131 |
-
challenging: list[RankedRecommendation] = []
|
| 132 |
-
vetoed: list[RankedRecommendation] = []
|
| 133 |
-
|
| 134 |
-
for result in results:
|
| 135 |
-
# Skip already-mastered topics
|
| 136 |
-
if mastery.is_mastered(result.topic_id):
|
| 137 |
-
continue
|
| 138 |
-
|
| 139 |
-
base_score = self.score(result, mastery)
|
| 140 |
-
topic_id = result.topic_id
|
| 141 |
-
|
| 142 |
-
breakdown = {
|
| 143 |
-
"gap": round(
|
| 144 |
-
self.w_gap * max(0.0, mastery.threshold - mastery.get(topic_id)) / mastery.threshold, 4
|
| 145 |
-
),
|
| 146 |
-
"readiness": round(self.w_readiness * (
|
| 147 |
-
sum(1 for p in self.curriculum.prerequisites(topic_id)
|
| 148 |
-
if mastery.get(p) >= mastery.soft_threshold)
|
| 149 |
-
/ max(len(self.curriculum.prerequisites(topic_id)), 1)
|
| 150 |
-
), 4),
|
| 151 |
-
"downstream": round(self.w_downstream * self._downstream_norm.get(topic_id, 0.0), 4),
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
rec = RankedRecommendation(
|
| 155 |
-
topic_id=result.topic_id,
|
| 156 |
-
topic_label=result.topic_label,
|
| 157 |
-
status=result.status,
|
| 158 |
-
mastery=round(result.mastery, 3),
|
| 159 |
-
score=round(base_score * (challenging_penalty if result.status == "challenging" else 1.0), 4),
|
| 160 |
-
reasoning=result.reasoning,
|
| 161 |
-
prerequisites=result.prerequisites,
|
| 162 |
-
unmet_prerequisites=result.unmet_prerequisites,
|
| 163 |
-
downstream_count=self._downstream_counts.get(result.topic_id, 0),
|
| 164 |
-
score_breakdown=breakdown,
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
if result.status == "approved":
|
| 168 |
-
approved.append(rec)
|
| 169 |
-
elif result.status == "challenging":
|
| 170 |
-
challenging.append(rec)
|
| 171 |
-
else:
|
| 172 |
-
vetoed.append(rec)
|
| 173 |
-
|
| 174 |
-
approved.sort(key=lambda r: r.score, reverse=True)
|
| 175 |
-
challenging.sort(key=lambda r: r.score, reverse=True)
|
| 176 |
-
|
| 177 |
-
total = len(results)
|
| 178 |
-
return {
|
| 179 |
-
"approved": approved[:top_n],
|
| 180 |
-
"challenging": challenging[:3],
|
| 181 |
-
"vetoed": vetoed[:5],
|
| 182 |
-
"stats": {
|
| 183 |
-
"total_topics": total,
|
| 184 |
-
"approved_count": len(approved),
|
| 185 |
-
"challenging_count": len(challenging),
|
| 186 |
-
"vetoed_count": len(vetoed),
|
| 187 |
-
"prerequisite_violation_rate": round(len(vetoed) / max(total, 1), 3),
|
| 188 |
-
},
|
| 189 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
-
streamlit>=1.
|
| 2 |
-
torch>=2.
|
| 3 |
pandas>=2.0.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
-
networkx>=3.
|
| 6 |
scikit-learn>=1.3.0
|
| 7 |
huggingface_hub>=0.20.0
|
| 8 |
-
fastapi>=0.110.0
|
| 9 |
-
pydantic>=2.0
|
|
|
|
| 1 |
+
streamlit>=1.32.0
|
| 2 |
+
torch>=2.9.0
|
| 3 |
pandas>=2.0.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
+
networkx>=3.1
|
| 6 |
scikit-learn>=1.3.0
|
| 7 |
huggingface_hub>=0.20.0
|
|
|
|
|
|
sakt_decay_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:673a79d8b9cd8a6cb4bb1081817a806c23f15ff54708e5546750e564bfc728f0
|
| 3 |
+
size 171713
|
models/sakt_model.pt → sakt_model.pt
RENAMED
|
File without changes
|
sakt_vanilla_best.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9b93a62b1132e17dde47ba1667e881a20b9614eaf3caad4494642cf2fbce8c2b
|
| 3 |
+
size 171713
|
training_curves.png
ADDED
|
Git LFS Details
|