study-partner / server.py
nz-nz's picture
Sync from GitHub via hub-sync
02c841f verified
Raw
History Blame Contribute Delete
17.2 kB
"""
Recall — custom frontend server (NAH-36).
Serves the polished `Recall.dc.html` design (frontend/index.html) and exposes a
thin JSON API over the EXISTING backend. The learning/content logic and the
`schema.py` data contract are treated as an API and are NOT modified here.
The Session dict lives server-side, keyed by a short id the client carries
around (mirrors the single-session gr.State model the Gradio app uses) — so the
reference answers in the deck never leave the server.
Run it (stub mode is on by default):

    pip install -r requirements.txt
    python server.py  # http://127.0.0.1:7860

Flip RECALL_STUB=0 once the real model is wired:

    RECALL_STUB=0 python server.py
The legacy Gradio UI is still mounted at /gradio.
"""
from __future__ import annotations
import os
import tempfile
import threading
import time
import uuid
from collections import OrderedDict
from pathlib import Path
import gradio as gr
from fastapi import File, Form, UploadFile
from fastapi.responses import FileResponse, JSONResponse, Response
from pydantic import BaseModel
from starlette.concurrency import run_in_threadpool
import content_pipeline as cp
import learning_engine as le
# Directory holding the static frontend (index.html) next to this module.
FRONTEND = Path(__file__).parent.joinpath("frontend")

# Single-process session store limits. The store is bounded so a client looping
# /api/generate on a public Space can't exhaust memory: entries past MAX_SESSIONS
# are evicted least-recently-used first, and entries idle for longer than
# SESSION_TTL_SECONDS are dropped. Both are overridable via the environment.
MAX_SESSIONS = int(os.environ.get("RECALL_MAX_SESSIONS", "500"))
SESSION_TTL_SECONDS = int(os.environ.get("RECALL_SESSION_TTL", str(2 * 60 * 60)))  # default 2h

# Input-size caps so one request can't exhaust memory or disk.
MAX_UPLOAD_BYTES = 1024 * 1024 * int(os.environ.get("RECALL_MAX_UPLOAD_MB", "10"))
MAX_TEXT_CHARS = int(os.environ.get("RECALL_MAX_TEXT_CHARS", "200000"))  # ~50k tokens

# sid -> (session, last_access_epoch), kept in least- to most-recently-used
# order so the eviction candidate is always the first entry.
SESSIONS: "OrderedDict[str, tuple[dict, float]]" = OrderedDict()
def _purge_expired(now: float | None = None) -> None:
    """Remove every session idle for longer than SESSION_TTL_SECONDS.

    `now` is the reference epoch time used for the idle check; when omitted
    the current `time.time()` is used.
    """
    if now is None:
        now = time.time()
    # Iterate over a snapshot: entries can't be removed from the dict while it
    # is being iterated directly.
    for sid, (_session, last_access) in list(SESSIONS.items()):
        if now - last_access > SESSION_TTL_SECONDS:
            SESSIONS.pop(sid, None)
def get_session(sid: str) -> dict | None:
    """Look up a live session by id and mark it most-recently-used.

    Expired sessions are purged first. Returns None when `sid` is unknown or
    has expired; the route handlers turn that into a 'session expired' reply.
    """
    _purge_expired()
    try:
        session, _previous_access = SESSIONS[sid]
    except KeyError:
        return None
    # Refresh the access time and move the entry to the MRU end of the store.
    SESSIONS[sid] = (session, time.time())
    SESSIONS.move_to_end(sid)
    return session
def put_session(sid: str, session: dict) -> None:
    """Insert or refresh `session` under `sid`, then enforce the size cap.

    Expired sessions are purged first; afterwards the least-recently-used
    entries are dropped until at most MAX_SESSIONS remain.
    """
    _purge_expired()
    SESSIONS[sid] = (session, time.time())
    SESSIONS.move_to_end(sid)
    # Number of entries over the cap; a non-positive value means nothing to evict.
    overflow = len(SESSIONS) - MAX_SESSIONS
    for _ in range(overflow):
        SESSIONS.popitem(last=False)
