Update README.md

#1
by Clementio - opened
.gitattributes CHANGED
@@ -1 +1,35 @@
1
- models/sakt_model.pt filter=lfs diff=lfs merge=lfs -text
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,52 +1,9 @@
1
  ---
2
- title: PLRS Logic Engine
3
- emoji: 🧠
4
- colorFrom: blue
5
- colorTo: indigo
6
- sdk: streamlit
7
- sdk_version: 1.33.0
8
- app_file: app.py
9
- pinned: true
10
  license: mit
11
- tags:
12
- - education
13
- - knowledge-tracing
14
- - recommendation-system
15
- - pytorch
16
- - transformers
17
- ---
18
-
19
- # PLRS — Personalized Learning Recommendation System
20
-
21
- > Constraint-aware personalized learning recommendations powered by Self-Attentive Knowledge Tracing (SAKT) and DAG prerequisite constraints.
22
-
23
- ## What it does
24
-
25
- PLRS combines a SAKT transformer model with a curriculum knowledge graph to generate recommendations that are both **personalized** and **pedagogically sound**. Topics are classified into three tiers:
26
-
27
- - ✅ **Approved** — prerequisites met, ready to learn
28
- - ⚠️ **Challenging** — prerequisites partially met
29
- - ❌ **Vetoed** — prerequisites not met, blocked
30
-
31
- ## Key results
32
-
33
- | Metric | PLRS | Collaborative Filtering |
34
- |--------|------|------------------------|
35
- | Val AUC | **0.7692** | — |
36
- | Prerequisite Violation Rate | **0.0%** | 81.3% |
37
-
38
- ## Bundled curricula
39
-
40
- - **Nigerian Secondary School Mathematics** (38 topics, 45 edges, JSS3–SS2)
41
- - **CS Fundamentals / Digital Technologies** (31 topics, 39 edges)
42
-
43
- ## Architecture
44
-
45
- ```
46
- Student History → SAKT → Mastery Vector → DAG Constraint Layer → Ranker → Recommendations
47
- ```
48
-
49
- ## Links
50
-
51
- - 📦 GitHub: [clementina-tom/plrs](https://github.com/clementina-tom/plrs)
52
- 📄 Paper/Report: Final Year Project, Computer Science
3
+ language:
4
+ - en
5
+ metrics:
6
+ - accuracy
7
+ base_model:
8
+ - google/timesfm-2.0-500m-pytorch
9
+ ---
app.py CHANGED
@@ -1,492 +1,343 @@
1
- """
2
- PLRS — Logic Engine
3
- HuggingFace Space entry point.
4
-
5
- Loads SAKT model weights from HF Hub (Clementio/PLRS).
6
- Bundles the plrs package inline (until PyPI release).
7
- """
8
-
9
- import json
10
- import sys
11
- from pathlib import Path
12
-
13
- import numpy as np
14
  import streamlit as st
15
  import torch
16
-
17
- ROOT = Path(__file__).resolve().parent
18
- sys.path.insert(0, str(ROOT))
19
-
20
- from plrs.curriculum.loader import load_dag
21
- from plrs.pipeline import PLRSPipeline
22
-
23
- # ── Page config ───────────────────────────────────────────────────────────────
24
- st.set_page_config(
25
- page_title="PLRS · Logic Engine",
26
- page_icon="🧠",
27
- layout="wide",
28
- initial_sidebar_state="expanded",
29
- )
30
-
31
- # ── Styling ───────────────────────────────────────────────────────────────────
32
- st.markdown("""
33
- <style>
34
- @import url('https://fonts.googleapis.com/css2?family=DM+Mono:wght@300;400;500&family=Syne:wght@400;600;700;800&display=swap');
35
-
36
- html, body, [class*="css"] {
37
- font-family: 'Syne', sans-serif;
38
- background-color: #0a0e1a;
39
- color: #c8d0e0;
40
- }
41
- #MainMenu, footer, header { visibility: hidden; }
42
- .block-container { padding: 1.5rem 2rem 2rem 2rem; max-width: 1400px; }
43
-
44
- [data-testid="stSidebar"] {
45
- background: #0d1221;
46
- border-right: 1px solid #1e2a40;
47
- }
48
- [data-testid="stSidebar"] .stMarkdown p {
49
- font-family: 'DM Mono', monospace;
50
- font-size: 0.75rem;
51
- color: #4a5568;
52
- letter-spacing: 0.08em;
53
- }
54
-
55
- .plrs-header {
56
- display: flex; align-items: baseline; gap: 1rem;
57
- padding-bottom: 1rem; border-bottom: 1px solid #1e2a40; margin-bottom: 1.5rem;
58
- }
59
- .plrs-title { font-size: 1.75rem; font-weight: 800; letter-spacing: -0.02em; color: #e8edf5; }
60
- .plrs-sub {
61
- font-family: 'DM Mono', monospace; font-size: 0.7rem; color: #3d8bcd;
62
- letter-spacing: 0.12em; text-transform: uppercase; padding: 2px 8px;
63
- border: 1px solid #1e3a5f; border-radius: 2px;
64
- }
65
-
66
- .stat-row { display: flex; gap: 0.75rem; margin-bottom: 1.5rem; }
67
- .stat-card {
68
- flex: 1; background: #0d1221; border: 1px solid #1e2a40;
69
- border-radius: 4px; padding: 0.9rem 1rem; position: relative; overflow: hidden;
70
- }
71
- .stat-card::before {
72
- content: ''; position: absolute; top: 0; left: 0; right: 0;
73
- height: 2px; background: var(--accent, #3d8bcd);
74
- }
75
- .stat-card.green::before { --accent: #22c55e; }
76
- .stat-card.amber::before { --accent: #f59e0b; }
77
- .stat-card.red::before { --accent: #ef4444; }
78
- .stat-card.blue::before { --accent: #3d8bcd; }
79
- .stat-label { font-family: 'DM Mono', monospace; font-size: 0.62rem; color: #4a5568; letter-spacing: 0.12em; text-transform: uppercase; margin-bottom: 0.25rem; }
80
- .stat-value { font-size: 1.6rem; font-weight: 700; color: #e8edf5; line-height: 1; }
81
- .stat-sub { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #4a5568; margin-top: 0.2rem; }
82
-
83
- .rec-card {
84
- background: #0d1221; border: 1px solid #1e2a40; border-radius: 4px;
85
- padding: 0.9rem 1rem; margin-bottom: 0.5rem;
86
- }
87
- .rec-card.approved { border-left: 3px solid #22c55e; }
88
- .rec-card.challenging { border-left: 3px solid #f59e0b; }
89
- .rec-card.vetoed { border-left: 3px solid #ef4444; opacity: 0.6; }
90
- .rec-title { font-size: 0.95rem; font-weight: 700; color: #e8edf5; margin-bottom: 0.15rem; }
91
- .rec-meta { font-family: 'DM Mono', monospace; font-size: 0.65rem; color: #4a5568; letter-spacing: 0.06em; }
92
- .rec-reason { font-size: 0.75rem; color: #8899aa; margin-top: 0.35rem; padding-top: 0.35rem; border-top: 1px solid #1e2a40; }
93
- .score-bar-wrap { background: #131a2e; border-radius: 2px; height: 3px; margin-top: 0.5rem; overflow: hidden; }
94
- .score-bar { height: 100%; border-radius: 2px; background: var(--bar-color, #3d8bcd); }
95
-
96
- .section-label {
97
- font-family: 'DM Mono', monospace; font-size: 0.65rem; letter-spacing: 0.14em;
98
- text-transform: uppercase; color: #4a5568; border-bottom: 1px solid #1e2a40;
99
- padding-bottom: 0.4rem; margin-bottom: 0.75rem; margin-top: 1.25rem;
100
- }
101
- .unlock-chip {
102
- display: inline-block; font-family: 'DM Mono', monospace; font-size: 0.65rem;
103
- background: #131a2e; border: 1px solid #1e3a5f; border-radius: 2px;
104
- padding: 2px 7px; margin: 2px 3px 2px 0; color: #3d8bcd;
105
- }
106
- .blocked-chip {
107
- display: inline-block; font-family: 'DM Mono', monospace; font-size: 0.65rem;
108
- background: #1a1010; border: 1px solid #3f1e1e; border-radius: 2px;
109
- padding: 2px 7px; margin: 2px 3px 2px 0; color: #ef4444;
110
- }
111
-
112
- .stTabs [data-baseweb="tab-list"] { gap: 0; border-bottom: 1px solid #1e2a40; background: transparent; }
113
- .stTabs [data-baseweb="tab"] { font-family: 'DM Mono', monospace; font-size: 0.7rem; letter-spacing: 0.08em; color: #4a5568; padding: 0.5rem 1.25rem; border-bottom: 2px solid transparent; }
114
- .stTabs [aria-selected="true"] { color: #3d8bcd; border-bottom-color: #3d8bcd; background: transparent; }
115
- </style>
116
- """, unsafe_allow_html=True)
117
-
118
-
119
- # ── Model + pipeline loading ──────────────────────────────────────────────────
120
-
121
- @st.cache_resource(show_spinner="Loading curriculum & model from HuggingFace...")
122
- def load_pipelines():
123
- from plrs.model.model_loader import load_model_from_hub
124
-
125
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
- maps = ROOT / "data" / "knowledge_maps"
127
-
128
- # Load model (tries decay, vanilla, then base)
129
- model, model_type = load_model_from_hub(device=str(device))
130
-
131
- pipelines = {}
132
- for domain, fname in [("math", "math_dag.json"), ("cs", "cs_dag.json")]:
133
- path = maps / fname
134
- if path.exists():
135
- curriculum = load_dag(path)
136
- pipeline = PLRSPipeline(curriculum)
137
- if model:
138
- pipeline._model = model
139
- pipelines[domain] = pipeline
140
-
141
- return pipelines, model is not None, model_type
142
-
143
 
144
  @st.cache_data
145
  def load_skill_encoder():
146
- import pandas as pd
147
- path = ROOT / "data" / "skill_encoder_v2.csv"
148
- if path.exists():
149
- return pd.read_csv(path)
150
- return None
151
-
152
- pipelines, has_model, model_type = load_pipelines()
153
- skill_encoder = load_skill_encoder()
154
-
155
- # ── Sidebar ───────────────────────────────────────────────────────────────────
156
- with st.sidebar:
157
- st.markdown("### 🧠 PLRS")
158
- st.markdown('<p style="font-family:\'DM Mono\',monospace;font-size:0.65rem;color:#4a5568;letter-spacing:0.1em;">LOGIC ENGINE v0.2.0</p>', unsafe_allow_html=True)
159
-
160
- if has_model:
161
- st.markdown(f'<p style="color:#22c55e;font-size:0.7rem;font-family:\'DM Mono\',monospace;">● {model_type} LOADED</p>', unsafe_allow_html=True)
162
- else:
163
- st.markdown('<p style="color:#f59e0b;font-size:0.7rem;font-family:\'DM Mono\',monospace;">● MANUAL MODE</p>', unsafe_allow_html=True)
164
-
165
- st.markdown("---")
166
-
167
- domain_label = st.selectbox("Curriculum", ["Nigerian SS Mathematics", "CS Fundamentals"])
168
- domain_key = "math" if "Mathematics" in domain_label else "cs"
169
- pipeline = pipelines[domain_key]
170
- curriculum = pipeline.curriculum
171
-
172
- st.markdown("---")
173
- threshold = st.slider("Mastery threshold", 0.50, 0.90, 0.70, 0.05)
174
- soft_threshold = st.slider("Challenging threshold", 0.20, 0.65, 0.50, 0.05)
175
- top_n = st.slider("Top N recommendations", 3, 10, 5)
176
-
177
- pipeline.threshold = threshold
178
- pipeline.soft_threshold = soft_threshold
179
- pipeline.top_n = top_n
180
-
181
- st.markdown("---")
182
- st.markdown(f'<p style="font-family:\'DM Mono\',monospace;font-size:0.65rem;color:#4a5568;">NODES: <span style="color:#e8edf5;">{curriculum.num_nodes}</span></p>', unsafe_allow_html=True)
183
- st.markdown(f'<p style="font-family:\'DM Mono\',monospace;font-size:0.65rem;color:#4a5568;">EDGES: <span style="color:#e8edf5;">{curriculum.num_edges}</span></p>', unsafe_allow_html=True)
184
- st.markdown(f'<p style="font-family:\'DM Mono\',monospace;font-size:0.65rem;color:#4a5568;">MODEL: <span style="color:#e8edf5;">{model_type}</span></p>', unsafe_allow_html=True)
185
- st.markdown(f'<p style="font-family:\'DM Mono\',monospace;font-size:0.65rem;color:#4a5568;">VIOLATION RATE: <span style="color:#22c55e;">0.0%</span></p>', unsafe_allow_html=True)
186
-
187
- st.markdown("---")
188
- st.markdown('<p style="font-family:\'DM Mono\',monospace;font-size:0.6rem;color:#2a3a50;">github.com/clementina-tom/plrs</p>', unsafe_allow_html=True)
189
-
190
-
191
- # ── Header ────────────────────────────────────────────────────────────────────
192
- st.markdown("""
193
- <div class="plrs-header">
194
- <span class="plrs-title">Logic Engine</span>
195
- <span class="plrs-sub">Personalized Learning · Constraint-Aware · SAKT + DAG</span>
196
- </div>
197
- """, unsafe_allow_html=True)
198
-
199
-
200
- # ── Tabs ──────────────────────────────────────────────────────────────────────
201
- tab1, tab2, tab3 = st.tabs(["RECOMMENDATIONS", "WHAT-IF SIMULATOR", "CURRICULUM MAP"])
202
-
203
- ACTIVITY_TO_DOMAIN = {
204
- "math": {
205
- "oucontent": "algebraic_expressions", "forumng": "statistics_basic",
206
- "homepage": "whole_numbers", "subpage": "plane_shapes",
207
- "resource": "indices", "url": "number_bases",
208
- "ouwiki": "proportion_variation", "glossary": "algebraic_factorization",
209
- "quiz": "quadratic_equations",
210
- },
211
- "cs": {
212
- "oucontent": "programming_concepts", "forumng": "ethics_technology",
213
- "homepage": "computer_basics", "subpage": "html_basics",
214
- "resource": "networking_fundamentals", "url": "internet_basics",
215
- "ouwiki": "cloud_basics", "glossary": "intro_databases",
216
- "quiz": "python_basics",
217
- },
218
- }
219
-
220
-
221
- # ══════════════════════════════════════════════════════════════════════════════
222
- # TAB 1 — RECOMMENDATIONS
223
- # ══════════════════════════════════════════════════════════════════════════════
224
- with tab1:
225
- col_left, col_right = st.columns([1, 1.4], gap="large")
226
-
227
- with col_left:
228
- st.markdown('<div class="section-label">Learner Profile</div>', unsafe_allow_html=True)
229
- mode = st.radio("Input mode", ["Manual sliders", "Simulate student"], horizontal=True, label_visibility="collapsed")
230
-
231
- mastery_scores = {}
232
-
233
- if mode == "Manual sliders":
234
- for node in curriculum.nodes:
235
- label = curriculum.label(node)
236
- level = curriculum.level(node)
237
- val = st.slider(
238
- f"{label}",
239
- 0.0, 1.0, 0.0, 0.05,
240
- key=f"mastery_{node}",
241
- help=f"Level: {level}"
242
- )
243
- mastery_scores[node] = val
244
  else:
245
- seq_len = st.slider("Sequence length", 10, 200, 50)
246
- seed = st.number_input("Student seed", 1, 9999, 42)
247
  np.random.seed(int(seed))
248
-
249
- activity_types = list(ACTIVITY_TO_DOMAIN[domain_key].keys())
250
- activity_probs = [0.38, 0.20, 0.15, 0.10, 0.06, 0.04, 0.03, 0.02, 0.02]
251
- mapping = ACTIVITY_TO_DOMAIN[domain_key]
252
-
253
- # Use skill_encoder to simulate skills that actually exist in the mapping
254
- if skill_encoder is not None:
255
- available_skills = skill_encoder["skill_id"].tolist()
256
- sim_skills = np.random.choice(available_skills, seq_len).tolist()
257
  else:
258
- n_skills = 5736
259
- sim_skills = np.random.randint(0, n_skills, seq_len).tolist()
260
-
261
- sim_corrects = np.random.randint(0, 2, seq_len).tolist()
262
-
263
- topic_scores: dict = {}
264
-
265
- # Map simulated skills back to topics using the CSV work you did
266
- for skill_id in sim_skills:
267
- # If we have the encoder, find the activity type
268
- if skill_encoder is not None:
269
- row = skill_encoder[skill_encoder["skill_id"] == skill_id]
270
- if not row.empty:
271
- act = row["activity_type"].values[0]
272
- topic_id = mapping.get(act)
273
- if topic_id and topic_id in curriculum.nodes:
274
- # Generate a mastery signal based on frequency/success
275
- score = topic_scores.get(topic_id, 0.1)
276
- # Every time they see this topic, increase mastery slightly
277
- topic_scores[topic_id] = min(1.0, score + np.random.random() * 0.2)
278
  else:
279
- # Fallback to the old simple mapping if CSV is missing
280
- act_idx = skill_id % 100
281
- cumulative = 0
282
- thresholds = [int(p * 100) for p in activity_probs]
283
- thresholds[-1] += 100 - sum(thresholds)
284
- act = activity_types[-1]
285
- for a, thresh in zip(activity_types, thresholds):
286
- cumulative += thresh
287
- if act_idx < cumulative:
288
- act = a
289
- break
290
- topic_id = mapping.get(act)
291
- if topic_id and topic_id in curriculum.nodes:
292
- topic_scores[topic_id] = 0.5 + np.random.random() * 0.4
293
-
294
- mastery_scores = {n: 0.0 for n in curriculum.nodes}
295
- mastery_scores.update(topic_scores)
296
- st.success(f"Simulated {seq_len} interactions {len(topic_scores)} topics mapped")
297
-
298
- if topic_scores:
299
- st.markdown('<div class="section-label">Mapped Mastery Signal</div>', unsafe_allow_html=True)
300
- for tid, score in sorted(topic_scores.items(), key=lambda x: -x[1]):
301
- pct = int(score * 100)
302
- color = "#22c55e" if score >= threshold else "#f59e0b" if score >= soft_threshold else "#ef4444"
303
- st.markdown(f"""
304
- <div style="margin-bottom:6px;">
305
- <div style="display:flex;justify-content:space-between;font-size:0.72rem;color:#8899aa;margin-bottom:2px;">
306
- <span>{curriculum.label(tid)}</span>
307
- <span style="font-family:'DM Mono',monospace;">{pct}%</span>
308
- </div>
309
- <div class="score-bar-wrap">
310
- <div class="score-bar" style="width:{pct}%;--bar-color:{color};"></div>
311
- </div>
312
- </div>
313
- """, unsafe_allow_html=True)
314
-
315
- run = st.button("⚡ Generate Recommendations", type="primary", use_container_width=True)
316
-
317
- with col_right:
318
- if run or mode == "Simulate student":
319
- # Enable cascading for simulation to ensure prerequisites are also "mastered"
320
- is_sim = (mode == "Simulate student")
321
- results = pipeline.recommend_from_mastery(mastery_scores, cascade=is_sim)
322
- summary = results["mastery_summary"]
323
- stats = results["stats"]
324
-
325
- mastery_pct = int(summary["mastery_rate"] * 100)
326
- vrate_pct = int(stats["prerequisite_violation_rate"] * 100)
327
-
328
- st.markdown(f"""
329
- <div class="stat-row">
330
- <div class="stat-card blue">
331
- <div class="stat-label">Mastered</div>
332
- <div class="stat-value">{summary['mastered']}<span style="font-size:0.9rem;color:#4a5568;">/{summary['total_topics']}</span></div>
333
- <div class="stat-sub">{mastery_pct}% rate</div>
334
- </div>
335
- <div class="stat-card green">
336
- <div class="stat-label">Approved</div>
337
- <div class="stat-value">{stats['approved_count']}</div>
338
- <div class="stat-sub">ready to learn</div>
339
- </div>
340
- <div class="stat-card amber">
341
- <div class="stat-label">Challenging</div>
342
- <div class="stat-value">{stats['challenging_count']}</div>
343
- <div class="stat-sub">partial prereqs</div>
344
- </div>
345
- <div class="stat-card red">
346
- <div class="stat-label">Violation rate</div>
347
- <div class="stat-value">{vrate_pct}<span style="font-size:0.9rem;color:#4a5568;">%</span></div>
348
- <div class="stat-sub">blocked topics</div>
349
- </div>
350
- </div>
351
- """, unsafe_allow_html=True)
352
-
353
- if results["approved"]:
354
- st.markdown('<div class="section-label">✅ Approved Recommendations</div>', unsafe_allow_html=True)
355
- for i, rec in enumerate(results["approved"]):
356
- score_pct = int(rec["score"] * 100)
357
- st.markdown(f"""
358
- <div class="rec-card approved">
359
- <div class="rec-title">{i+1}. {rec['topic_label']}</div>
360
- <div class="rec-meta">score: {rec['score']:.3f} &nbsp;·&nbsp; mastery: {int(rec['mastery']*100)}% &nbsp;·&nbsp; unlocks: {rec['downstream_count']}</div>
361
- <div class="rec-reason">{rec['reasoning']}</div>
362
- <div class="score-bar-wrap"><div class="score-bar" style="width:{score_pct}%;--bar-color:#22c55e;"></div></div>
363
- </div>
364
- """, unsafe_allow_html=True)
365
- else:
366
- st.info("No approved topics — lower the mastery threshold or set some mastery levels.")
367
-
368
- if results["challenging"]:
369
- st.markdown('<div class="section-label">⚠️ Challenging</div>', unsafe_allow_html=True)
370
- for rec in results["challenging"]:
371
- score_pct = int(rec["score"] * 100)
372
- unmet = ", ".join(rec["unmet_prerequisites"]) or "—"
373
- st.markdown(f"""
374
- <div class="rec-card challenging">
375
- <div class="rec-title">{rec['topic_label']}</div>
376
- <div class="rec-meta">score: {rec['score']:.3f} &nbsp;·&nbsp; strengthen: {unmet}</div>
377
- <div class="rec-reason">{rec['reasoning']}</div>
378
- <div class="score-bar-wrap"><div class="score-bar" style="width:{score_pct}%;--bar-color:#f59e0b;"></div></div>
379
- </div>
380
- """, unsafe_allow_html=True)
381
-
382
- if results["vetoed"]:
383
- with st.expander(f"❌ Vetoed topics ({stats['vetoed_count']} total — prerequisite check failed)"):
384
- for rec in results["vetoed"]:
385
- unmet = ", ".join(rec["unmet_prerequisites"]) or "—"
386
- st.markdown(f"""
387
- <div class="rec-card vetoed">
388
- <div class="rec-title">{rec['topic_label']}</div>
389
- <div class="rec-meta">blocked by: {unmet}</div>
390
- </div>
391
- """, unsafe_allow_html=True)
392
- else:
393
- st.markdown("""
394
- <div style="height:280px;display:flex;align-items:center;justify-content:center;
395
- border:1px dashed #1e2a40;border-radius:4px;color:#2a3a50;">
396
- <div style="text-align:center;">
397
- <div style="font-size:2rem;margin-bottom:0.5rem;">⚡</div>
398
- <div style="font-family:'DM Mono',monospace;font-size:0.7rem;letter-spacing:0.1em;">
399
- SET MASTERY LEVELS · THEN GENERATE
400
- </div>
401
- </div>
402
- </div>
403
- """, unsafe_allow_html=True)
404
-
405
-
406
- # ══════════════════════════════════════════════════════════════════════════════
407
- # TAB 2 — WHAT-IF SIMULATOR
408
- # ══════════════════════════════════════════════════════════════════════════════
409
- with tab2:
410
- st.markdown('<div class="section-label">Prerequisite Impact Simulator</div>', unsafe_allow_html=True)
411
- st.markdown('<p style="font-size:0.8rem;color:#8899aa;">Select any topic to see what it unlocks and what currently blocks it.</p>', unsafe_allow_html=True)
412
-
413
- node_options = {curriculum.label(n): n for n in curriculum.nodes}
414
- selected_label = st.selectbox("Select topic", list(node_options.keys()))
415
- selected_id = node_options[selected_label]
416
- wi = pipeline.what_if(selected_id)
417
-
418
- col_a, col_b = st.columns(2, gap="large")
419
-
420
- with col_a:
421
- st.markdown('<div class="section-label">🔓 What This Unlocks</div>', unsafe_allow_html=True)
422
- if wi["direct_unlocks"]:
423
- st.markdown("**Directly unlocks:**")
424
- st.markdown("".join(f'<span class="unlock-chip">{u["label"]}</span>' for u in wi["direct_unlocks"]), unsafe_allow_html=True)
425
- else:
426
- st.markdown('<span style="color:#4a5568;font-size:0.8rem;">Leaf node — no further topics.</span>', unsafe_allow_html=True)
427
-
428
- if wi["all_unlocks"]:
429
- st.markdown(f"**All downstream ({wi['total_unlocked']}):**")
430
- st.markdown("".join(f'<span class="unlock-chip">{u["label"]}</span>' for u in wi["all_unlocks"]), unsafe_allow_html=True)
431
-
432
- st.markdown(f"""
433
- <div class="stat-card blue" style="margin-top:1rem;max-width:180px;">
434
- <div class="stat-label">Total Unlocked</div>
435
- <div class="stat-value">{wi['total_unlocked']}</div>
436
- </div>
437
- """, unsafe_allow_html=True)
438
-
439
- with col_b:
440
- st.markdown('<div class="section-label">🔒 What Blocks This</div>', unsafe_allow_html=True)
441
- if wi["blocked_by"]:
442
- st.markdown("**Prerequisites:**")
443
- st.markdown("".join(f'<span class="blocked-chip">{b["label"]}</span>' for b in wi["blocked_by"]), unsafe_allow_html=True)
444
- else:
445
- st.markdown('<span style="color:#22c55e;font-size:0.8rem;font-family:\'DM Mono\',monospace;">Root topic — no prerequisites.</span>', unsafe_allow_html=True)
446
-
447
-
448
- # ══════════════════════════════════════════════════════════════════════════════
449
- # TAB 3 — CURRICULUM MAP
450
- # ══════════════════════════════════════════════════════════════════════════════
451
- with tab3:
452
- st.markdown('<div class="section-label">Curriculum Knowledge Graph</div>', unsafe_allow_html=True)
453
-
454
- col_info, col_table = st.columns([1, 2], gap="large")
455
-
456
- with col_info:
457
- roots = [n for n in curriculum.nodes if not curriculum.prerequisites(n)]
458
- leaves = [n for n in curriculum.nodes if not curriculum.successors(n)]
459
-
460
- st.markdown(f"""
461
- <div class="stat-card blue" style="margin-bottom:0.75rem;">
462
- <div class="stat-label">Domain</div>
463
- <div style="font-size:0.85rem;font-weight:700;color:#e8edf5;">{curriculum.domain}</div>
464
- </div>
465
- <div class="stat-card green" style="margin-bottom:0.75rem;">
466
- <div class="stat-label">Topics</div><div class="stat-value">{curriculum.num_nodes}</div>
467
- </div>
468
- <div class="stat-card amber">
469
- <div class="stat-label">Prerequisite Edges</div><div class="stat-value">{curriculum.num_edges}</div>
470
- </div>
471
- """, unsafe_allow_html=True)
472
-
473
- st.markdown('<div class="section-label">Root Topics</div>', unsafe_allow_html=True)
474
- st.markdown("".join(f'<span class="unlock-chip">{curriculum.label(r)}</span>' for r in roots), unsafe_allow_html=True)
475
-
476
- st.markdown('<div class="section-label">Leaf Topics</div>', unsafe_allow_html=True)
477
- st.markdown("".join(f'<span class="blocked-chip">{curriculum.label(l)}</span>' for l in leaves), unsafe_allow_html=True)
478
-
479
- with col_table:
480
- import pandas as pd
481
- st.markdown('<div class="section-label">All Topics</div>', unsafe_allow_html=True)
482
- rows = []
483
- for node in curriculum.nodes:
484
- rows.append({
485
- "Topic": curriculum.label(node),
486
- "Level": curriculum.level(node),
487
- "Prerequisites": len(curriculum.prerequisites(node)),
488
- "Unlocks (direct)": len(curriculum.successors(node)),
489
- "Total Downstream": len(curriculum.descendants(node)),
490
- })
491
- df = pd.DataFrame(rows).sort_values("Total Downstream", ascending=False)
492
- st.dataframe(df, use_container_width=True, height=480, hide_index=True)
1
  import streamlit as st
2
  import torch
3
+ import torch.nn as nn
4
+ import json
5
+ import pandas as pd
6
+ import networkx as nx
7
+ import numpy as np
8
+ from huggingface_hub import hf_hub_download
9
+ from typing import Dict, List, Optional, Tuple
10
+
11
+ st.set_page_config(page_title='Logic Engine', page_icon='🧠', layout='wide')
12
+
13
+ HF_REPO = 'Clementio/PLRS'
14
+
15
+ @st.cache_resource
16
+ def load_model():
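+ # Download the SAKT config and checkpoint from the Clementio/PLRS Hub repo, rebuild the network on CPU, and cache the result.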
17
+ config_path = hf_hub_download(repo_id=HF_REPO, filename='config.json')
18
+ with open(config_path) as f:
19
+ config = json.load(f)
20
+ model_path = hf_hub_download(repo_id=HF_REPO, filename='sakt_model.pt')
21
+ class SAKT(nn.Module):
22
+ def __init__(self, num_skills, embed_dim, num_heads, num_layers, max_seq_len, dropout):
23
+ super(SAKT, self).__init__()
24
+ self.num_skills = num_skills
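+ # Interaction vocabulary of size 2*num_skills + 1: a past interaction is encoded as skill_id + correct*num_skills (see run_sakt_inference); index 0 is the padding slot.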
25
+ self.interaction_embed = nn.Embedding(num_skills * 2 + 1, embed_dim, padding_idx=0)
26
+ self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
27
+ self.pos_embed = nn.Embedding(max_seq_len + 1, embed_dim)
28
+ encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True, dim_feedforward=embed_dim * 4, norm_first=True)
29
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
30
+ self.dropout = nn.Dropout(dropout)
31
+ self.output = nn.Linear(embed_dim, 1)
32
+ def forward(self, interactions, target_skills, mask, return_attention=False):
33
+ batch_size, seq_len = interactions.shape
34
+ positions = torch.arange(seq_len, device=interactions.device).unsqueeze(0).expand(batch_size, -1)
35
+ x = self.interaction_embed(interactions)
36
+ x = x + self.pos_embed(positions)
37
+ x = x * mask.unsqueeze(-1).float()
38
+ x = self.dropout(x)
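+ # Additive upper-triangular mask so each position attends only to earlier interactions (causal self-attention).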
39
+ causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)
40
+ x = self.transformer(x, mask=causal_mask, is_causal=False)
41
+ x = x * mask.unsqueeze(-1).float()
42
+ x = x + self.skill_embed(target_skills)
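+ # Condition each position on the queried next skill, then project to a single correctness logit per position.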
43
+ return self.output(x).squeeze(-1)
44
+ device = torch.device('cpu')
45
+ model = SAKT(num_skills=config['num_skills'], embed_dim=config['embed_dim'], num_heads=config['num_heads'], num_layers=config['num_layers'], max_seq_len=config['max_seq_len'], dropout=config['dropout'])
46
+ model.load_state_dict(torch.load(model_path, map_location=device))
47
+ model.eval()
48
+ return model, config, device
49
+
50
+ @st.cache_resource
51
+ def load_knowledge_maps():
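+ # Build the prerequisite DAGs from the bundled JSON maps: an edge u -> v means u is a prerequisite of v.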
52
+ def load_dag(path):
53
+ with open(path) as f:
54
+ data = json.load(f)
55
+ G = nx.DiGraph()
56
+ for node in data['nodes']:
57
+ G.add_node(node['id'], label=node['label'], level=node['level'], term=node['term'])
58
+ for edge in data['edges']:
59
+ G.add_edge(edge['from'], edge['to'])
60
+ return G
61
+ return load_dag('knowledge_maps/math_dag.json'), load_dag('knowledge_maps/cs_dag.json')
62
 
63
  @st.cache_data
64
  def load_skill_encoder():
65
+ return pd.read_csv('data/skill_encoder.csv')
66
+
67
+ class MasteryVector:
68
+ def __init__(self, graph, threshold=0.70):
69
+ self.graph = graph
70
+ self.threshold = threshold
71
+ self.mastery = {node: 0.0 for node in graph.nodes}
72
+ def update(self, topic_id, probability):
73
+ if topic_id in self.mastery: self.mastery[topic_id] = probability
74
+ def is_mastered(self, topic_id):
75
+ return self.mastery.get(topic_id, 0.0) >= self.threshold
76
+ def get_mastery(self, topic_id):
77
+ return self.mastery.get(topic_id, 0.0)
78
+ def get_mastery_summary(self):
79
+ mastered = [t for t in self.mastery if self.is_mastered(t)]
80
+ return {'total_topics': len(self.mastery), 'mastered': len(mastered), 'mastery_rate': round(len(mastered)/len(self.mastery), 3), 'mastered_topics': mastered}
81
+
82
+ class DAGConstraintLayer:
83
+ def __init__(self, graph, threshold=0.70, soft_threshold=0.50):
84
+ self.graph = graph
85
+ self.threshold = threshold
86
+ self.soft_threshold = soft_threshold # below full threshold but above this = challenging
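+ # Three-tier rule: any prerequisite below soft_threshold vetoes the topic, all prerequisites at or above threshold approve it, anything in between marks it challenging.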
87
+ def validate(self, topic_id, mastery_vector):
88
+ if topic_id not in self.graph.nodes: return 'vetoed', 'Topic not found.'
89
+ prerequisites = list(self.graph.predecessors(topic_id))
90
+ label = self.graph.nodes[topic_id].get('label', topic_id)
91
+ if not prerequisites: return 'approved', '✅ Foundational topic — no prerequisites.'
92
+ hard_fails = []
93
+ soft_fails = []
94
+ for p in prerequisites:
95
+ m = mastery_vector.get_mastery(p)
96
+ plabel = self.graph.nodes[p].get('label', p)
97
+ if m < self.soft_threshold:
98
+ hard_fails.append((plabel, m))
99
+ elif m < self.threshold:
100
+ soft_fails.append((plabel, m))
101
+ if hard_fails:
102
+ gaps = ', '.join([f"{l} ({m:.0%} mastered, need {self.threshold:.0%})" for l,m in hard_fails])
103
+ return 'vetoed', f'❌ Prerequisites not met: {gaps}'
104
+ elif soft_fails:
105
+ gaps = ', '.join([f"{l} ({m:.0%} mastered, need {self.threshold:.0%})" for l,m in soft_fails])
106
+ return 'challenging', f'⚠️ Challenging — prerequisites nearly met: {gaps}. Proceed with caution.'
107
+ else:
108
+ prereq_labels = [self.graph.nodes[p].get('label',p) for p in prerequisites]
109
+ return 'approved', f'✅ Prerequisites mastered: {", ".join(prereq_labels)}'
110
+
111
+ class RankingFunction:
112
+ def __init__(self, graph, threshold=0.70, w_gap=0.40, w_ready=0.35, w_downstream=0.25):
113
+ self.graph=graph; self.threshold=threshold; self.w_gap=w_gap; self.w_ready=w_ready; self.w_downstream=w_downstream
114
+ scores = {n: len(nx.descendants(graph, n)) for n in graph.nodes}
115
+ mx = max(scores.values()) if scores else 1
116
+ self._downstream = {n: s/mx for n,s in scores.items()}
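+ # score = w_gap*(normalized mastery gap) + w_ready*(fraction of prerequisites mastered) + w_downstream*(descendant count, normalized by the largest in the graph)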
117
+ def score(self, topic_id, mastery_vector):
118
+ current = mastery_vector.get_mastery(topic_id)
119
+ gap = min(max(0.0, self.threshold-current)/self.threshold, 1.0)
120
+ prereqs = list(self.graph.predecessors(topic_id))
121
+ readiness = 1.0 if not prereqs else sum(1 for p in prereqs if mastery_vector.is_mastered(p))/len(prereqs)
122
+ downstream = self._downstream.get(topic_id, 0.0)
123
+ return round(self.w_gap*gap + self.w_ready*readiness + self.w_downstream*downstream, 3)
124
+
125
+ class LearningRecommendationPipeline:
126
+ def __init__(self, graph, threshold=0.70, soft_threshold=0.50, top_n=5):
127
+ self.graph=graph
128
+ self.constraint=DAGConstraintLayer(graph, threshold, soft_threshold)
129
+ self.ranker=RankingFunction(graph, threshold)
130
+ self.top_n=top_n
131
+ def run(self, mastery_vector):
132
+ approved, challenging, vetoed = [], [], []
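+ # Classify every topic, rank the not-yet-mastered approved/challenging ones, and report the share of vetoed topics as the prerequisite violation rate.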
133
+ for topic_id in self.graph.nodes:
134
+ status, reasoning = self.constraint.validate(topic_id, mastery_vector)
135
+ entry = {'topic_id': topic_id, 'topic_label': self.graph.nodes[topic_id].get('label', topic_id), 'mastery': round(mastery_vector.get_mastery(topic_id),3), 'reasoning': reasoning, 'status': status}
136
+ if status == 'approved' and not mastery_vector.is_mastered(topic_id):
137
+ entry['score'] = self.ranker.score(topic_id, mastery_vector)
138
+ approved.append(entry)
139
+ elif status == 'challenging' and not mastery_vector.is_mastered(topic_id):
140
+ entry['score'] = self.ranker.score(topic_id, mastery_vector) * 0.8 # slight penalty
141
+ challenging.append(entry)
142
+ elif status == 'vetoed':
143
+ vetoed.append(entry)
144
+ approved.sort(key=lambda x: x['score'], reverse=True)
145
+ challenging.sort(key=lambda x: x['score'], reverse=True)
146
+ return {'top_recommendations': approved[:self.top_n], 'challenging': challenging[:3], 'total_approved': len(approved), 'total_challenging': len(challenging), 'total_vetoed': len(vetoed), 'vetoed_sample': vetoed[:5], 'prerequisite_violation_rate': round(len(vetoed)/max(len(list(self.graph.nodes)),1),3)}
147
+
148
+ ACTIVITY_TO_MATH = {'oucontent':'algebraic_expressions','forumng':'statistics_basic','homepage':'whole_numbers','subpage':'plane_shapes','resource':'indices','url':'number_bases','ouwiki':'proportion_variation','glossary':'algebraic_factorization','quiz':'quadratic_equations'}
149
+ ACTIVITY_TO_CS = {'oucontent':'programming_concepts','forumng':'ethics_technology','homepage':'computer_basics','subpage':'html_basics','resource':'networking_fundamentals','url':'internet_basics','ouwiki':'cloud_basics','glossary':'intro_databases','quiz':'python_basics'}
150
+
151
+ def run_sakt_inference(model, config, skill_seq, correct_seq, device):
152
+ max_len=config['max_seq_len']; n_skills=config['num_skills']
153
+ if len(skill_seq)>max_len: skill_seq=skill_seq[-max_len:]; correct_seq=correct_seq[-max_len:]
154
+ interactions=[s+c*n_skills for s,c in zip(skill_seq[:-1],correct_seq[:-1])]
155
+ target_skills=skill_seq[1:]
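+ # Next-step prediction: the first T-1 interactions form the history, skills 2..T are the targets; sequences are left-padded to max_seq_len.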
156
+ seq_len=len(interactions); pad_len=max_len-seq_len
157
+ interactions=[0]*pad_len+interactions; target_skills=[0]*pad_len+target_skills; mask=[False]*pad_len+[True]*seq_len
158
+ with torch.no_grad():
159
+ logits=model(torch.LongTensor([interactions]).to(device),torch.LongTensor([target_skills]).to(device),torch.BoolTensor([mask]).to(device))
160
+ probs=torch.sigmoid(logits).squeeze(0)
161
+ mastery={}; real_probs=probs[torch.BoolTensor(mask)].cpu().numpy(); real_skills=target_skills[pad_len:]
162
+ for skill_id,prob in zip(real_skills,real_probs): mastery[int(skill_id)]=float(prob)
163
+ return mastery
164
+
165
+ def build_mastery_vector(skill_probs, graph, skill_encoder_df, domain, threshold, soft_threshold):
166
+ mv=MasteryVector(graph, threshold); mapping=ACTIVITY_TO_MATH if domain=='math' else ACTIVITY_TO_CS
167
+ topic_scores={}
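+ # Map each SAKT skill prediction to a curriculum topic via its activity type, keeping the highest probability seen per topic.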
168
+ for skill_id,prob in skill_probs.items():
169
+ row=skill_encoder_df[skill_encoder_df['skill_id']==skill_id]
170
+ if row.empty: continue
171
+ act=row['activity_type'].values[0] if 'activity_type' in row.columns else None
172
+ topic_id=mapping.get(act) if act else None
173
+ if topic_id: topic_scores[topic_id]=max(topic_scores.get(topic_id,0.0),prob)
174
+ for topic_id,score in topic_scores.items(): mv.update(topic_id,score)
175
+ return mv
176
+
177
+ def what_if_analysis(topic_id, graph):
178
+ unlocks = list(nx.descendants(graph, topic_id))
179
+ direct_unlocks = list(graph.successors(topic_id))
180
+ blocked_by = list(graph.predecessors(topic_id))
181
+ unlock_labels = [graph.nodes[n].get('label',n) for n in direct_unlocks]
182
+ all_unlock_labels = [graph.nodes[n].get('label',n) for n in unlocks]
183
+ blocked_labels = [graph.nodes[n].get('label',n) for n in blocked_by]
184
+ return {'direct_unlocks': unlock_labels, 'all_unlocks': all_unlock_labels, 'blocked_by': blocked_labels, 'total_unlocked': len(unlocks)}
185
+
186
+ def get_attention_weights(model, config, skill_seq, correct_seq, device):
187
+ max_len=config['max_seq_len']; n_skills=config['num_skills']
188
+ if len(skill_seq)>max_len: skill_seq=skill_seq[-max_len:]; correct_seq=correct_seq[-max_len:]
189
+ interactions=[s+c*n_skills for s,c in zip(skill_seq[:-1],correct_seq[:-1])]
190
+ target_skills=skill_seq[1:]
191
+ seq_len=len(interactions); pad_len=max_len-seq_len
192
+ interactions=[0]*pad_len+interactions; target_skills=[0]*pad_len+target_skills; mask_list=[False]*pad_len+[True]*seq_len
193
+ interactions_t=torch.LongTensor([interactions]); target_t=torch.LongTensor([target_skills]); mask_t=torch.BoolTensor([mask_list])
194
+ attention_weights = []
195
+ def hook_fn(module, input, output):
196
+ if hasattr(module, 'self_attn'):
197
+ pass
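+ # Note: hook_fn is never registered on the transformer, so no attention weights are captured; in practice this helper returns the last 10 target skills and their predicted probabilities.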
198
+ with torch.no_grad():
199
+ positions=torch.arange(max_len).unsqueeze(0)
200
+ x=model.interaction_embed(interactions_t)+model.pos_embed(positions)
201
+ x=x*mask_t.unsqueeze(-1).float()
202
+ real_mask=mask_t.squeeze(0)
203
+ real_skills=target_skills[pad_len:]
204
+ real_probs=torch.sigmoid(model(interactions_t,target_t,mask_t)).squeeze(0)[real_mask].numpy()
205
+ return real_skills[-10:], real_probs[-10:], seq_len
206
+
207
+ def main():
208
+ model, config, device = load_model()
209
+ math_graph, cs_graph = load_knowledge_maps()
210
+ skill_encoder = load_skill_encoder()
211
+ st.title('🧠 Logic Engine')
212
+ st.subheader('Domain-Agnostic Constraint-Aware Learning Recommender')
213
+ st.markdown('---')
214
+ st.sidebar.title('⚙️ Configuration')
215
+ domain = st.sidebar.selectbox('Select Domain', ['Mathematics', 'CS Fundamentals'])
216
+ threshold = st.sidebar.slider('Mastery Threshold', 0.50, 0.90, 0.70, 0.05, help='Minimum mastery to consider a topic fully mastered')
217
+ soft_threshold = st.sidebar.slider('Challenging Threshold', 0.30, 0.70, 0.50, 0.05, help='Topics above this but below mastery threshold are marked Challenging')
218
+ top_n = st.sidebar.slider('Top N Recommendations', 3, 10, 5)
219
+ graph = math_graph if domain=='Mathematics' else cs_graph
220
+ domain_key = 'math' if domain=='Mathematics' else 'cs'
221
+ pipeline = LearningRecommendationPipeline(graph, threshold, soft_threshold, top_n)
222
+ st.sidebar.markdown('---')
223
+ st.sidebar.markdown('**About**')
224
+ st.sidebar.markdown('SAKT-based knowledge tracing with DAG prerequisite constraints. Three-tier recommendations: ✅ Approved, ⚠️ Challenging, ❌ Vetoed.')
225
+ tab1, tab2, tab3, tab4 = st.tabs(['🎯 Recommendations','🔍 What-If Simulator','🗺️ Knowledge Map','📊 Diagnostics'])
226
+
227
+ with tab1:
228
+ st.header('Learner Profile')
229
+ mode = st.radio('Input Mode', ['Manual Mastery Input','Simulate Student Sequence'], horizontal=True)
230
+ mastery_vector = MasteryVector(graph, threshold)
231
+ if mode=='Manual Mastery Input':
232
+ st.markdown('Set your current mastery level for each topic:')
233
+ cols=st.columns(2); nodes=list(graph.nodes)
234
+ for i,node in enumerate(nodes):
235
+ label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level','')
236
+ val=cols[i%2].slider(f'{label} ({level})',0.0,1.0,0.0,0.05,key=f'mastery_{node}')
237
+ mastery_vector.update(node,val)
238
  else:
239
+ seq_length=st.slider('Sequence Length',10,200,50)
240
+ seed=st.number_input('Student Seed',1,1000,42,1)
241
  np.random.seed(int(seed))
242
+ sim_skills=np.random.randint(0,config['num_skills'],seq_length).tolist()
243
+ sim_corrects=np.random.randint(0,2,seq_length).tolist()
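+ # Simulated learner: random skill IDs and correctness flags, scored by SAKT and then mapped onto curriculum topics.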
244
+ skill_probs=run_sakt_inference(model,config,sim_skills,sim_corrects,device)
245
+ mastery_vector=build_mastery_vector(skill_probs,graph,skill_encoder,domain_key,threshold,soft_threshold)
246
+ st.success(f'SAKT inference complete — {len(skill_probs)} skill predictions generated')
247
+ real_skills, real_probs, seq_len = get_attention_weights(model,config,sim_skills,sim_corrects,device)
248
+ if len(real_probs) > 0:
249
+ st.markdown('**📈 SAKT Mastery Signal (last 10 interactions):**')
250
+ attn_df = pd.DataFrame({'Interaction Step': [f'Step {seq_len-len(real_probs)+i+1}' for i in range(len(real_probs))], 'Predicted Mastery': [round(float(p),3) for p in real_probs]})
251
+ st.bar_chart(attn_df.set_index('Interaction Step'))
252
+ if st.button('🚀 Generate Recommendations', type='primary'):
253
+ output=pipeline.run(mastery_vector)
254
+ summary=mastery_vector.get_mastery_summary()
255
+ col1,col2,col3,col4,col5=st.columns(5)
256
+ col1.metric('Topics Mastered',f"{summary['mastered']} / {summary['total_topics']}")
257
+ col2.metric('Mastery Rate',f"{summary['mastery_rate']:.1%}")
258
+ col3.metric('✅ Approved',output['total_approved'])
259
+ col4.metric('⚠️ Challenging',output['total_challenging'])
260
+ col5.metric('Violation Rate',f"{output['prerequisite_violation_rate']:.1%}")
261
+ st.markdown('---')
262
+ st.subheader(f'✅ Top {top_n} Approved Recommendations')
263
+ if not output['top_recommendations']: st.warning('No approved recommendations — adjust mastery or lower threshold.')
264
  else:
265
+ for i,rec in enumerate(output['top_recommendations'],1):
266
+ with st.expander(f"{i}. {rec['topic_label']} — Score: {rec['score']} | Mastery: {rec['mastery']:.1%}", expanded=(i<=3)):
267
+ st.markdown(f"**Reasoning:** {rec['reasoning']}")
268
+ st.progress(rec['mastery'])
269
+ if output['challenging']:
270
+ st.markdown('---')
271
+ st.subheader('⚠️ Challenging Topics (proceed with caution)')
272
+ for rec in output['challenging']:
273
+ with st.expander(f"{rec['topic_label']} | Mastery: {rec['mastery']:.1%}"):
274
+ st.markdown(f"**Reasoning:** {rec['reasoning']}")
275
+ st.progress(rec['mastery'])
276
+ if output['vetoed_sample']:
277
+ st.markdown('---'); st.subheader('❌ Sample Vetoed Topics')
278
+ for rec in output['vetoed_sample']:
279
+ with st.expander(f"✗ {rec['topic_label']}"):
280
+ st.markdown(f"**Reason:** {rec['reasoning']}")
281
+
282
+ with tab2:
283
+ st.header('🔍 What-If Prerequisite Simulator')
284
+ st.markdown('Explore how mastering a topic unlocks future learning paths — or what is blocking you from starting it.')
285
+ nodes_list = list(graph.nodes)
286
+ labels_list = [graph.nodes[n].get('label',n) for n in nodes_list]
287
+ selected_label = st.selectbox('Select a topic to analyse:', labels_list)
288
+ selected_node = nodes_list[labels_list.index(selected_label)]
289
+ if st.button('🔍 Analyse Topic', type='primary'):
290
+ result = what_if_analysis(selected_node, graph)
291
+ col1, col2 = st.columns(2)
292
+ with col1:
293
+ st.subheader('🔓 If you master this topic...')
294
+ if result['direct_unlocks']:
295
+ st.markdown(f"**Directly unlocks {len(result['direct_unlocks'])} topic(s):**")
296
+ for t in result['direct_unlocks']: st.markdown(f' → {t}')
297
  else:
298
+ st.info('This is a terminal topic; it does not unlock further topics in this map.')
299
+ if result['all_unlocks']:
300
+ st.markdown(f"**Total topics eventually unlocked: {result['total_unlocked']}**")
301
+ with col2:
302
+ st.subheader('🔒 To start this topic you need...')
303
+ if result['blocked_by']:
304
+ st.markdown('**Prerequisites required:**')
305
+ for t in result['blocked_by']: st.markdown(f' ✓ {t}')
306
+ else:
307
+ st.success('This is a foundational topic — no prerequisites needed. You can start it now!')
308
+ if result['all_unlocks']:
309
+ st.markdown('---')
310
+ st.markdown('**Full learning path unlocked:**')
311
+ st.markdown(' → '.join([selected_label] + result['all_unlocks'][:8]) + ('...' if len(result['all_unlocks'])>8 else ''))
312
+
313
+ with tab3:
314
+ st.header(f'{domain} Knowledge Map')
315
+ st.markdown(f"**{graph.number_of_nodes()} topics** | **{graph.number_of_edges()} prerequisite relationships**")
316
+ rows=[]
317
+ for node in graph.nodes:
318
+ label=graph.nodes[node].get('label',node); level=graph.nodes[node].get('level',''); term=graph.nodes[node].get('term','')
319
+ prereqs=[graph.nodes[p].get('label',p) for p in graph.predecessors(node)]
320
+ rows.append({'Topic':label,'Level':level,'Term':term,'Prerequisites':', '.join(prereqs) if prereqs else 'None (Foundational)'})
321
+ st.dataframe(pd.DataFrame(rows), use_container_width=True, hide_index=True)
322
+ longest=nx.dag_longest_path(graph)
323
+ st.markdown('**Longest prerequisite chain:**')
324
+ st.markdown(' → '.join([graph.nodes[n].get('label',n) for n in longest]))
325
+
326
+ with tab4:
327
+ st.header('System Diagnostics')
328
+ col1,col2=st.columns(2)
329
+ with col1: st.subheader('Model Configuration'); st.json(config)
330
+ with col2:
331
+ st.subheader('DAG Statistics')
332
+ st.json({'domain':domain,'nodes':graph.number_of_nodes(),'edges':graph.number_of_edges(),'is_valid_dag':nx.is_directed_acyclic_graph(graph),'longest_path':len(nx.dag_longest_path(graph))})
333
+ st.subheader('Constraint Layer')
334
+ st.markdown(f'**Mastery threshold:** {threshold:.0%}; topics above this are considered mastered')
335
+ st.markdown(f'**Challenging threshold:** {soft_threshold:.0%} — topics between this and mastery threshold are marked ⚠️ Challenging')
336
+ st.markdown('**Hard veto:** topics with prerequisites below challenging threshold are fully blocked')
337
+ st.subheader('Domain Switching')
338
+ dcol1,dcol2=st.columns(2)
339
+ with dcol1: st.metric('Math DAG',f'{math_graph.number_of_nodes()} topics')
340
+ with dcol2: st.metric('CS DAG',f'{cs_graph.number_of_nodes()} topics')
341
+
342
+ if __name__ == '__main__':
343
+ main()
data/skill_encoder.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/skill_encoder_v2.csv DELETED
The diff for this file is too large to render. See raw diff
 
index.html DELETED
@@ -1,609 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en">
3
- <head>
4
- <meta charset="UTF-8" />
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
- <title>PLRS — Personalized Learning Recommendation System</title>
7
- <meta name="description" content="Constraint-aware personalized learning recommendations. Plug in your curriculum, get intelligent recommendations out." />
8
- <link rel="preconnect" href="https://fonts.googleapis.com" />
9
- <link href="https://fonts.googleapis.com/css2?family=DM+Mono:ital,wght@0,300;0,400;0,500;1,300&family=Syne:wght@400;600;700;800&display=swap" rel="stylesheet" />
10
-
11
- <style>
12
- *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
13
-
14
- :root {
15
- --bg: #080c18;
16
- --bg2: #0d1221;
17
- --bg3: #131a2e;
18
- --border: #1e2a40;
19
- --border2: #1e3a5f;
20
- --text: #c8d0e0;
21
- --text-dim: #4a5568;
22
- --text-hi: #e8edf5;
23
- --blue: #3d8bcd;
24
- --green: #22c55e;
25
- --amber: #f59e0b;
26
- --red: #ef4444;
27
- --mono: 'DM Mono', monospace;
28
- --sans: 'Syne', sans-serif;
29
- }
30
-
31
- html { scroll-behavior: smooth; }
32
-
33
- body {
34
- background: var(--bg);
35
- color: var(--text);
36
- font-family: var(--sans);
37
- line-height: 1.6;
38
- overflow-x: hidden;
39
- }
40
-
41
- /* ── Noise overlay ── */
42
- body::before {
43
- content: '';
44
- position: fixed; inset: 0;
45
- background-image: url("data:image/svg+xml,%3Csvg viewBox='0 0 256 256' xmlns='http://www.w3.org/2000/svg'%3E%3Cfilter id='n'%3E%3CfeTurbulence type='fractalNoise' baseFrequency='0.9' numOctaves='4' stitchTiles='stitch'/%3E%3C/filter%3E%3Crect width='100%25' height='100%25' filter='url(%23n)' opacity='0.03'/%3E%3C/svg%3E");
46
- pointer-events: none;
47
- z-index: 0;
48
- opacity: 0.4;
49
- }
50
-
51
- /* ── Nav ── */
52
- nav {
53
- position: fixed; top: 0; left: 0; right: 0;
54
- display: flex; align-items: center; justify-content: space-between;
55
- padding: 1rem 2.5rem;
56
- background: rgba(8, 12, 24, 0.85);
57
- backdrop-filter: blur(12px);
58
- border-bottom: 1px solid var(--border);
59
- z-index: 100;
60
- }
61
- .nav-logo {
62
- font-weight: 800; font-size: 1.1rem; color: var(--text-hi);
63
- letter-spacing: -0.02em; text-decoration: none;
64
- }
65
- .nav-logo span { color: var(--blue); }
66
- .nav-links { display: flex; gap: 2rem; align-items: center; }
67
- .nav-links a {
68
- font-family: var(--mono); font-size: 0.7rem; letter-spacing: 0.1em;
69
- color: var(--text-dim); text-decoration: none; text-transform: uppercase;
70
- transition: color 0.2s;
71
- }
72
- .nav-links a:hover { color: var(--blue); }
73
- .btn {
74
- display: inline-flex; align-items: center; gap: 0.5rem;
75
- padding: 0.5rem 1.1rem; border-radius: 3px; font-family: var(--mono);
76
- font-size: 0.7rem; letter-spacing: 0.08em; text-decoration: none;
77
- transition: all 0.2s; cursor: pointer; border: none;
78
- }
79
- .btn-primary {
80
- background: var(--blue); color: #fff;
81
- }
82
- .btn-primary:hover { background: #4d9bdd; }
83
- .btn-outline {
84
- background: transparent; color: var(--blue);
85
- border: 1px solid var(--border2);
86
- }
87
- .btn-outline:hover { border-color: var(--blue); background: rgba(61,139,205,0.07); }
88
-
89
- /* ── Hero ── */
90
- .hero {
91
- min-height: 100vh;
92
- display: flex; flex-direction: column; justify-content: center;
93
- padding: 8rem 2.5rem 5rem;
94
- max-width: 1100px; margin: 0 auto;
95
- position: relative;
96
- }
97
- .hero-eyebrow {
98
- font-family: var(--mono); font-size: 0.7rem; letter-spacing: 0.18em;
99
- color: var(--blue); text-transform: uppercase; margin-bottom: 1.5rem;
100
- display: flex; align-items: center; gap: 0.75rem;
101
- }
102
- .hero-eyebrow::before {
103
- content: ''; display: block; width: 2rem; height: 1px; background: var(--blue);
104
- }
105
- .hero h1 {
106
- font-size: clamp(2.8rem, 6vw, 5rem);
107
- font-weight: 800; line-height: 1.05;
108
- letter-spacing: -0.03em; color: var(--text-hi);
109
- margin-bottom: 1.5rem;
110
- }
111
- .hero h1 em {
112
- font-style: normal; color: var(--blue);
113
- }
114
- .hero-sub {
115
- font-size: 1.1rem; color: var(--text-dim);
116
- max-width: 560px; margin-bottom: 2.5rem;
117
- line-height: 1.7;
118
- }
119
- .hero-ctas { display: flex; gap: 0.75rem; flex-wrap: wrap; margin-bottom: 4rem; }
120
- .btn-hero {
121
- padding: 0.75rem 1.5rem; font-size: 0.8rem;
122
- }
123
-
124
- /* ── Stat strip ── */
125
- .stat-strip {
126
- display: flex; gap: 2.5rem; flex-wrap: wrap;
127
- border-top: 1px solid var(--border);
128
- padding-top: 2rem;
129
- }
130
- .stat-item {}
131
- .stat-num {
132
- font-size: 2rem; font-weight: 800; color: var(--text-hi);
133
- line-height: 1;
134
- }
135
- .stat-num span { color: var(--green); }
136
- .stat-label {
137
- font-family: var(--mono); font-size: 0.65rem; letter-spacing: 0.1em;
138
- color: var(--text-dim); text-transform: uppercase; margin-top: 0.2rem;
139
- }
140
-
141
- /* ── Grid background decoration ── */
142
- .hero-grid {
143
- position: absolute; top: 0; right: -5%; bottom: 0; width: 50%;
144
- background-image:
145
- linear-gradient(var(--border) 1px, transparent 1px),
146
- linear-gradient(90deg, var(--border) 1px, transparent 1px);
147
- background-size: 40px 40px;
148
- mask-image: linear-gradient(to left, rgba(0,0,0,0.15), transparent 70%);
149
- pointer-events: none;
150
- }
151
-
152
- /* ── Section ── */
153
- section {
154
- max-width: 1100px; margin: 0 auto;
155
- padding: 5rem 2.5rem;
156
- }
157
- .section-label {
158
- font-family: var(--mono); font-size: 0.65rem; letter-spacing: 0.18em;
159
- color: var(--blue); text-transform: uppercase;
160
- display: flex; align-items: center; gap: 0.75rem;
161
- margin-bottom: 1rem;
162
- }
163
- .section-label::before {
164
- content: ''; display: block; width: 1.5rem; height: 1px; background: var(--blue);
165
- }
166
- .section-title {
167
- font-size: clamp(1.8rem, 3.5vw, 2.5rem);
168
- font-weight: 800; letter-spacing: -0.02em; color: var(--text-hi);
169
- margin-bottom: 1rem;
170
- }
171
- .section-body {
172
- color: var(--text-dim); font-size: 0.95rem; max-width: 600px;
173
- line-height: 1.8; margin-bottom: 2.5rem;
174
- }
175
-
176
- /* ── Architecture flow ── */
177
- .arch-flow {
178
- display: flex; align-items: center; flex-wrap: wrap;
179
- gap: 0; margin: 2.5rem 0;
180
- }
181
- .arch-node {
182
- background: var(--bg2); border: 1px solid var(--border);
183
- border-radius: 4px; padding: 0.7rem 1rem;
184
- font-family: var(--mono); font-size: 0.72rem; color: var(--text);
185
- letter-spacing: 0.04em; position: relative;
186
- }
187
- .arch-node.highlight { border-color: var(--blue); color: var(--blue); }
188
- .arch-arrow {
189
- font-family: var(--mono); color: var(--border2); padding: 0 0.4rem;
190
- font-size: 0.9rem;
191
- }
192
-
193
- /* ── Three-tier cards ── */
194
- .tier-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 1rem; margin-top: 2rem; }
195
- .tier-card {
196
- background: var(--bg2); border: 1px solid var(--border);
197
- border-radius: 4px; padding: 1.5rem;
198
- position: relative; overflow: hidden;
199
- }
200
- .tier-card::before {
201
- content: ''; position: absolute; top: 0; left: 0; right: 0; height: 2px;
202
- background: var(--accent);
203
- }
204
- .tier-card.green { --accent: var(--green); }
205
- .tier-card.amber { --accent: var(--amber); }
206
- .tier-card.red { --accent: var(--red); }
207
- .tier-icon { font-size: 1.5rem; margin-bottom: 0.75rem; }
208
- .tier-name {
209
- font-weight: 700; font-size: 1rem; color: var(--text-hi);
210
- margin-bottom: 0.35rem;
211
- }
212
- .tier-desc { font-size: 0.8rem; color: var(--text-dim); line-height: 1.6; }
213
-
214
- /* ── Results table ── */
215
- .results-table {
216
- width: 100%; border-collapse: collapse;
217
- font-family: var(--mono); font-size: 0.78rem;
218
- margin-top: 2rem;
219
- }
220
- .results-table th {
221
- text-align: left; padding: 0.6rem 1rem;
222
- color: var(--text-dim); letter-spacing: 0.1em; text-transform: uppercase;
223
- font-size: 0.65rem; border-bottom: 1px solid var(--border);
224
- }
225
- .results-table td {
226
- padding: 0.75rem 1rem; border-bottom: 1px solid var(--border);
227
- color: var(--text);
228
- }
229
- .results-table tr:last-child td { border-bottom: none; }
230
- .results-table tr.highlight-row td { color: var(--text-hi); }
231
- .badge-green {
232
- background: rgba(34,197,94,0.1); color: var(--green);
233
- border: 1px solid rgba(34,197,94,0.3);
234
- padding: 1px 7px; border-radius: 2px; font-size: 0.65rem;
235
- }
236
- .badge-red {
237
- background: rgba(239,68,68,0.1); color: var(--red);
238
- border: 1px solid rgba(239,68,68,0.3);
239
- padding: 1px 7px; border-radius: 2px; font-size: 0.65rem;
240
- }
241
-
242
- /* ── Code block ── */
243
- .code-wrap {
244
- background: var(--bg2); border: 1px solid var(--border);
245
- border-radius: 4px; overflow: hidden; margin-top: 2rem;
246
- }
247
- .code-header {
248
- display: flex; align-items: center; justify-content: space-between;
249
- padding: 0.6rem 1rem; border-bottom: 1px solid var(--border);
250
- background: var(--bg3);
251
- }
252
- .code-dots { display: flex; gap: 5px; }
253
- .code-dots span {
254
- width: 10px; height: 10px; border-radius: 50%;
255
- background: var(--border2);
256
- }
257
- .code-lang {
258
- font-family: var(--mono); font-size: 0.62rem;
259
- color: var(--text-dim); letter-spacing: 0.1em;
260
- }
261
- pre {
262
- padding: 1.5rem;
263
- font-family: var(--mono); font-size: 0.78rem;
264
- line-height: 1.7; color: var(--text);
265
- overflow-x: auto;
266
- }
267
- .cm { color: #4a5568; } /* comment */
268
- .ck { color: #3d8bcd; } /* keyword */
269
- .cs { color: #22c55e; } /* string */
270
- .cn { color: #f59e0b; } /* number / name */
271
- .cf { color: #c084fc; } /* function */
272
-
273
- /* ── Feature grid ── */
274
- .feature-grid { display: grid; grid-template-columns: repeat(2, 1fr); gap: 1px; background: var(--border); margin-top: 2rem; border: 1px solid var(--border); border-radius: 4px; overflow: hidden; }
275
- .feature-cell {
276
- background: var(--bg); padding: 1.5rem;
277
- }
278
- .feature-icon { font-size: 1.2rem; margin-bottom: 0.75rem; }
279
- .feature-title { font-weight: 700; color: var(--text-hi); margin-bottom: 0.35rem; font-size: 0.9rem; }
280
- .feature-desc { font-size: 0.78rem; color: var(--text-dim); line-height: 1.6; }
281
-
282
- /* ── CTA section ── */
283
- .cta-section {
284
- background: var(--bg2);
285
- border-top: 1px solid var(--border);
286
- border-bottom: 1px solid var(--border);
287
- padding: 5rem 2.5rem;
288
- text-align: center;
289
- }
290
- .cta-inner { max-width: 600px; margin: 0 auto; }
291
- .cta-title { font-size: 2.2rem; font-weight: 800; letter-spacing: -0.02em; color: var(--text-hi); margin-bottom: 1rem; }
292
- .cta-sub { color: var(--text-dim); margin-bottom: 2rem; line-height: 1.7; }
293
- .cta-btns { display: flex; gap: 0.75rem; justify-content: center; flex-wrap: wrap; }
294
-
295
- /* ── Footer ── */
296
- footer {
297
- border-top: 1px solid var(--border);
298
- padding: 2rem 2.5rem;
299
- display: flex; justify-content: space-between; align-items: center;
300
- flex-wrap: wrap; gap: 1rem;
301
- max-width: 100%;
302
- }
303
- .footer-left { font-family: var(--mono); font-size: 0.65rem; color: var(--text-dim); }
304
- .footer-links { display: flex; gap: 1.5rem; }
305
- .footer-links a { font-family: var(--mono); font-size: 0.65rem; color: var(--text-dim); text-decoration: none; }
306
- .footer-links a:hover { color: var(--blue); }
307
-
308
- /* ── Animations ── */
309
- @keyframes fadeUp {
310
- from { opacity: 0; transform: translateY(20px); }
311
- to { opacity: 1; transform: translateY(0); }
312
- }
313
- .hero-eyebrow { animation: fadeUp 0.5s ease 0.1s both; }
314
- .hero h1 { animation: fadeUp 0.5s ease 0.2s both; }
315
- .hero-sub { animation: fadeUp 0.5s ease 0.3s both; }
316
- .hero-ctas { animation: fadeUp 0.5s ease 0.4s both; }
317
- .stat-strip { animation: fadeUp 0.5s ease 0.5s both; }
318
-
319
- /* ── Responsive ── */
320
- @media (max-width: 768px) {
321
- nav { padding: 0.75rem 1.25rem; }
322
- .nav-links .btn { display: none; }
323
- .hero { padding: 7rem 1.25rem 4rem; }
324
- .tier-grid { grid-template-columns: 1fr; }
325
- .feature-grid { grid-template-columns: 1fr; }
326
- section { padding: 3rem 1.25rem; }
327
- .arch-flow { gap: 0.25rem; }
328
- }
329
- </style>
330
- </head>
331
- <body>
332
-
333
- <!-- ── Nav ── -->
334
- <nav>
335
- <a href="#" class="nav-logo">PL<span>RS</span></a>
336
- <div class="nav-links">
337
- <a href="#how-it-works">How it works</a>
338
- <a href="#results">Results</a>
339
- <a href="#quickstart">Quickstart</a>
340
- <a href="https://github.com/clementina-tom/plrs" target="_blank">GitHub</a>
341
- <a href="https://huggingface.co/spaces/Clementio/PLRS" class="btn btn-primary btn-hero" target="_blank">Live Demo →</a>
342
- </div>
343
- </nav>
344
-
345
- <!-- ── Hero ── -->
346
- <div class="hero">
347
- <div class="hero-grid"></div>
348
-
349
- <div class="hero-eyebrow">Knowledge Tracing · Constraint-Aware · Open Source</div>
350
-
351
- <h1>Recommendations that<br/><em>respect</em> how learning works.</h1>
352
-
353
- <p class="hero-sub">
354
- PLRS combines Self-Attentive Knowledge Tracing with a DAG prerequisite constraint layer
355
- to generate personalized learning recommendations that are pedagogically sound —
356
- not just statistically optimal.
357
- </p>
358
-
359
- <div class="hero-ctas">
360
- <a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank" class="btn btn-primary btn-hero">
361
- Try the live demo
362
- </a>
363
- <a href="https://github.com/clementina-tom/plrs" target="_blank" class="btn btn-outline btn-hero">
364
- View on GitHub
365
- </a>
366
- <a href="#quickstart" class="btn btn-outline btn-hero">
367
- Quickstart
368
- </a>
369
- </div>
370
-
371
- <div class="stat-strip">
372
- <div class="stat-item">
373
- <div class="stat-num"><span>0.0</span>%</div>
374
- <div class="stat-label">Prerequisite violation rate</div>
375
- </div>
376
- <div class="stat-item">
377
- <div class="stat-num">0.7692</div>
378
- <div class="stat-label">SAKT Val AUC (OULAD)</div>
379
- </div>
380
- <div class="stat-item">
381
- <div class="stat-num">69</div>
382
- <div class="stat-label">Curriculum topics (2 domains)</div>
383
- </div>
384
- <div class="stat-item">
385
- <div class="stat-num">52</div>
386
- <div class="stat-label">Tests passing</div>
387
- </div>
388
- </div>
389
- </div>
390
-
391
- <!-- ── How it works ── -->
392
- <section id="how-it-works">
393
- <div class="section-label">Architecture</div>
394
- <h2 class="section-title">Three layers. One guarantee.</h2>
395
- <p class="section-body">
396
- Standard recommendation systems optimise for engagement or accuracy —
397
- they will happily recommend Calculus to a student who hasn't mastered Algebra.
398
- PLRS adds a constraint layer that makes this <em>structurally impossible</em>.
399
- </p>
400
-
401
- <div class="arch-flow">
402
- <div class="arch-node">Student History</div>
403
- <div class="arch-arrow">→</div>
404
- <div class="arch-node highlight">SAKT Model</div>
405
- <div class="arch-arrow">→</div>
406
- <div class="arch-node">Mastery Vector</div>
407
- <div class="arch-arrow">→</div>
408
- <div class="arch-node highlight">DAG Constraints</div>
409
- <div class="arch-arrow">→</div>
410
- <div class="arch-node">Multi-Objective Ranker</div>
411
- <div class="arch-arrow">→</div>
412
- <div class="arch-node highlight">Recommendations</div>
413
- </div>
414
-
415
- <div class="tier-grid">
416
- <div class="tier-card green">
417
- <div class="tier-icon">✅</div>
418
- <div class="tier-name">Approved</div>
419
- <div class="tier-desc">All prerequisites met above the mastery threshold. Student is ready to learn this topic now.</div>
420
- </div>
421
- <div class="tier-card amber">
422
- <div class="tier-icon">⚠️</div>
423
- <div class="tier-name">Challenging</div>
424
- <div class="tier-desc">Prerequisites partially met — above the soft threshold but below full mastery. Proceed with awareness.</div>
425
- </div>
426
- <div class="tier-card red">
427
- <div class="tier-icon">❌</div>
428
- <div class="tier-name">Vetoed</div>
429
- <div class="tier-desc">One or more prerequisites not met. Structurally blocked until foundations are solid.</div>
430
- </div>
431
- </div>
432
- </section>
433
-
434
- <!-- ── Results ── -->
435
- <section id="results" style="border-top: 1px solid var(--border);">
436
- <div class="section-label">Evaluation</div>
437
- <h2 class="section-title">0% violation rate. Not a tuning choice.</h2>
438
- <p class="section-body">
439
- Evaluated on the Open University Learning Analytics Dataset (OULAD) with
440
- Nigerian secondary school curriculum knowledge maps. The 0% violation rate
441
- is a structural guarantee from the DAG constraint layer — not a hyperparameter.
442
- </p>
443
-
444
- <table class="results-table">
445
- <thead>
446
- <tr>
447
- <th>Model</th>
448
- <th>Val AUC</th>
449
- <th>Prerequisite Violation Rate</th>
450
- <th>Coverage</th>
451
- </tr>
452
- </thead>
453
- <tbody>
454
- <tr class="highlight-row">
455
- <td><strong>PLRS (SAKT + DAG)</strong></td>
456
- <td><strong>0.7692</strong></td>
457
- <td><span class="badge-green">0.0%</span></td>
458
- <td>Full curriculum</td>
459
- </tr>
460
- <tr>
461
- <td>Collaborative Filtering</td>
462
- <td>—</td>
463
- <td><span class="badge-red">81.3%</span></td>
464
- <td>Partial</td>
465
- </tr>
466
- <tr>
467
- <td>Matrix Factorization</td>
468
- <td>—</td>
469
- <td><span class="badge-red">83.7%</span></td>
470
- <td>Partial</td>
471
- </tr>
472
- <tr>
473
- <td>BKT (baseline)</td>
474
- <td>~0.67</td>
475
- <td><span class="badge-red">No constraint layer</span></td>
476
- <td>Partial</td>
477
- </tr>
478
- </tbody>
479
- </table>
480
- </section>
481
-
482
- <!-- ── Quickstart ── -->
483
- <section id="quickstart" style="border-top: 1px solid var(--border);">
484
- <div class="section-label">Quickstart</div>
485
- <h2 class="section-title">Plug in your curriculum.</h2>
486
- <p class="section-body">
487
- PLRS is curriculum-agnostic. Define your knowledge graph in a simple JSON format
488
- and get recommendations immediately. No retraining required for new domains.
489
- </p>
490
-
491
- <div class="code-wrap">
492
- <div class="code-header">
493
- <div class="code-dots"><span></span><span></span><span></span></div>
494
- <div class="code-lang">PYTHON</div>
495
- </div>
496
- <pre><span class="ck">from</span> plrs <span class="ck">import</span> PLRSPipeline
497
- <span class="ck">from</span> plrs.curriculum <span class="ck">import</span> load_dag
498
-
499
- <span class="cm"># Load your curriculum (JSON knowledge graph)</span>
500
- curriculum = <span class="cf">load_dag</span>(<span class="cs">"math_dag.json"</span>)
501
-
502
- <span class="cm"># Create pipeline — no model needed for mastery-dict mode</span>
503
- pipeline = <span class="cf">PLRSPipeline</span>(curriculum)
504
-
505
- <span class="cm"># Get recommendations from student mastery scores</span>
506
- results = pipeline.<span class="cf">recommend_from_mastery</span>({
507
- <span class="cs">"whole_numbers"</span>: <span class="cn">0.90</span>,
508
- <span class="cs">"algebraic_expressions"</span>: <span class="cn">0.75</span>,
509
- <span class="cs">"quadratic_equations"</span>: <span class="cn">0.40</span>,
510
- })
511
-
512
- <span class="ck">for</span> rec <span class="ck">in</span> results[<span class="cs">"approved"</span>]:
513
- <span class="cf">print</span>(<span class="cs">f"✅ {rec['topic_label']} (score={rec['score']})"</span>)
514
- <span class="cf">print</span>(<span class="cs">f" {rec['reasoning']}"</span>)
515
-
516
- <span class="cm"># What-if: what does mastering this topic unlock?</span>
517
- wi = pipeline.<span class="cf">what_if</span>(<span class="cs">"algebraic_expressions"</span>)
518
- <span class="cf">print</span>(<span class="cs">f"Unlocks {wi['total_unlocked']} downstream topics"</span>)</pre>
519
- </div>
520
-
521
- <div class="code-wrap" style="margin-top: 1rem;">
522
- <div class="code-header">
523
- <div class="code-dots"><span></span><span></span><span></span></div>
524
- <div class="code-lang">REST API</div>
525
- </div>
526
- <pre><span class="cm"># Start the server</span>
527
- $ python scripts/serve.py
528
- <span class="cm"># → http://127.0.0.1:8000/docs</span>
529
-
530
- <span class="cm"># Get recommendations</span>
531
- $ curl -X POST http://localhost:<span class="cn">8000</span>/recommend \
532
- -H <span class="cs">"Content-Type: application/json"</span> \
533
- -d <span class="cs">'{"domain":"math","mastery_scores":{"whole_numbers":0.9}}'</span></pre>
534
- </div>
535
- </section>
536
-
537
- <!-- ── Features ── -->
538
- <section style="border-top: 1px solid var(--border);">
539
- <div class="section-label">Features</div>
540
- <h2 class="section-title">Built for real deployment.</h2>
541
-
542
- <div class="feature-grid">
543
- <div class="feature-cell">
544
- <div class="feature-icon">🔌</div>
545
- <div class="feature-title">Curriculum-agnostic</div>
546
- <div class="feature-desc">Define any knowledge graph in a simple JSON format. Ships with Nigerian secondary school Maths and CS Fundamentals (NERDC JSS3–SS2).</div>
547
- </div>
548
- <div class="feature-cell">
549
- <div class="feature-icon">⚡</div>
550
- <div class="feature-title">FastAPI REST backend</div>
551
- <div class="feature-desc">Production-ready API with <code>/recommend</code>, <code>/what-if</code>, and <code>/curriculum</code> endpoints. Auto-generated OpenAPI docs.</div>
552
- </div>
553
- <div class="feature-cell">
554
- <div class="feature-icon">🧠</div>
555
- <div class="feature-title">SAKT + Forgetting Curve</div>
556
- <div class="feature-desc">Self-Attentive Knowledge Tracing with optional Ebbinghaus decay attention — older interactions contribute less to current mastery estimates.</div>
557
- </div>
558
- <div class="feature-cell">
559
- <div class="feature-icon">🔍</div>
560
- <div class="feature-title">What-If Simulator</div>
561
- <div class="feature-desc">"If I master Trigonometry now, what unlocks?" — live DAG traversal shows direct and transitive downstream topics.</div>
562
- </div>
563
- <div class="feature-cell">
564
- <div class="feature-icon">📦</div>
565
- <div class="feature-title">PyPI-ready package</div>
566
- <div class="feature-desc"><code>pip install plrs</code> — modular architecture with clean public API. Full type annotations throughout.</div>
567
- </div>
568
- <div class="feature-cell">
569
- <div class="feature-icon">🧪</div>
570
- <div class="feature-title">52 tests, CI on 3 Python versions</div>
571
- <div class="feature-desc">Unit tests, API integration tests, and evaluator tests. GitHub Actions runs on Python 3.10, 3.11, and 3.12.</div>
572
- </div>
573
- </div>
574
- </section>
575
-
576
- <!-- ── CTA ── -->
577
- <div class="cta-section">
578
- <div class="cta-inner">
579
- <div class="cta-title">Try it now — no setup required.</div>
580
- <p class="cta-sub">
581
- The live demo runs the full pipeline in your browser.
582
- Adjust mastery sliders, simulate student sequences, explore the curriculum graph.
583
- </p>
584
- <div class="cta-btns">
585
- <a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank" class="btn btn-primary btn-hero">
586
- Open live demo →
587
- </a>
588
- <a href="https://github.com/clementina-tom/plrs" target="_blank" class="btn btn-outline btn-hero">
589
- Star on GitHub
590
- </a>
591
- </div>
592
- </div>
593
- </div>
594
-
595
- <!-- ── Footer ── -->
596
- <footer>
597
- <div class="footer-left">
598
- PLRS — Personalized Learning Recommendation System<br/>
599
- MIT License · Built by <a href="https://github.com/clementina-tom" style="color:var(--blue);text-decoration:none;">Clementina Tom</a>
600
- </div>
601
- <div class="footer-links">
602
- <a href="https://github.com/clementina-tom/plrs" target="_blank">GitHub</a>
603
- <a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank">HuggingFace</a>
604
- <a href="https://huggingface.co/spaces/Clementio/PLRS" target="_blank">Live Demo</a>
605
- </div>
606
- </footer>
607
-
608
- </body>
609
- </html>
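
Note: the quickstart section of the deleted page above shows the `/recommend` call with curl. For completeness, here is a minimal Python sketch of the same request. It assumes the FastAPI server from `scripts/serve.py` is running locally; the response field names (`approved`, `challenging`, `vetoed`) are assumptions based on the three-tier classification described on the page, not a confirmed response schema.

```python
# Minimal sketch of the /recommend call shown above with curl.
# Assumes the PLRS FastAPI server is running on localhost:8000; the tier
# field names in the response are assumptions based on the page text.
import requests

payload = {
    "domain": "math",
    "mastery_scores": {"whole_numbers": 0.9, "algebraic_expressions": 0.75},
}
resp = requests.post("http://localhost:8000/recommend", json=payload, timeout=10)
resp.raise_for_status()
data = resp.json()

for tier in ("approved", "challenging", "vetoed"):
    topics = data.get(tier, [])
    print(f"{tier}: {len(topics)} topic(s)")
```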
{data/knowledge_maps → knowledge_maps}/cs_dag.json RENAMED
File without changes
{data/knowledge_maps → knowledge_maps}/math_dag.json RENAMED
File without changes
plrs/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- PLRS — Personalized Learning Recommendation System
3
- ====================================================
4
- Constraint-aware personalized learning recommendations.
5
- Plug in your curriculum DAG, get intelligent recommendations out.
6
-
7
- Quick start:
8
- from plrs import PLRSPipeline
9
- from plrs.curriculum import load_dag
10
-
11
- graph = load_dag("my_curriculum.json")
12
- pipeline = PLRSPipeline(graph)
13
- results = pipeline.recommend(student_history)
14
- """
15
-
16
- from plrs.pipeline import PLRSPipeline
17
- from plrs.model.sakt import SAKTModel
18
- from plrs.constraints.dag import DAGConstraintLayer
19
- from plrs.ranking.ranker import MultiObjectiveRanker
20
- from plrs.curriculum.loader import load_dag, CurriculumGraph
21
-
22
- __version__ = "0.1.0"
23
- __all__ = [
24
- "PLRSPipeline",
25
- "SAKTModel",
26
- "DAGConstraintLayer",
27
- "MultiObjectiveRanker",
28
- "load_dag",
29
- "CurriculumGraph",
30
- ]
plrs/constraints/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from plrs.constraints.dag import DAGConstraintLayer, MasteryVector, ConstraintResult
2
-
3
- __all__ = ["DAGConstraintLayer", "MasteryVector", "ConstraintResult"]
 
 
 
 
plrs/constraints/dag.py DELETED
@@ -1,201 +0,0 @@
1
- """
2
- plrs.constraints.dag
3
- ====================
4
- DAG-based prerequisite constraint layer.
5
-
6
- Three-tier classification:
7
- - approved : prerequisites met, topic is ready
8
- - challenging : prerequisites partially met (above soft threshold)
9
- - vetoed : prerequisites not met, topic is blocked
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- from dataclasses import dataclass, field
15
- from typing import Literal
16
-
17
- from plrs.curriculum.loader import CurriculumGraph
18
-
19
- Status = Literal["approved", "challenging", "vetoed"]
20
-
21
-
22
- class MasteryVector:
23
- """
24
- Holds a student's estimated mastery probability per topic.
25
-
26
- Parameters
27
- ----------
28
- curriculum : CurriculumGraph
29
- threshold : float
30
- Mastery threshold — above this, a topic is considered mastered (default 0.70).
31
- soft_threshold : float
32
- Soft threshold — above this but below threshold, a topic is "challenging" (default 0.50).
33
- """
34
-
35
- def __init__(
36
- self,
37
- curriculum: CurriculumGraph,
38
- threshold: float = 0.70,
39
- soft_threshold: float = 0.50,
40
- ) -> None:
41
- self.curriculum = curriculum
42
- self.threshold = threshold
43
- self.soft_threshold = soft_threshold
44
- self._mastery: dict[str, float] = {node: 0.0 for node in curriculum.nodes}
45
-
46
- # ------------------------------------------------------------------ #
47
- # Mutations #
48
- # ------------------------------------------------------------------ #
49
-
50
- def update(self, topic_id: str, probability: float) -> None:
51
- """Set mastery probability for a topic (clamped to [0, 1])."""
52
- if topic_id in self._mastery:
53
- self._mastery[topic_id] = max(0.0, min(1.0, probability))
54
-
55
- def update_batch(self, updates: dict[str, float]) -> None:
56
- """Update multiple topics at once."""
57
- for topic_id, prob in updates.items():
58
- self.update(topic_id, prob)
59
-
60
- def cascade_up(self) -> None:
61
- """
62
- Propagate mastery scores upward through the DAG.
63
-
64
- If a student has high mastery on a topic, infer that their
65
- prerequisites are also likely mastered.
66
- """
67
- changed = True
68
- while changed:
69
- changed = False
70
- for node in self.curriculum.nodes:
71
- node_mastery = self.get(node)
72
- if node_mastery < 0.40:
73
- continue
74
- # For each prerequisite of this node
75
- for prereq in self.curriculum.prerequisites(node):
76
- prereq_mastery = self.get(prereq)
77
- # Infer prerequisite mastery as at least 85% of descendant mastery
78
- inferred = min(node_mastery * 0.85, 0.95)
79
- if inferred > prereq_mastery:
80
- self.update(prereq, inferred)
81
- changed = True
82
-
83
- # ------------------------------------------------------------------ #
84
- # Queries #
85
- # ------------------------------------------------------------------ #
86
-
87
- def get(self, topic_id: str) -> float:
88
- return self._mastery.get(topic_id, 0.0)
89
-
90
- def is_mastered(self, topic_id: str) -> bool:
91
- return self.get(topic_id) >= self.threshold
92
-
93
- def is_partial(self, topic_id: str) -> bool:
94
- """Between soft_threshold and threshold — partially mastered."""
95
- v = self.get(topic_id)
96
- return self.soft_threshold <= v < self.threshold
97
-
98
- def summary(self) -> dict:
99
- mastered = [t for t in self._mastery if self.is_mastered(t)]
100
- partial = [t for t in self._mastery if self.is_partial(t)]
101
- return {
102
- "total_topics": len(self._mastery),
103
- "mastered": len(mastered),
104
- "partial": len(partial),
105
- "not_started": len(self._mastery) - len(mastered) - len(partial),
106
- "mastery_rate": round(len(mastered) / max(len(self._mastery), 1), 3),
107
- "mastered_topics": mastered,
108
- }
109
-
110
- def to_dict(self) -> dict[str, float]:
111
- return dict(self._mastery)
112
-
113
- def __repr__(self) -> str:
114
- s = self.summary()
115
- return (
116
- f"MasteryVector(mastered={s['mastered']}/{s['total_topics']}, "
117
- f"rate={s['mastery_rate']:.1%})"
118
- )
119
-
120
-
121
- @dataclass
122
- class ConstraintResult:
123
- topic_id: str
124
- topic_label: str
125
- status: Status
126
- mastery: float
127
- reasoning: str
128
- score: float = 0.0
129
- prerequisites: list[str] = field(default_factory=list)
130
- unmet_prerequisites: list[str] = field(default_factory=list)
131
-
132
-
133
- class DAGConstraintLayer:
134
- """
135
- Validates topic recommendations against curriculum prerequisite structure.
136
-
137
- Uses three-tier soft constraint logic:
138
- - mastery >= threshold on ALL prerequisites → approved
139
- - mastery >= soft_threshold on ALL prereqs → challenging
140
- - any prerequisite below soft_threshold → vetoed
141
- """
142
-
143
- def __init__(self, curriculum: CurriculumGraph) -> None:
144
- self.curriculum = curriculum
145
-
146
- def validate(
147
- self,
148
- topic_id: str,
149
- mastery: MasteryVector,
150
- ) -> ConstraintResult:
151
- label = self.curriculum.label(topic_id)
152
- prereqs = self.curriculum.prerequisites(topic_id)
153
- topic_mastery = mastery.get(topic_id)
154
-
155
- if not prereqs:
156
- return ConstraintResult(
157
- topic_id=topic_id,
158
- topic_label=label,
159
- status="approved",
160
- mastery=topic_mastery,
161
- reasoning="No prerequisites required.",
162
- prerequisites=[],
163
- unmet_prerequisites=[],
164
- )
165
-
166
- prereq_labels = [self.curriculum.label(p) for p in prereqs]
167
- unmet_hard = [p for p in prereqs if not mastery.is_mastered(p)]
168
- unmet_soft = [p for p in prereqs if mastery.get(p) < mastery.soft_threshold]
169
-
170
- if not unmet_soft:
171
- # All prereqs above soft threshold — at least challenging
172
- if not unmet_hard:
173
- status = "approved"
174
- reasoning = f"All {len(prereqs)} prerequisite(s) met."
175
- else:
176
- status = "challenging"
177
- unmet_labels = [self.curriculum.label(p) for p in unmet_hard]
178
- reasoning = (
179
- f"Prerequisite(s) partially met. "
180
- f"Strengthen: {', '.join(unmet_labels)}."
181
- )
182
- else:
183
- status = "vetoed"
184
- unmet_labels = [self.curriculum.label(p) for p in unmet_soft]
185
- reasoning = (
186
- f"Blocked. Master first: {', '.join(unmet_labels)}."
187
- )
188
-
189
- return ConstraintResult(
190
- topic_id=topic_id,
191
- topic_label=label,
192
- status=status,
193
- mastery=topic_mastery,
194
- reasoning=reasoning,
195
- prerequisites=prereq_labels,
196
- unmet_prerequisites=[self.curriculum.label(p) for p in (unmet_hard if status == "challenging" else unmet_soft)],
197
- )
198
-
199
- def validate_all(self, mastery: MasteryVector) -> list[ConstraintResult]:
200
- """Validate every topic in the curriculum."""
201
- return [self.validate(node, mastery) for node in self.curriculum.nodes]
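
A minimal sketch of how the deleted `MasteryVector` and `DAGConstraintLayer` above fit together, using a three-topic chain built directly with networkx. The topic names and scores are illustrative only; the thresholds are the defaults documented in the class.

```python
# Minimal sketch exercising the three-tier logic from plrs/constraints/dag.py
# on a tiny hand-built curriculum; topic names and scores are illustrative.
import networkx as nx
from plrs.curriculum.loader import CurriculumGraph
from plrs.constraints.dag import DAGConstraintLayer, MasteryVector

G = nx.DiGraph()
for node_id, label in [
    ("whole_numbers", "Whole Numbers"),
    ("algebraic_expressions", "Algebraic Expressions"),
    ("quadratic_equations", "Quadratic Equations"),
]:
    G.add_node(node_id, label=label)
G.add_edge("whole_numbers", "algebraic_expressions")
G.add_edge("algebraic_expressions", "quadratic_equations")

curriculum = CurriculumGraph(domain="Mathematics", graph=G)

mastery = MasteryVector(curriculum, threshold=0.70, soft_threshold=0.50)
mastery.update_batch({"whole_numbers": 0.90, "algebraic_expressions": 0.60})

layer = DAGConstraintLayer(curriculum)
for result in layer.validate_all(mastery):
    # whole_numbers        -> approved    (no prerequisites)
    # algebraic_expressions -> approved   (whole_numbers 0.90 >= 0.70)
    # quadratic_equations  -> challenging (algebraic_expressions 0.60 is above 0.50 but below 0.70)
    print(result.topic_label, result.status, result.reasoning)
```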
plrs/curriculum/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from plrs.curriculum.loader import load_dag, CurriculumGraph
2
-
3
- __all__ = ["load_dag", "CurriculumGraph"]
 
 
 
 
plrs/curriculum/loader.py DELETED
@@ -1,144 +0,0 @@
1
- """
2
- plrs.curriculum.loader
3
- ======================
4
- Load and validate curriculum knowledge graphs from JSON.
5
-
6
- The JSON schema is deliberately simple so educators can author their own:
7
-
8
- {
9
- "domain": "Mathematics",
10
- "nodes": [
11
- {"id": "algebra_basics", "label": "Algebra Basics", "level": "JSS3"},
12
- {"id": "quadratic_equations", "label": "Quadratic Equations", "level": "SS1"}
13
- ],
14
- "edges": [
15
- {"from": "algebra_basics", "to": "quadratic_equations"}
16
- ]
17
- }
18
- """
19
-
20
- from __future__ import annotations
21
-
22
- import json
23
- from dataclasses import dataclass, field
24
- from pathlib import Path
25
- from typing import Any
26
-
27
- import networkx as nx
28
-
29
-
30
- @dataclass
31
- class CurriculumGraph:
32
- """Thin wrapper around a NetworkX DiGraph with domain metadata."""
33
-
34
- domain: str
35
- graph: nx.DiGraph
36
- meta: dict[str, Any] = field(default_factory=dict)
37
-
38
- # ------------------------------------------------------------------ #
39
- # Properties #
40
- # ------------------------------------------------------------------ #
41
-
42
- @property
43
- def nodes(self) -> list[str]:
44
- return list(self.graph.nodes)
45
-
46
- @property
47
- def num_nodes(self) -> int:
48
- return self.graph.number_of_nodes()
49
-
50
- @property
51
- def num_edges(self) -> int:
52
- return self.graph.number_of_edges()
53
-
54
- def label(self, node_id: str) -> str:
55
- return self.graph.nodes[node_id].get("label", node_id)
56
-
57
- def level(self, node_id: str) -> str:
58
- return self.graph.nodes[node_id].get("level", "")
59
-
60
- def prerequisites(self, node_id: str) -> list[str]:
61
- return list(self.graph.predecessors(node_id))
62
-
63
- def successors(self, node_id: str) -> list[str]:
64
- return list(self.graph.successors(node_id))
65
-
66
- def descendants(self, node_id: str) -> list[str]:
67
- return list(nx.descendants(self.graph, node_id))
68
-
69
- def validate(self) -> list[str]:
70
- """Return a list of validation warnings (empty = all good)."""
71
- warnings: list[str] = []
72
- if not nx.is_directed_acyclic_graph(self.graph):
73
- warnings.append("Graph contains cycles — prerequisite checking will be unreliable.")
74
- isolates = list(nx.isolates(self.graph))
75
- if isolates:
76
- warnings.append(f"{len(isolates)} isolated nodes (no edges): {isolates[:5]}")
77
- return warnings
78
-
79
- def __repr__(self) -> str:
80
- return (
81
- f"CurriculumGraph(domain={self.domain!r}, "
82
- f"nodes={self.num_nodes}, edges={self.num_edges})"
83
- )
84
-
85
-
86
- def load_dag(path: str | Path) -> CurriculumGraph:
87
- """
88
- Load a curriculum DAG from a JSON file.
89
-
90
- Parameters
91
- ----------
92
- path : str or Path
93
- Path to the curriculum JSON file.
94
-
95
- Returns
96
- -------
97
- CurriculumGraph
98
-
99
- Raises
100
- ------
101
- FileNotFoundError
102
- If the file does not exist.
103
- ValueError
104
- If the JSON schema is invalid.
105
- """
106
- path = Path(path)
107
- if not path.exists():
108
- raise FileNotFoundError(f"Curriculum file not found: {path}")
109
-
110
- with open(path) as f:
111
- data = json.load(f)
112
-
113
- _validate_schema(data, path)
114
-
115
- domain = data.get("domain", path.stem)
116
- meta = {k: v for k, v in data.items() if k not in ("nodes", "edges", "domain")}
117
-
118
- G = nx.DiGraph()
119
- for node in data["nodes"]:
120
- G.add_node(node["id"], **{k: v for k, v in node.items() if k != "id"})
121
- for edge in data["edges"]:
122
- G.add_edge(edge["from"], edge["to"])
123
-
124
- curriculum = CurriculumGraph(domain=domain, graph=G, meta=meta)
125
-
126
- warnings = curriculum.validate()
127
- for w in warnings:
128
- import warnings as _w
129
- _w.warn(f"[PLRS] {w}", stacklevel=2)
130
-
131
- return curriculum
132
-
133
-
134
- def _validate_schema(data: dict, path: Path) -> None:
135
- if "nodes" not in data:
136
- raise ValueError(f"{path}: Missing required key 'nodes'")
137
- if "edges" not in data:
138
- raise ValueError(f"{path}: Missing required key 'edges'")
139
- for i, node in enumerate(data["nodes"]):
140
- if "id" not in node:
141
- raise ValueError(f"{path}: Node at index {i} missing required key 'id'")
142
- for i, edge in enumerate(data["edges"]):
143
- if "from" not in edge or "to" not in edge:
144
- raise ValueError(f"{path}: Edge at index {i} missing 'from' or 'to'")
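
A minimal authoring sketch for the JSON schema documented in the deleted loader above: it writes a two-topic curriculum to disk and loads it back with `load_dag`. The file name and topics are illustrative only.

```python
# Minimal sketch of authoring a curriculum in the JSON schema documented in
# plrs/curriculum/loader.py and loading it back; file name and topics are
# illustrative only.
import json
from pathlib import Path

from plrs.curriculum.loader import load_dag

curriculum_json = {
    "domain": "Mathematics",
    "nodes": [
        {"id": "algebra_basics", "label": "Algebra Basics", "level": "JSS3"},
        {"id": "quadratic_equations", "label": "Quadratic Equations", "level": "SS1"},
    ],
    "edges": [
        {"from": "algebra_basics", "to": "quadratic_equations"},
    ],
}

path = Path("my_curriculum.json")
path.write_text(json.dumps(curriculum_json, indent=2))

graph = load_dag(path)   # warns about cycles or isolated nodes if present
print(graph)             # CurriculumGraph(domain='Mathematics', nodes=2, edges=1)
print(graph.prerequisites("quadratic_equations"))  # ['algebra_basics']
```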
plrs/model/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from plrs.model.sakt import SAKTModel
2
- from plrs.model.sakt_decay import SAKTWithDecay
3
- from plrs.model.trainer import SAKTTrainer, TrainerConfig, load_sequences_from_csv
4
-
5
- __all__ = ["SAKTModel", "SAKTWithDecay", "SAKTTrainer", "TrainerConfig", "load_sequences_from_csv"]
 
 
 
 
 
 
plrs/model/evaluator.py DELETED
@@ -1,374 +0,0 @@
1
- """
2
- plrs.model.evaluator
3
- ====================
4
- Evaluation suite for PLRS.
5
-
6
- Metrics:
7
- - Knowledge Tracing: AUC-ROC, Accuracy, Binary Cross-Entropy
8
- - Recommendation: Prerequisite Violation Rate, Coverage, Diversity
9
- - Baselines: Random, Popularity, BKT (Bayesian Knowledge Tracing)
10
-
11
- Usage:
12
- from plrs.model.evaluator import PLRSEvaluator
13
- evaluator = PLRSEvaluator(pipeline, curriculum)
14
- report = evaluator.evaluate(test_sequences, skill_to_topic)
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- import time
20
- from dataclasses import dataclass, field
21
- from typing import Any
22
-
23
- import numpy as np
24
-
25
- try:
26
- from sklearn.metrics import roc_auc_score, accuracy_score, log_loss
27
- HAS_SKLEARN = True
28
- except ImportError:
29
- HAS_SKLEARN = False
30
-
31
-
32
- # ── Baseline models ───────────────────────────────────────────────────────────
33
-
34
- class RandomBaseline:
35
- """Predicts 0.5 for every interaction."""
36
- def predict(self, skill_seq, correct_seq):
37
- return {i: 0.5 for i in range(len(skill_seq))}
38
-
39
- def recommend(self, curriculum, n=5):
40
- import random
41
- return random.sample(curriculum.nodes, min(n, len(curriculum.nodes)))
42
-
43
-
44
- class PopularityBaseline:
45
- """Recommends the most-seen skills; predicts by global correctness rate."""
46
-
47
- def __init__(self):
48
- self.skill_correct: dict[int, list[float]] = {}
49
- self.topic_count: dict[str, int] = {}
50
-
51
- def fit(self, sequences, skill_to_topic=None):
52
- for skill_seq, correct_seq in sequences:
53
- for skill, correct in zip(skill_seq, correct_seq):
54
- self.skill_correct.setdefault(skill, []).append(float(correct))
55
- if skill_to_topic:
56
- topic = skill_to_topic.get(skill)
57
- if topic:
58
- self.topic_count[topic] = self.topic_count.get(topic, 0) + 1
59
-
60
- def predict_prob(self, skill_id: int) -> float:
61
- history = self.skill_correct.get(skill_id, [])
62
- return float(np.mean(history)) if history else 0.5
63
-
64
- def recommend(self, curriculum, n=5):
65
- if not self.topic_count:
66
- return curriculum.nodes[:n]
67
- sorted_topics = sorted(self.topic_count, key=self.topic_count.get, reverse=True)
68
- return [t for t in sorted_topics if t in curriculum.nodes][:n]
69
-
70
-
71
- class BKTBaseline:
72
- """
73
- Bayesian Knowledge Tracing (per-skill).
74
- Simple 4-parameter model: p_init, p_transit, p_slip, p_guess.
75
- """
76
-
77
- def __init__(self, p_init=0.3, p_transit=0.1, p_slip=0.1, p_guess=0.2):
78
- self.p_init = p_init
79
- self.p_transit = p_transit
80
- self.p_slip = p_slip
81
- self.p_guess = p_guess
82
- self._mastery: dict[int, float] = {}
83
-
84
- def _update(self, skill: int, correct: int) -> float:
85
- p = self._mastery.get(skill, self.p_init)
86
- # Bayes update
87
- if correct:
88
- num = p * (1 - self.p_slip)
89
- den = num + (1 - p) * self.p_guess
90
- else:
91
- num = p * self.p_slip
92
- den = num + (1 - p) * (1 - self.p_guess)
93
- p_post = num / max(den, 1e-9)
94
- # Learning
95
- p_post = p_post + (1 - p_post) * self.p_transit
96
- self._mastery[skill] = p_post
97
- return p_post
98
-
99
- def predict_sequence(self, skill_seq: list[int], correct_seq: list[int]) -> list[float]:
100
- self._mastery = {}
101
- probs = []
102
- for skill, correct in zip(skill_seq[:-1], correct_seq[:-1]):
103
- self._update(skill, correct)
104
- next_skill = skill_seq[len(probs) + 1]
105
- probs.append(self._mastery.get(next_skill, self.p_init))
106
- return probs
107
-
108
- def get_mastery(self) -> dict[int, float]:
109
- return dict(self._mastery)
110
-
111
-
112
- # ── Result dataclasses ────────────────────────────────────────────────────────
113
-
114
- @dataclass
115
- class KTMetrics:
116
- """Knowledge tracing evaluation metrics."""
117
- model_name: str
118
- auc: float
119
- accuracy: float
120
- log_loss: float
121
- n_samples: int
122
- elapsed_s: float
123
-
124
-
125
- @dataclass
126
- class RecommendMetrics:
127
- """Recommendation quality metrics."""
128
- violation_rate: float # fraction of recommendations that violate prerequisites
129
- coverage: float # fraction of curriculum covered by recommendations
130
- avg_downstream: float # avg topics unlocked by recommendations
131
- mastery_rate: float # avg student mastery in test set
132
-
133
-
134
- @dataclass
135
- class EvaluationReport:
136
- """Full evaluation report."""
137
- kt_metrics: list[KTMetrics]
138
- rec_metrics: RecommendMetrics | None
139
- config: dict[str, Any]
140
- timestamp: str
141
-
142
- def print(self) -> None:
143
- print("\n" + "=" * 62)
144
- print(" PLRS EVALUATION REPORT")
145
- print("=" * 62)
146
-
147
- print(f"\n{'Model':<22} {'AUC':>8} {'Accuracy':>10} {'Log Loss':>10} {'Samples':>8}")
148
- print("-" * 62)
149
- for m in self.kt_metrics:
150
- print(f"{m.model_name:<22} {m.auc:>8.4f} {m.accuracy:>10.4f} {m.log_loss:>10.4f} {m.n_samples:>8,}")
151
-
152
- if self.rec_metrics:
153
- r = self.rec_metrics
154
- print("\nRecommendation Metrics")
155
- print(f" Prerequisite violation rate : {r.violation_rate:.1%}")
156
- print(f" Curriculum coverage : {r.coverage:.1%}")
157
- print(f" Avg downstream unlocked : {r.avg_downstream:.1f}")
158
- print(f" Avg student mastery rate : {r.mastery_rate:.1%}")
159
-
160
- print("=" * 62 + "\n")
161
-
162
- def to_dict(self) -> dict:
163
- return {
164
- "kt_metrics": [
165
- {
166
- "model": m.model_name,
167
- "auc": round(m.auc, 6),
168
- "accuracy": round(m.accuracy, 6),
169
- "log_loss": round(m.log_loss, 6),
170
- "n_samples": m.n_samples,
171
- "elapsed_s": round(m.elapsed_s, 3),
172
- }
173
- for m in self.kt_metrics
174
- ],
175
- "rec_metrics": {
176
- "violation_rate": round(self.rec_metrics.violation_rate, 6),
177
- "coverage": round(self.rec_metrics.coverage, 6),
178
- "avg_downstream": round(self.rec_metrics.avg_downstream, 3),
179
- "mastery_rate": round(self.rec_metrics.mastery_rate, 6),
180
- } if self.rec_metrics else None,
181
- "config": self.config,
182
- "timestamp": self.timestamp,
183
- }
184
-
185
-
186
- # ── Main evaluator ────────────────────────────────────────────────────────────
187
-
188
- class PLRSEvaluator:
189
- """
190
- Evaluate PLRS against baselines on held-out student sequences.
191
-
192
- Parameters
193
- ----------
194
- pipeline : PLRSPipeline
195
- A loaded pipeline (with or without SAKT model).
196
- """
197
-
198
- def __init__(self, pipeline) -> None:
199
- self.pipeline = pipeline
200
- self.curriculum = pipeline.curriculum
201
-
202
- def evaluate(
203
- self,
204
- test_sequences: list[tuple[list[int], list[int]]],
205
- skill_to_topic: dict[int, str] | None = None,
206
- train_sequences: list[tuple[list[int], list[int]]] | None = None,
207
- include_baselines: bool = True,
208
- ) -> EvaluationReport:
209
- """
210
- Run full evaluation.
211
-
212
- Parameters
213
- ----------
214
- test_sequences : list of (skill_seq, correct_seq)
215
- skill_to_topic : dict mapping skill_id → curriculum topic_id
216
- train_sequences : used to fit popularity baseline
217
- include_baselines : whether to evaluate BKT and popularity baselines
218
-
219
- Returns
220
- -------
221
- EvaluationReport
222
- """
223
- import datetime
224
-
225
- kt_metrics: list[KTMetrics] = []
226
-
227
- # ── SAKT evaluation ──────────────────────────────────────────
228
- if self.pipeline._model is not None:
229
- kt_metrics.append(
230
- self._eval_sakt(test_sequences)
231
- )
232
-
233
- # ── Baselines ────────────────────────────────────────────────
234
- if include_baselines:
235
- kt_metrics.append(self._eval_random(test_sequences))
236
- kt_metrics.append(self._eval_bkt(test_sequences))
237
-
238
- pop = PopularityBaseline()
239
- pop.fit(train_sequences or test_sequences, skill_to_topic)
240
- kt_metrics.append(self._eval_popularity(test_sequences, pop))
241
-
242
- # ── Recommendation metrics ───────────────────────────────────
243
- rec_metrics = self._eval_recommendations(test_sequences, skill_to_topic)
244
-
245
- return EvaluationReport(
246
- kt_metrics=kt_metrics,
247
- rec_metrics=rec_metrics,
248
- config={
249
- "threshold": self.pipeline.threshold,
250
- "soft_threshold": self.pipeline.soft_threshold,
251
- "top_n": self.pipeline.top_n,
252
- "n_test_students": len(test_sequences),
253
- },
254
- timestamp=datetime.datetime.now().isoformat(),
255
- )
256
-
257
- # ── KT evaluation helpers ─────────────────────────────────────────────────
258
-
259
- def _eval_sakt(self, sequences) -> KTMetrics:
260
- t0 = time.time()
261
- all_probs, all_labels = [], []
262
-
263
- for skill_seq, correct_seq in sequences:
264
- if len(skill_seq) < 2:
265
- continue
266
- probs = self.pipeline._model.predict_mastery(skill_seq, correct_seq)
267
- for skill_id, prob in probs.items():
268
- if skill_id < len(correct_seq):
269
- all_probs.append(prob)
270
- all_labels.append(float(correct_seq[skill_id]))
271
-
272
- return self._compute_kt_metrics("SAKT", all_probs, all_labels, time.time() - t0)
273
-
274
- def _eval_random(self, sequences) -> KTMetrics:
275
- t0 = time.time()
276
- all_probs, all_labels = [], []
277
- for skill_seq, correct_seq in sequences:
278
- for correct in correct_seq[1:]:
279
- all_probs.append(0.5)
280
- all_labels.append(float(correct))
281
- return self._compute_kt_metrics("Random (baseline)", all_probs, all_labels, time.time() - t0)
282
-
283
- def _eval_bkt(self, sequences) -> KTMetrics:
284
- t0 = time.time()
285
- all_probs, all_labels = [], []
286
- bkt = BKTBaseline()
287
- for skill_seq, correct_seq in sequences:
288
- if len(skill_seq) < 2:
289
- continue
290
- probs = bkt.predict_sequence(skill_seq, correct_seq)
291
- labels = [float(c) for c in correct_seq[1:len(probs) + 1]]
292
- all_probs.extend(probs)
293
- all_labels.extend(labels)
294
- return self._compute_kt_metrics("BKT (baseline)", all_probs, all_labels, time.time() - t0)
295
-
296
- def _eval_popularity(self, sequences, pop: PopularityBaseline) -> KTMetrics:
297
- t0 = time.time()
298
- all_probs, all_labels = [], []
299
- for skill_seq, correct_seq in sequences:
300
- for skill, correct in zip(skill_seq[1:], correct_seq[1:]):
301
- all_probs.append(pop.predict_prob(skill))
302
- all_labels.append(float(correct))
303
- return self._compute_kt_metrics("Popularity (baseline)", all_probs, all_labels, time.time() - t0)
304
-
305
- @staticmethod
306
- def _compute_kt_metrics(name, probs, labels, elapsed) -> KTMetrics:
307
- probs_arr = np.nan_to_num(np.array(probs), nan=0.5)
308
- labels_arr = np.nan_to_num(np.array(labels), nan=0.0)
309
- n = len(probs_arr)
310
-
311
- if HAS_SKLEARN and n > 0 and len(np.unique(labels_arr)) > 1:
312
- auc = float(roc_auc_score(labels_arr, probs_arr))
313
- acc = float(accuracy_score(labels_arr, (probs_arr >= 0.5).astype(int)))
314
- loss = float(log_loss(labels_arr, np.clip(probs_arr, 1e-7, 1 - 1e-7)))
315
- else:
316
- auc = 0.5
317
- acc = float(((probs_arr >= 0.5) == labels_arr).mean()) if n > 0 else 0.0
318
- loss = float(-np.mean(
319
- labels_arr * np.log(probs_arr + 1e-7) +
320
- (1 - labels_arr) * np.log(1 - probs_arr + 1e-7)
321
- )) if n > 0 else 0.0
322
-
323
- return KTMetrics(
324
- model_name=name, auc=auc, accuracy=acc,
325
- log_loss=loss, n_samples=n, elapsed_s=elapsed,
326
- )
327
-
328
- # ── Recommendation evaluation ─────────────────────────────────────────────
329
-
330
- def _eval_recommendations(
331
- self,
332
- sequences,
333
- skill_to_topic,
334
- ) -> RecommendMetrics:
335
- violation_rates, coverages, downstreams, mastery_rates = [], [], [], []
336
-
337
- for skill_seq, correct_seq in sequences:
338
- # Build mastery from sequence
339
- if skill_to_topic:
340
- topic_scores: dict[str, float] = {}
341
- for skill, correct in zip(skill_seq, correct_seq):
342
- topic = skill_to_topic.get(skill)
343
- if topic and topic in self.curriculum.nodes:
344
- topic_scores[topic] = max(topic_scores.get(topic, 0.0), float(correct))
345
- mastery_scores = {n: 0.0 for n in self.curriculum.nodes}
346
- mastery_scores.update(topic_scores)
347
- else:
348
- mastery_scores = {n: 0.0 for n in self.curriculum.nodes}
349
-
350
- results = self.pipeline.recommend_from_mastery(mastery_scores)
351
- stats = results["stats"]
352
- summary = results["mastery_summary"]
353
-
354
- violation_rates.append(stats["prerequisite_violation_rate"])
355
- mastery_rates.append(summary["mastery_rate"])
356
-
357
- # Coverage: fraction of curriculum represented in approved+challenging
358
- rec_topics = set(
359
- r["topic_id"] for r in results["approved"] + results["challenging"]
360
- )
361
- coverages.append(len(rec_topics) / max(self.curriculum.num_nodes, 1))
362
-
363
- # Avg downstream unlock value
364
- if results["approved"]:
365
- downstreams.append(
366
- np.mean([r["downstream_count"] for r in results["approved"]])
367
- )
368
-
369
- return RecommendMetrics(
370
- violation_rate=float(np.mean(violation_rates)) if violation_rates else 0.0,
371
- coverage=float(np.mean(coverages)) if coverages else 0.0,
372
- avg_downstream=float(np.mean(downstreams)) if downstreams else 0.0,
373
- mastery_rate=float(np.mean(mastery_rates)) if mastery_rates else 0.0,
374
- )
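
A minimal usage sketch for the deleted evaluator above. The sequences and skill-to-topic mapping are toy placeholders, and `pipeline` is assumed to be an already-constructed `PLRSPipeline`.

```python
# Minimal sketch of running PLRSEvaluator from plrs/model/evaluator.py.
# `pipeline` is assumed to be an already-built PLRSPipeline; the sequences
# and skill_to_topic mapping below are toy placeholders.
from plrs.model.evaluator import PLRSEvaluator

# Held-out students as (skill_seq, correct_seq) pairs
test_sequences = [
    ([3, 7, 7, 12], [1, 0, 1, 1]),
    ([3, 3, 12], [1, 1, 0]),
]
skill_to_topic = {3: "whole_numbers", 7: "algebraic_expressions", 12: "quadratic_equations"}

evaluator = PLRSEvaluator(pipeline)
report = evaluator.evaluate(
    test_sequences,
    skill_to_topic=skill_to_topic,
    include_baselines=True,   # adds Random, BKT and Popularity rows
)
report.print()                # formatted table of AUC / accuracy / log loss
print(report.to_dict()["rec_metrics"])
```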
plrs/model/model_loader.py DELETED
@@ -1,116 +0,0 @@
1
- """
2
- HF Space model loader — updated for SAKTWithDecay (v0.2.0 weights).
3
-
4
- Drop this file into your HF Space as `model_loader.py` and call
5
- `load_model_from_hub()` in app.py instead of the old loading logic.
6
-
7
- The v0.2.0 weights (sakt_decay_best.pt) are saved with our new format:
8
- {
9
- "state_dict": {...},
10
- "model_type": "SAKTWithDecay",
11
- "config": {"num_skills": 20, "embed_dim": 64, ...}
12
- }
13
-
14
- Falls back gracefully to mastery-dict mode if weights can't be loaded.
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- import json
20
- from pathlib import Path
21
-
22
- import torch
23
-
24
- HF_REPO = "Clementio/PLRS"
25
-
26
-
27
- def load_model_from_hub(device: str = "cpu"):
28
- """
29
- Load SAKT model weights from HuggingFace Hub.
30
-
31
- Tries files in priority order:
32
- 1. sakt_decay_best.pt (v0.2.0 — decay attention)
33
- 2. sakt_vanilla_best.pt (v0.2.0 — vanilla transformer)
34
- 3. sakt_model.pt (v0.1.0 — synthetic baseline)
35
-
36
- Returns (model, model_type_str) or (None, "unavailable").
37
- """
38
- try:
39
- from huggingface_hub import hf_hub_download
40
- except ImportError:
41
- return None, "huggingface_hub not installed"
42
-
43
- for filename, model_type in [
44
- ("models/sakt_decay_best.pt", "SAKTWithDecay"),
45
- ("models/sakt_vanilla_best.pt", "SAKTModel"),
46
- ("models/sakt_model.pt", "SAKTModel"),
47
- ]:
48
- try:
49
- path = hf_hub_download(repo_id=HF_REPO, filename=filename)
50
- model = _load_weights(path, model_type, device)
51
- if model is not None:
52
- return model, model_type
53
- except Exception:
54
- continue
55
-
56
- return None, "unavailable"
57
-
58
-
59
- def _load_weights(path: str, preferred_type: str, device: str):
60
- """Load model weights from a .pt file, handling both old and new formats."""
61
- try:
62
- payload = torch.load(path, map_location=device, weights_only=False)
63
- except Exception:
64
- return None
65
-
66
- # ── New format (v0.2.0): {"state_dict": ..., "model_type": ..., "config": ...}
67
- if isinstance(payload, dict) and "state_dict" in payload:
68
- cfg = payload.get("config", {})
69
- model_type = payload.get("model_type", preferred_type)
70
-
71
- if model_type == "SAKTWithDecay":
72
- from plrs.model.sakt_decay import SAKTWithDecay
73
- model = SAKTWithDecay(
74
- num_skills=cfg.get("num_skills", 5737),
75
- embed_dim=cfg.get("embed_dim", 64),
76
- num_heads=cfg.get("num_heads", 8),
77
- dropout=cfg.get("dropout", 0.2),
78
- max_seq_len=cfg.get("max_seq_len", 100),
79
- decay_init=cfg.get("decay_init", 1.0),
80
- )
81
- else:
82
- from plrs.model.sakt import SAKTModel
83
- model = SAKTModel(
84
- num_skills=cfg.get("num_skills", 5737),
85
- embed_dim=cfg.get("embed_dim", 64),
86
- num_heads=cfg.get("num_heads", 8),
87
- dropout=cfg.get("dropout", 0.2),
88
- max_seq_len=cfg.get("max_seq_len", 100),
89
- )
90
-
91
- try:
92
- model.load_state_dict(payload["state_dict"], strict=False)
93
- model.eval()
94
- model.to(device)
95
- return model
96
- except Exception:
97
- return None
98
-
99
- # ── Old format (v0.1.0 FYP): raw state_dict + separate config.json
100
- try:
101
- config_path = Path(path).parent / "config.json"
102
- if config_path.exists():
103
- config = json.loads(config_path.read_text())
104
- else:
105
- config = {"num_skills": 5736, "embed_dim": 64}
106
-
107
- from plrs.model.sakt import SAKTModel
108
- model = SAKTModel(
109
- num_skills=config.get("num_skills", 5736),
110
- embed_dim=config.get("embed_dim", 64),
111
- )
112
- model.load_state_dict(payload, strict=False)
113
- model.eval()
114
- return model
115
- except Exception:
116
- return None
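
A minimal sketch of how the loader above would be called from a Space's `app.py`, falling back to mastery-dict mode when no weights can be fetched, as the module docstring describes. The fallback message wording is an assumption.

```python
# Minimal sketch of using load_model_from_hub() from a Space's app.py, per the
# module docstring above; the mastery-dict fallback follows the same docstring.
from model_loader import load_model_from_hub

model, model_type = load_model_from_hub(device="cpu")

if model is None:
    print(f"SAKT weights unavailable ({model_type}); using mastery-dict mode only.")
else:
    print(f"Loaded {model_type} from the Hub.")
    # Both SAKT variants expose predict_mastery (SAKTWithDecay is described
    # as a drop-in replacement for SAKTModel).
    mastery = model.predict_mastery([3, 7, 12, 7], [1, 0, 1, 1])
    print(mastery)  # {skill_id: predicted mastery probability}
```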
plrs/model/sakt.py DELETED
@@ -1,219 +0,0 @@
1
- """
2
- plrs.model.sakt
3
- ===============
4
- Self-Attentive Knowledge Tracing (SAKT) model.
5
-
6
- Architecture: transformer-style attention over student interaction sequences.
7
- Each interaction is encoded as (skill_id + correctness * n_skills + 1), with index 0 reserved for padding.
8
-
9
- Reference: Pandey & Karypis, 2019 — "A Self-Attentive model for Knowledge Tracing"
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- from pathlib import Path
15
- from typing import Any
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
-
21
- class SAKTModel(nn.Module):
22
- """
23
- SAKT: Self-Attentive Knowledge Tracing.
24
-
25
- Parameters
26
- ----------
27
- num_skills : int
28
- Total number of unique skills in the dataset.
29
- embed_dim : int
30
- Embedding dimension for interactions and positions.
31
- num_heads : int
32
- Number of attention heads.
33
- dropout : float
34
- Dropout rate.
35
- max_seq_len : int
36
- Maximum interaction sequence length.
37
- """
38
-
39
- def __init__(
40
- self,
41
- num_skills: int,
42
- embed_dim: int = 64,
43
- num_heads: int = 8,
44
- dropout: float = 0.2,
45
- max_seq_len: int = 100,
46
- ) -> None:
47
- super().__init__()
48
- self.num_skills = num_skills
49
- self.embed_dim = embed_dim
50
- self.max_seq_len = max_seq_len
51
-
52
- # Interaction embedding: (skill, correct) → dense vector
53
- self.interaction_embed = nn.Embedding(2 * num_skills + 2, embed_dim, padding_idx=0) # +2: shift+1 means max index = 2*n+1
54
- # Positional embedding
55
- self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
56
-
57
- # Self-attention layer
58
- self.self_attn = nn.MultiheadAttention(
59
- embed_dim=embed_dim,
60
- num_heads=num_heads,
61
- dropout=dropout,
62
- batch_first=True,
63
- )
64
-
65
- self.layer_norm1 = nn.LayerNorm(embed_dim)
66
- self.layer_norm2 = nn.LayerNorm(embed_dim)
67
-
68
- self.ffn = nn.Sequential(
69
- nn.Linear(embed_dim, embed_dim * 2),
70
- nn.ReLU(),
71
- nn.Dropout(dropout),
72
- nn.Linear(embed_dim * 2, embed_dim),
73
- )
74
-
75
- # Skill query embedding for target prediction
76
- self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
77
-
78
- self.output_layer = nn.Linear(embed_dim * 2, 1)
79
- self.dropout = nn.Dropout(dropout)
80
-
81
- def forward(
82
- self,
83
- interactions: torch.Tensor, # (batch, seq_len)
84
- target_skills: torch.Tensor, # (batch, seq_len)
85
- mask: torch.Tensor, # (batch, seq_len) bool — True = real token
86
- ) -> torch.Tensor:
87
- """
88
- Forward pass.
89
-
90
- Returns
91
- -------
92
- torch.Tensor of shape (batch, seq_len) — logits per position.
93
- """
94
- batch_size, seq_len = interactions.shape
95
- positions = torch.arange(seq_len, device=interactions.device).unsqueeze(0)
96
-
97
- x = self.interaction_embed(interactions) + self.pos_embed(positions)
98
- x = self.dropout(x)
99
-
100
- # Causal mask — bool upper-triangular (MHA handles conversion internally)
101
- causal_mask = torch.triu(
102
- torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool),
103
- diagonal=1,
104
- )
105
-
106
- # Key padding mask: True = ignore (PyTorch MHA convention)
107
- key_padding_mask = ~mask # (batch, seq_len) bool
108
-
109
- x_attn, _ = self.self_attn(
110
- query=x,
111
- key=x,
112
- value=x,
113
- attn_mask=causal_mask,
114
- key_padding_mask=key_padding_mask,
115
- )
116
- # Replace any NaN in attention output (from fully-masked rows) with 0
117
- x_attn = torch.nan_to_num(x_attn, nan=0.0)
118
- x = self.layer_norm1(x + x_attn)
119
- x = self.layer_norm2(x + self.ffn(x))
120
-
121
- # Concatenate with target skill embedding for final prediction
122
- skill_x = self.skill_embed(target_skills)
123
- out = self.output_layer(torch.cat([x, skill_x], dim=-1)).squeeze(-1)
124
-
125
- return out # (batch, seq_len) logits
126
-
127
- # ------------------------------------------------------------------ #
128
- # Inference helpers #
129
- # ------------------------------------------------------------------ #
130
-
131
- @torch.no_grad()
132
- def predict_mastery(
133
- self,
134
- skill_seq: list[int],
135
- correct_seq: list[int],
136
- device: torch.device | str = "cpu",
137
- ) -> dict[int, float]:
138
- """
139
- Run inference on a student's interaction history.
140
-
141
- Parameters
142
- ----------
143
- skill_seq : list[int]
144
- Sequence of skill IDs the student interacted with.
145
- correct_seq : list[int]
146
- Corresponding correctness (1 = correct, 0 = incorrect).
147
- device : str or torch.device
148
-
149
- Returns
150
- -------
151
- dict[int, float]
152
- Mapping from skill_id → predicted mastery probability.
153
- """
154
- if len(skill_seq) < 2:
155
- return {}
156
-
157
- if len(skill_seq) > self.max_seq_len:
158
- skill_seq = skill_seq[-self.max_seq_len:]
159
- correct_seq = correct_seq[-self.max_seq_len:]
160
-
161
- interactions = [s + c * self.num_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
162
- target_skills = skill_seq[1:]
163
-
164
- seq_len = len(interactions)
165
- pad_len = self.max_seq_len - seq_len
166
-
167
- interactions_padded = [0] * pad_len + interactions
168
- target_padded = [0] * pad_len + target_skills
169
- mask = [False] * pad_len + [True] * seq_len
170
-
171
- interactions_t = torch.LongTensor([interactions_padded]).to(device)
172
- target_t = torch.LongTensor([target_padded]).to(device)
173
- mask_t = torch.BoolTensor([mask]).to(device)
174
-
175
- self.eval()
176
- self.to(device)
177
-
178
- logits = self(interactions_t, target_t, mask_t)
179
- probs = torch.sigmoid(logits).squeeze(0)
180
-
181
- real_probs = probs[mask_t.squeeze(0)].cpu().numpy()  # index with a mask on the same device as probs
182
- mastery = {
183
- int(skill_id): float(prob)
184
- for skill_id, prob in zip(target_skills, real_probs)
185
- }
186
- return mastery
187
-
188
- # ------------------------------------------------------------------ #
189
- # Serialisation #
190
- # ------------------------------------------------------------------ #
191
-
192
- def save(self, path: str | Path, config: dict[str, Any] | None = None) -> None:
193
- """Save model weights and config to a .pt file."""
194
- payload = {
195
- "state_dict": self.state_dict(),
196
- "config": config or {
197
- "num_skills": self.num_skills,
198
- "embed_dim": self.embed_dim,
199
- "max_seq_len": self.max_seq_len,
200
- },
201
- }
202
- torch.save(payload, path)
203
-
204
- @classmethod
205
- def load(cls, path: str | Path, device: str | torch.device = "cpu") -> "SAKTModel":
206
- """Load a saved SAKT model."""
207
- payload = torch.load(path, map_location=device, weights_only=False)
208
- config = payload["config"]
209
- model = cls(
210
- num_skills=config["num_skills"],
211
- embed_dim=config.get("embed_dim", 64),
212
- num_heads=config.get("num_heads", 8),
213
- dropout=config.get("dropout", 0.2),
214
- max_seq_len=config.get("max_seq_len", 100),
215
- )
216
- model.load_state_dict(payload["state_dict"])
217
- model.to(device)
218
- model.eval()
219
- return model
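
A small sketch of the interaction encoding used by `predict_mastery` above, showing how a (skill, correctness) pair maps to an embedding index with index 0 reserved for padding. The skill count is illustrative.

```python
# Sketch of the interaction encoding used in SAKTModel.predict_mastery:
# index = skill_id + correct * num_skills + 1, with index 0 reserved for padding.
num_skills = 20

def encode(skill_id: int, correct: int) -> int:
    return skill_id + correct * num_skills + 1

print(encode(3, 0))   # 4   -> skill 3 answered incorrectly
print(encode(3, 1))   # 24  -> skill 3 answered correctly
# The embedding table is sized 2 * num_skills + 2 so that the largest possible
# index (2 * num_skills + 1, if skill ids are 1-indexed) still fits.
```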
plrs/model/sakt_decay.py DELETED
@@ -1,253 +0,0 @@
1
- """
2
- plrs.model.sakt_decay
3
- =====================
4
- SAKT with Ebbinghaus Forgetting Curve Decay.
5
-
6
- Extends the base SAKT model by applying a temporal decay penalty to
7
- attention weights, reflecting that older interactions contribute less to
8
- current mastery estimates.
9
-
10
- The decay is inspired by the Ebbinghaus retention curve; the implemented penalty is
11
- decay_rate * log(1 + t) on the attention logits, i.e. weights scale as (1 + t)^(-decay_rate).
12
-
13
- Where t is the time gap between interaction j and the current position i,
14
- measured in interaction steps (or elapsed time if timestamps are available).
15
-
16
- This typically improves val AUC by 0.01–0.02 over vanilla SAKT.
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- import math
22
- from pathlib import Path
23
- from typing import Any
24
-
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
-
29
-
30
- class DecayAttention(nn.Module):
31
- """
32
- Multi-head attention with Ebbinghaus forgetting curve decay.
33
-
34
- Applies position-based temporal decay to attention logits before softmax:
35
- attention_logits[i, j] -= decay_rate_learned * log(1 + |i - j|)
36
-
37
- The decay rate is a learned scalar per head, initialised from a prior.
38
- """
39
-
40
- def __init__(
41
- self,
42
- embed_dim: int,
43
- num_heads: int,
44
- dropout: float = 0.2,
45
- decay_init: float = 1.0,
46
- ) -> None:
47
- super().__init__()
48
- self.embed_dim = embed_dim
49
- self.num_heads = num_heads
50
- self.head_dim = embed_dim // num_heads
51
- assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
52
-
53
- self.q_proj = nn.Linear(embed_dim, embed_dim)
54
- self.k_proj = nn.Linear(embed_dim, embed_dim)
55
- self.v_proj = nn.Linear(embed_dim, embed_dim)
56
- self.out_proj = nn.Linear(embed_dim, embed_dim)
57
- self.dropout = nn.Dropout(dropout)
58
-
59
- # Learned decay rate per head — initialised to decay_init
60
- # Constrained positive via softplus during forward
61
- self.decay_logit = nn.Parameter(
62
- torch.full((num_heads,), math.log(math.exp(decay_init) - 1))
63
- )
64
-
65
- def forward(
66
- self,
67
- x: torch.Tensor, # (batch, seq_len, embed_dim)
68
- causal_mask: torch.Tensor, # (seq_len, seq_len) bool — True = block
69
- key_padding_mask: torch.Tensor, # (batch, seq_len) bool — True = pad
70
- ) -> torch.Tensor:
71
- B, L, D = x.shape
72
- H, Hd = self.num_heads, self.head_dim
73
-
74
- Q = self.q_proj(x).view(B, L, H, Hd).transpose(1, 2) # (B, H, L, Hd)
75
- K = self.k_proj(x).view(B, L, H, Hd).transpose(1, 2)
76
- V = self.v_proj(x).view(B, L, H, Hd).transpose(1, 2)
77
-
78
- # Scaled dot-product attention scores
79
- scale = math.sqrt(self.head_dim)
80
- scores = torch.matmul(Q, K.transpose(-2, -1)) / scale # (B, H, L, L)
81
-
82
- # ── Ebbinghaus decay ──────────────────────────────────────── #
83
- # Build temporal distance matrix: dist[i, j] = |i - j|
84
- positions = torch.arange(L, device=x.device)
85
- dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float() # (L, L)
86
-
87
- # decay = softplus(decay_logit) ensures strictly positive rates
88
- decay_rate = F.softplus(self.decay_logit) # (H,)
89
-
90
- # Decay penalty: rate_h * log(1 + dist) — shape (H, L, L)
91
- decay_penalty = decay_rate.view(H, 1, 1) * torch.log1p(dist).unsqueeze(0)
92
- scores = scores - decay_penalty.unsqueeze(0) # broadcast over batch
93
- # ─────────────────────────────────────────────────────────── #
94
-
95
- # Apply causal mask
96
- scores = scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), -1e9)
97
-
98
- # Apply padding mask
99
- if key_padding_mask is not None:
100
- scores = scores.masked_fill(
101
- key_padding_mask.unsqueeze(1).unsqueeze(2), -1e9
102
- )
103
-
104
- attn = F.softmax(scores, dim=-1)
105
- attn = self.dropout(attn)
106
-
107
- out = torch.matmul(attn, V) # (B, H, L, Hd)
108
- out = out.transpose(1, 2).contiguous().view(B, L, D) # (B, L, D)
109
- return self.out_proj(out)
110
-
111
-
112
- class SAKTWithDecay(nn.Module):
113
- """
114
- SAKT + Ebbinghaus Forgetting Curve Decay.
115
-
116
- Drop-in replacement for SAKTModel with improved AUC through
117
- temporal decay attention. All other architecture details are identical.
118
-
119
- Parameters
120
- ----------
121
- num_skills : int
122
- embed_dim : int
123
- num_heads : int
124
- dropout : float
125
- max_seq_len : int
126
- decay_init : float
127
- Initial decay rate (higher = faster forgetting). Default 1.0.
128
- """
129
-
130
- def __init__(
131
- self,
132
- num_skills: int,
133
- embed_dim: int = 64,
134
- num_heads: int = 8,
135
- dropout: float = 0.2,
136
- max_seq_len: int = 100,
137
- decay_init: float = 1.0,
138
- ) -> None:
139
- super().__init__()
140
- self.num_skills = num_skills
141
- self.embed_dim = embed_dim
142
- self.max_seq_len = max_seq_len
143
-
144
- self.interaction_embed = nn.Embedding(2 * num_skills + 2, embed_dim, padding_idx=0) # +2: shift+1 means max index = 2*n+1
145
- self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
146
-
147
- # Decay-aware attention replaces nn.MultiheadAttention
148
- self.decay_attn = DecayAttention(embed_dim, num_heads, dropout, decay_init)
149
-
150
- self.layer_norm1 = nn.LayerNorm(embed_dim)
151
- self.layer_norm2 = nn.LayerNorm(embed_dim)
152
- self.ffn = nn.Sequential(
153
- nn.Linear(embed_dim, embed_dim * 2),
154
- nn.ReLU(),
155
- nn.Dropout(dropout),
156
- nn.Linear(embed_dim * 2, embed_dim),
157
- )
158
-
159
- self.skill_embed = nn.Embedding(num_skills + 1, embed_dim, padding_idx=0)
160
- self.output_layer = nn.Linear(embed_dim * 2, 1)
161
- self.dropout = nn.Dropout(dropout)
162
-
163
- def forward(
164
- self,
165
- interactions: torch.Tensor,
166
- target_skills: torch.Tensor,
167
- mask: torch.Tensor,
168
- ) -> torch.Tensor:
169
- B, L = interactions.shape
170
- positions = torch.arange(L, device=interactions.device).unsqueeze(0)
171
-
172
- x = self.interaction_embed(interactions) + self.pos_embed(positions)
173
- x = self.dropout(x)
174
-
175
- causal_mask = torch.triu(
176
- torch.ones(L, L, device=x.device, dtype=torch.bool), diagonal=1
177
- )
178
- key_padding_mask = ~mask # True = ignore
179
-
180
- x_attn = self.decay_attn(x, causal_mask, key_padding_mask)
181
- x = self.layer_norm1(x + x_attn)
182
- x = self.layer_norm2(x + self.ffn(x))
183
-
184
- skill_x = self.skill_embed(target_skills)
185
- out = self.output_layer(torch.cat([x, skill_x], dim=-1)).squeeze(-1)
186
- return out
187
-
188
- @torch.no_grad()
189
- def predict_mastery(
190
- self,
191
- skill_seq: list[int],
192
- correct_seq: list[int],
193
- device: torch.device | str = "cpu",
194
- ) -> dict[int, float]:
195
- """Same interface as SAKTModel.predict_mastery."""
196
- if len(skill_seq) < 2:
197
- return {}
198
-
199
- if len(skill_seq) > self.max_seq_len:
200
- skill_seq = skill_seq[-self.max_seq_len:]
201
- correct_seq = correct_seq[-self.max_seq_len:]
202
-
203
- interactions = [s + c * self.num_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
204
- target_skills = skill_seq[1:]
205
- seq_len = len(interactions)
206
- pad_len = self.max_seq_len - seq_len
207
-
208
- interactions_padded = [0] * pad_len + interactions
209
- target_padded = [0] * pad_len + target_skills
210
- mask_list = [False] * pad_len + [True] * seq_len
211
-
212
- self.eval()
213
- self.to(device)
214
-
215
- logits = self(
216
- torch.LongTensor([interactions_padded]).to(device),
217
- torch.LongTensor([target_padded]).to(device),
218
- torch.BoolTensor([mask_list]).to(device),
219
- )
220
- probs = torch.sigmoid(logits).squeeze(0)
221
- real_probs = probs[torch.BoolTensor(mask_list)].cpu().numpy()
222
-
223
- return {int(sid): float(p) for sid, p in zip(target_skills, real_probs)}
224
-
225
- def save(self, path: str | Path, config: dict[str, Any] | None = None) -> None:
226
- payload = {
227
- "state_dict": self.state_dict(),
228
- "model_type": "SAKTWithDecay",
229
- "config": config or {
230
- "num_skills": self.num_skills,
231
- "embed_dim": self.embed_dim,
232
- "max_seq_len": self.max_seq_len,
233
- "model_type": "SAKTWithDecay",
234
- },
235
- }
236
- torch.save(payload, path)
237
-
238
- @classmethod
239
- def load(cls, path: str | Path, device: str | torch.device = "cpu") -> "SAKTWithDecay":
240
- payload = torch.load(path, map_location=device, weights_only=False)
241
- cfg = payload["config"]
242
- model = cls(
243
- num_skills=cfg["num_skills"],
244
- embed_dim=cfg.get("embed_dim", 64),
245
- num_heads=cfg.get("num_heads", 8),
246
- dropout=cfg.get("dropout", 0.2),
247
- max_seq_len=cfg.get("max_seq_len", 100),
248
- decay_init=cfg.get("decay_init", 1.0),
249
- )
250
- model.load_state_dict(payload["state_dict"])
251
- model.to(device)
252
- model.eval()
253
- return model
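The decay mechanism in the file above is easiest to see in isolation. Below is a self-contained toy sketch of the penalty term from `DecayAttention.forward`; all values are made up, and the real module additionally applies causal and padding masks before the softmax.

```python
import math

import torch
import torch.nn.functional as F

# Toy illustration of the Ebbinghaus-style penalty used by DecayAttention (sketch only).
L, H = 5, 2                                   # sequence length, attention heads
decay_init = 1.0
decay_logit = torch.full((H,), math.log(math.exp(decay_init) - 1))
decay_rate = F.softplus(decay_logit)          # strictly positive per-head rates

positions = torch.arange(L)
dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs().float()  # |i - j|
penalty = decay_rate.view(H, 1, 1) * torch.log1p(dist).unsqueeze(0)     # (H, L, L)

scores = torch.zeros(H, L, L)                 # stand-in for raw attention logits
decayed = scores - penalty                    # distant interactions are pushed down
print(torch.softmax(decayed, dim=-1)[0, -1])  # the last query attends mostly to recent keys
```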
 
plrs/model/trainer.py DELETED
@@ -1,437 +0,0 @@
1
- """
2
- plrs.model.trainer
3
- ==================
4
- Training loop for the SAKT knowledge tracing model.
5
-
6
- Handles:
7
- - Dataset preparation from raw interaction logs
8
- - Train / validation split
9
- - Training with early stopping
10
- - Checkpoint saving (best val AUC)
11
- - Metrics: AUC, accuracy, loss
12
-
13
- Expected input format (CSV or DataFrame):
14
- student_id | skill_id | correct | timestamp (optional)
15
- """
16
-
17
- from __future__ import annotations
18
-
19
- import time
20
- from dataclasses import dataclass, field
21
- from pathlib import Path
22
- from typing import Iterator
23
-
24
- import numpy as np
25
- import torch
26
- import torch.nn as nn
27
- from torch.utils.data import DataLoader, Dataset
28
-
29
- try:
30
- from sklearn.metrics import roc_auc_score
31
- HAS_SKLEARN = True
32
- except ImportError:
33
- HAS_SKLEARN = False
34
-
35
-
36
- # ------------------------------------------------------------------ #
37
- # Dataset #
38
- # ------------------------------------------------------------------ #
39
-
40
- class KTDataset(Dataset):
41
- """
42
- Knowledge Tracing dataset.
43
-
44
- Each sample is one student's full interaction sequence, windowed to
45
- max_seq_len. Long sequences are split into multiple windows.
46
-
47
- Parameters
48
- ----------
49
- sequences : list of (skill_seq, correct_seq)
50
- Each element is a tuple of parallel lists.
51
- max_seq_len : int
52
- n_skills : int
53
- """
54
-
55
- def __init__(
56
- self,
57
- sequences: list[tuple[list[int], list[int]]],
58
- max_seq_len: int = 100,
59
- n_skills: int = 5736,
60
- ) -> None:
61
- self.max_seq_len = max_seq_len
62
- self.n_skills = n_skills
63
- self.samples: list[tuple[list[int], list[int]]] = []
64
-
65
- for skill_seq, correct_seq in sequences:
66
- # Window long sequences
67
- for start in range(0, max(1, len(skill_seq) - 1), max_seq_len // 2):
68
- end = start + max_seq_len + 1
69
- s = skill_seq[start:end]
70
- c = correct_seq[start:end]
71
- if len(s) >= 2:
72
- self.samples.append((s, c))
73
-
74
- def __len__(self) -> int:
75
- return len(self.samples)
76
-
77
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
78
- skill_seq, correct_seq = self.samples[idx]
79
-
80
- if len(skill_seq) > self.max_seq_len + 1:
81
- skill_seq = skill_seq[-self.max_seq_len - 1:]
82
- correct_seq = correct_seq[-self.max_seq_len - 1:]
83
-
84
- interactions = [s + c * self.n_skills + 1 for s, c in zip(skill_seq[:-1], correct_seq[:-1])] # +1: reserve 0 for padding
85
- target_skills = skill_seq[1:]
86
- target_correct = correct_seq[1:]
87
-
88
- seq_len = len(interactions)
89
- pad_len = self.max_seq_len - seq_len
90
-
91
- interactions_padded = [0] * pad_len + interactions
92
- target_padded = [0] * pad_len + target_skills
93
- correct_padded = [0] * pad_len + target_correct
94
- mask = [False] * pad_len + [True] * seq_len
95
-
96
- return {
97
- "interactions": torch.LongTensor(interactions_padded),
98
- "target_skills": torch.LongTensor(target_padded),
99
- "target_correct": torch.FloatTensor(correct_padded),
100
- "mask": torch.BoolTensor(mask),
101
- }
102
-
103
-
104
- def collate_fn(batch: list[dict]) -> dict[str, torch.Tensor]:
105
- return {k: torch.stack([b[k] for b in batch]) for k in batch[0]}
106
-
107
-
108
- # ------------------------------------------------------------------ #
109
- # Trainer config #
110
- # ------------------------------------------------------------------ #
111
-
112
- @dataclass
113
- class TrainerConfig:
114
- # Model
115
- num_skills: int = 5736
116
- embed_dim: int = 64
117
- num_heads: int = 8
118
- dropout: float = 0.2
119
- max_seq_len: int = 100
120
-
121
- # Training
122
- epochs: int = 50
123
- batch_size: int = 64
124
- lr: float = 1e-3
125
- weight_decay: float = 1e-5
126
- val_split: float = 0.1
127
-
128
- # Early stopping
129
- patience: int = 5
130
- min_delta: float = 1e-4
131
-
132
- # Output
133
- output_dir: str = "checkpoints"
134
- run_name: str = "sakt_run"
135
-
136
- # Device
137
- device: str = "auto" # "auto" | "cpu" | "cuda" | "mps"
138
-
139
-
140
- # ------------------------------------------------------------------ #
141
- # Trainer #
142
- # ------------------------------------------------------------------ #
143
-
144
- @dataclass
145
- class EpochMetrics:
146
- epoch: int
147
- train_loss: float
148
- val_loss: float
149
- val_auc: float
150
- val_acc: float
151
- elapsed: float
152
-
153
-
154
- class SAKTTrainer:
155
- """
156
- Trainer for the SAKT knowledge tracing model.
157
-
158
- Parameters
159
- ----------
160
- config : TrainerConfig
161
- """
162
-
163
- def __init__(self, config: TrainerConfig) -> None:
164
- self.config = config
165
- self.device = self._resolve_device(config.device)
166
- self.output_dir = Path(config.output_dir)
167
- self.output_dir.mkdir(parents=True, exist_ok=True)
168
-
169
- # ---------------------------------------------------------------- #
170
- # Public API #
171
- # ---------------------------------------------------------------- #
172
-
173
- def fit(
174
- self,
175
- sequences: list[tuple[list[int], list[int]]],
176
- val_sequences: list[tuple[list[int], list[int]]] | None = None,
177
- ) -> list[EpochMetrics]:
178
- """
179
- Train the SAKT model on interaction sequences.
180
-
181
- Parameters
182
- ----------
183
- sequences : list of (skill_seq, correct_seq)
184
- Training data. Each element is a student's full history.
185
- val_sequences : list of (skill_seq, correct_seq), optional
186
- If None, val_split fraction of sequences is held out.
187
-
188
- Returns
189
- -------
190
- list[EpochMetrics] — training history
191
- """
192
- from plrs.model.sakt import SAKTModel
193
-
194
- cfg = self.config
195
-
196
- # Split if no explicit val set
197
- if val_sequences is None:
198
- n_val = max(1, int(len(sequences) * cfg.val_split))
199
- idx = np.random.permutation(len(sequences))
200
- val_sequences = [sequences[i] for i in idx[:n_val]]
201
- train_sequences = [sequences[i] for i in idx[n_val:]]
202
- else:
203
- train_sequences = sequences
204
-
205
- print(f"Training samples : {len(train_sequences)} students")
206
- print(f"Validation samples: {len(val_sequences)} students")
207
- print(f"Device: {self.device}")
208
-
209
- train_ds = KTDataset(train_sequences, cfg.max_seq_len, cfg.num_skills)
210
- val_ds = KTDataset(val_sequences, cfg.max_seq_len, cfg.num_skills)
211
-
212
- train_loader = DataLoader(
213
- train_ds, batch_size=cfg.batch_size, shuffle=True,
214
- collate_fn=collate_fn, num_workers=0,
215
- )
216
- val_loader = DataLoader(
217
- val_ds, batch_size=cfg.batch_size * 2, shuffle=False,
218
- collate_fn=collate_fn, num_workers=0,
219
- )
220
-
221
- model = SAKTModel(
222
- num_skills=cfg.num_skills,
223
- embed_dim=cfg.embed_dim,
224
- num_heads=cfg.num_heads,
225
- dropout=cfg.dropout,
226
- max_seq_len=cfg.max_seq_len,
227
- ).to(self.device)
228
-
229
- print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
230
-
231
- optimizer = torch.optim.Adam(
232
- model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
233
- )
234
-
235
- # Zero out NaN gradients that arise from softmax backward over fully-padded rows.
236
- # This is a known issue with nn.MultiheadAttention + bool key_padding_mask.
237
- # The hook is safe: it only zeroes truly NaN gradients, never valid ones.
238
- def _zero_nan_grad(grad: torch.Tensor) -> torch.Tensor:
239
- return torch.nan_to_num(grad, nan=0.0)
240
- for p in model.parameters():
241
- p.register_hook(_zero_nan_grad)
242
- scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
243
- optimizer, mode="max", patience=2, factor=0.5
244
- )
245
- criterion = nn.BCEWithLogitsLoss()
246
-
247
- history: list[EpochMetrics] = []
248
- best_auc = 0.0
249
- patience_counter = 0
250
- best_path = self.output_dir / f"{cfg.run_name}_best.pt"
251
-
252
- print(f"\n{'Epoch':>6} {'Train Loss':>11} {'Val Loss':>9} {'Val AUC':>9} {'Val Acc':>9} {'Time':>7}")
253
- print("-" * 58)
254
-
255
- for epoch in range(1, cfg.epochs + 1):
256
- t0 = time.time()
257
-
258
- train_loss = self._train_epoch(model, train_loader, optimizer, criterion)
259
- val_loss, val_auc, val_acc = self._val_epoch(model, val_loader, criterion)
260
-
261
- scheduler.step(val_auc)
262
- elapsed = time.time() - t0
263
-
264
- metrics = EpochMetrics(
265
- epoch=epoch,
266
- train_loss=train_loss,
267
- val_loss=val_loss,
268
- val_auc=val_auc,
269
- val_acc=val_acc,
270
- elapsed=elapsed,
271
- )
272
- history.append(metrics)
273
-
274
- print(
275
- f"{epoch:>6} {train_loss:>11.4f} {val_loss:>9.4f} "
276
- f"{val_auc:>9.4f} {val_acc:>9.4f} {elapsed:>6.1f}s"
277
- )
278
-
279
- # Save best
280
- if val_auc > best_auc + cfg.min_delta:
281
- best_auc = val_auc
282
- patience_counter = 0
283
- model.save(best_path, config=self._model_config())
284
- print(f" ✅ New best AUC: {best_auc:.4f} → saved to {best_path}")
285
- else:
286
- patience_counter += 1
287
- if patience_counter >= cfg.patience:
288
- print(f"\nEarly stopping at epoch {epoch} (patience={cfg.patience})")
289
- break
290
-
291
- print(f"\nTraining complete. Best val AUC: {best_auc:.4f}")
292
- print(f"Best model: {best_path}")
293
- return history
294
-
295
- # ---------------------------------------------------------------- #
296
- # Internal #
297
- # ---------------------------------------------------------------- #
298
-
299
- def _train_epoch(self, model, loader, optimizer, criterion) -> float:
300
- model.train()
301
- total_loss = 0.0
302
-
303
- for batch in loader:
304
- interactions = batch["interactions"].to(self.device)
305
- target_skills = batch["target_skills"].to(self.device)
306
- target_correct = batch["target_correct"].to(self.device)
307
- mask = batch["mask"].to(self.device)
308
-
309
- optimizer.zero_grad()
310
- logits = model(interactions, target_skills, mask)
311
-
312
- # Only compute loss on real (non-padded) positions
313
- real_logits = logits[mask]
314
- real_targets = target_correct[mask]
315
-
316
- loss = criterion(real_logits, real_targets)
317
- loss.backward()
318
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
319
- optimizer.step()
320
-
321
- total_loss += loss.item()
322
-
323
- return total_loss / max(len(loader), 1)
324
-
325
- @torch.no_grad()
326
- def _val_epoch(self, model, loader, criterion) -> tuple[float, float, float]:
327
- model.eval()
328
- total_loss = 0.0
329
- all_probs: list[float] = []
330
- all_labels: list[float] = []
331
-
332
- for batch in loader:
333
- interactions = batch["interactions"].to(self.device)
334
- target_skills = batch["target_skills"].to(self.device)
335
- target_correct = batch["target_correct"].to(self.device)
336
- mask = batch["mask"].to(self.device)
337
-
338
- logits = model(interactions, target_skills, mask)
339
- real_logits = logits[mask]
340
- real_targets = target_correct[mask]
341
-
342
- loss = criterion(real_logits, real_targets)
343
- total_loss += loss.item()
344
-
345
- probs = torch.sigmoid(real_logits).cpu().numpy()
346
- labels = real_targets.cpu().numpy()
347
- all_probs.extend(probs.tolist())
348
- all_labels.extend(labels.tolist())
349
-
350
- avg_loss = total_loss / max(len(loader), 1)
351
- all_probs_arr = np.array(all_probs)
352
- all_labels_arr = np.array(all_labels)
353
-
354
- # Guard against NaN (can occur with very small val sets)
355
- all_probs_arr = np.nan_to_num(all_probs_arr, nan=0.5)
356
- all_labels_arr = np.nan_to_num(all_labels_arr, nan=0.0)
357
-
358
- if HAS_SKLEARN and len(np.unique(all_labels_arr)) > 1:
359
- auc = float(roc_auc_score(all_labels_arr, all_probs_arr))
360
- else:
361
- auc = 0.5 # fallback (single class or no sklearn)
362
-
363
- acc = float(((all_probs_arr >= 0.5) == all_labels_arr).mean())
364
- return avg_loss, auc, acc
365
-
366
- def _model_config(self) -> dict:
367
- cfg = self.config
368
- return {
369
- "num_skills": cfg.num_skills,
370
- "embed_dim": cfg.embed_dim,
371
- "num_heads": cfg.num_heads,
372
- "dropout": cfg.dropout,
373
- "max_seq_len": cfg.max_seq_len,
374
- }
375
-
376
- @staticmethod
377
- def _resolve_device(device: str) -> torch.device:
378
- if device == "auto":
379
- if torch.cuda.is_available():
380
- return torch.device("cuda")
381
- if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
382
- return torch.device("mps")
383
- return torch.device("cpu")
384
- return torch.device(device)
385
-
386
-
387
- # ------------------------------------------------------------------ #
388
- # Utilities #
389
- # ------------------------------------------------------------------ #
390
-
391
- def load_sequences_from_csv(
392
- path: str | Path,
393
- student_col: str = "student_id",
394
- skill_col: str = "skill_id",
395
- correct_col: str = "correct",
396
- timestamp_col: str | None = "timestamp",
397
- min_seq_len: int = 5,
398
- ) -> list[tuple[list[int], list[int]]]:
399
- """
400
- Load student interaction sequences from a CSV file.
401
-
402
- Parameters
403
- ----------
404
- path : str or Path
405
- CSV with columns: student_id, skill_id, correct, [timestamp]
406
- student_col, skill_col, correct_col : str
407
- Column names.
408
- timestamp_col : str or None
409
- If provided, sort interactions by this column within each student.
410
- min_seq_len : int
411
- Drop students with fewer than this many interactions.
412
-
413
- Returns
414
- -------
415
- list of (skill_seq, correct_seq) tuples
416
- """
417
- import pandas as pd
418
-
419
- df = pd.read_csv(path)
420
-
421
- required = [student_col, skill_col, correct_col]
422
- missing = [c for c in required if c not in df.columns]
423
- if missing:
424
- raise ValueError(f"Missing columns in CSV: {missing}. Found: {df.columns.tolist()}")
425
-
426
- if timestamp_col and timestamp_col in df.columns:
427
- df = df.sort_values([student_col, timestamp_col])
428
-
429
- sequences = []
430
- for _, group in df.groupby(student_col):
431
- skills = group[skill_col].astype(int).tolist()
432
- corrects = group[correct_col].astype(int).tolist()
433
- if len(skills) >= min_seq_len:
434
- sequences.append((skills, corrects))
435
-
436
- print(f"Loaded {len(sequences)} student sequences from {path}")
437
- return sequences
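A short usage sketch for the removed trainer module, assembled from the signatures above; the CSV filename is a placeholder.

```python
# Sketch: train SAKT on a local interaction log with the removed trainer utilities.
from plrs.model.trainer import SAKTTrainer, TrainerConfig, load_sequences_from_csv

# Expected CSV columns: student_id, skill_id, correct and, optionally, timestamp.
sequences = load_sequences_from_csv("interactions.csv", min_seq_len=5)

config = TrainerConfig(num_skills=5736, epochs=50, batch_size=64, lr=1e-3,
                       output_dir="checkpoints", run_name="sakt_run")
trainer = SAKTTrainer(config)

history = trainer.fit(sequences)            # early-stops on val AUC; best model -> checkpoints/sakt_run_best.pt
best_auc = max(m.val_auc for m in history)  # history is a list of EpochMetrics
```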
 
plrs/pipeline.py DELETED
@@ -1,236 +0,0 @@
1
- """
2
- plrs.pipeline
3
- =============
4
- PLRSPipeline: the main entry point.
5
-
6
- Orchestrates SAKT inference → DAG constraint validation → multi-objective ranking.
7
-
8
- Usage
9
- -----
10
- from plrs import PLRSPipeline
11
- from plrs.curriculum import load_dag
12
-
13
- curriculum = load_dag("math_dag.json")
14
- pipeline = PLRSPipeline(curriculum, model_path="sakt_model.pt")
15
-
16
- # From raw interaction history
17
- results = pipeline.recommend_from_history(
18
- skill_seq=[12, 45, 3, 78],
19
- correct_seq=[1, 0, 1, 1],
20
- )
21
-
22
- # From pre-computed mastery dict
23
- results = pipeline.recommend_from_mastery(
24
- mastery_scores={"algebra_basics": 0.85, "quadratic_equations": 0.42}
25
- )
26
- """
27
-
28
- from __future__ import annotations
29
-
30
- from pathlib import Path
31
- from typing import Any
32
-
33
- from plrs.constraints.dag import DAGConstraintLayer, MasteryVector
34
- from plrs.curriculum.loader import CurriculumGraph
35
- from plrs.ranking.ranker import MultiObjectiveRanker, RankedRecommendation
36
-
37
-
38
- class PLRSPipeline:
39
- """
40
- End-to-end PLRS recommendation pipeline.
41
-
42
- Parameters
43
- ----------
44
- curriculum : CurriculumGraph
45
- model_path : str or Path, optional
46
- Path to a trained SAKT .pt file. If None, only mastery-dict mode is available.
47
- threshold : float
48
- Mastery threshold (default 0.70).
49
- soft_threshold : float
50
- Soft constraint threshold (default 0.50).
51
- top_n : int
52
- Number of top approved recommendations (default 5).
53
- w_gap, w_readiness, w_downstream : float
54
- Ranker objective weights.
55
- device : str
56
- PyTorch device for model inference (default "cpu").
57
- """
58
-
59
- def __init__(
60
- self,
61
- curriculum: CurriculumGraph,
62
- model_path: str | Path | None = None,
63
- threshold: float = 0.70,
64
- soft_threshold: float = 0.50,
65
- top_n: int = 5,
66
- w_gap: float = 0.4,
67
- w_readiness: float = 0.4,
68
- w_downstream: float = 0.2,
69
- device: str = "cpu",
70
- ) -> None:
71
- self.curriculum = curriculum
72
- self.threshold = threshold
73
- self.soft_threshold = soft_threshold
74
- self.top_n = top_n
75
- self.device = device
76
-
77
- self.constraint_layer = DAGConstraintLayer(curriculum)
78
- self.ranker = MultiObjectiveRanker(
79
- curriculum,
80
- w_gap=w_gap,
81
- w_readiness=w_readiness,
82
- w_downstream=w_downstream,
83
- )
84
-
85
- self._model = None
86
- if model_path is not None:
87
- self._load_model(model_path)
88
-
89
- # ------------------------------------------------------------------ #
90
- # Public API #
91
- # ------------------------------------------------------------------ #
92
-
93
- def recommend_from_mastery(
94
- self,
95
- mastery_scores: dict[str, float],
96
- cascade: bool = False,
97
- ) -> dict[str, Any]:
98
- """
99
- Generate recommendations from a pre-computed mastery dict.
100
-
101
- Parameters
102
- ----------
103
- mastery_scores : dict[str, float]
104
- Mapping from topic_id → mastery probability [0, 1].
105
- cascade : bool
106
- If True, propagate mastery upward through prerequisites.
107
-
108
- Returns
109
- -------
110
- dict with keys: approved, challenging, vetoed, stats, mastery_summary
111
- """
112
- mastery = self._build_mastery_vector(mastery_scores)
113
- if cascade:
114
- mastery.cascade_up()
115
- return self._run(mastery)
116
-
117
- def recommend_from_history(
118
- self,
119
- skill_seq: list[int],
120
- correct_seq: list[int],
121
- skill_to_topic: dict[int, str] | None = None,
122
- cascade: bool = False,
123
- ) -> dict[str, Any]:
124
- """
125
- Generate recommendations from raw student interaction history.
126
-
127
- Requires a loaded SAKT model (pass model_path to __init__).
128
-
129
- Parameters
130
- ----------
131
- skill_seq : list[int]
132
- Sequence of skill IDs from the student's history.
133
- correct_seq : list[int]
134
- Corresponding correctness flags (1/0).
135
- skill_to_topic : dict[int, str], optional
136
- Mapping from SAKT skill_id → curriculum topic_id.
137
- Required to map model output back to DAG nodes.
138
- cascade : bool
139
- If True, propagate mastery upward through prerequisites.
140
-
141
- Returns
142
- -------
143
- dict with keys: approved, challenging, vetoed, stats, mastery_summary
144
- """
145
- if self._model is None:
146
- raise RuntimeError(
147
- "No model loaded. Pass model_path to PLRSPipeline() to use history-based inference."
148
- )
149
-
150
- skill_probs = self._model.predict_mastery(skill_seq, correct_seq, device=self.device)
151
-
152
- if skill_to_topic:
153
- mastery_scores = {}
154
- for skill_id, prob in skill_probs.items():
155
- topic_id = skill_to_topic.get(skill_id)
156
- if topic_id:
157
- mastery_scores[topic_id] = max(mastery_scores.get(topic_id, 0.0), prob)
158
- else:
159
- # Without mapping, return raw skill probabilities (limited utility)
160
- mastery_scores = {str(k): v for k, v in skill_probs.items()}
161
-
162
- mastery = self._build_mastery_vector(mastery_scores)
163
- if cascade:
164
- mastery.cascade_up()
165
- return self._run(mastery)
166
-
167
- def what_if(self, topic_id: str) -> dict[str, Any]:
168
- """
169
- What-if analysis: what unlocks if a student masters this topic?
170
-
171
- Parameters
172
- ----------
173
- topic_id : str
174
-
175
- Returns
176
- -------
177
- dict with direct_unlocks, all_unlocks, blocked_by, total_unlocked
178
- """
179
- graph = self.curriculum.graph
180
- direct = self.curriculum.successors(topic_id)
181
- all_unlocks = self.curriculum.descendants(topic_id)
182
- blocked_by = self.curriculum.prerequisites(topic_id)
183
-
184
- return {
185
- "topic_id": topic_id,
186
- "topic_label": self.curriculum.label(topic_id),
187
- "direct_unlocks": [
188
- {"id": n, "label": self.curriculum.label(n)} for n in direct
189
- ],
190
- "all_unlocks": [
191
- {"id": n, "label": self.curriculum.label(n)} for n in all_unlocks
192
- ],
193
- "blocked_by": [
194
- {"id": n, "label": self.curriculum.label(n)} for n in blocked_by
195
- ],
196
- "total_unlocked": len(all_unlocks),
197
- }
198
-
199
- # ------------------------------------------------------------------ #
200
- # Internal helpers #
201
- # ------------------------------------------------------------------ #
202
-
203
- def _build_mastery_vector(self, mastery_scores: dict[str, float]) -> MasteryVector:
204
- mv = MasteryVector(self.curriculum, self.threshold, self.soft_threshold)
205
- mv.update_batch(mastery_scores)
206
- return mv
207
-
208
- def _run(self, mastery: MasteryVector) -> dict[str, Any]:
209
- constraint_results = self.constraint_layer.validate_all(mastery)
210
- ranked = self.ranker.rank(constraint_results, mastery, top_n=self.top_n)
211
- ranked["mastery_summary"] = mastery.summary()
212
-
213
- # Serialise to plain dicts for API/JSON friendliness
214
- for key in ("approved", "challenging", "vetoed"):
215
- ranked[key] = [self._rec_to_dict(r) for r in ranked[key]]
216
-
217
- return ranked
218
-
219
- def _load_model(self, path: str | Path) -> None:
220
- from plrs.model.sakt import SAKTModel
221
- self._model = SAKTModel.load(path, device=self.device)
222
-
223
- @staticmethod
224
- def _rec_to_dict(rec: RankedRecommendation) -> dict[str, Any]:
225
- return {
226
- "topic_id": rec.topic_id,
227
- "topic_label": rec.topic_label,
228
- "status": rec.status,
229
- "mastery": rec.mastery,
230
- "score": rec.score,
231
- "reasoning": rec.reasoning,
232
- "prerequisites": rec.prerequisites,
233
- "unmet_prerequisites": rec.unmet_prerequisites,
234
- "downstream_count": rec.downstream_count,
235
- "score_breakdown": rec.score_breakdown,
236
- }
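The module docstring above already shows how the removed pipeline was constructed; as a companion, here is a sketch of consuming its output, with keys taken from `_run` and `_rec_to_dict`. The curriculum file and mastery values are the illustrative ones from the docstring.

```python
# Sketch: build the removed PLRSPipeline and walk its result dict.
from plrs import PLRSPipeline
from plrs.curriculum import load_dag

curriculum = load_dag("math_dag.json")
pipeline = PLRSPipeline(curriculum, model_path="sakt_model.pt", top_n=5)

results = pipeline.recommend_from_mastery(
    {"algebra_basics": 0.85, "quadratic_equations": 0.42}, cascade=True
)

for rec in results["approved"]:             # each rec is a plain dict (see _rec_to_dict)
    print(rec["topic_label"], rec["score"], rec["unmet_prerequisites"])

print(results["stats"]["prerequisite_violation_rate"])
print(results["mastery_summary"])
```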
 
plrs/ranking/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from plrs.ranking.ranker import MultiObjectiveRanker, RankedRecommendation
2
-
3
- __all__ = ["MultiObjectiveRanker", "RankedRecommendation"]
 
plrs/ranking/ranker.py DELETED
@@ -1,189 +0,0 @@
1
- """
2
- plrs.ranking.ranker
3
- ===================
4
- Multi-objective ranking function for approved/challenging topics.
5
-
6
- Scoring signals:
7
- 1. Mastery gap — how close the student is to mastering this topic
8
- 2. Readiness — fraction of prerequisites met
9
- 3. Downstream value — how many future topics this unlocks (normalised)
10
-
11
- Weights are configurable. Default: gap=0.4, readiness=0.4, downstream=0.2
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- from dataclasses import dataclass
17
-
18
- import networkx as nx
19
-
20
- from plrs.constraints.dag import ConstraintResult, MasteryVector
21
- from plrs.curriculum.loader import CurriculumGraph
22
-
23
-
24
- @dataclass
25
- class RankedRecommendation:
26
- topic_id: str
27
- topic_label: str
28
- status: str # "approved" | "challenging"
29
- mastery: float
30
- score: float
31
- reasoning: str
32
- prerequisites: list[str]
33
- unmet_prerequisites: list[str]
34
- downstream_count: int
35
- score_breakdown: dict[str, float]
36
-
37
-
38
- class MultiObjectiveRanker:
39
- """
40
- Ranks constraint-validated topics by a weighted combination of signals.
41
-
42
- Parameters
43
- ----------
44
- curriculum : CurriculumGraph
45
- w_gap : float
46
- Weight for mastery gap signal (default 0.4).
47
- w_readiness : float
48
- Weight for prerequisite readiness signal (default 0.4).
49
- w_downstream : float
50
- Weight for downstream unlock value (default 0.2).
51
- """
52
-
53
- def __init__(
54
- self,
55
- curriculum: CurriculumGraph,
56
- w_gap: float = 0.4,
57
- w_readiness: float = 0.4,
58
- w_downstream: float = 0.2,
59
- ) -> None:
60
- self.curriculum = curriculum
61
- self.w_gap = w_gap
62
- self.w_readiness = w_readiness
63
- self.w_downstream = w_downstream
64
-
65
- # Pre-compute downstream counts (expensive on large graphs; cache it)
66
- self._downstream_counts = self._compute_downstream_counts()
67
- max_d = max(self._downstream_counts.values(), default=1)
68
- self._downstream_norm = {
69
- node: count / max(max_d, 1)
70
- for node, count in self._downstream_counts.items()
71
- }
72
-
73
- def _compute_downstream_counts(self) -> dict[str, int]:
74
- return {
75
- node: len(nx.descendants(self.curriculum.graph, node))
76
- for node in self.curriculum.nodes
77
- }
78
-
79
- def score(self, result: ConstraintResult, mastery: MasteryVector) -> float:
80
- """Compute composite score for a single topic."""
81
- topic_id = result.topic_id
82
-
83
- # 1. Mastery gap: student is close but not mastered → higher priority
84
- gap = max(0.0, mastery.threshold - mastery.get(topic_id))
85
- gap_score = gap / mastery.threshold # normalise to [0, 1]
86
-
87
- # 2. Readiness: fraction of prerequisites above soft threshold
88
- prereqs = self.curriculum.prerequisites(topic_id)
89
- if prereqs:
90
- readiness = sum(
91
- 1 for p in prereqs if mastery.get(p) >= mastery.soft_threshold
92
- ) / len(prereqs)
93
- else:
94
- readiness = 1.0
95
-
96
- # 3. Downstream value
97
- downstream = self._downstream_norm.get(topic_id, 0.0)
98
-
99
- score = (
100
- self.w_gap * gap_score
101
- + self.w_readiness * readiness
102
- + self.w_downstream * downstream
103
- )
104
-
105
- return round(score, 4)
106
-
107
- def rank(
108
- self,
109
- results: list[ConstraintResult],
110
- mastery: MasteryVector,
111
- top_n: int = 5,
112
- challenging_penalty: float = 0.8,
113
- ) -> dict[str, list[RankedRecommendation]]:
114
- """
115
- Rank a list of constraint results into approved / challenging / vetoed.
116
-
117
- Parameters
118
- ----------
119
- results : list[ConstraintResult]
120
- mastery : MasteryVector
121
- top_n : int
122
- Number of top approved recommendations to return.
123
- challenging_penalty : float
124
- Score multiplier applied to challenging topics (default 0.8).
125
-
126
- Returns
127
- -------
128
- dict with keys: "approved", "challenging", "vetoed", "stats"
129
- """
130
- approved: list[RankedRecommendation] = []
131
- challenging: list[RankedRecommendation] = []
132
- vetoed: list[RankedRecommendation] = []
133
-
134
- for result in results:
135
- # Skip already-mastered topics
136
- if mastery.is_mastered(result.topic_id):
137
- continue
138
-
139
- base_score = self.score(result, mastery)
140
- topic_id = result.topic_id
141
-
142
- breakdown = {
143
- "gap": round(
144
- self.w_gap * max(0.0, mastery.threshold - mastery.get(topic_id)) / mastery.threshold, 4
145
- ),
146
- "readiness": round(self.w_readiness * (
147
- sum(1 for p in self.curriculum.prerequisites(topic_id)
148
- if mastery.get(p) >= mastery.soft_threshold)
149
- / max(len(self.curriculum.prerequisites(topic_id)), 1)
150
- ), 4),
151
- "downstream": round(self.w_downstream * self._downstream_norm.get(topic_id, 0.0), 4),
152
- }
153
-
154
- rec = RankedRecommendation(
155
- topic_id=result.topic_id,
156
- topic_label=result.topic_label,
157
- status=result.status,
158
- mastery=round(result.mastery, 3),
159
- score=round(base_score * (challenging_penalty if result.status == "challenging" else 1.0), 4),
160
- reasoning=result.reasoning,
161
- prerequisites=result.prerequisites,
162
- unmet_prerequisites=result.unmet_prerequisites,
163
- downstream_count=self._downstream_counts.get(result.topic_id, 0),
164
- score_breakdown=breakdown,
165
- )
166
-
167
- if result.status == "approved":
168
- approved.append(rec)
169
- elif result.status == "challenging":
170
- challenging.append(rec)
171
- else:
172
- vetoed.append(rec)
173
-
174
- approved.sort(key=lambda r: r.score, reverse=True)
175
- challenging.sort(key=lambda r: r.score, reverse=True)
176
-
177
- total = len(results)
178
- return {
179
- "approved": approved[:top_n],
180
- "challenging": challenging[:3],
181
- "vetoed": vetoed[:5],
182
- "stats": {
183
- "total_topics": total,
184
- "approved_count": len(approved),
185
- "challenging_count": len(challenging),
186
- "vetoed_count": len(vetoed),
187
- "prerequisite_violation_rate": round(len(vetoed) / max(total, 1), 3),
188
- },
189
- }
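To make the weighting above concrete, a worked example of the composite score computed in `MultiObjectiveRanker.score`; the inputs are made up, and the 0.70 threshold is the pipeline default.

```python
# Worked example of score = w_gap * gap + w_readiness * readiness + w_downstream * downstream.
w_gap, w_readiness, w_downstream = 0.4, 0.4, 0.2
threshold = 0.70

mastery = 0.42                                         # current mastery of the candidate topic
gap_score = max(0.0, threshold - mastery) / threshold  # (0.70 - 0.42) / 0.70 = 0.4
readiness = 2 / 3                                      # 2 of 3 prerequisites above the soft threshold
downstream = 0.25                                      # normalised count of unlocked descendants

score = w_gap * gap_score + w_readiness * readiness + w_downstream * downstream
print(round(score, 4))                                 # 0.4767
```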
 
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
- streamlit>=1.33.0
2
- torch>=2.0.0
3
  pandas>=2.0.0
4
  numpy>=1.24.0
5
- networkx>=3.0
6
  scikit-learn>=1.3.0
7
  huggingface_hub>=0.20.0
8
- fastapi>=0.110.0
9
- pydantic>=2.0
 
1
+ streamlit>=1.32.0
2
+ torch>=2.9.0
3
  pandas>=2.0.0
4
  numpy>=1.24.0
5
+ networkx>=3.1
6
  scikit-learn>=1.3.0
7
  huggingface_hub>=0.20.0
 
 
models/sakt_model.pt → sakt_model.pt RENAMED
File without changes