# study-partner / llm.py
# nz-nz's picture
# Sync from GitHub via hub-sync
# ae15cb7 verified
# Raw
# History Blame Contribute Delete
# 21.4 kB
"""
Recall — shared inference wrapper. OWNER: Nikolai (Module B)
Everything that touches the model goes through `chat()`. Both content_pipeline
and learning_engine import this and nothing else model-related.
Default is STUB mode (RECALL_STUB=1) so `python app.py` runs with no GPU and no
model download. Flip RECALL_STUB=0 once the real MiniCPM call works on the Space.
Model is a one-env-var config flip (NAH-9). Set RECALL_MODEL to a known alias
or any full HF id; the default is the multimodal MiniCPM-V 4.6 (alias `v46`).
If the Space is too slow / OOM, swap to a smaller model with no code change:
RECALL_MODEL=1b RECALL_STUB=0 python app.py # MiniCPM5-1B (fast fallback)
RECALL_MODEL=4b RECALL_STUB=0 python app.py # MiniCPM3-4B (Tiny Titan badge)
Note: the 1b / 4b / 8b aliases are legacy text-only models that need
transformers<5.0; they do not load against the pinned transformers 5.x.
Aliases resolve via MODELS below; an unknown value is treated as a literal HF id.
Load dtype/device default to bf16 + device_map="auto" (correct for the Space's
CUDA GPU). For a local real-model smoke test on Apple Silicon, override them —
bf16 on MPS produces garbage, so use CPU/float32:
RECALL_STUB=0 RECALL_MODEL=1b RECALL_DTYPE=float32 RECALL_DEVICE=cpu python app.py
"""
from __future__ import annotations
import json
import os
import re
# Stub mode (the default) answers every chat() call with canned JSON so the app
# runs with no GPU and no model download. Anything other than "1" disables it.
STUB = os.getenv("RECALL_STUB", "1") == "1"

# Known models, keyed by short alias so swapping is a single env-var flip.
MODELS = {
    "v46": "openbmb/MiniCPM-V-4.6",  # default / primary — multimodal (text + image)
    "8b": "openbmb/MiniCPM4.1-8B",   # legacy text-only (needs transformers<5.0)
    "1b": "openbmb/MiniCPM5-1B",     # legacy fast fallback
    "4b": "openbmb/MiniCPM3-4B",     # legacy mid fallback (Tiny Titan badge)
}

# Default is the multimodal MiniCPM-V 4.6 so the same model grades text AND reads
# image-only / scanned PDFs. The legacy text aliases need transformers<5.0 and no
# longer load against the pinned transformers 5.x — keep them only for reference.
# Surrounding whitespace is trimmed, and an empty value (e.g. a blank Space
# variable) counts as unset instead of becoming an empty model id.
_requested = os.getenv("RECALL_MODEL", "").strip() or "v46"
# Accept an alias ("v46"; matched case-insensitively, so "8B" works too) or a
# full HF id ("org/model"), which is passed through verbatim with its case intact.
MODEL_ID = MODELS.get(_requested.lower(), _requested)
def _is_vision_model(model_id: str) -> bool:
"""MiniCPM-V (vision) ids load via a different class + processor than the
text-only MiniCPM models. Detect by the '-V' family marker."""
return "minicpm-v" in model_id.lower()
# Decided once at import time from the configured model id; chat() uses it to
# pick the MiniCPM-V (processor) path over the text-only (tokenizer) path.
VISION = _is_vision_model(MODEL_ID)

# Lazily-populated singletons, filled in by _load() / _load_vision() on first
# real (non-stub) call. _tokenizer serves the text-only path; MiniCPM-V uses an
# AutoProcessor (image+text) in _processor instead of a tokenizer.
_model = _tokenizer = _processor = None
def active_model() -> str:
    """Report which model chat() will use.

    Returns the literal 'stub' while RECALL_STUB is on; otherwise the resolved
    HF model id.
    """
    if STUB:
        return "stub"
    return MODEL_ID
