""" Run Component 7 inference benchmark on the same 5 Python prompts. Outputs before/after syntax-valid score. """ from __future__ import annotations import argparse import json import sys from pathlib import Path from typing import Any, Dict import torch import yaml # Ensure imports work from project root. PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from src.inference_engine.inference_engine import DecodingConfig, InferenceEngine # noqa: E402 from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets # noqa: E402 from src.tokenizer.code_tokenizer import CodeTokenizer # noqa: E402 PROMPTS = [ "Write a Python function to check if a number is prime.", "Write Python code to reverse a string without using slicing.", "Create a Python function that returns Fibonacci numbers up to n.", "Write Python code to count word frequency in a sentence.", "Write a Python function to sort a list of dictionaries by a key.", ] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run Component 7 inference benchmark.") parser.add_argument("--config", default="configs/component7_inference_config.yaml") return parser.parse_args() def load_yaml(path: Path) -> Dict[str, Any]: if not path.exists(): raise FileNotFoundError(f"Config not found: {path}") data = yaml.safe_load(path.read_text(encoding="utf-8")) if not isinstance(data, dict): raise ValueError("Invalid YAML config.") return data def build_model_config(path: Path) -> ModelConfig: cfg = load_yaml(path) preset = cfg.get("preset") model_cfg = cfg.get("model", {}) if preset: merged = get_model_presets()[preset].__dict__.copy() merged.update(model_cfg) return ModelConfig(**merged) return ModelConfig(**model_cfg) def main() -> None: args = parse_args() try: cfg = load_yaml(Path(args.config)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device.type != "cuda": raise RuntimeError("CUDA is required for Component 7 benchmark.") model_cfg = build_model_config(PROJECT_ROOT / cfg["model"]["model_config_path"]) model = CodeTransformerLM(model_cfg).to(device) ckpt_path = PROJECT_ROOT / cfg["model"]["checkpoint_path"] payload = torch.load(ckpt_path, map_location=device) model.load_state_dict(payload["model_state"]) model.half() tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / cfg["model"]["tokenizer_dir"])) dcfg = DecodingConfig( max_new_tokens=int(cfg["inference"].get("max_new_tokens", 180)), greedy_temperature=float(cfg["inference"].get("greedy_temperature", 0.0)), retry2_temperature=float(cfg["inference"].get("retry2_temperature", 0.25)), retry2_top_p=float(cfg["inference"].get("retry2_top_p", 0.85)), retry3_temperature=float(cfg["inference"].get("retry3_temperature", 0.35)), retry3_top_p=float(cfg["inference"].get("retry3_top_p", 0.90)), max_retries=int(cfg["inference"].get("max_retries", 3)), min_tokens_before_stop_check=int(cfg["inference"].get("min_tokens_before_stop_check", 24)), ) engine = InferenceEngine(model=model, tokenizer=tokenizer, device=device) rows = [] syntax_ok_count = 0 for p in PROMPTS: res = engine.generate_with_retry(prompt=p, language=str(cfg["inference"].get("language", "python")), cfg=dcfg) final = res["final"] syntax_ok = bool(final["syntax_ok"]) syntax_ok_count += 1 if syntax_ok else 0 rows.append( { "prompt": p, "final_code": final["code"], "syntax_ok": syntax_ok, "attempt_used": final["attempt"], "generated_tokens": final["generated_tokens"], "attempts": res["attempts"], } ) before_score = None before_path = PROJECT_ROOT / "artifacts" / "evaluation" / "component6_eval_results.json" if before_path.exists(): d = json.loads(before_path.read_text(encoding="utf-8")) try: before_score = sum(1 for x in d["checkpoints"][0]["generations"] if x["python_syntax_ok"]) except Exception: before_score = None out = { "checkpoint": str(ckpt_path), "step": int(payload.get("step", -1)), "before_component6_syntax_ok_out_of_5": before_score, "after_component7_syntax_ok_out_of_5": syntax_ok_count, "prompts": rows, } out_path = PROJECT_ROOT / cfg["output"]["results_json"] out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8") print("Component 7 inference benchmark completed.") if before_score is not None: print(f"Before (Component 6): {before_score}/5 syntax-valid") print(f"After (Component 7): {syntax_ok_count}/5 syntax-valid") print(f"Saved results: {out_path}") except Exception as exc: print("Component 7 benchmark failed.") print(f"What went wrong: {exc}") print("Fix suggestion: verify checkpoint and tokenizer paths.") raise SystemExit(1) if __name__ == "__main__": main()