Spaces:
Running on Zero
Running on Zero
| """Inference smoke test — does the trained model actually run the agentic loop? | |
| The training smoke lives on Modal (`modal run training/modal_app.py --smoke`); this | |
| is its INFERENCE counterpart, run on HF (a ZeroGPU Space) and locally. It loads a | |
| model, drives ONE (or N) full DESCRIBE->SAMPLE->QUERY->ANSWER episode(s) through | |
| the real SQLEnvironment via `evaluation.ModelPolicy`, and prints the transcript + | |
| **per-question latency** — the ADR 0006 acceptance gate (target <= ~15s/question | |
| on the chosen Space tier) you check BEFORE recording the demo. | |
| The model-load + ModelPolicy + env loop here is exactly what the Gradio Space's | |
| `app.py` runs inside `@spaces.GPU`, so validating it here de-risks the Space. | |
| # local (Off the Grid) — a merged model dir or a Hub id | |
| uv run python scripts/inference_smoke.py --model <hf-id-or-local-dir> | |
| # measure latency over a few questions, pin weights for reproducibility | |
| uv run python scripts/inference_smoke.py --model <id> --n 5 --revision <sha> | |
| Heavy deps (torch/transformers) are imported lazily so `--help` works without them. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import time | |
| def _run(args: argparse.Namespace) -> int: | |
| from sql_env.evaluation import ModelPolicy | |
| from sql_env.server.sql_environment import SQLEnvironment | |
| from sql_env.training.data_loading import load_model_and_tokenizer | |
| print(f"Loading {args.model} (revision={args.revision or 'HEAD'}) ...") | |
| model, tokenizer = load_model_and_tokenizer(args.model, revision=args.revision) | |
| model.eval() | |
| env = SQLEnvironment( | |
| questions_path=args.questions, | |
| db_dir=args.db_dir, | |
| tokenizer=tokenizer, | |
| step_budget=args.step_budget, | |
| ) | |
| policy = ModelPolicy( | |
| model, tokenizer, enable_thinking=args.enable_thinking | |
| ) | |
| latencies: list[float] = [] | |
| correct = 0 | |
| for ep in range(args.n): | |
| obs = env.reset(seed=args.seed + ep) | |
| t0 = time.perf_counter() | |
| steps = 0 | |
| if ep == 0: | |
| print(f"\n--- episode {ep} | Q: {obs.question[:100]}") | |
| while not obs.done and steps < args.step_budget + 2: | |
| action = policy.select_action(obs) | |
| obs = env.step(action) | |
| steps += 1 | |
| if ep == 0: | |
| arg = (action.argument or "")[:80] | |
| print(f" [{steps}] {action.action_type}: {arg}") | |
| dt = time.perf_counter() - t0 | |
| latencies.append(dt) | |
| is_correct = (obs.reward or 0.0) > 0.0 | |
| correct += int(is_correct) | |
| print( | |
| f"episode {ep}: {steps} steps, {dt:.1f}s, " | |
| f"{'CORRECT' if is_correct else 'wrong'}" | |
| ) | |
| avg = sum(latencies) / len(latencies) | |
| print( | |
| f"\nSMOKE RESULT: {correct}/{args.n} correct | " | |
| f"avg {avg:.1f}s/question (max {max(latencies):.1f}s)" | |
| ) | |
| # ADR 0006 latency gate. | |
| if avg > args.latency_budget: | |
| print( | |
| f"WARNING: avg {avg:.1f}s exceeds the {args.latency_budget:.0f}s " | |
| "demo-latency budget β quantize (4-bit / GGUF), use a bigger Space " | |
| "tier, or demo the 0.6B (see ADR 0006 / the week-plan triage)." | |
| ) | |
| return 1 | |
| print("Latency OK for the demo tier.") | |
| return 0 | |
| def main() -> int: | |
| p = argparse.ArgumentParser(description=__doc__) | |
| p.add_argument("--model", required=True, help="HF model id or local dir") | |
| p.add_argument("--revision", default=None, help="Hub commit SHA to pin weights") | |
| p.add_argument("--questions", default="data/questions/eval_n50.json") | |
| p.add_argument("--db-dir", default="data/databases") | |
| p.add_argument("--n", type=int, default=1, help="number of episodes") | |
| p.add_argument("--seed", type=int, default=0) | |
| p.add_argument("--step-budget", type=int, default=10) | |
| p.add_argument("--enable-thinking", action="store_true") | |
| p.add_argument( | |
| "--latency-budget", | |
| type=float, | |
| default=15.0, | |
| help="per-question seconds budget (ADR 0006 demo gate)", | |
| ) | |
| return _run(p.parse_args()) | |
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)