"""
Recall — shared inference wrapper.
OWNER: Nikolai (Module B)

Everything that touches the model goes through `chat()` (and `chat_json()`, which
wraps it). Both content_pipeline and learning_engine import this and nothing
else model-related.

Default is STUB mode (RECALL_STUB=1) so `python app.py` runs with no GPU and no
model download. Flip RECALL_STUB=0 once the real MiniCPM call works on the Space.

Model is a one-env-var config flip (NAH-9). Set RECALL_MODEL to a known alias or
any full HF id; the default is "v46" (openbmb/MiniCPM-V-4.6, multimodal: text +
image). The legacy text-only aliases need transformers<5.0, so they only load in
an environment that has the older transformers installed, e.g.:

    RECALL_MODEL=1b RECALL_STUB=0 python app.py   # MiniCPM5-1B (fast fallback)
    RECALL_MODEL=4b RECALL_STUB=0 python app.py   # MiniCPM3-4B (Tiny Titan badge)

Aliases resolve via MODELS below; an unknown value is treated as a literal HF id.

Load dtype/device default to bf16 + device_map="auto" (correct for the Space's
CUDA GPU). For a local real-model smoke test on Apple Silicon, override them —
bf16 on MPS produces garbage, so use CPU/float32:

    RECALL_STUB=0 RECALL_MODEL=1b RECALL_DTYPE=float32 RECALL_DEVICE=cpu python app.py
"""
from __future__ import annotations

import json
import os
import re

STUB = os.getenv("RECALL_STUB", "1") == "1"

# Known models, keyed by short alias so swapping is a single env-var flip.
MODELS = {
    "v46": "openbmb/MiniCPM-V-4.6",  # default / primary — multimodal (text + image)
    "8b": "openbmb/MiniCPM4.1-8B",   # legacy text-only (needs transformers<5.0)
    "1b": "openbmb/MiniCPM5-1B",     # legacy fast fallback
    "4b": "openbmb/MiniCPM3-4B",     # legacy mid fallback (Tiny Titan badge)
}
# Default is the multimodal MiniCPM-V 4.6 so the same model grades text AND reads
# image-only / scanned PDFs. The legacy text aliases need transformers<5.0 and no
# longer load against the pinned transformers 5.x — keep them only for reference.
_requested = os.getenv("RECALL_MODEL", "v46")
# Accept an alias ("v46") or a full HF id ("org/model") passed through verbatim.
MODEL_ID = MODELS.get(_requested, _requested)


def _is_vision_model(model_id: str) -> bool:
    """MiniCPM-V (vision) ids load via a different class + processor than the
    text-only MiniCPM models. Detect by the '-V' family marker."""
    return "minicpm-v" in model_id.lower()


VISION = _is_vision_model(MODEL_ID)

_model = None
_tokenizer = None
_processor = None  # MiniCPM-V uses an AutoProcessor (image+text) instead of a tokenizer


def active_model() -> str:
    """The HF model id currently configured ('stub' when running stubbed)."""
    return "stub" if STUB else MODEL_ID


# Load-time dtype/device, overridable for local dev (defaults are correct for
# the Space's CUDA GPU). bf16 on Apple-Silicon MPS produces garbage output, so a
# Mac real-model smoke test needs RECALL_DTYPE=float32 RECALL_DEVICE=cpu; unset,
# behavior is unchanged (bf16 + device_map="auto").
_DTYPE_ALIASES = {
    "bfloat16": "bfloat16", "bf16": "bfloat16",
    "float16": "float16", "fp16": "float16", "half": "float16",
    "float32": "float32", "fp32": "float32", "float": "float32",
}


def _resolve_dtype_name() -> str:
    """Normalized torch dtype name from RECALL_DTYPE (default 'bfloat16').
    Unknown values fall back to the default rather than erroring at load."""
    return _DTYPE_ALIASES.get(os.getenv("RECALL_DTYPE", "bfloat16").lower(), "bfloat16")


