recall / learning_engine.py
arturogp3's picture
Sync from GitHub via hub-sync
40c272a verified
Raw
History Blame Contribute Delete
11.8 kB
"""
Recall — Module B: Learning Engine. OWNER: Nikolai
The brain: scheduling (SM-2-lite), grading, adaptation, follow-up generation,
and the recap. Runs in STUB mode out of the box. Public signatures are fixed —
app.py depends on them.
"""
from __future__ import annotations
import llm
from schema import (
Card, GradeResult, Session, new_card, new_card_state, new_grade, validate_card,
)
# STUB is owned by llm (single source of truth) and read dynamically as
# `llm.STUB` so every module agrees and runtime/reload changes are honored.
# ---- Session lifecycle -----------------------------------------------------
def init_session(deck: list[Card]) -> Session:
    """Build a fresh study session for *deck*.

    Every card gets a brand-new scheduling state and is queued in deck order.
    The deck is shallow-copied into the session, so later appends to the
    session's deck do not touch the caller's list. History starts empty and
    the streak starts at zero.
    """
    card_ids = [card["id"] for card in deck]
    fresh_states = {}
    for card_id in card_ids:
        fresh_states[card_id] = new_card_state(card_id)
    return Session(
        deck=list(deck),
        states=fresh_states,
        queue=card_ids,
        history=[],
        streak=0,
    )
# Scheduler tuning knobs, shared by next_card's weak-topic bias and the recap.
WEAK_TOPIC_THRESHOLD: float = 3.0 # avg grade below this = a topic the user is weak on
WEAK_LOOKAHEAD: int = 4 # how far down the queue we'll reach to surface a weak card
def next_card(session: Session) -> Card | None:
    """
    Pick the card to study now, or None when the queue is empty.

    The scheduler peeks at the first few queued cards and prefers one whose
    topic currently has the lowest average grade, provided that topic counts
    as weak. Before any answers are recorded there is nothing to bias on, so
    cards are simply served in queue order.

    Whichever card is chosen is moved to position 0 of the queue, which keeps
    `apply_result`'s "pop the front" contract intact.
    """
    pending = session["queue"]
    if not pending:
        return None
    chosen = _weak_biased_index(session)
    if chosen > 0:
        # Promote the weak-topic card ahead of everything else.
        promoted_id = pending.pop(chosen)
        pending.insert(0, promoted_id)
    return _find(session, pending[0])
# ---- Grading ---------------------------------------------------------------
def grade_answer(card: Card, user_answer: str) -> GradeResult:
    """Grade *user_answer* against the card's reference answer.

    In stub mode (`llm.STUB` truthy) a word-overlap heuristic stands in for
    the model: two or more shared words score 5, exactly one scores 3, none
    scores 1. Otherwise the model is asked for a JSON grade; if it does not
    return a usable one, a fallback score of 2 is returned together with the
    reference answer so the student can self-check.

    Args:
        card: The card being answered; its "question", "answer" and "topic"
            entries are read.
        user_answer: The student's free-text answer. A falsy value is treated
            as "" by the stub heuristic.

    Returns:
        The grade produced by `new_grade` from the score, the explanation and
        the missed concept (the card's topic when the model names none).
    """
    if llm.STUB:
        # Trivial heuristic so the stub demo "feels" responsive.
        ans = (user_answer or "").strip().lower()
        ref = card["answer"].strip().lower()
        overlap = len(set(ans.split()) & set(ref.split()))
        score = 5 if overlap >= 2 else (3 if overlap == 1 else 1)
        expl = ("Correct — you hit the key idea." if score >= 3
                else f"Not quite. Expected something like: {card['answer']}")
        return new_grade(score, expl, missed_concept=card["topic"])

    messages = [
        {"role": "system", "content":
            "You grade a student's answer against a reference answer. "
            "Return ONLY a JSON object with keys: "
            "score (integer 0-5), explanation (string for the student), "
            "missed_concept (short string naming what they got wrong, or \"\")."},
        {"role": "user", "content":
            f"Question: {card['question']}\nReference answer: {card['answer']}\n"
            f"Student answer: {user_answer}\nGrade it."},
    ]
    # Parser + one repair retry; safe default if the model never returns JSON.
    data = llm.chat_json(messages, max_tokens=256)
    if not _valid_grade(data):
        return new_grade(
            2,
            "Couldn't grade automatically — compare your answer to the "
            f"reference: {card['answer']}",
            card["topic"],
        )

    # `or ""` before str(): a JSON null must not become the literal text
    # "None" in what the student reads.
    explanation = str(data.get("explanation") or "").strip()
    # Strip first, then fall back, so a whitespace-only concept also resolves
    # to the card's topic instead of an empty string.
    missed = str(data.get("missed_concept") or "").strip()
    return new_grade(
        int(data["score"]),
        explanation or f"Reference answer: {card['answer']}",
        missed or card["topic"],
    )