# Built-in photosynthesis notes used as the deck source when /api/generate is
# called with sample == "bio". Lets the Upload screen's sample chip work even
# with no real PDF on disk.
SAMPLE_NOTES = (
    "Photosynthesis happens in the chloroplast. The light-dependent reactions "
    "occur in the thylakoid membranes, where water is split, ATP and NADPH are "
    "produced, and oxygen is released. The Calvin cycle takes place in the "
    "stroma, where the enzyme RuBisCO fixes CO2 onto RuBP. Cellular respiration "
    "occurs in the mitochondria; most ATP is made during oxidative "
    "phosphorylation, as the electron transport chain pumps protons and oxygen "
    "acts as the final electron acceptor, forming water."
)
# Text lines for the "scanned-slides.pdf" sample. They are rasterized into an
# IMAGE (no text layer) so the sample chip demonstrates the real image-PDF OCR
# path end-to-end: the vision model reads the rendered page and writes a deck
# from it (image PDFs are now supported, so the old "image-only = error" copy
# no longer applies). The first line is drawn as the slide title; the rest are
# drawn as body text.
SCAN_SAMPLE_LINES = (
    "Photosynthesis - Lecture 3",
    "",
    "Occurs in the chloroplast.",
    "Light-dependent reactions happen in the thylakoid",
    "membranes: water is split, ATP and NADPH are made,",
    "and oxygen is released.",
    "The Calvin cycle runs in the stroma, where RuBisCO",
    "fixes CO2 onto RuBP to produce glucose.",
)
def _render_scan_sample():
    """Draw SCAN_SAMPLE_LINES onto a white 1024x720 image and return it.

    The result is a synthetic 'scanned slide': text rasterized into pixels with
    no text layer, so the sample chip exercises the vision/OCR path in one click.
    """
    from PIL import Image, ImageDraw, ImageFont

    def _load_font(px):
        # Prefer a real TrueType face so the text renders at the requested size.
        candidates = (
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "DejaVuSans.ttf",
        )
        for candidate in candidates:
            try:
                face = ImageFont.truetype(candidate, px)
            except Exception:  # noqa: BLE001 — no TTF on this host; try the next option
                continue
            return face
        # Pillow >=10.1 lets load_default scale; without it the 5px bitmap font is
        # too small for the model to OCR reliably.
        try:
            return ImageFont.load_default(size=px)
        except TypeError:
            return ImageFont.load_default()

    canvas = Image.new("RGB", (1024, 720), "white")
    pen = ImageDraw.Draw(canvas)
    title_font = _load_font(40)
    body_font = _load_font(30)

    cursor_y = 70
    for index, text in enumerate(SCAN_SAMPLE_LINES):
        is_title = index == 0
        if is_title:
            font, advance = title_font, 64
        else:
            font, advance = body_font, 52
        pen.text((70, cursor_y), text, fill=(15, 15, 15), font=font)
        cursor_y += advance
    return canvas
# Cached PNG encoding of the synthetic scanned slide; filled on first request.
_SCAN_SAMPLE_PNG: bytes | None = None


def _scan_sample_png() -> bytes:
    """Return the synthetic scanned slide as PNG bytes, encoding it only once.

    The first call renders the slide via `_render_scan_sample()` and stores the
    encoded bytes in the module-level cache; later calls return the cached
    bytes. The preview therefore comes from the same renderer the vision/OCR
    path uses, and matches it as long as that render stays deterministic.
    """
    global _SCAN_SAMPLE_PNG
    if _SCAN_SAMPLE_PNG is not None:
        return _SCAN_SAMPLE_PNG

    from io import BytesIO

    stream = BytesIO()
    _render_scan_sample().save(stream, format="PNG")
    _SCAN_SAMPLE_PNG = stream.getvalue()
    return _SCAN_SAMPLE_PNG
# `gradio.Server` extends FastAPI: we keep all our own routes (custom frontend +
# JSON API) AND get gradio's launcher, which owns port 7860 on a Hugging Face
# gradio-SDK Space and wires up the queue + ZeroGPU GPU allocation. A plain
# FastAPI + manual `uvicorn.run(7860)` collides with the Space's own gradio
# server ("address already in use") — `gradio.Server` is the supported way to run
# a custom frontend on a gradio Space.
# Every route below is registered on this single application object.
app = gr.Server(title="Recall")
# A gradio.Server MUST register at least one `@app.api(...)` endpoint: launch()
# builds its internal Blocks from these, and on a Space an *empty* Blocks doesn't
# stay "running" (launch returns and the process exits → RUNTIME_ERROR). This is
# also the gradio-native handle (`gradio_client` / the JS Client) onto the same
# backend the custom frontend drives over plain JSON `/api/*` routes below.
@app.api(name="health")
def health() -> str:
    """Liveness probe + the Server's required gradio endpoint.

    Takes no input and always returns the constant string "ok".
    """
    return "ok"
