Spaces:
Running on Zero
Running on Zero
| """ | |
| Recall — shared inference wrapper. OWNER: Nikolai (Module B) | |
| Everything that touches the model goes through `chat()`. Both content_pipeline | |
| and learning_engine import this and nothing else model-related. | |
| Default is STUB mode (RECALL_STUB=1) so `python app.py` runs with no GPU and no | |
| model download. Flip RECALL_STUB=0 once the real MiniCPM call works on the Space. | |
Model is a one-env-var config flip (NAH-9). Set RECALL_MODEL to a known alias
or any full HF id; the default is `v46` (openbmb/MiniCPM-V-4.6, multimodal).
If the Space is too slow / OOM, swap to a smaller model with no code change:
| RECALL_MODEL=1b RECALL_STUB=0 python app.py # MiniCPM5-1B (fast fallback) | |
| RECALL_MODEL=4b RECALL_STUB=0 python app.py # MiniCPM3-4B (Tiny Titan badge) | |
| Aliases resolve via MODELS below; an unknown value is treated as a literal HF id. | |
| Load dtype/device default to bf16 + device_map="auto" (correct for the Space's | |
| CUDA GPU). For a local real-model smoke test on Apple Silicon, override them — | |
| bf16 on MPS produces garbage, so use CPU/float32: | |
| RECALL_STUB=0 RECALL_MODEL=1b RECALL_DTYPE=float32 RECALL_DEVICE=cpu python app.py | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
# Stub mode is ON unless RECALL_STUB is set to something other than "1".
STUB = os.getenv("RECALL_STUB", "1") == "1"

# Known models, keyed by short alias so swapping is a single env-var flip.
MODELS = {
    "v46": "openbmb/MiniCPM-V-4.6",  # default / primary — multimodal (text + image)
    "8b": "openbmb/MiniCPM4.1-8B",   # legacy text-only (needs transformers<5.0)
    "1b": "openbmb/MiniCPM5-1B",     # legacy fast fallback
    "4b": "openbmb/MiniCPM3-4B",     # legacy mid fallback (Tiny Titan badge)
}

# Default is the multimodal MiniCPM-V 4.6 so the same model grades text AND reads
# image-only / scanned PDFs. The legacy text aliases need transformers<5.0 and no
# longer load against the pinned transformers 5.x — keep them only for reference.
_DEFAULT_MODEL_ALIAS = "v46"


def _resolve_model_id(requested):
    """Map a RECALL_MODEL value to a Hugging Face model id.

    Args:
        requested: A short alias from MODELS (matched case-insensitively, so
            "8B" and "8b" are the same), a full HF id ("org/model"), or an
            empty/None value.

    Returns:
        The HF id for a known alias; the default model's id when `requested`
        is empty or whitespace-only; otherwise `requested` itself, stripped of
        surrounding whitespace and otherwise passed through verbatim.
    """
    name = (requested or "").strip()
    if not name:
        # An exported-but-empty RECALL_MODEL would otherwise become the model id "".
        name = _DEFAULT_MODEL_ALIAS
    # Only the alias lookup is case-folded; a literal HF id keeps its casing.
    return MODELS.get(name.lower(), name)


_requested = os.getenv("RECALL_MODEL", _DEFAULT_MODEL_ALIAS)
# Accept an alias ("v46") or a full HF id ("org/model") passed through verbatim.
MODEL_ID = _resolve_model_id(_requested)
def _is_vision_model(model_id: str) -> bool:
    """Return True when `model_id` names a MiniCPM-V (vision) model.

    Detection is a case-insensitive substring check for the "minicpm-v" family
    marker. Vision ids are loaded through a different model class and a
    processor rather than the text-only tokenizer path.
    """
    family_marker = "minicpm-v"
    lowered = model_id.lower()
    return family_marker in lowered
# Whether the configured model id belongs to the MiniCPM-V (vision) family;
# chat() uses this flag to pick the vision inference path.
VISION = _is_vision_model(MODEL_ID)

# Lazily-populated singletons, filled in by _load() / _load_vision() on first
# real-model use. MiniCPM-V uses an AutoProcessor (image+text) instead of a
# tokenizer, so only one of _tokenizer / _processor is populated per process.
_model = _tokenizer = _processor = None
def active_model() -> str:
    """Return the configured HF model id, or the literal 'stub' in stub mode."""
    if STUB:
        return "stub"
    return MODEL_ID
# Load-time dtype/device, overridable for local dev (defaults are correct for
# the Space's CUDA GPU). bf16 on Apple-Silicon MPS produces garbage output, so a
# Mac real-model smoke test needs RECALL_DTYPE=float32 RECALL_DEVICE=cpu; unset,
# behavior is unchanged (bf16 + device_map="auto").
_DTYPE_ALIASES = {
    "bfloat16": "bfloat16", "bf16": "bfloat16",
    "float16": "float16", "fp16": "float16", "half": "float16",
    "float32": "float32", "fp32": "float32", "float": "float32",
}


def _resolve_dtype_name() -> str:
    """Return the normalized torch dtype name selected by RECALL_DTYPE.

    The value is matched case-insensitively after trimming surrounding
    whitespace and an optional "torch." prefix, so " FP32 " and
    "torch.float32" both resolve to "float32". Without that normalization such
    values would silently fall back to bf16 — the very dtype the override
    exists to avoid on Apple Silicon.

    Returns:
        One of "bfloat16", "float16", "float32". Unset, empty, or unknown
        values fall back to "bfloat16" rather than erroring at load time.
    """
    default = "bfloat16"
    key = os.getenv("RECALL_DTYPE", default).strip().lower()
    prefix = "torch."
    if key.startswith(prefix):
        key = key[len(prefix):]
    return _DTYPE_ALIASES.get(key, default)
def _resolve_device_map():
    """Return the `device_map` value to pass to `from_pretrained`.

    RECALL_DEVICE, when set to a non-empty string (e.g. 'cpu'), is returned
    as-is; otherwise the result is 'auto'.
    """
    override = os.getenv("RECALL_DEVICE")
    if override:
        return override
    return "auto"
def _load():
    """Load the text-only model and tokenizer on first use.

    Populates the module-level `_model` / `_tokenizer`; later calls are no-ops
    once `_model` is set. torch and transformers are imported here, not at
    module import time, so stub mode never needs them. chat() only reaches
    this when STUB is off and the active model is not a vision model.
    """
    global _model, _tokenizer
    if _model is None:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        dtype_name = _resolve_dtype_name()
        dtype = getattr(torch, dtype_name)
        device_map = _resolve_device_map()
        print(f"[recall] loading model: {MODEL_ID} (dtype={dtype_name}, "
              f"device_map={device_map})")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        load_kwargs = {
            "torch_dtype": dtype,
            "device_map": device_map,
            "trust_remote_code": True,
        }
        _model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)
| # ---- MiniCPM-V (multimodal) path ------------------------------------------- | |
| # NEEDS GPU VERIFICATION: the calls below mirror the official MiniCPM-V-4.6 demo | |
| # Space (openbmb/MiniCPM-V-4.6-Demo) but can't be exercised without a GPU + the | |
| # ~9B model. The stub and legacy text paths are unchanged and remain testable. | |
def _maybe_gpu(fn):
    """Decorate `fn` with `spaces.GPU(duration=120)` when `spaces` is importable.

    If importing `spaces` fails for any reason (e.g. it is not installed in a
    stub/local environment), `fn` is returned unchanged.

    Per the original author's notes, registering a @spaces.GPU function is also
    what keeps a ZeroGPU Space healthy (see server.py).
    NOTE(review): nothing in the code visible here applies this helper —
    confirm where it is used.
    """
    try:
        import spaces
    except Exception:  # noqa: BLE001 — not installed (stub/local): run un-wrapped
        spaces = None
    if spaces is None:
        return fn
    gpu_decorator = spaces.GPU(duration=120)
    return gpu_decorator(fn)
def _load_vision() -> None:
    """Load the MiniCPM-V model and its processor on first use.

    Populates the module-level `_model` / `_processor`; later calls are no-ops
    once `_model` is set. The model is loaded with SDPA attention and put in
    eval mode. Reached from the vision chat path only.
    """
    global _model, _processor
    if _model is None:
        import torch
        from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration

        dtype_name = _resolve_dtype_name()
        dtype = getattr(torch, dtype_name)
        device_map = _resolve_device_map()
        print(f"[recall] loading vision model: {MODEL_ID} (dtype={dtype_name}, "
              f"device_map={device_map})")
        _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        loaded = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=dtype,
            attn_implementation="sdpa",
            trust_remote_code=True,
            device_map=device_map,
        )
        _model = loaded.eval()
def _to_vision_content(content):
    """Convert a message's `content` into a list of MiniCPM-V content parts.

    A plain string becomes a single text part. Otherwise `content` is iterated:
    each string item becomes a text part and every other item is wrapped as an
    image part (expected to be a PIL.Image or another image-like object the
    processor accepts).
    """
    def as_part(item):
        if isinstance(item, str):
            return {"type": "text", "text": item}
        return {"type": "image", "image": item}

    if isinstance(content, str):
        return [as_part(content)]
    return [as_part(item) for item in content]
def _chat_vision(messages: list[dict], max_tokens: int) -> str:
    """Run one non-streaming MiniCPM-V generation and return the decoded reply.

    Steps: ensure the model/processor are loaded, normalize each message's
    content into text/image parts, tokenize through the processor's chat
    template (with enable_thinking=False), cast floating-point inputs to the
    configured dtype, generate greedily, and decode only the newly generated
    tokens.

    Per the original author's notes, this mirrors the official MiniCPM-V-4.6
    demo's processor+generate call, and greedy decoding was chosen because
    callers want strict JSON.
    """
    _load_vision()
    import torch

    vision_msgs = []
    for message in messages:
        vision_msgs.append({
            "role": message["role"],
            "content": _to_vision_content(message["content"]),
        })

    image_options = {
        "downsample_mode": "16x",
        "max_slice_nums": 9,
        "use_image_id": True,
    }
    encoded = _processor.apply_chat_template(
        vision_msgs,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        enable_thinking=False,
        processor_kwargs=image_options,
    )
    inputs = encoded.to(_model.device)

    # MiniCPM-V wants floating inputs (e.g. pixel_values) in the model dtype.
    # Only existing keys are reassigned, so mutating during iteration is safe.
    target_dtype = getattr(torch, _resolve_dtype_name())
    for name, value in inputs.items():
        if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
            inputs[name] = value.to(dtype=target_dtype)

    with torch.no_grad():
        # do_sample=False → greedy decoding.
        generated = _model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=False,
            downsample_mode="16x",
        )

    # Drop the echoed prompt tokens; keep only the completion.
    prompt_len = inputs["input_ids"].shape[1]
    completion = generated[0][prompt_len:]
    decoded = _processor.tokenizer.decode(completion, skip_special_tokens=True)
    return decoded.strip()
def _render_prompt(messages: list[dict]) -> str:
    """Render `messages` to a prompt string via the tokenizer's chat template.

    The first attempt passes enable_thinking=False so that hybrid reasoning
    models answer directly instead of spending the token budget on a <think>
    preamble. If that call raises for any reason (e.g. the template rejects
    the flag), the template is rendered again without it. This is an
    optimization only: extract_json() strips a leaked <think> block anyway.
    """
    base_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        return _tokenizer.apply_chat_template(
            messages, enable_thinking=False, **base_kwargs
        )
    except Exception:  # noqa: BLE001 — template can't take the flag; render plain
        return _tokenizer.apply_chat_template(messages, **base_kwargs)
def chat(messages: list[dict], max_tokens: int = 512) -> str:
    """Run one chat completion and return the assistant's text.

    Args:
        messages: [{"role": "system"|"user"|"assistant", "content": ...}, ...].
            `content` is normally a str; for the multimodal model it may also
            be a list mixing strings and PIL.Image objects (image+text).
        max_tokens: Cap on newly generated tokens. Keep it tight — latency is
            the demo killer.

    Returns:
        The stripped reply text. In stub mode this is a canned reply; with a
        vision model the call is delegated to the vision path; otherwise the
        text-only model samples with temperature 0.7 / top_p 0.9.

    NOTE(review): the previous docstring said GPU work is wrapped with
    @spaces.GPU inside the vision path; no such decoration is visible in this
    module's code — confirm where that registration happens.
    """
    if STUB:
        return _stub_reply(messages)
    if VISION:
        return _chat_vision(messages, max_tokens)

    _load()
    prompt = _render_prompt(messages)
    encoded = _tokenizer(prompt, return_tensors="pt").to(_model.device)
    sampling = {"do_sample": True, "temperature": 0.7, "top_p": 0.9}
    generated = _model.generate(**encoded, max_new_tokens=max_tokens, **sampling)
    # Drop the echoed prompt tokens; keep only the completion.
    prompt_len = encoded["input_ids"].shape[1]
    completion = generated[0][prompt_len:]
    return _tokenizer.decode(completion, skip_special_tokens=True).strip()
# ---- JSON helper: model output is never trusted ----------------------------
_THINK_CLOSE = re.compile(r"</think\s*>", re.IGNORECASE)


def _strip_think(text: str) -> str:
    """Return what follows the last closing </think> tag, stripped.

    Reasoning models may emit <think>…</think> before the answer, and when the
    chat template pre-fills the opening tag only the closing one appears in
    the reply. Anchoring on the LAST closing tag also keeps stray braces inside
    the reasoning out of the later JSON search.

    When there is no closing tag (including a truncated, never-closed <think>),
    `text` is returned completely unchanged — not even stripped.
    """
    closers = list(_THINK_CLOSE.finditer(text))
    if not closers:
        return text
    answer_start = closers[-1].end()
    return text[answer_start:].strip()
def _loads(s: str):
    """Parse `s` as JSON, tolerating replies escaped as if they were a string.

    Some models escape their JSON like a string literal — quotes as `\\"` and
    newlines as `\\n`. When the straight parse fails and `s` contains `\\"`,
    two recoveries are tried in order:

    1. decode `s` as the body of a JSON string (undoing `\\"`, `\\n`, `\\t` and
       unicode escapes in one pass), then parse the decoded text;
    2. simply replace every `\\"` with `"` and parse that.

    Valid JSON parses on the first try and never reaches the recoveries, so
    legitimately escaped quotes inside a string are untouched.

    Returns:
        The parsed value, or None when nothing parses. A reply that is the
        JSON literal `null` also yields None.
    """
    try:
        return json.loads(s)
    except Exception:
        pass

    escaped_quote = '\\"'
    if escaped_quote not in s:
        return None

    recoveries = (
        lambda: json.loads('"' + s + '"'),       # un-escape the whole reply once
        lambda: s.replace(escaped_quote, '"'),   # backstop: collapse escaped quotes
    )
    for recover in recoveries:
        try:
            return json.loads(recover())
        except Exception:
            continue
    return None
def _scan_json_values(text: str) -> list:
    """Collect every JSON object/array that can be decoded from `text`, in order.

    The scan moves left to right. At each `{` or `[` it tries to decode a
    complete value starting there: on success the value is kept and scanning
    resumes right after it; on failure the scan advances one character. All
    other characters (prose, stray quotes, trailing `"}`) are skipped. This
    covers replies like `{...} {...} {...}` with no array wrapper.
    """
    decoder = json.JSONDecoder()
    found = []
    pos = 0
    length = len(text)
    while pos < length:
        if text[pos] not in "{[":
            pos += 1
            continue
        try:
            value, after = decoder.raw_decode(text, pos)
        except ValueError:
            pos += 1
        else:
            found.append(value)
            pos = after
    return found
def _open_brackets(s: str) -> tuple[list, bool]:
    """Report what is still open at the end of `s`.

    Returns:
        A pair `(closers, in_string)`: `closers` lists the closing characters
        for every bracket still open, outermost first, and `in_string` is True
        when `s` ends inside a double-quoted string. Brackets inside strings
        are ignored and backslash escapes are honoured. A closing bracket pops
        the most recent opener without checking that the kinds match, and is
        ignored when nothing is open.
    """
    closer_for = {"{": "}", "[": "]"}
    pending: list[str] = []
    inside_string = False
    escaped = False
    for char in s:
        if inside_string:
            if escaped:
                escaped = False          # this char was consumed by the backslash
            elif char == "\\":
                escaped = True
            elif char == '"':
                inside_string = False
        elif char == '"':
            inside_string = True
        elif char in closer_for:
            pending.append(closer_for[char])
        elif char in "}]" and pending:
            pending.pop()
    return pending, inside_string
def _repair_truncated(text: str):
    """Best-effort recovery of a JSON value cut off by the model's token limit.

    Strategy, starting from the first `{` or `[` in `text`:

    1. Close a dangling string and every still-open bracket, then parse.
    2. If the text stops inside a string right after an unpaired backslash,
       the quote appended in step 1 is itself escaped and the string never
       closes; retry with that dangling backslash dropped.
    3. Otherwise walk back over the top-level element boundaries (commas at
       bracket depth 1), latest first, cut there, close the brackets and parse.

    Args:
        text: The model reply, after think/fence stripping.

    Returns:
        The parsed value, or None when `text` has no bracket, nothing is left
        open (not a truncation), or no candidate parses.
    """
    starts = [p for p in (text.find("{"), text.find("[")) if p >= 0]
    if not starts:
        return None
    s = text[min(starts):]
    stack, in_str = _open_brackets(s)
    if not stack and not in_str:
        return None  # nothing left open — not a truncation we can repair

    closers = "".join(reversed(stack))

    # Attempt 1: close the value as-is (a trailing complete pair survives).
    candidates = [s + ('"' if in_str else "") + closers]
    if in_str:
        # An odd run of trailing backslashes inside a string means the last one
        # is an escape cut off mid-sequence. Kept as a second candidate so any
        # input that already parsed in attempt 1 gives the same result.
        trailing_backslashes = len(s) - len(s.rstrip("\\"))
        if trailing_backslashes % 2 == 1:
            candidates.append(s[:-1] + '"' + closers)
    for candidate in candidates:
        data = _loads(candidate)
        if data is not None:
            return data

    # Attempt 2..n: drop the trailing incomplete element. Top-level element
    # boundaries are commas seen at bracket-depth 1.
    depth = 0
    in_str = esc = False
    boundaries: list[int] = []
    for i, ch in enumerate(s):
        if in_str:
            if esc:
                esc = False
            elif ch == "\\":
                esc = True
            elif ch == '"':
                in_str = False
            continue
        if ch == '"':
            in_str = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
        elif ch == "," and depth == 1:
            boundaries.append(i)

    # Latest boundary first, so the largest parseable prefix wins.
    for cut in reversed(boundaries):
        head = s[:cut]
        still_open, _ = _open_brackets(head)
        data = _loads(head + "".join(reversed(still_open)))
        if data is not None:
            return data
    return None
def extract_json(text: str):
    """Pull JSON out of a model reply.

    Stages, each tried only if the previous one yields nothing:

    1. Strip a <think> preamble and ``` / ```json fence markers, then parse
       the whole remaining text.
    2. Scan for every decodable object/array. Exactly one is returned as-is;
       several are flattened into one list (arrays contribute their items,
       objects are appended) so callers expecting an array still work.
    3. Parse the widest `[...]` or `{...}` span found by regex — for a value
       embedded in prose and/or over-escaped.
    4. Treat the text as cut off by the token limit and try to repair it.

    Returns:
        The parsed object/array, a list when several values were emitted
        without an array wrapper, or None. Callers must handle None (skip
        card / use fallback grade).
    """
    body = _strip_think(text.strip())
    # strip ```json fences if present
    body = re.sub(r"^```(?:json)?|```$", "", body, flags=re.MULTILINE).strip()

    parsed = _loads(body)
    if parsed is not None:
        return parsed

    found = _scan_json_values(body)
    if len(found) == 1:
        return found[0]
    if found:
        merged: list = []
        for value in found:
            if isinstance(value, list):
                merged.extend(value)
            else:
                merged.append(value)
        return merged

    embedded = re.search(r"(\[.*\]|\{.*\})", body, re.DOTALL)
    if embedded:
        parsed = _loads(embedded.group(1))
        if parsed is not None:
            return parsed

    return _repair_truncated(body)
def _augment_last_user(messages: list[dict]) -> list[dict]:
    """Return a copy of `messages` with a 'JSON only' reminder on the last user turn.

    Used for the repair pass. The reminder is folded into the existing request
    rather than sent as a separate 'that was not valid JSON' turn, because the
    model tended to grade such a meta message as if it were the student's
    answer. The bad reply is not echoed back either.

    Each message dict is shallow-copied, so the caller's list and dicts are not
    modified. String content gets the reminder appended; list (multimodal)
    content gets it appended as one more item. If there is no user turn at all,
    a new user message holding just the reminder is appended.
    """
    reminder = ("\n\nIMPORTANT: reply with ONLY the raw JSON value — no prose, no "
                "markdown fences, no commentary before or after it.")
    copied = [dict(message) for message in messages]

    last_user = next(
        (message for message in reversed(copied) if message.get("role") == "user"),
        None,
    )
    if last_user is None:
        copied.append({"role": "user", "content": reminder.strip()})
        return copied

    existing = last_user.get("content")
    if isinstance(existing, list):  # multimodal content (images + text parts)
        last_user["content"] = [*existing, reminder]
    else:
        last_user["content"] = f"{existing}{reminder}"
    return copied
def chat_json(messages: list[dict], max_tokens: int = 256, retries: int = 1):
    """Call the model and parse its reply as JSON, with repair passes.

    The first attempt sends `messages` as given. After each unparseable reply,
    while attempts remain, the original task is re-asked with a terse
    'ONLY JSON' reminder folded into its final user turn.

    Args:
        messages: Chat messages in the shape chat() accepts.
        max_tokens: Generation cap forwarded to chat().
        retries: Number of repair passes after the first attempt.

    Returns:
        The parsed object/array, or None if every attempt fails. Callers must
        handle None with a safe default — never crash the study loop.
    """
    total_attempts = retries + 1
    prompt = list(messages)
    for attempt_no in range(total_attempts):
        parsed = extract_json(chat(prompt, max_tokens=max_tokens))
        if parsed is not None:
            return parsed
        is_last_attempt = attempt_no >= retries
        if not is_last_attempt:
            # Always rebuilt from the ORIGINAL messages, so reminders never stack.
            prompt = _augment_last_user(messages)
    return None
| # ---- Stub replies so the app runs with no model ---------------------------- | |
def _msg_text(content) -> str:
    """Return the text of a message's content, ignoring non-string parts.

    A str is returned unchanged; a list yields its string items joined by
    single spaces (images and other objects are skipped); anything else yields
    an empty string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    text_parts = [part for part in content if isinstance(part, str)]
    return " ".join(text_parts)
