case-forge / core /infer.py
nextmarte's picture
Honest fallback note (GPU busy / quota vs offline)
37c38d2 verified
Raw
History Blame Contribute Delete
6.77 kB
"""Inference for the fine-tuned student — short request → full case+note JSON.
Runtime = **ZeroGPU**, in-Space (free). On the first GPU call we load base
Qwen3-4B + the published LoRA and `merge_and_unload()` it (folding the adapter into
the weights removes PEFT overhead → faster decode), then generate via transformers
under `@spaces.GPU`. `max_new_tokens` is capped so a full case fits ZeroGPU's ~120s
window. Locally (no `spaces`/CUDA) it falls back to a real sample so the UI works
offline. Heavy/offline work (corpus gen, training, merge) runs on Modal, not here.
Config (env):
CASE_FORGE_BASE base model id
CASE_FORGE_ADAPTER HF repo id of the published LoRA
CASE_FORGE_MAX_TOKENS generation cap (default 2800 — fits the ZeroGPU window)
CASE_FORGE_DEMO=1 force the demo sample (no model load)
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
# Default the PyTorch CUDA allocator config to expandable segments; setdefault
# leaves any value already present in the environment untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

_ROOT = Path(__file__).resolve().parent.parent # case-forge/
_MONOREPO = _ROOT.parent # build-small-hackathon/
# Make both directories importable so the project imports below resolve.
# Each missing entry is inserted at index 0 in turn, so when both are newly
# added _MONOREPO ends up ahead of _ROOT on sys.path.
for _p in (str(_ROOT), str(_MONOREPO)):
    if _p not in sys.path:
        sys.path.insert(0, _p)

# These imports must stay below the sys.path setup above (hence the E402 waivers).
from data.schema import validate_case # noqa: E402
from pipeline.prompts import Seed, build_minimal_prompt # noqa: E402
from shared import gpu # noqa: E402
# Hub id of the base model that is loaded before the adapter is applied.
BASE_MODEL = os.environ.get("CASE_FORGE_BASE", "Qwen/Qwen3-4B-Instruct-2507")
# Hub repo id of the published LoRA; an empty value means "no adapter".
ADAPTER_REPO = os.environ.get(
    "CASE_FORGE_ADAPTER", "build-small-hackathon/case-forge-qwen3-4b").strip()
# Generation cap passed as max_new_tokens (default 2800).
MAX_NEW_TOKENS = int(os.environ.get("CASE_FORGE_MAX_TOKENS", "2800"))
# Truthy spellings are matched case-insensitively so that CASE_FORGE_DEMO=True
# or CASE_FORGE_DEMO=YES force the demo sample instead of being silently ignored.
FORCE_DEMO = (
    os.environ.get("CASE_FORGE_DEMO", "").strip().lower() in ("1", "true", "yes")
)

# Process-wide cache of the loaded model and tokenizer; both stay None until
# the first successful load.
_MODEL = None
_TOK = None
def _has_cuda() -> bool:
try:
import torch
return torch.cuda.is_available()
except Exception:
return False
def _ensure_model() -> None:
    """Populate the module-level model/tokenizer cache on first use.

    Loads the tokenizer and the bfloat16 base model onto CUDA, then, when an
    adapter repo is configured, applies the LoRA and merges it into the base
    weights. Subsequent calls return immediately. Intended to run inside the
    GPU-allocated context.
    """
    global _MODEL, _TOK
    if _MODEL is not None:
        return

    # Heavy imports are deferred so importing this module stays cheap.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "cuda",
        "trust_remote_code": True,
    }
    network = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_options)

    if ADAPTER_REPO:
        from peft import PeftModel

        adapted = PeftModel.from_pretrained(network, ADAPTER_REPO)
        # Fold the LoRA into the base weights -> faster generation.
        network = adapted.merge_and_unload()

    network.eval()
    # Publish only after everything succeeded, so a failed load is retried.
    _MODEL, _TOK = network, tokenizer
@gpu.gpu(duration=120)
def _generate_raw(messages: list[dict]) -> str:
    """Run the model on ZeroGPU and return the raw decoded completion.

    Args:
        messages: Chat messages handed to the tokenizer's chat template.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens removed.
    """
    import torch

    _ensure_model()

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        text = _TOK.apply_chat_template(
            messages, enable_thinking=False, **template_kwargs,
        )
    except TypeError:
        # Retry without the flag for templates/tokenizers that reject it.
        text = _TOK.apply_chat_template(messages, **template_kwargs)

    inputs = _TOK(text, return_tensors="pt").to(_MODEL.device)

    # Fall back to EOS only when no pad token is defined. An `or` here would
    # also discard a legitimate pad id of 0.
    pad_id = _TOK.pad_token_id
    if pad_id is None:
        pad_id = _TOK.eos_token_id

    with torch.no_grad():
        out = _MODEL.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True, temperature=0.7, top_p=0.95,
            pad_token_id=pad_id,
        )

    prompt_len = inputs.input_ids.shape[1]
    return _TOK.decode(out[0][prompt_len:], skip_special_tokens=True)
def _parse(raw: str) -> dict | None:
try:
return json.loads(raw[raw.find("{"): raw.rfind("}") + 1])
except Exception:
return None
# --- demo fallback -------------------------------------------------------
# Lazily-built cache of demo samples keyed by language; None until first use.
_DEMO_BANK: dict[str, dict] | None = None


def _demo_bank() -> dict[str, dict]:
    """Return one sample case per language from the local corpus, cached.

    ``data/synthetic/demo_infer.json`` (its ``"obj"`` entry) is consulted
    first, then the ``pairs*.jsonl`` files are scanned newest-first for each
    line's ``"pair"`` entry until two languages are covered. Every step is
    best-effort: unreadable files, malformed lines and records without a
    string ``language`` are skipped so the offline fallback never raises.

    Returns:
        Mapping of language code to sample dict; empty when nothing usable is
        found. The same dict object is returned on later calls.
    """
    global _DEMO_BANK
    if _DEMO_BANK is not None:
        return _DEMO_BANK

    wanted = 2  # stop as soon as two languages have a sample
    bank: dict[str, dict] = {}
    synthetic = _ROOT / "data" / "synthetic"

    demo = synthetic / "demo_infer.json"
    if demo.exists():
        try:
            obj = json.loads(demo.read_text(encoding="utf-8")).get("obj")
        except (OSError, ValueError, AttributeError):
            # Unreadable / invalid / non-object file: carry on with the corpus.
            obj = None
        if isinstance(obj, dict) and obj:
            language = obj.get("language", "pt")
            bank[language if isinstance(language, str) else "pt"] = obj

    for name in ("pairs_v3.jsonl", "pairs_v2.jsonl", "pairs.jsonl"):
        if len(bank) >= wanted:
            break
        path = synthetic / name
        if not path.exists():
            continue
        try:
            # Stream the file: only the first few records are needed, and file
            # iteration splits on real newlines only (str.splitlines would
            # also cut records at U+2028 and similar separators inside JSON).
            with path.open(encoding="utf-8") as handle:
                for line in handle:
                    if len(bank) >= wanted:
                        break
                    try:
                        pair = json.loads(line).get("pair")
                    except (ValueError, AttributeError):
                        continue
                    if not isinstance(pair, dict) or not pair:
                        continue
                    language = pair.get("language")
                    # A missing or non-string language cannot key the bank.
                    if isinstance(language, str) and language not in bank:
                        bank[language] = pair
        except (OSError, UnicodeDecodeError):
            continue

    _DEMO_BANK = bank
    return bank
def _demo_result(seed: Seed, reason: str = "offline") -> dict:
    """Build a demo-mode result dict from the local sample bank.

    Args:
        seed: Request seed; only its ``language`` selects the sample.
        reason: 'offline' (no GPU/local) or 'busy' (Space GPU error / ZeroGPU
            quota); copied verbatim into the result.

    Returns:
        A dict with keys obj, valid, errors, warnings, raw, demo, reason.
        ``raw`` is always empty and ``demo`` is always True.
    """
    samples = _demo_bank()

    # Prefer the requested language, otherwise take whatever sample exists.
    sample = samples.get(seed.language)
    if not sample:
        sample = next(iter(samples.values()), None)

    if sample:
        valid, errors, warnings = validate_case(sample)
    else:
        valid, errors, warnings = False, ["sem amostra de demo"], []

    return {
        "obj": sample,
        "valid": valid,
        "errors": errors,
        "warnings": warnings,
        "raw": "",
        "demo": True,
        "reason": reason,
    }
# --- public API ----------------------------------------------------------
def generate(domain: str, topic: str, level: str = "MBA",
             language: str = "pt", theory: str = "") -> dict:
    """Forge one case+note from a short request, on ZeroGPU.

    Args:
        domain: Subject area; blank falls back to "administração".
        topic: Case topic; blank is passed through as an empty string.
        level: Audience level; blank falls back to "MBA".
        language: "pt" or "en"; anything else is treated as "pt".
        theory: Comma-separated theory names; empty items are dropped.

    Returns:
        A dict with keys obj, valid, errors, warnings, raw, demo, reason.
        In demo mode ``demo`` is True and ``reason`` is "offline" (demo forced
        or no GPU runtime) or "busy" (generation raised); otherwise ``demo``
        is False and ``reason`` is None.
    """
    clean_domain = (domain or "administração").strip()
    clean_topic = (topic or "").strip()
    clean_level = (level or "MBA").strip()
    clean_language = language if language in ("pt", "en") else "pt"
    theories = [name for name in map(str.strip, (theory or "").split(",")) if name]
    seed = Seed(
        domain=clean_domain,
        topic=clean_topic,
        level=clean_level,
        language=clean_language,
        theory=theories,
    )

    # Demo mode is checked first so the GPU probes are skipped when it is forced.
    if FORCE_DEMO:
        return _demo_result(seed, "offline")
    if not gpu._HAS_SPACES and not _has_cuda():
        return _demo_result(seed, "offline")

    try:
        raw = _generate_raw(build_minimal_prompt(seed))
    except Exception as exc:
        # On the Space this is typically a ZeroGPU quota / GPU-unavailable abort.
        fallback = _demo_result(seed, "busy")
        fallback["errors"] = [f"falha na geração: {exc}"] + fallback["errors"]
        return fallback

    parsed = _parse(raw)
    if parsed:
        valid, errors, warnings = validate_case(parsed)
    else:
        valid, errors, warnings = False, ["parse falhou"], []

    return {
        "obj": parsed,
        "valid": valid,
        "errors": errors,
        "warnings": warnings,
        "raw": raw,
        "demo": False,
        "reason": None,
    }
# Public API: the generate() entry point, plus Seed re-exported from
# pipeline.prompts.
__all__ = ["generate", "Seed"]