# Load-time dtype/device, overridable for local dev (defaults are correct for
# the Space's CUDA GPU). bf16 on Apple-Silicon MPS produces garbage output, so a
# Mac real-model smoke test needs RECALL_DTYPE=float32 RECALL_DEVICE=cpu; unset,
# behavior is unchanged (bf16 + device_map="auto").
_DTYPE_ALIASES = {
"bfloat16": "bfloat16", "bf16": "bfloat16",
"float16": "float16", "fp16": "float16", "half": "float16",
"float32": "float32", "fp32": "float32", "float": "float32",
}
def _resolve_dtype_name() -> str:
"""Normalized torch dtype name from RECALL_DTYPE (default 'bfloat16').
Unknown values fall back to the default rather than erroring at load."""
return _DTYPE_ALIASES.get(os.getenv("RECALL_DTYPE", "bfloat16").lower(), "bfloat16")
def _resolve_device_map():
"""device_map for from_pretrained. Default 'auto' (accelerate places it);
RECALL_DEVICE overrides, e.g. 'cpu' for stable local CPU inference."""
return os.getenv("RECALL_DEVICE") or "auto"
def _load():
    """Load the text-only MiniCPM model and tokenizer on first use.

    Returns immediately once `_model` is populated. chat() reaches this only
    when STUB is off and the configured model is not a vision model. torch and
    transformers are imported here so stub mode never needs them installed.
    """
    global _model, _tokenizer
    if _model is None:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        dtype_name = _resolve_dtype_name()
        torch_dtype = getattr(torch, dtype_name)
        device_map = _resolve_device_map()
        print(f"[recall] loading model: {MODEL_ID} (dtype={dtype_name}, "
              f"device_map={device_map})")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
# ---- MiniCPM-V (multimodal) path -------------------------------------------
# NEEDS GPU VERIFICATION: the calls below mirror the official MiniCPM-V-4.6 demo
# Space (openbmb/MiniCPM-V-4.6-Demo) but can't be exercised without a GPU + the
# ~9B model. The stub and legacy text paths are unchanged and remain testable.
def _maybe_gpu(fn):
"""Wrap with HF ZeroGPU's @spaces.GPU when available; otherwise a no-op.
`spaces` ships only in the real-model deps and is effect-free off a ZeroGPU
Space, so this is safe in stub/local environments where it isn't installed.
Registering a @spaces.GPU function is ALSO what keeps a ZeroGPU Space healthy
(a ZeroGPU Space with none flips to RUNTIME_ERROR — see server.py)."""
try:
import spaces
except Exception: # noqa: BLE001 — not installed (stub/local): run un-wrapped
return fn
return spaces.GPU(duration=120)(fn)
def _load_vision() -> None:
    """Load the MiniCPM-V model and its AutoProcessor on first use.

    Returns immediately once `_model` is populated. Used when STUB is off and
    the active model is a vision model. The model is loaded with SDPA
    attention and switched to eval mode.
    """
    global _model, _processor
    if _model is None:
        import torch
        from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration

        dtype_name = _resolve_dtype_name()
        torch_dtype = getattr(torch, dtype_name)
        device_map = _resolve_device_map()
        print(f"[recall] loading vision model: {MODEL_ID} (dtype={dtype_name}, "
              f"device_map={device_map})")
        _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        loaded = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch_dtype,
            attn_implementation="sdpa",
            trust_remote_code=True,
            device_map=device_map,
        )
        _model = loaded.eval()
def _to_vision_content(content):
"""Normalize a message's `content` to MiniCPM-V parts. Accepts a plain string
(text-only) or a list mixing strings and PIL.Image objects (image+text)."""
if isinstance(content, str):
return [{"type": "text", "text": content}]
parts = []
for item in content:
if isinstance(item, str):
parts.append({"type": "text", "text": item})
else: # a PIL.Image (or anything image-like the processor accepts)
parts.append({"type": "image", "image": item})
return parts
@_maybe_gpu
def _chat_vision(messages: list[dict], max_tokens: int) -> str:
    """Run one non-streaming MiniCPM-V 4.6 generation and return the reply.

    Mirrors the official demo's processor + generate call. Each message's
    `content` may be a str or a list mixing strings and PIL images (see
    _to_vision_content). enable_thinking=False keeps the tight token budget
    for the JSON answer instead of a <think> preamble.
    """
    _load_vision()
    import torch

    vision_messages = []
    for message in messages:
        parts = _to_vision_content(message["content"])
        vision_messages.append({"role": message["role"], "content": parts})

    batch = _processor.apply_chat_template(
        vision_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
        processor_kwargs={
            "downsample_mode": "16x",
            "max_slice_nums": 9,
            "use_image_id": True,
        },
    ).to(_model.device)

    # MiniCPM-V wants floating inputs (e.g. pixel_values) in the model dtype;
    # integer tensors such as input_ids are left alone. Only existing keys are
    # reassigned, so updating while iterating is safe.
    float_dtype = getattr(torch, _resolve_dtype_name())
    for name, tensor in batch.items():
        if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor):
            batch[name] = tensor.to(dtype=float_dtype)

    with torch.no_grad():
        # Greedy decoding (do_sample=False): every caller wants a strict JSON
        # object/array, and greedy is markedly more reliable at that than
        # sampling for MiniCPM-V — verified on GPU.
        generated = _model.generate(
            **batch,
            max_new_tokens=max_tokens,
            do_sample=False,
            downsample_mode="16x",
        )
    prompt_length = batch["input_ids"].shape[1]
    reply_ids = generated[0][prompt_length:]
    return _processor.tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