def _stub_reply(messages: list[dict]) -> str:
    """Return a canned reply chosen by keywords found in the conversation text.

    All message texts are joined and lower-cased, then checked in this order:
    'generate' together with 'question' → a JSON array of three sample cards;
    'grade' or 'score' → a JSON grade object; 'follow' → a JSON array with one
    follow-up card; otherwise a plain non-JSON placeholder string.
    """
    texts = [_msg_text(message.get("content", "")) for message in messages]
    content = " ".join(texts).lower()

    if "generate" in content and "question" in content:
        cards = [
            {"question": "[stub] What is the main idea of the source text?",
             "answer": "The main concept described in the passage.",
             "topic": "Stub Topic",
             "difficulty": 1},
            {"question": "[stub] How does the key concept apply in this context?",
             "answer": "It applies by connecting the described mechanism to the outcome.",
             "topic": "Stub Topic",
             "difficulty": 2},
            {"question": "[stub] Compare and contrast the two ideas presented.",
             "answer": "They differ in scope but share the same underlying principle.",
             "topic": "Stub Topic",
             "difficulty": 3},
        ]
        return json.dumps(cards)

    wants_grade = "grade" in content or "score" in content
    if wants_grade:
        grade = {
            "score": 4,
            "explanation": "[stub] Close — you captured the main idea but missed a detail.",
            "missed_concept": "the specific detail",
        }
        return json.dumps(grade)

    if "follow" in content:
        follow_ups = [
            {"question": "[stub follow-up] Can you restate the missed detail?",
             "answer": "The specific detail from the passage.",
             "topic": "Stub Topic"},
        ]
        return json.dumps(follow_ups)

    return "[stub] model reply"