def _valid_grade(data) -> bool:
    """Return True only if *data* is a dict carrying a numeric score in 0..5.

    The score may be anything `int()` can convert (an int, a float, or a
    numeric string such as "4"). Everything else, including non-dict
    payloads, a missing "score" key, booleans, NaN and infinities, is
    rejected rather than raising.
    """
    if not isinstance(data, dict) or "score" not in data:
        return False
    raw = data["score"]
    if isinstance(raw, bool):
        # bool is an int subclass: JSON true/false would otherwise pass as
        # the scores 1/0.
        return False
    try:
        return 0 <= int(raw) <= 5
    except (TypeError, ValueError, OverflowError):
        # OverflowError comes from int(float("inf")); reachable because
        # json.loads accepts the non-standard Infinity literal.
        return False
# ---- Adaptation: SM-2-lite -------------------------------------------------
def apply_result(session: Session, card: Card, grade: GradeResult,
                 user_answer: str = "") -> Session:
    """Record a graded answer and reschedule the card (SM-2-lite).

    Bumps the card's rep count and last grade, drops the card from the front
    of the queue if it is sitting there, then:

    * correct: ease +0.1 (capped at 3.0), interval becomes
      max(2, int(interval * ease)), streak +1, and the card is re-queued
      `interval` slots down;
    * incorrect: lapses +1, ease -0.2 (floored at 1.3), interval reset to 1,
      streak reset to 0, and the card is re-queued at slot 2.

    Finally a history entry (card id, answer text, score, topic) is appended.
    The session is mutated in place and also returned.
    """
    card_id = card["id"]
    state = session["states"][card_id]
    state["reps"] += 1
    state["last_grade"] = grade["score"]

    # The served card is expected at the head of the queue; take it off.
    queue = session["queue"]
    if queue and queue[0] == card_id:
        del queue[0]

    if not grade["correct"]:
        state["lapses"] += 1
        state["ease"] = max(1.3, state["ease"] - 0.2)
        state["interval"] = 1
        session["streak"] = 0
        reinsert_at = 2  # comes back soon
    else:
        state["ease"] = min(3.0, state["ease"] + 0.1)
        state["interval"] = max(2, int(state["interval"] * state["ease"]))
        session["streak"] += 1
        reinsert_at = state["interval"]  # comes back later
    _insert_at(session, card_id, reinsert_at)

    entry = {
        "card_id": card_id,
        "user_answer": user_answer,
        "grade": grade["score"],
        "topic": card["topic"],
    }
    session["history"].append(entry)
    return session
def _text_or_empty(value) -> str:
    """Stringify and strip a model-supplied field, mapping None (JSON null) to ""."""
    return "" if value is None else str(value).strip()


