"""Strict evaluation harness.

Runs infer_local.py in a subprocess for each test prompt, validates the JSON
payload it returns, and prints an aggregate pass/fail summary.
"""

import argparse
import ast
import json
import subprocess
import sys

DEFAULT_TEST_PROMPTS = [
    "Fix this Python code: def add(a,b) return a+b",
    "Explain what this code does: for i in range(3): print(i)",
    "Write Python code for linear regression and explain it.",
    "Debug this snippet: if x = 5: print(x)",
]


def run_inference(python_exec, model_path, base_model, prompt, max_new_tokens, allow_downloads):
    """Invoke infer_local.py once and return (payload, error)."""
    cmd = [
        python_exec,
        "infer_local.py",
        "--model-path", model_path,
        "--base-model", base_model,
        "--prompt", prompt,
        "--max-new-tokens", str(max_new_tokens),
    ]
    if allow_downloads:
        cmd.append("--allow-downloads")

    result = subprocess.run(cmd, check=False, capture_output=True, text=True)
    if result.returncode != 0:
        return None, f"inference failed: {result.stderr.strip()}"

    stdout = result.stdout.strip()
    try:
        payload = json.loads(stdout)
        return payload, None
    except json.JSONDecodeError as exc:
        # Some libraries may emit informational logs before/after the JSON,
        # so fall back to extracting the outermost {...} span.
        merged = f"{result.stdout}\n{result.stderr}"
        start = merged.find("{")
        end = merged.rfind("}")
        if start != -1 and end != -1 and end > start:
            candidate = merged[start : end + 1]
            try:
                payload = json.loads(candidate)
                return payload, None
            except json.JSONDecodeError:
                pass
        return None, f"invalid json output: {exc}: {stdout[:300]}"


def safe_float(value):
    """Coerce a value to float, defaulting to 0.0 on failure."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def prompt_expects_code(prompt):
    """Heuristic: does the prompt ask for code to be written, fixed, or explained?"""
    prompt_l = prompt.lower()
    markers = (
        "fix", "debug", "repair", "write", "create", "generate", "implement",
        "function", "code", "snippet", "python",
        "multiply", "multiplication", "product",
        "add", "addition", "sum",
        "subtract", "subtraction", "difference",
        "divide", "division", "quotient",
    )
    return any(marker in prompt_l for marker in markers)


def code_is_valid_for_prompt(prompt, code):
    """Check that returned code is non-empty, Python-like, and parses cleanly."""
    code = str(code or "").strip()
    if not code:
        return False
    if not prompt_expects_code(prompt):
        return True
    python_like = any(
        marker in code
        for marker in ("def ", "import ", "class ", "print(", "return ", "for ", "if ")
    )
    if not python_like:
        return False
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False


def score_payload(prompt, payload):
    """Score one inference payload against schema and content checks."""
    required_keys = {
        "code",
        "explanation",
        "confidence",
        "important_tokens",
        "relevancy_score",
        "hallucination",
        "hallucination_check_reason",
        "latency_ms",
    }
    has_all_keys = required_keys.issubset(payload.keys())
    code_ok = code_is_valid_for_prompt(prompt, payload.get("code", ""))
    explanation_ok = bool(str(payload.get("explanation", "")).strip())
    confidence = safe_float(payload.get("confidence", 0.0))
    relevancy = safe_float(payload.get("relevancy_score", 0.0))
    hallucination = bool(payload.get("hallucination", False))
    return {
        "schema_ok": has_all_keys,
        "content_ok": code_ok and explanation_ok,
        "confidence": confidence,
        "relevancy": relevancy,
        "hallucination": hallucination,
    }


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="model")
    parser.add_argument("--base-model", type=str, default="Qwen/Qwen2.5-Coder-0.5B-Instruct")
    parser.add_argument("--max-new-tokens", type=int, default=320)
    parser.add_argument("--strict-min-confidence", type=float, default=0.6)
    parser.add_argument("--strict-min-relevancy", type=float, default=0.25)
    parser.add_argument("--prompt", action="append", default=[])
    parser.add_argument(
        "--allow-downloads",
        action="store_true",
        help="Allow infer_local.py to download missing model files from Hugging Face.",
    )
    args = parser.parse_args()

    prompts = args.prompt if args.prompt else DEFAULT_TEST_PROMPTS
    results = []
    passed = 0

    for prompt in prompts:
        payload, error = run_inference(
            python_exec=sys.executable,
            model_path=args.model_path,
            base_model=args.base_model,
            prompt=prompt,
            max_new_tokens=args.max_new_tokens,
            allow_downloads=args.allow_downloads,
        )
        if error:
            results.append({"prompt": prompt, "error": error, "pass": False})
            continue

        metrics = score_payload(prompt, payload)
        is_pass = (
            metrics["schema_ok"]
            and metrics["content_ok"]
            and metrics["confidence"] >= args.strict_min_confidence
            and metrics["relevancy"] >= args.strict_min_relevancy
            and not metrics["hallucination"]
        )
        if is_pass:
            passed += 1
        results.append(
            {
                "prompt": prompt,
                "pass": is_pass,
                "metrics": metrics,
            }
        )

    accuracy = passed / len(prompts) if prompts else 0.0
    summary = {
        "total_tests": len(prompts),
        "passed_tests": passed,
        "accuracy": round(accuracy, 4),
        "thresholds": {
            "min_confidence": args.strict_min_confidence,
            "min_relevancy": args.strict_min_relevancy,
            "hallucination_must_be_false": True,
        },
        "results": results,
    }
    print(json.dumps(summary, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
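# For reference, a sketch of a payload that would satisfy score_payload's checks
# is shown below. Only the key names are required by this harness; the values
# here are hypothetical and the actual contents depend on what infer_local.py
# emits for a given prompt.
#
# {
#   "code": "def add(a, b):\n    return a + b",
#   "explanation": "Defines a function that returns the sum of its two arguments.",
#   "confidence": 0.82,
#   "important_tokens": ["def", "add", "return"],
#   "relevancy_score": 0.41,
#   "hallucination": false,
#   "hallucination_check_reason": "output addresses the prompt directly",
#   "latency_ms": 850
# }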