analyst-buddy / scripts /inference_smoke.py
hjerpe's picture
F006/F008: serve Qwen models + model switcher (vanilla-first)
656f91e verified
Raw
History Blame Contribute Delete
4.12 kB
"""Inference smoke test — does the trained model actually run the agentic loop?
The training smoke lives on Modal (`modal run training/modal_app.py --smoke`); this
is its INFERENCE counterpart, run on HF (a ZeroGPU Space) and locally. It loads a
model, drives ONE (or N) full DESCRIBE->SAMPLE->QUERY->ANSWER episode(s) through
the real SQLEnvironment via `evaluation.ModelPolicy`, and prints the transcript +
**per-question latency** — the ADR 0006 acceptance gate (target <= ~15s/question
on the chosen Space tier) you check BEFORE recording the demo.
The model-load + ModelPolicy + env loop here is exactly what the Gradio Space's
`app.py` runs inside `@spaces.GPU`, so validating it here de-risks the Space.
# local (Off the Grid) — a merged model dir or a Hub id
uv run python scripts/inference_smoke.py --model <hf-id-or-local-dir>
# measure latency over a few questions, pin weights for reproducibility
uv run python scripts/inference_smoke.py --model <id> --n 5 --revision <sha>
Heavy deps (torch/transformers) are imported lazily so `--help` works without them.
"""
from __future__ import annotations
import argparse
import time
def _run(args: argparse.Namespace) -> int:
    """Run the smoke episodes and report correctness plus per-question latency.

    Loads the model/tokenizer named by ``args.model`` (optionally pinned to
    ``args.revision``), builds an ``SQLEnvironment`` and a ``ModelPolicy``, and
    drives ``args.n`` episodes. Each episode is stepped until the observation
    reports ``done`` or ``args.step_budget + 2`` steps have been taken. The
    first episode is printed as a step-by-step transcript; every episode gets
    a one-line summary.

    Args:
        args: Parsed CLI namespace. Reads ``n``, ``model``, ``revision``,
            ``questions``, ``db_dir``, ``step_budget``, ``enable_thinking``,
            ``seed`` and ``latency_budget``.

    Returns:
        ``0`` if the average episode wall time is within
        ``args.latency_budget``, ``1`` if it exceeds it, and ``2`` if
        ``args.n`` is less than 1 (nothing to measure).
    """
    # Reject an empty run up front: with no episodes the average below would
    # divide by zero, and we would have paid for a model load for nothing.
    if args.n < 1:
        import sys

        print(f"error: --n must be >= 1 (got {args.n})", file=sys.stderr)
        return 2

    # Heavy, project-local imports stay lazy so `--help` works without them.
    from sql_env.evaluation import ModelPolicy
    from sql_env.server.sql_environment import SQLEnvironment
    from sql_env.training.data_loading import load_model_and_tokenizer

    print(f"Loading {args.model} (revision={args.revision or 'HEAD'}) ...")
    model, tokenizer = load_model_and_tokenizer(args.model, revision=args.revision)
    model.eval()

    env = SQLEnvironment(
        questions_path=args.questions,
        db_dir=args.db_dir,
        tokenizer=tokenizer,
        step_budget=args.step_budget,
    )
    policy = ModelPolicy(
        model, tokenizer, enable_thinking=args.enable_thinking
    )

    latencies: list[float] = []
    correct = 0
    for ep in range(args.n):
        # A distinct seed per episode; the reset itself is not timed.
        obs = env.reset(seed=args.seed + ep)
        t0 = time.perf_counter()
        steps = 0
        if ep == 0:
            print(f"\n--- episode {ep} | Q: {obs.question[:100]}")
        # The "+ 2" cap bounds the loop even if the observation never
        # reports done.
        while not obs.done and steps < args.step_budget + 2:
            action = policy.select_action(obs)
            obs = env.step(action)
            steps += 1
            if ep == 0:
                arg = (action.argument or "")[:80]
                print(f"  [{steps}] {action.action_type}: {arg}")
        dt = time.perf_counter() - t0
        latencies.append(dt)
        # A missing reward counts as 0.0; any positive reward counts as correct.
        is_correct = (obs.reward or 0.0) > 0.0
        correct += int(is_correct)
        print(
            f"episode {ep}: {steps} steps, {dt:.1f}s, "
            f"{'CORRECT' if is_correct else 'wrong'}"
        )

    avg = sum(latencies) / len(latencies)
    print(
        f"\nSMOKE RESULT: {correct}/{args.n} correct | "
        f"avg {avg:.1f}s/question (max {max(latencies):.1f}s)"
    )
    # ADR 0006 latency gate.
    if avg > args.latency_budget:
        print(
            f"WARNING: avg {avg:.1f}s exceeds the {args.latency_budget:.0f}s "
            "demo-latency budget — quantize (4-bit / GGUF), use a bigger Space "
            "tier, or demo the 0.6B (see ADR 0006 / the week-plan triage)."
        )
        return 1
    print("Latency OK for the demo tier.")
    return 0
def main(argv: list[str] | None = None) -> int:
    """Parse command-line options and run the inference smoke test.

    Args:
        argv: Argument list to parse in place of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        The exit status returned by ``_run``. Usage errors (missing
        ``--model``, ``--n`` below 1) exit through argparse with status 2.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's example commands on their own lines in
        # `--help`; the default formatter re-wraps them into one paragraph.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--model", required=True, help="HF model id or local dir")
    p.add_argument("--revision", default=None, help="Hub commit SHA to pin weights")
    p.add_argument("--questions", default="data/questions/eval_n50.json")
    p.add_argument("--db-dir", default="data/databases")
    p.add_argument("--n", type=int, default=1, help="number of episodes")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--step-budget", type=int, default=10)
    p.add_argument("--enable-thinking", action="store_true")
    p.add_argument(
        "--latency-budget",
        type=float,
        default=15.0,
        help="per-question seconds budget (ADR 0006 demo gate)",
    )
    args = p.parse_args(argv)
    # Zero episodes leaves nothing to average; report it as a usage error
    # before any model is loaded.
    if args.n < 1:
        p.error("--n must be at least 1")
    return _run(args)
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    status = main()
    raise SystemExit(status)