def generate_followups(card: Card, grade: GradeResult, n: int = 2) -> list[Card]:
    """The money feature: new cards drilling exactly what was missed.

    Args:
        card: The card that was missed. Follow-ups reuse its source chunk,
            record its id as `parent_id`, and are one difficulty step easier
            (never below 1).
        grade: The grade for the miss; its "missed_concept" is handed to the
            model as the thing to drill.
        n: Maximum number of follow-ups to return. Values <= 0 yield [].

    Returns:
        Up to `n` new cards. Stub mode returns canned restatements of the
        original question (there are only two). Otherwise only generated
        cards that pass `validate_card` are returned, so the list may be
        shorter than `n`, or empty if the model's output is unusable.
    """
    if n <= 0:
        # A negative n would otherwise slice from the wrong end of the list.
        return []
    easier = max(1, card["difficulty"] - 1)

    if llm.STUB:
        # Two canned drills so the demo shows the design's "+2 new questions"
        # adaptive moment. The real path below returns up to `n`.
        prompts = [
            f"[follow-up] In your own words, what's the key idea behind: {card['question']}",
            f"[follow-up] Restate: {card['question']}",
        ]
        return [
            new_card(
                p,
                card["answer"],
                topic=card["topic"],
                source_chunk=card["source_chunk"],
                difficulty=easier,
                parent_id=card["id"],
            )
            for p in prompts[:n]
        ]

    messages = [
        {"role": "system", "content":
            "The student missed a concept. Generate follow-up quiz questions that "
            "drill it. Return ONLY a JSON array with keys: question, answer, topic."},
        {"role": "user", "content":
            f"Original question: {card['question']}\n"
            f"Missed concept: {grade['missed_concept']}\n"
            f"Source: {card['source_chunk']}\nGenerate {n} simpler follow-ups."},
    ]
    data = llm.extract_json(llm.chat(messages, max_tokens=400))
    if not isinstance(data, list):
        return []

    out: list[Card] = []
    # Walk every item rather than only the first n, so a malformed entry
    # does not use up a slot that a later valid entry could fill.
    for item in data:
        if len(out) >= n:
            break
        if not isinstance(item, dict):
            continue
        c = new_card(
            _text_or_empty(item.get("question")),
            _text_or_empty(item.get("answer")),
            # A null or blank topic falls back to the parent card's topic.
            topic=_text_or_empty(item.get("topic")) or card["topic"],
            source_chunk=card["source_chunk"],
            difficulty=easier,
            parent_id=card["id"],
        )
        if validate_card(c):
            out.append(c)
    return out
def add_followups(session: Session, cards: list[Card]) -> Session:
    """Add generated follow-up cards to the session so they are asked soon.

    Each card is appended to the deck, given a fresh scheduling state, and
    slotted into queue position 1. Because every card takes that same slot,
    a later card in *cards* ends up ahead of an earlier one. The session is
    mutated in place and also returned.
    """
    for followup in cards:
        session["deck"].append(followup)
        followup_id = followup["id"]
        session["states"][followup_id] = new_card_state(followup_id)
        _insert_at(session, followup_id, 1)
    return session
def grade_and_adapt(session: Session, user_answer: str) -> tuple[GradeResult | None, list[Card]]:
    """Run one complete study step on the current card.

    Sequence: pick the next card, grade the answer, apply the result to the
    schedule, and, only when the answer was graded incorrect, generate
    follow-up cards and enqueue them.

    Returns (grade, injected_cards). `grade` is None only when the queue is
    empty; `injected_cards` is empty unless follow-ups were actually added.

    This is the canonical study-loop sequence. Both the Gradio app and the
    JSON server call it instead of re-implementing the next_card → grade →
    apply → follow-up dance, so the loop can never drift between the two
    frontends.
    """
    current = next_card(session)
    if current is None:
        return None, []

    answer_text = user_answer or ""
    result = grade_answer(current, answer_text)
    apply_result(session, current, result, user_answer=answer_text)

    if result["correct"]:
        return result, []

    followups = generate_followups(current, result)
    if not followups:
        return result, []
    add_followups(session, followups)
    return result, followups
def replace_card(session: Session, old_id: str, new: Card) -> Session:
    """Substitute *new* for the card with id *old_id* (difficulty toggle, NAH-32).

    The deck entry is swapped in place, the old card's state is discarded and
    the new card starts with a fresh one (it is effectively a new question),
    and every occurrence of the old id in the queue is rewritten so the
    queue's "pop the front" contract still holds. The session is mutated and
    also returned.
    """
    swapped_deck = []
    for existing in session["deck"]:
        swapped_deck.append(new if existing["id"] == old_id else existing)
    session["deck"] = swapped_deck

    session["states"].pop(old_id, None)
    new_id = new["id"]
    session["states"][new_id] = new_card_state(new_id)

    rewritten_queue = []
    for queued_id in session["queue"]:
        rewritten_queue.append(new_id if queued_id == old_id else queued_id)
    session["queue"] = rewritten_queue
    return session