def _resolve_device_map():
    """device_map for from_pretrained. Default 'auto' (accelerate places it);
    RECALL_DEVICE overrides, e.g. 'cpu' for stable local CPU inference."""
    return os.getenv("RECALL_DEVICE") or "auto"


def _load() -> None:
    """Lazy-load the text-only model + tokenizer once.

    Only called when STUB is off and the active model is not a vision model.
    """
    global _model, _tokenizer
    if _model is not None:
        return
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    dtype_name = _resolve_dtype_name()
    device_map = _resolve_device_map()
    print(f"[recall] loading model: {MODEL_ID} (dtype={dtype_name}, "
          f"device_map={device_map})")
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=getattr(torch, dtype_name),
        device_map=device_map,
        trust_remote_code=True,
    )


# ---- MiniCPM-V (multimodal) path -------------------------------------------
# NEEDS GPU VERIFICATION: the calls below mirror the official MiniCPM-V-4.6 demo
# Space (openbmb/MiniCPM-V-4.6-Demo) but can't be exercised without a GPU + the
# ~9B model. The stub and legacy text paths are unchanged and remain testable.

def _maybe_gpu(fn):
    """Wrap with HF ZeroGPU's @spaces.GPU when available; otherwise a no-op.

    `spaces` ships only in the real-model deps and is effect-free off a ZeroGPU
    Space, so this is safe in stub/local environments where it isn't installed.
    Registering a @spaces.GPU function is ALSO what keeps a ZeroGPU Space healthy
    (a ZeroGPU Space with none flips to RUNTIME_ERROR — see server.py)."""
    try:
        import spaces
    except Exception:  # noqa: BLE001 — not installed (stub/local): run un-wrapped
        return fn
    return spaces.GPU(duration=120)(fn)


def _load_vision() -> None:
    """Lazy-load the MiniCPM-V model + processor once.

    Only called when STUB is off and the active model is a vision model.
    """
    global _model, _processor
    if _model is not None:
        return
    import torch
    from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration

    dtype_name = _resolve_dtype_name()
    device_map = _resolve_device_map()
    print(f"[recall] loading vision model: {MODEL_ID} (dtype={dtype_name}, "
          f"device_map={device_map})")
    _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    _model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=getattr(torch, dtype_name),
        attn_implementation="sdpa",
        trust_remote_code=True,
        device_map=device_map,
    ).eval()


def _to_vision_content(content):
    """Normalize a message's `content` to MiniCPM-V parts.

    Accepts a plain string (text-only) or a list mixing strings and PIL.Image
    objects (image+text). Non-string list items are passed through as images.
    """
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    parts = []
    for item in content:
        if isinstance(item, str):
            parts.append({"type": "text", "text": item})
        else:  # a PIL.Image (or anything image-like the processor accepts)
            parts.append({"type": "image", "image": item})
    return parts


@_maybe_gpu
def _chat_vision(messages: list[dict], max_tokens: int) -> str:
    """MiniCPM-V 4.6 inference, mirroring the official demo's processor+generate
    call (non-streaming). enable_thinking=False keeps the tight token budget for
    the JSON answer instead of a preamble."""
    _load_vision()
    import torch

    msgs = [{"role": m["role"], "content": _to_vision_content(m["content"])}
            for m in messages]
    inputs = _processor.apply_chat_template(
        msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
        processor_kwargs={
            "downsample_mode": "16x",
            "max_slice_nums": 9,
            "use_image_id": True,
        },
    ).to(_model.device)
    # MiniCPM-V wants floating inputs (e.g. pixel_values) in the model dtype.
    # Resolve the dtype once instead of once per tensor.
    dtype = getattr(torch, _resolve_dtype_name())
    for k, v in inputs.items():
        if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
            inputs[k] = v.to(dtype=dtype)
    with torch.no_grad():
        # Greedy decoding (do_sample=False): every caller wants a strict JSON
        # object/array, and greedy is markedly more reliable at that than sampling
        # for MiniCPM-V — verified on GPU. enable_thinking is already False so the
        # tight token budget goes to the answer, not a preamble.
        out = _model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            downsample_mode="16x",
        )
    gen = out[0][inputs["input_ids"].shape[1]:]
    return _processor.tokenizer.decode(gen, skip_special_tokens=True).strip()


