Spaces:
Running on Zero
Running on Zero
| """Inference for the fine-tuned student — short request → full case+note JSON. | |
| Runtime = **ZeroGPU**, in-Space (free). On the first GPU call we load base | |
| Qwen3-4B + the published LoRA and `merge_and_unload()` it (folding the adapter into | |
| the weights removes PEFT overhead → faster decode), then generate via transformers | |
| under `@spaces.GPU`. `max_new_tokens` is capped so a full case fits ZeroGPU's ~120s | |
| window. Locally (no `spaces`/CUDA) it falls back to a real sample so the UI works | |
| offline. Heavy/offline work (corpus gen, training, merge) runs on Modal, not here. | |
| Config (env): | |
| CASE_FORGE_BASE base model id | |
| CASE_FORGE_ADAPTER HF repo id of the published LoRA | |
| CASE_FORGE_MAX_TOKENS generation cap (default 2800 — fits the ZeroGPU window) | |
| CASE_FORGE_DEMO=1 force the demo sample (no model load); "true" and "yes" are also accepted | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
# Set the allocator option at import time. Every `import torch` in this file is
# deferred into a function body, so this runs first — presumably so the setting
# is in place before torch initialises CUDA (TODO confirm against PyTorch docs).
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

_ROOT = Path(__file__).resolve().parent.parent  # case-forge/
_MONOREPO = _ROOT.parent  # build-small-hackathon/

# Make the project packages (`data`, `pipeline`, `shared`) importable. Each
# missing entry is pushed to the front of sys.path, so when both are added the
# monorepo root ends up ahead of the project root.
for _p in map(str, (_ROOT, _MONOREPO)):
    if _p in sys.path:
        continue
    sys.path.insert(0, _p)
| from data.schema import validate_case # noqa: E402 | |
| from pipeline.prompts import Seed, build_minimal_prompt # noqa: E402 | |
| from shared import gpu # noqa: E402 | |
# Runtime configuration, read once from the environment at import time.
BASE_MODEL = os.getenv("CASE_FORGE_BASE", "Qwen/Qwen3-4B-Instruct-2507")
ADAPTER_REPO = os.getenv(
    "CASE_FORGE_ADAPTER",
    "build-small-hackathon/case-forge-qwen3-4b",
).strip()  # an empty value means "no adapter": the base model is used as-is
MAX_NEW_TOKENS = int(os.getenv("CASE_FORGE_MAX_TOKENS", "2800"))
FORCE_DEMO = os.getenv("CASE_FORGE_DEMO", "").strip() in {"1", "true", "yes"}

# Process-wide cache for the loaded model and tokenizer; filled lazily by
# _ensure_model() on the first generation call.
_MODEL = None
_TOK = None
def _has_cuda() -> bool:
    """Report whether torch can be imported and sees a CUDA device.

    Any failure (torch not installed, a broken CUDA setup, ...) counts as
    "no CUDA", so callers can use this as a simple capability probe.
    """
    try:
        import torch

        available = torch.cuda.is_available()
    except Exception:
        available = False
    return available
def _ensure_model() -> None:
    """Load the tokenizer and model once and cache them in ``_TOK`` / ``_MODEL``.

    The base model ``BASE_MODEL`` is loaded in bfloat16 onto the CUDA device.
    When ``ADAPTER_REPO`` is non-empty, the LoRA adapter is applied with PEFT
    and merged into the base weights. Later calls return immediately once
    ``_MODEL`` is set. This is intended to be called from inside the
    GPU-allocated context.
    """
    global _MODEL, _TOK
    if _MODEL is not None:
        return

    # Heavy imports are deferred so importing this module stays cheap offline.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "cuda",
        "trust_remote_code": True,
    }
    net = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)

    if ADAPTER_REPO:
        from peft import PeftModel

        adapted = PeftModel.from_pretrained(net, ADAPTER_REPO)
        # Fold the LoRA into the base weights -> faster generation.
        net = adapted.merge_and_unload()

    net.eval()
    _MODEL, _TOK = net, tokenizer
def _generate_raw(messages: list[dict]) -> str:
    """Run the model on ``messages`` and return the raw decoded completion.

    The messages are rendered with the tokenizer's chat template, sampled with
    temperature 0.7 / top-p 0.95 for up to ``MAX_NEW_TOKENS`` new tokens, and
    only the newly generated tokens (not the prompt) are decoded.

    NOTE(review): no GPU decorator is applied to this function in the visible
    code; presumably the ZeroGPU allocation is provided by the caller or by
    ``shared.gpu`` — confirm.

    Raises whatever model loading, tokenisation or generation raises; the
    public ``generate`` entry point is responsible for catching it.
    """
    import torch

    _ensure_model()

    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
    try:
        # Request thinking-free output where the chat template knows this flag.
        text = _TOK.apply_chat_template(
            messages, enable_thinking=False, **template_kwargs,
        )
    except TypeError:
        # Tokenizer rejected the extra keyword: retry without it.
        text = _TOK.apply_chat_template(messages, **template_kwargs)

    inputs = _TOK(text, return_tensors="pt").to(_MODEL.device)

    # Fall back to EOS only when no pad token is defined. A plain `or` would
    # also discard a legitimate pad token id of 0.
    pad_id = _TOK.pad_token_id
    if pad_id is None:
        pad_id = _TOK.eos_token_id

    with torch.no_grad():
        out = _MODEL.generate(
            **inputs, max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True, temperature=0.7, top_p=0.95,
            pad_token_id=pad_id,
        )

    # The output sequence starts with the prompt tokens; drop them.
    prompt_len = inputs.input_ids.shape[1]
    return _TOK.decode(out[0][prompt_len:], skip_special_tokens=True)
def _parse(raw: str) -> dict | None:
    """Extract the first JSON object embedded in ``raw``.

    Decoding starts at the first ``{`` and stops at the end of that object, so
    text before it (such as a Markdown fence) and text after it (even when it
    contains braces of its own) is ignored.

    Returns the decoded dict, or ``None`` when ``raw`` is not a string, has no
    ``{``, or the text from that brace on is not a complete JSON object.
    """
    if not isinstance(raw, str):
        return None
    start = raw.find("{")
    if start < 0:
        return None
    try:
        # raw_decode parses exactly one JSON value and tolerates trailing
        # text, unlike json.loads on a first-"{"-to-last-"}" slice.
        obj, _end = json.JSONDecoder().raw_decode(raw, start)
    except (ValueError, RecursionError):
        return None
    return obj if isinstance(obj, dict) else None
| # --- demo fallback ------------------------------------------------------- | |
# Cache for _demo_bank(); None until the first call has scanned the corpus.
_DEMO_BANK: dict[str, dict] | None = None


def _demo_bank() -> dict[str, dict]:
    """Return one sample case per language from the local corpus (cached).

    Sources, in order, all under ``_ROOT/data/synthetic``:

    1. ``demo_infer.json`` — its ``"obj"`` entry, keyed by the object's
       ``"language"`` (``"pt"`` when the key is absent).
    2. ``pairs_v3.jsonl``, ``pairs_v2.jsonl``, ``pairs.jsonl`` — the first
       ``"pair"`` seen for each language not already in the bank.

    Scanning stops once two languages are collected. Missing or unreadable
    files and malformed records are skipped, because this is a best-effort
    fallback that must not raise. The result may be empty.
    """
    global _DEMO_BANK
    if _DEMO_BANK is not None:
        return _DEMO_BANK

    wanted = 2  # stop once two distinct languages have a sample
    synthetic = _ROOT / "data" / "synthetic"
    bank: dict[str, dict] = {}

    try:
        payload = json.loads(
            (synthetic / "demo_infer.json").read_text(encoding="utf-8")
        )
    except (OSError, ValueError):  # missing/unreadable file, bad UTF-8 or JSON
        payload = None
    obj = payload.get("obj") if isinstance(payload, dict) else None
    if isinstance(obj, dict) and obj:
        lang = obj.get("language", "pt")
        if isinstance(lang, str):
            bank[lang] = obj

    for name in ("pairs_v3.jsonl", "pairs_v2.jsonl", "pairs.jsonl"):
        if len(bank) >= wanted:
            break
        try:
            # Stream the file so we stop reading once the bank is full, and
            # split only on real newlines (str.splitlines also splits on
            # characters such as U+2028 that may occur inside JSON strings).
            with (synthetic / name).open(encoding="utf-8") as fh:
                for line in fh:
                    if len(bank) >= wanted:
                        break
                    try:
                        record = json.loads(line)
                    except ValueError:
                        continue
                    pair = record.get("pair") if isinstance(record, dict) else None
                    if not isinstance(pair, dict):
                        continue
                    lang = pair.get("language")
                    # Skip records without a usable language instead of
                    # failing on pair["language"].
                    if isinstance(lang, str) and lang not in bank:
                        bank[lang] = pair
        except (OSError, UnicodeDecodeError):
            continue

    _DEMO_BANK = bank
    return bank
def _demo_result(seed: Seed, reason: str = "offline") -> dict:
    """Build a demo-mode result dict from the local sample bank.

    The sample for ``seed.language`` is preferred; otherwise the first sample
    in the bank is used; with an empty bank ``obj`` is ``None`` and the result
    is marked invalid.

    ``reason`` is ``'offline'`` (no GPU / local run) or ``'busy'`` (Space GPU
    error / ZeroGPU quota) and is passed through unchanged.
    """
    samples = _demo_bank()
    sample = samples.get(seed.language)
    if not sample:
        sample = next(iter(samples.values()), None)

    if sample:
        valid, errors, warnings = validate_case(sample)
    else:
        valid, errors, warnings = False, ["sem amostra de demo"], []

    return {
        "obj": sample,
        "valid": valid,
        "errors": errors,
        "warnings": warnings,
        "raw": "",
        "demo": True,
        "reason": reason,
    }
| # --- public API ---------------------------------------------------------- | |
def generate(domain: str, topic: str, level: str = "MBA",
             language: str = "pt", theory: str = "") -> dict:
    """Forge one case+note from a short request.

    Args:
        domain: Subject area; blank or whitespace-only falls back to
            "administração".
        topic: Case topic; surrounding whitespace is stripped.
        level: Audience level; blank or whitespace-only falls back to "MBA".
        language: "pt" or "en" (case and surrounding whitespace are ignored);
            anything else falls back to "pt".
        theory: Comma-separated theory names; empty items are dropped.

    Returns:
        A dict with keys ``obj``, ``valid``, ``errors``, ``warnings``, ``raw``,
        ``demo`` and ``reason``. ``demo`` is True and ``reason`` is
        ``"offline"`` when the demo sample is forced or no GPU runtime is
        detected, and ``"busy"`` when generation raised. On a real generation
        ``demo`` is False and ``reason`` is None.
    """
    # Strip before applying defaults so whitespace-only input (e.g. from an
    # empty UI textbox) still gets the default instead of an empty string.
    lang = language.strip().lower() if isinstance(language, str) else ""
    seed = Seed(
        domain=(domain or "").strip() or "administração",
        topic=(topic or "").strip(),
        level=(level or "").strip() or "MBA",
        language=lang if lang in ("pt", "en") else "pt",
        theory=[t.strip() for t in (theory or "").split(",") if t.strip()],
    )

    if FORCE_DEMO or not (gpu._HAS_SPACES or _has_cuda()):
        return _demo_result(seed, "offline")

    try:
        raw = _generate_raw(build_minimal_prompt(seed))
    except Exception as exc:
        # On the Space this is typically a ZeroGPU quota / GPU-unavailable abort.
        out = _demo_result(seed, "busy")
        out["errors"] = [f"falha na geração: {exc}"] + out["errors"]
        return out

    obj = _parse(raw)
    ok, errs, warns = validate_case(obj) if obj else (False, ["parse falhou"], [])
    return {"obj": obj, "valid": ok, "errors": errs, "warnings": warns,
            "raw": raw, "demo": False, "reason": None}
# Public API: the generate() entry point plus the Seed type imported from
# pipeline.prompts, re-exported for callers of this module.
__all__ = ["generate", "Seed"]