# ---- serialization ---------------------------------------------------------
def _card_out(card: dict | None) -> dict | None:
"""The client never needs (or should see) the reference answer or the raw
source chunk — strip the card down to what the UI renders."""
if not card:
return None
return {
"id": card["id"],
"question": card["question"],
"topic": card["topic"],
"difficulty": card["difficulty"],
"parent_id": card.get("parent_id"),
}
def _view(session: dict) -> dict:
"""Display state the header / mastery bars / deck rail are built from."""
deck = session["deck"]
history = session["history"]
answered = len(history)
total = len(deck)
# Cards that have graduated out of the queue (mastered this session). Progress
# tracks this rather than `answered` so the bar climbs as cards are mastered
# and reaches total exactly when the queue drains — instead of pinning at
# total/total the moment you've seen every card once.
mastered = total - len(session["queue"])
stats: dict[str, dict] = {}
for h in history:
s = stats.setdefault(h["topic"], {"correct": 0, "total": 0})
s["total"] += 1
if h["grade"] >= 3:
s["correct"] += 1
return {
"total": total,
"answered": answered,
"posDisplay": min(mastered + 1, total) if total else 0,
"streak": session["streak"],
"topicStats": stats,
"rail": [
{"id": c["id"], "topic": c["topic"], "injected": bool(c.get("parent_id"))}
for c in deck
],
}
# ---- request models --------------------------------------------------------
class SidBody(BaseModel):
    """JSON body for endpoints that only need the session id."""

    # Server-side session key returned by /api/generate.
    sid: str
class GradeBody(BaseModel):
    """JSON body for /api/grade: the session id plus the learner's answer."""

    # Server-side session key returned by /api/generate.
    sid: str
    # The learner's answer text; defaults to an empty string when omitted.
    answer: str = ""
class RegenBody(BaseModel):
    """JSON body for /api/regenerate: the session id plus the requested direction."""

    # Server-side session key returned by /api/generate.
    sid: str
    # Requested difficulty change. NOTE(review): declared as a plain str, so
    # values other than the two below are not rejected by this model.
    direction: str  # "harder" | "easier"
# ---- API -------------------------------------------------------------------
def _deck_failure(exc: Exception) -> JSONResponse:
    """Friendly 422 for a failed deck build; only the exception's type name is exposed."""
    return JSONResponse(
        {"error": f"Couldn't build a deck from that ({type(exc).__name__}). "
                  "Try different material."},
        status_code=422,
    )