# ---- Recap -----------------------------------------------------------------
def recap(session: Session) -> dict:
    """Summarise the session: mastered vs. weak topics plus a short reflection.

    Topics are split on their average grade across the history. The
    reflection is a canned sentence in stub mode and otherwise whatever
    `llm.chat` returns for a one-sentence prompt.

    Returns a dict with keys "mastered", "weak_topics", "reflection",
    "streak" and "answered" (the number of history entries).
    """
    averages = _topic_averages(session)
    # Same threshold the scheduler uses to decide what to resurface, so a topic
    # the recap calls "weak" is exactly one next_card brings back sooner.
    mastered = []
    weak = []
    for topic, average in averages.items():
        if average >= WEAK_TOPIC_THRESHOLD:
            mastered.append(topic)
        elif average < WEAK_TOPIC_THRESHOLD:
            weak.append(topic)

    if not llm.STUB:
        msg = [
            {"role": "system", "content":
                "Write one encouraging sentence reflecting on a study session."},
            {"role": "user", "content":
                f"Mastered: {mastered}. Weak: {weak}. Streak: {session['streak']}."},
        ]
        reflection = llm.chat(msg, max_tokens=80)
    else:
        strong_text = ", ".join(mastered) or "nothing yet"
        weak_text = ", ".join(weak) or "no weak spots"
        reflection = (f"Solid start. You're strong on {strong_text}; "
                      f"{weak_text} could use another pass.")

    summary = {
        "mastered": mastered,
        "weak_topics": weak,
        "reflection": reflection,
        "streak": session["streak"],
        "answered": len(session["history"]),
    }
    return summary
# ---- helpers ---------------------------------------------------------------
def _find(session: Session, card_id: str) -> Card | None:
    """Return the first deck card whose id equals *card_id*, or None if absent."""
    for candidate in session["deck"]:
        if candidate["id"] == card_id:
            return candidate
    return None
def _topic_averages(session: Session) -> dict[str, float]:
    """Map each topic seen in the history to its mean grade.

    The result is an empty dict until the first answer has been recorded.
    """
    by_topic: dict[str, list[int]] = {}
    for entry in session["history"]:
        topic = entry["topic"]
        if topic not in by_topic:
            by_topic[topic] = []
        by_topic[topic].append(entry["grade"])

    averages: dict[str, float] = {}
    for topic, topic_grades in by_topic.items():
        averages[topic] = _avg(topic_grades)
    return averages
def _weak_biased_index(session: Session) -> int:
    """
    Queue index of the card that should be served next.

    Only the first WEAK_LOOKAHEAD queue entries are considered. Among those,
    the card whose topic has the lowest average grade wins, but only topics
    whose average is under WEAK_TOPIC_THRESHOLD qualify; on a tie the card
    nearest the front is kept. The result is 0 (normal order) when there is
    no history yet or no card in reach belongs to a weak topic.
    """
    queue = session["queue"]
    topic_avgs = _topic_averages(session)
    if not topic_avgs:
        return 0

    chosen_idx = 0
    lowest = None
    for position, queued_id in enumerate(queue[:WEAK_LOOKAHEAD]):
        card = _find(session, queued_id)
        # Ids with no deck entry, and topics never answered, have no average.
        topic_avg = topic_avgs.get(card["topic"]) if card is not None else None
        is_weak = topic_avg is not None and not topic_avg >= WEAK_TOPIC_THRESHOLD
        if is_weak and (lowest is None or topic_avg < lowest):
            chosen_idx, lowest = position, topic_avg
    return chosen_idx
def _insert_at(session: Session, card_id: str, pos: int) -> None:
    """Insert *card_id* into the queue at *pos*, clamped to [0, len(queue)]."""
    queue = session["queue"]
    if pos > len(queue):
        pos = len(queue)
    if pos < 0:
        # Without this, list.insert would count a negative index from the end.
        pos = 0
    queue.insert(pos, card_id)
def _avg(xs: list[int]) -> float:
    """Arithmetic mean of *xs*; 0.0 for an empty list."""
    if not xs:
        return 0.0
    return sum(xs) / len(xs)