def _render_prompt(messages: list[dict]) -> str:
    """Build the prompt string.

    MiniCPM4.1/MiniCPM5 are hybrid reasoning models; we pass
    enable_thinking=False so they answer directly instead of spending the
    (deliberately tight) token budget on a preamble that would push the JSON
    answer past max_tokens — and slow the demo. Non-reasoning models (e.g.
    MiniCPM3-4B) ignore the unused template variable; templates that actively
    reject it fall back to a plain render. extract_json() still strips any that
    leaks through, so this is an optimization, not a correctness dependency."""
    try:
        return _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
            enable_thinking=False,
        )
    except Exception:  # noqa: BLE001 — template can't take the flag; render plain
        return _tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True,
        )


def chat(messages: list[dict], max_tokens: int = 512) -> str:
    """
    messages: [{"role": "system"|"user"|"assistant", "content": str}, ...]
    Returns the assistant's text.

    `content` is normally a str. For the multimodal model it may also be a list
    mixing strings and PIL.Image objects (image+text) — e.g. for image-only PDFs.
    GPU work is wrapped with @spaces.GPU inside the vision path; that decorator
    is also what keeps a ZeroGPU Space healthy.

    Keep max_tokens tight — latency is the demo killer.
    """
    if STUB:
        return _stub_reply(messages)
    if VISION:
        return _chat_vision(messages, max_tokens)

    _load()
    text = _render_prompt(messages)
    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    out = _model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    gen = out[0][inputs["input_ids"].shape[1]:]
    return _tokenizer.decode(gen, skip_special_tokens=True).strip()


# ---- JSON helper: model output is never trusted ----------------------------

# Matches the closing "think" tag that ends a reasoning preamble ("\/" is just an
# escaped "/"). This MUST be non-empty: an empty pattern matches at end-of-string,
# which made _strip_think discard every reply.
_THINK_CLOSE = re.compile(r"<\/think>", re.IGNORECASE)


def _strip_think(text: str) -> str:
    """Drop a reasoning-model preamble.

    MiniCPM4.1/MiniCPM5 are hybrid reasoning models that can emit a think block
    (opening think tag, reasoning, closing think tag) before the actual answer;
    when the chat template pre-fills the opening tag, only the closing tag shows
    up in the reply. Either way the answer (the JSON we want) is whatever follows
    the LAST closing tag, so anchoring there also defuses stray braces inside the
    reasoning that would otherwise mislead the JSON search in extract_json.

    Text with no closing tag — including a truncated, never-closed think block —
    is returned unchanged; extract_json's fallbacks and the caller's repair
    retry / safe default handle whatever that yields.
    """
    last = None
    for last in _THINK_CLOSE.finditer(text):
        pass
    return text[last.end():].strip() if last is not None else text


def _loads(s: str):
    """json.loads, but tolerant of models that over-escape their output.

    Seen with MiniCPM4.1-8B, which sometimes escapes JSON as if it were a string
    literal — quotes as `\\"` and newlines as `\\n` — e.g. `[\\n {\\"k\\": \\"v\\"}\\n]`
    instead of real JSON. If the straight parse fails and the text carries `\\"`,
    retry by (a) decoding it as a JSON string body, which undoes \\", \\n, \\t and
    unicode escapes in one shot, then parsing the result (raw control characters
    such as real newlines are tolerated in that outer decode, since a reply can mix
    escaped and real newlines); and (b) a simpler quote-only un-escape as a
    backstop. Strictly additive: valid JSON parses on the first try and never
    reaches the fallbacks, so legitimately escaped quotes inside a string are
    untouched.

    Returns the parsed value or None.
    """
    try:
        return json.loads(s)
    except (ValueError, RecursionError):
        pass
    if '\\"' in s:
        # (a) Treat the whole reply as an escaped string and decode it once.
        try:
            return json.loads(json.loads('"' + s + '"', strict=False))
        except (ValueError, RecursionError):
            pass
        # (b) Backstop: just collapse the escaped quotes.
        try:
            return json.loads(s.replace('\\"', '"'))
        except (ValueError, RecursionError):
            pass
    return None