def _render_prompt(messages: list[dict]) -> str:
    """Render `messages` into a prompt string with the tokenizer's chat template.

    MiniCPM4.1 / MiniCPM5 are hybrid reasoning models. Passing
    enable_thinking=False makes them answer directly instead of spending the
    (deliberately tight) token budget on a <think> preamble that would push
    the JSON past max_tokens and slow the demo. Non-reasoning models (e.g.
    MiniCPM3-4B) ignore the unused template variable. If the template raises
    on the flag, the prompt is rendered again without it. extract_json()
    still strips any <think> that leaks through, so this is an optimization,
    not a correctness dependency.
    """
    common = {"tokenize": False, "add_generation_prompt": True}
    try:
        return _tokenizer.apply_chat_template(
            messages, enable_thinking=False, **common
        )
    except Exception:  # noqa: BLE001 — template can't take the flag; render plain
        return _tokenizer.apply_chat_template(messages, **common)
def chat(messages: list[dict], max_tokens: int = 512) -> str:
    """Send a conversation to the configured model and return its reply text.

    messages: [{"role": "system"|"user"|"assistant", "content": str}, ...]
    `content` is normally a str. For the multimodal model it may also be a
    list mixing strings and PIL.Image objects (image+text), e.g. for
    image-only PDFs.

    Dispatch: STUB -> canned reply; vision model -> _chat_vision (wrapped with
    @spaces.GPU, which is also what keeps a ZeroGPU Space healthy); otherwise
    the legacy text-only path. Keep max_tokens tight — latency is the demo
    killer.
    """
    if STUB:
        return _stub_reply(messages)
    if VISION:
        return _chat_vision(messages, max_tokens)
    return _chat_text(messages, max_tokens)