@app.post("/api/generate")
async def api_generate(
    text: str = Form(""),
    sample: str = Form(""),
    file: UploadFile | None = File(None),
):
    """Build a deck from the submitted material and open a new study session.

    The source is chosen in priority order: pasted `text`, the "bio" sample
    notes, the "scan" sample slide (vision path), then an uploaded `file`.

    Returns `{"sid", "card", "view"}` on success, or a JSON `{"error": ...}`
    with status 400 (nothing usable submitted), 413 (upload over
    MAX_UPLOAD_BYTES) or 422 (extraction or deck generation failed).
    """
    need_more = ("I need a little more to work with — paste a paragraph of "
                 "notes or pick a PDF, and I'll build your deck.")
    source = ""
    # Set when a scanned/image-only PDF is turned into a deck by the vision model
    # (no text to chunk); `None` means "fall through to the text path below".
    deck_from_images = None

    if text and text.strip():
        source = text.strip()[:MAX_TEXT_CHARS]  # cap to bound chunking work
        if len(source) < 40:
            return JSONResponse({"error": need_more}, status_code=400)
    elif sample == "bio":
        source = SAMPLE_NOTES
    elif sample == "scan":
        # Render a slide to an image and run the OCR/vision path — demonstrates
        # image-PDF support end-to-end (stub returns the canned deck).
        slide = _render_scan_sample()
        try:
            deck_from_images = await run_in_threadpool(
                cp.generate_deck_from_images, [slide])
        except Exception as e:  # noqa: BLE001 — same friendly copy as the text path
            return _deck_failure(e)
    elif file is not None:
        suffix = Path(file.filename or "upload").suffix or ".txt"
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        tmp_path = tmp.name
        # The temp file is created with delete=False, so it must be unlinked by
        # hand. The try/finally wraps the streaming loop too: a failed read or
        # write mid-upload would otherwise leave the file behind on disk.
        try:
            # Stream to the temp file in capped chunks so an oversized upload
            # never gets fully buffered in memory.
            too_large = False
            with tmp:
                size = 0
                while chunk := await file.read(1024 * 1024):
                    size += len(chunk)
                    if size > MAX_UPLOAD_BYTES:
                        too_large = True
                        break
                    tmp.write(chunk)
            if too_large:
                return JSONResponse(
                    {"error": f"That file is too large (limit "
                              f"{MAX_UPLOAD_BYTES // (1024 * 1024)} MB). Upload a "
                              "smaller PDF or paste the notes instead."},
                    status_code=413,
                )
            try:
                # Extraction (PDF parsing) is blocking — keep it off the event loop.
                source = await run_in_threadpool(cp.extract_text, tmp_path)
            except cp.ExtractionError as e:
                # A scanned/image-only PDF has no text to chunk — render its pages
                # and let the multimodal model read them directly. Fall back to
                # the original error if it's not image-only or no pages render.
                if await run_in_threadpool(cp.is_image_only_pdf, tmp_path):
                    images = await run_in_threadpool(cp.render_pdf_images, tmp_path)
                    if images:
                        try:
                            deck_from_images = await run_in_threadpool(
                                cp.generate_deck_from_images, images)
                        except Exception as gen_err:  # noqa: BLE001 — friendly copy, never crash
                            return _deck_failure(gen_err)
                if deck_from_images is None:
                    return JSONResponse({"error": str(e)}, status_code=422)
        finally:
            os.unlink(tmp_path)
    else:
        return JSONResponse({"error": need_more}, status_code=400)

    if deck_from_images is not None:
        deck = deck_from_images
    else:
        try:
            # Deck generation hits the model — run it in a worker thread so a slow
            # call doesn't block every other request on the event loop.
            deck = await run_in_threadpool(cp.generate_deck, source)
        except Exception as e:  # noqa: BLE001 — surface as friendly copy, never crash
            return _deck_failure(e)

    if not deck:
        return JSONResponse(
            {"error": "Couldn't generate questions from that. Try different material."},
            status_code=422,
        )

    session = le.init_session(deck)
    card = le.next_card(session)
    sid = uuid.uuid4().hex
    put_session(sid, session)
    return {"sid": sid, "card": _card_out(card), "view": _view(session)}
@app.post("/api/grade")
async def api_grade(body: GradeBody):
    """Grade the learner's answer for the current card and adapt the session.

    Replies 404 when the session is unknown or expired, `{"done": True, ...}`
    when the study step yields no grade, and otherwise the grade summary, the
    ids of any injected follow-up cards, and the refreshed view.
    """
    session = get_session(body.sid)
    if session is None:
        return JSONResponse({"error": "session expired"}, status_code=404)

    # Grading + follow-up generation hit the model; run the whole study step in
    # a worker thread so it doesn't block the event loop.
    answer = body.answer or ""
    grade, fups = await run_in_threadpool(le.grade_and_adapt, session, answer)
    if grade is None:
        return {"done": True, "view": _view(session)}

    injected_ids = [followup["id"] for followup in fups]
    put_session(body.sid, session)

    # Client-facing field name -> key in the grade dict from the engine.
    exposed = (
        ("score", "score"),
        ("correct", "correct"),
        ("explanation", "explanation"),
        ("missed", "missed_concept"),
    )
    grade_out = {public: grade[internal] for public, internal in exposed}
    return {
        "grade": grade_out,
        "injectedIds": injected_ids,
        "view": _view(session),
    }