def _scan_json_values(text: str) -> list:
    """Walk the string and collect every top-level JSON value.

    Handles models that emit several values with no array wrapper — e.g.
    MiniCPM-V on image input returns `{...} {...} {...}` (space-separated
    objects, no brackets) — and ignores junk between/around them (stray quotes,
    prose, trailing `"}`)."""
    dec = json.JSONDecoder()
    out, i, n = [], 0, len(text)
    while i < n:
        if text[i] in "{[":
            try:
                val, end = dec.raw_decode(text, i)
                out.append(val)
                i = end
                continue
            except (ValueError, RecursionError):
                pass
        i += 1
    return out


def _open_brackets(s: str) -> tuple[list, bool]:
    """The bracket closers still open at the end of `s`, plus whether `s` ends
    inside a string literal. String-aware, so braces inside quotes don't count."""
    stack: list[str] = []
    in_str = esc = False
    for ch in s:
        if in_str:
            if esc:
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
        elif ch == "{":
            stack.append("}")
        elif ch == "[":
            stack.append("]")
        elif ch in "}]" and stack:
            stack.pop()
    return stack, in_str


def _repair_truncated(text: str):
    """Best-effort recovery of a JSON value cut off mid-stream by the model's
    token limit — a common cause of an otherwise-clean grade/deck failing to
    parse (e.g. a reasoning preamble eats the budget). Closes a dangling string
    and any still-open brackets; if that won't parse, walks back to each
    completed top-level element and retries.

    Returns the parsed value or None.
    """
    starts = [p for p in (text.find("{"), text.find("[")) if p >= 0]
    if not starts:
        return None
    s = text[min(starts):]
    stack, in_str = _open_brackets(s)
    if not stack and not in_str:
        return None  # nothing left open — not a truncation we can repair

    # Attempt 1: close the value as-is (a trailing complete pair survives).
    data = _loads(s + ('"' if in_str else "") + "".join(reversed(stack)))
    if data is not None:
        return data

    # Attempt 2..n: drop the trailing incomplete element. Top-level element
    # boundaries are commas seen at bracket-depth 1.
    depth = 0
    in_str = esc = False
    boundaries: list[int] = []
    for i, ch in enumerate(s):
        if in_str:
            if esc:
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
        elif ch == "," and depth == 1:
            boundaries.append(i)
    for cut in reversed(boundaries):
        head = s[:cut]
        st, _ = _open_brackets(head)
        data = _loads(head + "".join(reversed(st)))
        if data is not None:
            return data
    return None


def extract_json(text: str):
    """
    Pull JSON out of a model reply. Returns the parsed object/array, a list when
    the model emitted several values without an array wrapper, or None.
    Callers must handle None (skip card / use fallback grade).
    """
    text = _strip_think(text.strip())
    # strip ```json fences if present
    text = re.sub(r"^```(?:json)?|```$", "", text, flags=re.MULTILINE).strip()

    data = _loads(text)
    if data is not None:
        return data

    # The whole text didn't parse — commonly because the model concatenated
    # multiple JSON values (objects and/or arrays). Collect them all and flatten
    # to a single list so callers expecting an array still work.
    values = _scan_json_values(text)
    if len(values) == 1:
        return values[0]
    if values:
        flat: list = []
        for v in values:
            if isinstance(v, list):
                flat.extend(v)
            else:
                flat.append(v)
        return flat

    # Next: a single object/array embedded in prose and/or over-escaped
    # (\" / \n) — the plain scan above can't read that, but _loads can.
    match = re.search(r"(\[.*\]|\{.*\})", text, re.DOTALL)
    if match:
        data = _loads(match.group(1))
        if data is not None:
            return data

    # Last resort: the value was cut off by the token limit (unterminated string
    # / open brackets) — recover the largest valid prefix.
    return _repair_truncated(text)


