"""Accuracy harness: runs infer_local.py over a set of test prompts and scores the JSON output."""

import argparse
import ast
import json
import subprocess
import sys

DEFAULT_TEST_PROMPTS = [
    "Fix this Python code: def add(a,b) return a+b",
    "Explain what this code does: for i in range(3): print(i)",
    "Write Python code for linear regression and explain it.",
    "Debug this snippet: if x = 5: print(x)",
]
|
|
def run_inference(python_exec, model_path, base_model, prompt, max_new_tokens, allow_downloads):
    """Invoke infer_local.py for a single prompt and return (payload, error)."""
    cmd = [
        python_exec,
        "infer_local.py",
        "--model-path",
        model_path,
        "--base-model",
        base_model,
        "--prompt",
        prompt,
        "--max-new-tokens",
        str(max_new_tokens),
    ]
    if allow_downloads:
        cmd.append("--allow-downloads")
    result = subprocess.run(cmd, check=False, capture_output=True, text=True)
    if result.returncode != 0:
        return None, f"inference failed: {result.stderr.strip()}"

    stdout = result.stdout.strip()
    try:
        payload = json.loads(stdout)
        return payload, None
    except json.JSONDecodeError as exc:
        # Fallback: the child process may print log lines around the JSON object,
        # so try to recover the outermost {...} span from the combined output.
        merged = f"{result.stdout}\n{result.stderr}"
        start = merged.find("{")
        end = merged.rfind("}")
        if start != -1 and end != -1 and end > start:
            candidate = merged[start : end + 1]
            try:
                payload = json.loads(candidate)
                return payload, None
            except json.JSONDecodeError:
                pass
        return None, f"invalid json output: {exc}: {stdout[:300]}"
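# Illustrative call (arguments mirror the argparse defaults below; "model" is a placeholder
# local model directory):
#   payload, error = run_inference(
#       sys.executable, "model", "Qwen/Qwen2.5-Coder-0.5B-Instruct",
#       "Fix this Python code: def add(a,b) return a+b", 320, False,
#   )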
|
|
def safe_float(value):
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0
|
|
def prompt_expects_code(prompt):
    prompt_l = prompt.lower()
    markers = (
        "fix",
        "debug",
        "repair",
        "write",
        "create",
        "generate",
        "implement",
        "function",
        "code",
        "snippet",
        "python",
        "multiply",
        "multiplication",
        "product",
        "add",
        "addition",
        "sum",
        "subtract",
        "subtraction",
        "difference",
        "divide",
        "division",
        "quotient",
    )
    return any(marker in prompt_l for marker in markers)
|
|
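# Note: prompts that merely mention "code" or "python" (such as the default "Explain what
# this code does" prompt) still count as expecting code, so their replies must pass the
# syntax check below.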
def code_is_valid_for_prompt(prompt, code):
    code = str(code or "").strip()
    if not code:
        return False
    if not prompt_expects_code(prompt):
        return True
    python_like = any(
        marker in code
        for marker in ("def ", "import ", "class ", "print(", "return ", "for ", "if ")
    )
    if not python_like:
        return False
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False
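# For example, for a code-fixing prompt the reply "def add(a,b) return a+b" is rejected
# (ast.parse raises SyntaxError), while "def add(a, b): return a + b" is accepted.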
|
|
def score_payload(prompt, payload):
    """Check the payload schema and derive the metrics used for pass/fail scoring."""
    required_keys = {
        "code",
        "explanation",
        "confidence",
        "important_tokens",
        "relevancy_score",
        "hallucination",
        "hallucination_check_reason",
        "latency_ms",
    }
    has_all_keys = required_keys.issubset(payload.keys())
    code_ok = code_is_valid_for_prompt(prompt, payload.get("code", ""))
    explanation_ok = bool(str(payload.get("explanation", "")).strip())
    confidence = safe_float(payload.get("confidence", 0.0))
    relevancy = safe_float(payload.get("relevancy_score", 0.0))
    hallucination = bool(payload.get("hallucination", False))

    return {
        "schema_ok": has_all_keys,
        "content_ok": code_ok and explanation_ok,
        "confidence": confidence,
        "relevancy": relevancy,
        "hallucination": hallucination,
    }
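# Illustrative scoring (not executed): for the "Fix this Python code" prompt, a payload of
#   {"code": "def add(a, b):\n    return a + b", "explanation": "Adds two numbers.",
#    "confidence": 0.9, "important_tokens": [], "relevancy_score": 0.8,
#    "hallucination": False, "hallucination_check_reason": "n/a", "latency_ms": 120}
# scores as {"schema_ok": True, "content_ok": True, "confidence": 0.9, "relevancy": 0.8,
#            "hallucination": False}.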
|
|
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="model")
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-Coder-0.5B-Instruct")
    parser.add_argument("--max-new-tokens", type=int, default=320)
    parser.add_argument("--strict-min-confidence", type=float, default=0.6)
    parser.add_argument("--strict-min-relevancy", type=float, default=0.25)
    parser.add_argument("--prompt", action="append", default=[])
    parser.add_argument(
        "--allow-downloads",
        action="store_true",
        help="Allow infer_local.py to download missing model files from Hugging Face.",
    )
    args = parser.parse_args()

    prompts = args.prompt if args.prompt else DEFAULT_TEST_PROMPTS
    results = []
    passed = 0

    for prompt in prompts:
        payload, error = run_inference(
            python_exec=sys.executable,
            model_path=args.model_path,
            base_model=args.base_model,
            prompt=prompt,
            max_new_tokens=args.max_new_tokens,
            allow_downloads=args.allow_downloads,
        )
        if error:
            results.append({"prompt": prompt, "error": error, "pass": False})
            continue

        metrics = score_payload(prompt, payload)
        is_pass = (
            metrics["schema_ok"]
            and metrics["content_ok"]
            and metrics["confidence"] >= args.strict_min_confidence
            and metrics["relevancy"] >= args.strict_min_relevancy
            and not metrics["hallucination"]
        )
        if is_pass:
            passed += 1

        results.append(
            {
                "prompt": prompt,
                "pass": is_pass,
                "metrics": metrics,
            }
        )

    accuracy = passed / len(prompts) if prompts else 0.0
    summary = {
        "total_tests": len(prompts),
        "passed_tests": passed,
        "accuracy": round(accuracy, 4),
        "thresholds": {
            "min_confidence": args.strict_min_confidence,
            "min_relevancy": args.strict_min_relevancy,
            "hallucination_must_be_false": True,
        },
        "results": results,
    }
    print(json.dumps(summary, indent=2, ensure_ascii=False))
|
|
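# Example invocation (the script filename here is illustrative):
#   python test_accuracy.py --model-path model \
#       --prompt "Write Python code for linear regression and explain it." --allow-downloads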
if __name__ == "__main__":
    main()