def _chat_text(messages: list[dict], max_tokens: int) -> str:
    """Legacy text-only path: render the template, sample, decode the new tokens."""
    _load()
    prompt = _render_prompt(messages)
    encoded = _tokenizer(prompt, return_tensors="pt").to(_model.device)
    generated = _model.generate(
        **encoded,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    # Decode only the continuation, not the echoed prompt tokens.
    prompt_length = encoded["input_ids"].shape[1]
    continuation = generated[0][prompt_length:]
    return _tokenizer.decode(continuation, skip_special_tokens=True).strip()
# ---- JSON helper: model output is never trusted ----------------------------
_THINK_CLOSE = re.compile(r"</think\s*>", re.IGNORECASE)
def _strip_think(text: str) -> str:
"""Drop a reasoning-model <think> preamble. MiniCPM4.1/MiniCPM5 are hybrid
reasoning models that emit <think>…</think> before the actual answer; when
the chat template pre-fills the opening tag only the closing </think> shows
up in the reply. Either way the answer (the JSON we want) is whatever follows
the LAST </think>, so anchoring there also defuses stray braces inside the
reasoning that would otherwise mislead the JSON search below. A truncated,
never-closed <think> leaves the text untouched -> extract_json returns None
-> the caller's repair retry / safe default handles it."""
last = None
for last in _THINK_CLOSE.finditer(text):
pass
return text[last.end():].strip() if last else text
def _loads(s: str):
"""json.loads, but tolerant of models that over-escape their output. Seen
with MiniCPM4.1-8B, which sometimes escapes JSON as if it were a string
literal — quotes as `\\"` and newlines as `\\n` — e.g.
`[\\n {\\"k\\": \\"v\\"}\\n]` instead of real JSON. If the straight parse
fails and the text carries `\\"`, retry by (a) decoding it as a JSON string
body, which undoes \\", \\n, \\t and unicode escapes in one shot, then
parsing the result; and (b) a simpler quote-only un-escape as a backstop.
Strictly additive: valid JSON parses on the first try and never reaches the
fallbacks, so legitimately escaped quotes inside a string are untouched.
Returns the parsed value or None."""
try:
return json.loads(s)
except Exception:
pass
if '\\"' in s:
# (a) Treat the whole reply as an escaped string and decode it once.
try:
return json.loads(json.loads('"' + s + '"'))
except Exception:
pass
# (b) Backstop: just collapse the escaped quotes.
try:
return json.loads(s.replace('\\"', '"'))
except Exception:
pass
return None
def _scan_json_values(text: str) -> list:
"""Walk the string and collect every top-level JSON value. Handles models
that emit several values with no array wrapper — e.g. MiniCPM-V on image
input returns `{...} {...} {...}` (space-separated objects, no brackets) —
and ignores junk between/around them (stray quotes, prose, trailing `"}`)."""
dec = json.JSONDecoder()
out, i, n = [], 0, len(text)
while i < n:
if text[i] in "{[":
try:
val, end = dec.raw_decode(text, i)
out.append(val)
i = end
continue
except ValueError:
pass
i += 1
return out
def _open_brackets(s: str) -> tuple[list, bool]:
"""The bracket closers still open at the end of `s`, plus whether `s` ends
inside a string literal. String-aware, so braces inside quotes don't count."""
stack: list[str] = []
in_str = esc = False
for ch in s:
if in_str:
if esc:
esc = False
elif ch == "\\":
esc = True
elif ch == '"':
in_str = False
continue
if ch == '"':
in_str = True
elif ch == "{":
stack.append("}")
elif ch == "[":
stack.append("]")
elif ch in "}]" and stack:
stack.pop()
return stack, in_str
def _repair_truncated(text: str):
    """Best-effort recovery of a JSON value cut off by the token limit.

    A value truncated mid-stream (e.g. a reasoning preamble ate the budget)
    is a common reason an otherwise-clean grade/deck fails to parse. Starting
    at the first '{' or '[':
      1. close any dangling string and all open brackets, and try to parse;
      2. failing that, cut back to each comma that separates top-level
         elements (bracket depth 1), latest first, close what's open, and
         retry.
    Returns the parsed value, or None when there's nothing open to repair or
    no candidate parses.
    """
    openers = [idx for idx in (text.find("{"), text.find("[")) if idx != -1]
    if not openers:
        return None
    fragment = text[min(openers):]

    pending, dangling_string = _open_brackets(fragment)
    if not pending and not dangling_string:
        return None  # nothing left open — not a truncation we can repair

    # Attempt 1: close the value as-is (a trailing complete pair survives).
    closing = ('"' if dangling_string else "") + "".join(reversed(pending))
    repaired = _loads(fragment + closing)
    if repaired is not None:
        return repaired

    # Attempt 2..n: drop the trailing incomplete element, using the commas
    # seen at bracket-depth 1 (outside strings) as element boundaries.
    separators: list[int] = []
    depth = 0
    quoted = escaped = False
    for pos, ch in enumerate(fragment):
        if quoted:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                quoted = False
        elif ch == '"':
            quoted = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
        elif ch == "," and depth == 1:
            separators.append(pos)

    for cut in reversed(separators):
        head = fragment[:cut]
        still_open, _ = _open_brackets(head)
        repaired = _loads(head + "".join(reversed(still_open)))
        if repaired is not None:
            return repaired
    return None
def extract_json(text: str):
    """Pull JSON out of a model reply.

    Returns the parsed object/array, a flat list when the model emitted
    several values without an array wrapper, or None. Callers must handle
    None (skip card / use fallback grade). Strategies, in order:
      1. strip a <think> preamble and ```json fences, then parse the rest;
      2. scan for concatenated top-level values and flatten them into a list;
      3. parse the outermost {...} / [...] span (handles prose and/or
         over-escaped output);
      4. repair a value truncated by the token limit.
    """
    cleaned = _strip_think(text.strip())
    # strip ```json fences if present
    cleaned = re.sub(r"^```(?:json)?|```$", "", cleaned, flags=re.MULTILINE).strip()

    parsed = _loads(cleaned)
    if parsed is not None:
        return parsed

    # The whole text didn't parse, commonly because the model concatenated
    # several JSON values. Collect them all; when there's more than one,
    # flatten into a single list so callers expecting an array still work.
    values = _scan_json_values(cleaned)
    if values:
        if len(values) == 1:
            return values[0]
        flattened: list = []
        for value in values:
            if isinstance(value, list):
                flattened.extend(value)
            else:
                flattened.append(value)
        return flattened

    # A single object/array embedded in prose and/or over-escaped (\" / \n).
    # The plain scan above can't read that, but _loads can.
    embedded = re.search(r"(\[.*\]|\{.*\})", cleaned, re.DOTALL)
    if embedded:
        parsed = _loads(embedded.group(1))
        if parsed is not None:
            return parsed

    # The value was cut off by the token limit (unterminated string / open
    # brackets): recover the largest valid prefix.
    return _repair_truncated(cleaned)
def _augment_last_user(messages: list[dict]) -> list[dict]:
"""A copy of `messages` with a terse 'JSON only' reminder appended to the
FINAL user turn — used for the repair pass.
Appending to the existing instruction (rather than injecting a standalone
'that was not valid JSON' user turn, the previous approach) changes the prompt
enough to break a deterministic bad reply WITHOUT handing the model a meta
message it would otherwise *grade as if it were the student's answer* — which
produced real-looking but nonsensical grades like score 0 / "incorrect JSON
syntax". We also drop the bad reply rather than echo it back, so the model
doesn't anchor on its own malformed output."""
out = [dict(m) for m in messages]
reminder = ("\n\nIMPORTANT: reply with ONLY the raw JSON value — no prose, no "
"markdown fences, no commentary before or after it.")
for m in reversed(out):
if m.get("role") == "user":
c = m.get("content")
if isinstance(c, list): # multimodal content (images + text parts)
m["content"] = list(c) + [reminder]
else:
m["content"] = f"{c}{reminder}"
return out
out.append({"role": "user", "content": reminder.strip()})
return out
def chat_json(messages: list[dict], max_tokens: int = 256, retries: int = 1):
    """Call the model and parse its reply as JSON, with up to `retries` repair passes.

    Model output is never trusted. If a reply doesn't yield JSON, the SAME
    task is re-asked with a terse "ONLY JSON" reminder folded into the final
    user turn (see _augment_last_user for why no separate 'that was not valid
    JSON' turn is injected). Every repair pass is built from the original
    `messages`.

    Returns the parsed object/array, or None if every attempt fails. Callers
    must handle None with a safe default, so the study loop never crashes.
    """
    prompt = list(messages)
    for attempt in range(retries + 1):
        parsed = extract_json(chat(prompt, max_tokens=max_tokens))
        if parsed is not None:
            return parsed
        if attempt == retries:
            break
        prompt = _augment_last_user(messages)
    return None
# ---- Stub replies so the app runs with no model ----------------------------
def _msg_text(content) -> str:
"""Text of a message's content, ignoring any images (content may be a str or
a list mixing strings and PIL.Image objects)."""
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(p for p in content if isinstance(p, str))
return ""
def _stub_reply(messages: list[dict]) -> str:
    """Return a canned reply chosen by keywords in the conversation text.

    The text of every message (images ignored) is lower-cased and checked in
    this order:
      - "generate" and "question"  -> a 3-card JSON deck (difficulty 1..3)
      - "grade" or "score"         -> a JSON grade object
      - "follow"                   -> a 1-card JSON follow-up deck
    Anything else returns the plain string "[stub] model reply".
    """
    haystack = " ".join(_msg_text(m.get("content", "")) for m in messages).lower()
    topic = "Stub Topic"

    if "generate" in haystack and "question" in haystack:
        payload = [
            {"question": "[stub] What is the main idea of the source text?",
             "answer": "The main concept described in the passage.",
             "topic": topic,
             "difficulty": 1},
            {"question": "[stub] How does the key concept apply in this context?",
             "answer": "It applies by connecting the described mechanism to the outcome.",
             "topic": topic,
             "difficulty": 2},
            {"question": "[stub] Compare and contrast the two ideas presented.",
             "answer": "They differ in scope but share the same underlying principle.",
             "topic": topic,
             "difficulty": 3},
        ]
    elif "grade" in haystack or "score" in haystack:
        payload = {
            "score": 4,
            "explanation": "[stub] Close — you captured the main idea but missed a detail.",
            "missed_concept": "the specific detail",
        }
    elif "follow" in haystack:
        payload = [
            {"question": "[stub follow-up] Can you restate the missed detail?",
             "answer": "The specific detail from the passage.",
             "topic": topic},
        ]
    else:
        return "[stub] model reply"
    return json.dumps(payload)