"""Inference for the fine-tuned student — short request → full case+note JSON. Runtime = **ZeroGPU**, in-Space (free). On the first GPU call we load base Qwen3-4B + the published LoRA and `merge_and_unload()` it (folding the adapter into the weights removes PEFT overhead → faster decode), then generate via transformers under `@spaces.GPU`. `max_new_tokens` is capped so a full case fits ZeroGPU's ~120s window. Locally (no `spaces`/CUDA) it falls back to a real sample so the UI works offline. Heavy/offline work (corpus gen, training, merge) runs on Modal, not here. Config (env): CASE_FORGE_BASE base model id CASE_FORGE_ADAPTER HF repo id of the published LoRA CASE_FORGE_MAX_TOKENS generation cap (default 2800 — fits the ZeroGPU window) CASE_FORGE_DEMO=1 force the demo sample (no model load) """ from __future__ import annotations import json import os import sys from pathlib import Path os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") _ROOT = Path(__file__).resolve().parent.parent # case-forge/ _MONOREPO = _ROOT.parent # build-small-hackathon/ for _p in (str(_ROOT), str(_MONOREPO)): if _p not in sys.path: sys.path.insert(0, _p) from data.schema import validate_case # noqa: E402 from pipeline.prompts import Seed, build_minimal_prompt # noqa: E402 from shared import gpu # noqa: E402 BASE_MODEL = os.environ.get("CASE_FORGE_BASE", "Qwen/Qwen3-4B-Instruct-2507") ADAPTER_REPO = os.environ.get( "CASE_FORGE_ADAPTER", "build-small-hackathon/case-forge-qwen3-4b").strip() MAX_NEW_TOKENS = int(os.environ.get("CASE_FORGE_MAX_TOKENS", "2800")) FORCE_DEMO = os.environ.get("CASE_FORGE_DEMO", "").strip() in ("1", "true", "yes") _MODEL = None _TOK = None def _has_cuda() -> bool: try: import torch return torch.cuda.is_available() except Exception: return False def _ensure_model() -> None: """Lazy-load base + LoRA and merge — runs inside the GPU-allocated context.""" global _MODEL, _TOK if _MODEL is not None: return import torch from transformers import AutoModelForCausalLM, AutoTokenizer tok = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True, ) if ADAPTER_REPO: from peft import PeftModel model = PeftModel.from_pretrained(model, ADAPTER_REPO) model = model.merge_and_unload() # fold LoRA into base → faster generation model.eval() _MODEL, _TOK = model, tok @gpu.gpu(duration=120) def _generate_raw(messages: list[dict]) -> str: """Run the model on ZeroGPU and return the raw decoded completion.""" import torch _ensure_model() try: text = _TOK.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, ) except TypeError: text = _TOK.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = _TOK(text, return_tensors="pt").to(_MODEL.device) with torch.no_grad(): out = _MODEL.generate( **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True, temperature=0.7, top_p=0.95, pad_token_id=_TOK.pad_token_id or _TOK.eos_token_id, ) return _TOK.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) def _parse(raw: str) -> dict | None: try: return json.loads(raw[raw.find("{"): raw.rfind("}") + 1]) except Exception: return None # --- demo fallback ------------------------------------------------------- _DEMO_BANK: dict[str, dict] | None = None def _demo_bank() -> dict[str, dict]: """One valid sample case per language from the local corpus — UI works offline.""" global _DEMO_BANK if _DEMO_BANK is not None: return _DEMO_BANK bank: dict[str, dict] = {} demo = _ROOT / "data" / "synthetic" / "demo_infer.json" if demo.exists(): try: obj = json.loads(demo.read_text(encoding="utf-8")).get("obj") if obj: bank[obj.get("language", "pt")] = obj except Exception: pass for name in ("pairs_v3.jsonl", "pairs_v2.jsonl", "pairs.jsonl"): if len(bank) >= 2: break path = _ROOT / "data" / "synthetic" / name if not path.exists(): continue for line in path.read_text(encoding="utf-8").splitlines(): if len(bank) >= 2: break try: pair = json.loads(line).get("pair") except Exception: continue if pair and pair.get("language") not in bank: bank[pair["language"]] = pair _DEMO_BANK = bank return bank def _demo_result(seed: Seed, reason: str = "offline") -> dict: """reason: 'offline' (no GPU/local) or 'busy' (Space GPU error / ZeroGPU quota).""" bank = _demo_bank() obj = bank.get(seed.language) or next(iter(bank.values()), None) ok, errs, warns = validate_case(obj) if obj else (False, ["sem amostra de demo"], []) return {"obj": obj, "valid": ok, "errors": errs, "warnings": warns, "raw": "", "demo": True, "reason": reason} # --- public API ---------------------------------------------------------- def generate(domain: str, topic: str, level: str = "MBA", language: str = "pt", theory: str = "") -> dict: """Forge one case+note from a short request, on ZeroGPU. Returns {obj, valid, errors, warnings, raw, demo}. """ seed = Seed( domain=(domain or "administração").strip(), topic=(topic or "").strip(), level=(level or "MBA").strip(), language=language if language in ("pt", "en") else "pt", theory=[t.strip() for t in (theory or "").split(",") if t.strip()], ) if FORCE_DEMO or not (gpu._HAS_SPACES or _has_cuda()): return _demo_result(seed, "offline") try: raw = _generate_raw(build_minimal_prompt(seed)) except Exception as exc: # On the Space this is typically a ZeroGPU quota / GPU-unavailable abort. out = _demo_result(seed, "busy") out["errors"] = [f"falha na geração: {exc}"] + out["errors"] return out obj = _parse(raw) ok, errs, warns = validate_case(obj) if obj else (False, ["parse falhou"], []) return {"obj": obj, "valid": ok, "errors": errs, "warnings": warns, "raw": raw, "demo": False, "reason": None} __all__ = ["generate", "Seed"]