def _augment_last_user(messages: list[dict]) -> list[dict]:
    """A copy of `messages` with a terse 'JSON only' reminder appended to the
    FINAL user turn — used for the repair pass.

    Appending to the existing instruction (rather than injecting a standalone
    'that was not valid JSON' user turn, the previous approach) changes the prompt
    enough to break a deterministic bad reply WITHOUT handing the model a meta
    message it would otherwise *grade as if it were the student's answer* — which
    produced real-looking but nonsensical grades like score 0 / "incorrect JSON
    syntax". We also drop the bad reply rather than echo it back, so the model
    doesn't anchor on its own malformed output."""
    out = [dict(m) for m in messages]
    reminder = ("\n\nIMPORTANT: reply with ONLY the raw JSON value — no prose, no "
                "markdown fences, no commentary before or after it.")
    for m in reversed(out):
        if m.get("role") == "user":
            c = m.get("content")
            if isinstance(c, list):  # multimodal content (images + text parts)
                m["content"] = list(c) + [reminder]
            else:
                m["content"] = f"{c}{reminder}"
            return out
    out.append({"role": "user", "content": reminder.strip()})
    return out


def chat_json(messages: list[dict], max_tokens: int = 256, retries: int = 1):
    """
    Call the model and parse its reply as JSON, with up to `retries` repair passes.

    Model output is never trusted: if the first reply isn't valid JSON we re-ask
    with a terse "ONLY JSON" reminder folded into the request and try again.
    Returns the parsed object/array, or None if every attempt fails (callers must
    handle None with a safe default — never crash the study loop).
    """
    convo = list(messages)
    for attempt in range(retries + 1):
        reply = chat(convo, max_tokens=max_tokens)
        data = extract_json(reply)
        if data is not None:
            return data
        if attempt < retries:
            # Repair pass: re-ask the SAME task with a format reminder folded into
            # the final user turn (see _augment_last_user for why we don't inject
            # a separate 'that was not valid JSON' turn).
            convo = _augment_last_user(messages)
    return None


# ---- Stub replies so the app runs with no model ----------------------------

def _msg_text(content) -> str:
    """Text of a message's content, ignoring any images (content may be a str or
    a list mixing strings and PIL.Image objects)."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(p for p in content if isinstance(p, str))
    return ""


def _stub_reply(messages: list[dict]) -> str:
    """Cheap deterministic-ish replies keyed off the caller's intent tag."""
    content = " ".join(_msg_text(m.get("content", "")) for m in messages).lower()

    if "generate" in content and "question" in content:
        return json.dumps([
            {"question": "[stub] What is the main idea of the source text?",
             "answer": "The main concept described in the passage.",
             "topic": "Stub Topic", "difficulty": 1},
            {"question": "[stub] How does the key concept apply in this context?",
             "answer": "It applies by connecting the described mechanism to the outcome.",
             "topic": "Stub Topic", "difficulty": 2},
            {"question": "[stub] Compare and contrast the two ideas presented.",
             "answer": "They differ in scope but share the same underlying principle.",
             "topic": "Stub Topic", "difficulty": 3},
        ])

    if "grade" in content or "score" in content:
        return json.dumps({
            "score": 4,
            "explanation": "[stub] Close — you captured the main idea but missed a detail.",
            "missed_concept": "the specific detail",
        })

    if "follow" in content:
        return json.dumps([
            {"question": "[stub follow-up] Can you restate the missed detail?",
             "answer": "The specific detail from the passage.",
             "topic": "Stub Topic"},
        ])

    return "[stub] model reply"