@app.post("/api/regenerate")
async def api_regenerate(body: RegenBody):
    """Replace the current card with a regenerated variant.

    Replies 404 when the session is unknown or expired, and
    `{"card": None, ...}` when there is no current card to replace.
    """
    session = get_session(body.sid)
    if session is None:
        return JSONResponse({"error": "session expired"}, status_code=404)

    current = le.next_card(session)
    if current is None:
        return {"card": None, "view": _view(session)}

    # Regeneration hits the model, so it runs in a worker thread.
    replacement = await run_in_threadpool(cp.regenerate, current, body.direction)
    session = le.replace_card(session, current["id"], replacement)
    put_session(body.sid, session)

    # Any direction other than "harder" is labelled "easier".
    if body.direction == "harder":
        label = "harder"
    else:
        label = "easier"
    payload = _card_out(replacement)
    payload["diffLabel"] = label
    return {"card": payload, "view": _view(session)}
@app.post("/api/next")
async def api_next(body: SidBody):
    """Return the session's next card (client-safe fields only) and the view.

    Replies 404 when the session is unknown or expired.
    """
    session = get_session(body.sid)
    if session is None:
        return JSONResponse({"error": "session expired"}, status_code=404)
    upcoming = _card_out(le.next_card(session))
    return {"card": upcoming, "view": _view(session)}
@app.post("/api/recap")
async def api_recap(body: SidBody):
    """Return the end-of-session recap together with the current view.

    Replies 404 when the session is unknown or expired.
    """
    session = get_session(body.sid)
    if session is None:
        return JSONResponse({"error": "session expired"}, status_code=404)
    # The reflection line hits the model, so build the recap off the event loop.
    summary = await run_in_threadpool(le.recap, session)
    return {"recap": summary, "view": _view(session)}
@app.post("/api/restart")
async def api_restart(body: SidBody):
    """Study the same source deck again from a clean session (Recap → restart).

    The fresh session is stored under the same sid. Replies 404 when the
    session is unknown or expired.
    """
    session = get_session(body.sid)
    if session is None:
        return JSONResponse({"error": "session expired"}, status_code=404)

    # Keep only the original cards: anything with a truthy parent_id was
    # injected during study and is left out of the rebuilt deck.
    originals = []
    for item in session["deck"]:
        if item.get("parent_id"):
            continue
        originals.append(item)

    fresh = le.init_session(originals)
    first = le.next_card(fresh)
    put_session(body.sid, fresh)
    return {"card": _card_out(first), "view": _view(fresh)}
# ---- frontend --------------------------------------------------------------
@app.get("/api/sample/scan.png")
async def sample_scan_png():
    """Serve the preview image of the scanned-slides sample as a PNG.

    The bytes come from `_scan_sample_png()`, i.e. the same renderer the
    vision/OCR path consumes, so the thumbnail stays honest about what's OCR'd.
    """
    png_bytes = _scan_sample_png()
    return Response(content=png_bytes, media_type="image/png")
@app.get("/")
async def index():
    """Serve the custom frontend's entry page (frontend/index.html)."""
    page = FRONTEND.joinpath("index.html")
    return FileResponse(page)
# HF runs `python server.py` (README app_file) and proxies port 7860. `gradio.Server`
# launches at MODULE TOP LEVEL; on a Space gradio binds 0.0.0.0:7860 automatically.
# We launch non-blocking and hold the main thread ourselves so the (daemon) uvicorn
# server keeps serving. (Locally, set GRADIO_SERVER_PORT to use a different port.)
#
# Two deploy gotchas, both learned the hard way:
# * gradio is pinned to 6.10.0. On 6.17.x a custom gradio.Server doesn't stay up
#   under the Space runtime. 6.10.0 is the version gradio's own `Server` example
#   ships, and it still allows huggingface-hub<1.0 (needed by the real model).
#   Keep the README sdk_version and requirements*.txt gradio pins in lockstep.
# * Stub mode registers no `@spaces.GPU` function, which makes a ZeroGPU Space's
#   `spaces` reload server error out and flip the Space to RUNTIME_ERROR even
#   though uvicorn serves fine. So the stub demo runs on CPU-basic hardware. Once
#   the real model (with `@spaces.GPU`) is wired, switch hardware back to ZeroGPU.
#
# NOTE: both statements run on import as well as on `python server.py`; there is
# no `__main__` guard here.
app.launch(show_error=True, prevent_thread_lock=True)
# Nothing keeps a reference to this Event, so it is never set: wait() parks the
# main thread indefinitely, which keeps the process alive.
threading.Event().wait()