File size: 11,828 Bytes
f31d2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
40c272a
 
f31d2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab8e02f
 
 
 
f31d2d4
ab8e02f
 
 
 
 
 
 
 
 
 
 
f31d2d4
ab8e02f
 
 
 
 
f31d2d4
 
 
 
 
40c272a
f31d2d4
 
 
 
 
 
 
 
 
 
 
560c197
 
 
 
f31d2d4
 
 
 
560c197
 
 
 
 
 
 
 
 
f31d2d4
560c197
 
 
 
f31d2d4
 
 
560c197
 
 
 
 
 
 
 
 
 
f31d2d4
 
40c272a
 
f31d2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40c272a
f31d2d4
 
 
 
 
 
 
 
40c272a
1b69513
 
 
 
 
 
f31d2d4
 
1b69513
f31d2d4
 
 
 
 
1b69513
 
f31d2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40c272a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd6592b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f31d2d4
 
 
 
 
 
 
40c272a
 
 
 
f31d2d4
40c272a
f31d2d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab8e02f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f31d2d4
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
Recall — Module B: Learning Engine.  OWNER: Nikolai

The brain: scheduling (SM-2-lite), grading, adaptation, follow-up generation,
and the recap. Runs in STUB mode out of the box. Public signatures are fixed —
app.py depends on them.
"""
from __future__ import annotations

import llm
from schema import (
    Card, GradeResult, Session, new_card, new_card_state, new_grade, validate_card,
)

# STUB is owned by llm (single source of truth) and read dynamically as
# `llm.STUB` so every module agrees and runtime/reload changes are honored.


# ---- Session lifecycle -----------------------------------------------------

def init_session(deck: list[Card]) -> Session:
    """Start a new study session over `deck`.

    Every card gets a fresh CardState, and the queue serves the cards in deck
    order. History starts empty and the streak at zero. The deck is
    shallow-copied, so later additions to the session's deck (e.g. follow-ups)
    never touch the caller's list.
    """
    fresh_states = {}
    for card in deck:
        fresh_states[card["id"]] = new_card_state(card["id"])
    order = [card["id"] for card in deck]
    return Session(
        deck=list(deck),
        states=fresh_states,
        queue=order,
        history=[],
        streak=0,
    )


# A topic whose average grade is below this counts as "weak". The scheduler
# (next_card) and the recap both use it, so they agree on what "weak" means.
WEAK_TOPIC_THRESHOLD: float = 3.0
# How many cards from the front of the queue next_card may look at when
# pulling a weak-topic card forward.
WEAK_LOOKAHEAD: int = 4


def next_card(session: Session) -> Card | None:
    """
    Pick the card to study next, or return None when the queue is empty.

    Within the first WEAK_LOOKAHEAD queued cards, a card from the user's
    weakest topic (lowest average grade, below WEAK_TOPIC_THRESHOLD) is
    preferred. Before any answers are recorded, or when nothing in reach is
    weak, the queue is served in its current order.

    The selected card is moved to the head of the queue. This keeps
    `apply_result`'s "pop the front" contract valid.
    """
    if not session["queue"]:
        return None

    pick = _weak_biased_index(session)
    queue = session["queue"]
    if pick:
        # Rotate the weak-topic card to the front; the others keep their order.
        chosen = queue.pop(pick)
        queue.insert(0, chosen)
    return _find(session, queue[0])


# ---- Grading ---------------------------------------------------------------

def grade_answer(card: Card, user_answer: str) -> GradeResult:
    """Grade `user_answer` against the card's reference answer.

    STUB mode (`llm.STUB`) uses a word-overlap heuristic so the demo responds
    without a model:
    - two or more words shared with the reference score 5;
    - one shared word scores 3;
    - no shared words score 1.

    Live mode asks the model for a JSON grade. If the reply has no usable
    0-5 score (see `_valid_grade`), a fallback grade of 2 is returned that
    points the student at the reference answer.

    The model's text fields are coerced defensively:
    - a missing, null or blank `explanation` is replaced by the reference
      answer;
    - a missing, null or blank `missed_concept` is replaced by the card's
      topic, so follow-up generation always has a concept to drill.
    """
    if llm.STUB:
        # Trivial heuristic so the stub demo "feels" responsive.
        ans = (user_answer or "").strip().lower()
        ref = card["answer"].strip().lower()
        overlap = len(set(ans.split()) & set(ref.split()))
        score = 5 if overlap >= 2 else (3 if overlap == 1 else 1)
        expl = ("Correct — you hit the key idea." if score >= 3
                else f"Not quite. Expected something like: {card['answer']}")
        return new_grade(score, expl, missed_concept=card["topic"])

    messages = [
        {"role": "system", "content":
         "You grade a student's answer against a reference answer. "
         "Return ONLY a JSON object with keys: "
         "score (integer 0-5), explanation (string for the student), "
         "missed_concept (short string naming what they got wrong, or \"\")."},
        {"role": "user", "content":
         f"Question: {card['question']}\nReference answer: {card['answer']}\n"
         f"Student answer: {user_answer}\nGrade it."},
    ]
    # Parser + one repair retry; safe default if the model never returns JSON.
    data = llm.chat_json(messages, max_tokens=256)
    if not _valid_grade(data):
        return new_grade(
            2,
            "Couldn't grade automatically — compare your answer to the "
            f"reference: {card['answer']}",
            card["topic"],
        )

    # `or ""` before str(): a JSON null must not become the literal text "None".
    explanation = str(data.get("explanation") or "").strip()
    missed = str(data.get("missed_concept") or "").strip()
    return new_grade(
        int(data["score"]),
        explanation or f"Reference answer: {card['answer']}",
        missed or card["topic"],
    )


def _valid_grade(data) -> bool:
    """A grade is usable only if it carries a numeric, in-range score."""
    if not isinstance(data, dict) or "score" not in data:
        return False
    try:
        return 0 <= int(data["score"]) <= 5
    except (TypeError, ValueError):
        return False


# ---- Adaptation: SM-2-lite -------------------------------------------------

def apply_result(session: Session, card: Card, grade: GradeResult,
                 user_answer: str = "") -> Session:
    """Fold one graded answer into the session schedule (SM-2-lite).

    The card's rep count is bumped and its score recorded. The card is
    removed from the queue only if it sits at the front, then rescheduled:

    - correct: ease +0.1 (capped at 3.0); interval becomes
      max(2, int(interval * ease)); streak +1; the card is re-queued
      `interval` slots ahead.
    - incorrect: one lapse counted; ease -0.2 (floored at 1.3); interval
      reset to 1; streak reset to 0; the card is re-queued two slots ahead
      so it comes back soon.

    A history entry is appended in both cases. The session is mutated in
    place and returned.
    """
    card_id = card["id"]
    state = session["states"][card_id]
    queue = session["queue"]

    state["reps"] += 1
    state["last_grade"] = grade["score"]

    if queue and queue[0] == card_id:
        del queue[0]

    if not grade["correct"]:
        state["lapses"] += 1
        state["ease"] = max(1.3, state["ease"] - 0.2)
        state["interval"] = 1
        session["streak"] = 0
        slot = 2
    else:
        state["ease"] = min(3.0, state["ease"] + 0.1)
        state["interval"] = max(2, int(state["interval"] * state["ease"]))
        session["streak"] += 1
        slot = state["interval"]
    _insert_at(session, card_id, slot)

    entry = {
        "card_id": card_id,
        "user_answer": user_answer,
        "grade": grade["score"],
        "topic": card["topic"],
    }
    session["history"].append(entry)
    return session


def generate_followups(card: Card, grade: GradeResult, n: int = 2) -> list[Card]:
    """The money feature: new cards drilling exactly what was missed.

    Returns at most `n` follow-up cards; `n <= 0` yields none. Each card is
    a child of `card` (`parent_id`), shares its source chunk, and is one
    difficulty step easier (never below 1).

    STUB mode returns up to two canned drills. In live mode the model is
    asked for a JSON array:
    - non-dict items and cards rejected by `validate_card` are skipped;
    - later items fill their place, so a malformed entry does not cost a
      follow-up;
    - a reply that is not a JSON array yields no cards.
    """
    if n <= 0:
        # Guard: a negative n would otherwise still return cards, because
        # seq[:-1] is a non-empty slice.
        return []

    if llm.STUB:
        # Two canned drills so the demo shows the design's "+2 new questions"
        # adaptive moment. The real path below returns up to `n`.
        prompts = [
            f"[follow-up] In your own words, what's the key idea behind: {card['question']}",
            f"[follow-up] Restate: {card['question']}",
        ]
        return [
            new_card(
                p,
                card["answer"],
                topic=card["topic"],
                source_chunk=card["source_chunk"],
                difficulty=max(1, card["difficulty"] - 1),
                parent_id=card["id"],
            )
            for p in prompts[:n]
        ]

    messages = [
        {"role": "system", "content":
         "The student missed a concept. Generate follow-up quiz questions that "
         "drill it. Return ONLY a JSON array with keys: question, answer, topic."},
        {"role": "user", "content":
         f"Original question: {card['question']}\n"
         f"Missed concept: {grade['missed_concept']}\n"
         f"Source: {card['source_chunk']}\nGenerate {n} simpler follow-ups."},
    ]
    data = llm.extract_json(llm.chat(messages, max_tokens=400))
    if not isinstance(data, list):
        return []

    out: list[Card] = []
    for item in data:
        if len(out) >= n:
            break
        if not isinstance(item, dict):
            continue
        c = new_card(
            _as_text(item.get("question")),
            _as_text(item.get("answer")),
            topic=_as_text(item.get("topic")) or card["topic"],
            source_chunk=card["source_chunk"],
            difficulty=max(1, card["difficulty"] - 1),
            parent_id=card["id"],
        )
        if validate_card(c):
            out.append(c)
    return out


def _as_text(value) -> str:
    """Coerce a JSON field to stripped text; None (JSON null) becomes ""."""
    return "" if value is None else str(value).strip()


def add_followups(session: Session, cards: list[Card]) -> Session:
    """Register generated follow-ups into the deck + queue (near-term).

    Each card is appended to the deck, given a fresh CardState, and queued
    right behind the current front of the queue. The cards go into
    consecutive slots (1, 2, ...) so they are studied in the order given.
    Inserting every card at slot 1 would serve them in reverse whenever the
    queue is non-empty. Slots past the end are clamped by `_insert_at`.

    The session is mutated in place and returned.
    """
    for slot, c in enumerate(cards, start=1):
        session["deck"].append(c)
        session["states"][c["id"]] = new_card_state(c["id"])
        _insert_at(session, c["id"], slot)
    return session


def grade_and_adapt(session: Session, user_answer: str) -> tuple[GradeResult | None, list[Card]]:
    """Run one complete study step on the card at the head of the queue.

    The step grades the answer, applies the result to the schedule and, on a
    miss, generates and enqueues follow-up cards. Returns a pair
    (grade, injected_cards). The grade is None only when there was no card
    to study; injected_cards is empty unless follow-ups were added.

    This is the single canonical study-loop sequence. The Gradio app and the
    JSON server both call it rather than re-implementing
    next_card -> grade -> apply -> follow-up, so the two frontends cannot
    drift apart.
    """
    current = next_card(session)
    if current is None:
        return None, []

    answer = user_answer or ""
    result = grade_answer(current, answer)
    apply_result(session, current, result, user_answer=answer)

    if result["correct"]:
        return result, []

    followups = generate_followups(current, result)
    if not followups:
        return result, []
    add_followups(session, followups)
    return result, followups


def replace_card(session: Session, old_id: str, new: Card) -> Session:
    """Swap card `old_id` for `new` in place (difficulty toggle, NAH-32).

    The matching deck entry is replaced. The old CardState is dropped and
    `new` gets a fresh one, since it is effectively a new question. Every
    queue occurrence of `old_id` is rewritten to the new id, so queue
    positions (and the "pop the front" contract) are unchanged. The session
    is mutated and returned.
    """
    new_id = new["id"]

    rebuilt_deck = []
    for existing in session["deck"]:
        rebuilt_deck.append(new if existing["id"] == old_id else existing)
    session["deck"] = rebuilt_deck

    states = session["states"]
    if old_id in states:
        del states[old_id]
    states[new_id] = new_card_state(new_id)

    session["queue"] = [new_id if queued == old_id else queued
                        for queued in session["queue"]]
    return session


# ---- Recap -----------------------------------------------------------------

def recap(session: Session) -> dict:
    """Summarize the session for the end-of-session screen.

    Returns a dict with:
    - mastered: topics whose average grade is >= WEAK_TOPIC_THRESHOLD;
    - weak_topics: the remaining topics;
    - reflection: a one-sentence reflection;
    - streak: the current streak;
    - answered: the number of recorded answers.

    Topics appear in the order they first occur in the history.

    The weak/mastered cut-off is the same one the scheduler uses when
    deciding what to resurface. A topic the recap calls "weak" is therefore
    exactly one next_card brings back sooner.

    The reflection is a canned sentence in STUB mode. Otherwise it comes
    from the model (whitespace stripped) and falls back to the canned
    sentence if the model returns nothing.
    """
    # Reuse the scheduler's per-topic averages instead of regrouping history.
    averages = _topic_averages(session)
    mastered: list[str] = []
    weak: list[str] = []
    for topic, avg in averages.items():
        if avg >= WEAK_TOPIC_THRESHOLD:
            mastered.append(topic)
        else:
            weak.append(topic)

    reflection = ""
    if not llm.STUB:
        msg = [
            {"role": "system", "content":
             "Write one encouraging sentence reflecting on a study session."},
            {"role": "user", "content":
             f"Mastered: {mastered}. Weak: {weak}. Streak: {session['streak']}."},
        ]
        reflection = str(llm.chat(msg, max_tokens=80) or "").strip()
    if not reflection:
        reflection = _canned_reflection(mastered, weak)

    return {
        "mastered": mastered,
        "weak_topics": weak,
        "reflection": reflection,
        "streak": session["streak"],
        "answered": len(session["history"]),
    }


def _canned_reflection(mastered: list[str], weak: list[str]) -> str:
    """Deterministic reflection used in STUB mode and when the model is silent."""
    return ("Solid start. You're strong on "
            f"{', '.join(mastered) or 'nothing yet'}; "
            f"{', '.join(weak) or 'no weak spots'} could use another pass.")


# ---- helpers ---------------------------------------------------------------

def _find(session: Session, card_id: str) -> Card | None:
    return next((c for c in session["deck"] if c["id"] == card_id), None)


def _topic_averages(session: Session) -> dict[str, float]:
    """Map each answered topic to its mean grade.

    Topics appear in the order they first occur in the history. The result
    is an empty dict until the first answer is recorded.
    """
    by_topic: dict[str, list[int]] = {}
    for entry in session["history"]:
        topic = entry["topic"]
        if topic not in by_topic:
            by_topic[topic] = []
        by_topic[topic].append(entry["grade"])

    averages: dict[str, float] = {}
    for topic, scores in by_topic.items():
        averages[topic] = _avg(scores)
    return averages


def _weak_biased_index(session: Session) -> int:
    """
    Return the queue index of the card that should be served next.

    Only the first WEAK_LOOKAHEAD queued cards are considered. Among those
    whose topic is weak (average grade below WEAK_TOPIC_THRESHOLD), the one
    with the lowest average wins; ties go to the earlier position.

    Returns 0 (normal order) when there is no history yet, or when no card
    in reach belongs to a weak topic. Queue ids missing from the deck are
    ignored.
    """
    averages = _topic_averages(session)
    if not averages:
        return 0

    weak_candidates: list[tuple[float, int]] = []
    for pos, cid in enumerate(session["queue"][:WEAK_LOOKAHEAD]):
        card = _find(session, cid)
        if card is None:
            continue
        avg = averages.get(card["topic"])
        if avg is not None and avg < WEAK_TOPIC_THRESHOLD:
            weak_candidates.append((avg, pos))

    if not weak_candidates:
        return 0
    # Tuples sort by average first, then by position, which keeps the
    # earliest card on ties.
    return min(weak_candidates)[1]


def _insert_at(session: Session, card_id: str, pos: int) -> None:
    pos = max(0, min(pos, len(session["queue"])))
    session["queue"].insert(pos, card_id)


def _avg(xs: list[int]) -> float:
    return sum(xs) / len(xs) if xs else 0.0