"""
Run the Component 7 inference benchmark on the same five Python prompts
and report the syntax-valid score before (Component 6) and after
(Component 7).
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Any, Dict |
|
|
| import torch |
| import yaml |
|
|
| |
# The `src.*` imports below need the repository root on sys.path. The root is
# taken to be the grandparent directory of this file (this script sits one
# directory below it). Prepend it only once so repeated imports stay clean.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)
|
|
| from src.inference_engine.inference_engine import DecodingConfig, InferenceEngine |
| from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets |
| from src.tokenizer.code_tokenizer import CodeTokenizer |
|
|
# Fixed benchmark prompt set. Each entry is sent verbatim to the inference
# engine; keep the set stable between runs so syntax-valid scores stay
# comparable across components.
PROMPTS = [
    "Write a Python function to check if a number is prime.",
    "Write Python code to reverse a string without using slicing.",
    "Create a Python function that returns Fibonacci numbers up to n.",
    "Write Python code to count word frequency in a sentence.",
    "Write a Python function to sort a list of dictionaries by a key.",
]
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the benchmark.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        Namespace with ``config``: path to the Component 7 inference YAML.
    """
    parser = argparse.ArgumentParser(description="Run Component 7 inference benchmark.")
    parser.add_argument(
        "--config",
        default="configs/component7_inference_config.yaml",
        help="Path to the Component 7 inference YAML config (default: %(default)s).",
    )
    # Passing argv explicitly lets callers and tests drive the parser without
    # touching sys.argv; None keeps argparse's standard sys.argv behaviour.
    return parser.parse_args(argv)
|
|
|
|
def load_yaml(path: Path) -> Dict[str, Any]:
    """Load a YAML file whose top level must be a mapping.

    Args:
        path: Location of the YAML file. A plain ``str`` is also accepted.

    Returns:
        The parsed top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` is not an existing regular file.
        ValueError: If the document is empty or its top level is not a mapping.
    """
    path = Path(path)
    # is_file() rather than exists(): a directory would otherwise pass the
    # check and fail later inside read_text() with an unrelated error.
    if not path.is_file():
        raise FileNotFoundError(f"Config not found: {path}")
    # safe_load only constructs plain Python objects (no arbitrary tags).
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        raise ValueError(
            f"Invalid YAML config: {path} (expected a mapping at the top level, "
            f"got {type(data).__name__})."
        )
    return data
|
|
|
|
def build_model_config(path: Path) -> ModelConfig:
    """Build a ModelConfig from a model YAML file.

    The file may name a ``preset`` (a key of ``get_model_presets()``) and/or
    provide a ``model`` mapping. With a preset, the ``model`` values override
    the preset's fields; without one, ``model`` alone is passed to
    ``ModelConfig``.

    Args:
        path: YAML file describing the model configuration.

    Returns:
        The constructed ModelConfig.

    Raises:
        KeyError: If ``preset`` does not name a known preset.
        TypeError: If the ``model`` section is present but not a mapping.
    """
    cfg = load_yaml(path)
    preset = cfg.get("preset")
    # A bare `model:` key parses as None; treat it like an absent section
    # instead of crashing later inside dict.update() / ModelConfig(**...).
    model_cfg = cfg.get("model") or {}
    if not isinstance(model_cfg, dict):
        raise TypeError(
            f"'model' section in {path} must be a mapping, got {type(model_cfg).__name__}."
        )

    if not preset:
        return ModelConfig(**model_cfg)

    presets = get_model_presets()
    if preset not in presets:
        available = ", ".join(str(name) for name in presets)
        raise KeyError(f"Unknown model preset {preset!r} in {path}; available: {available}")

    # Copy the preset's fields so the overrides below do not mutate the
    # preset object returned by get_model_presets().
    merged = dict(presets[preset].__dict__)
    merged.update(model_cfg)
    return ModelConfig(**merged)
|
|
|
|
def _resolve_config_path(raw: str) -> Path:
    """Return ``raw`` unchanged if it is absolute or exists; otherwise resolve it under PROJECT_ROOT."""
    path = Path(raw)
    if path.is_absolute() or path.exists():
        return path
    # The default config path is repository-relative; this lets the script
    # run from any working directory.
    return PROJECT_ROOT / path


def _load_model(
    model_section: Dict[str, Any], device: torch.device
) -> tuple[CodeTransformerLM, Path, Dict[str, Any]]:
    """Build the model, load checkpoint weights and prepare it for fp16 inference.

    Returns the model, the checkpoint path and the raw checkpoint payload.
    """
    model_cfg = build_model_config(PROJECT_ROOT / model_section["model_config_path"])
    model = CodeTransformerLM(model_cfg).to(device)

    ckpt_path = PROJECT_ROOT / model_section["checkpoint_path"]
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Recent torch releases default to weights_only=True, which
    # may reject payloads holding non-tensor objects.
    payload = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(payload["model_state"])
    # Inference mode: without eval(), any dropout layers in the model would
    # stay active during generation and perturb the outputs.
    model.eval()
    model.half()
    return model, ckpt_path, payload


def _build_decoding_config(inference: Dict[str, Any]) -> DecodingConfig:
    """Translate the ``inference`` config section into a DecodingConfig, using defaults for missing keys."""
    return DecodingConfig(
        max_new_tokens=int(inference.get("max_new_tokens", 180)),
        greedy_temperature=float(inference.get("greedy_temperature", 0.0)),
        retry2_temperature=float(inference.get("retry2_temperature", 0.25)),
        retry2_top_p=float(inference.get("retry2_top_p", 0.85)),
        retry3_temperature=float(inference.get("retry3_temperature", 0.35)),
        retry3_top_p=float(inference.get("retry3_top_p", 0.90)),
        max_retries=int(inference.get("max_retries", 3)),
        min_tokens_before_stop_check=int(inference.get("min_tokens_before_stop_check", 24)),
    )


def _run_prompts(
    engine: InferenceEngine, dcfg: DecodingConfig, language: str
) -> tuple[list[Dict[str, Any]], int]:
    """Generate code for every prompt in PROMPTS; return per-prompt rows and the syntax-valid count."""
    rows: list[Dict[str, Any]] = []
    for prompt in PROMPTS:
        res = engine.generate_with_retry(prompt=prompt, language=language, cfg=dcfg)
        final = res["final"]
        rows.append(
            {
                "prompt": prompt,
                "final_code": final["code"],
                "syntax_ok": bool(final["syntax_ok"]),
                "attempt_used": final["attempt"],
                "generated_tokens": final["generated_tokens"],
                "attempts": res["attempts"],
            }
        )
    syntax_ok_count = sum(1 for row in rows if row["syntax_ok"])
    return rows, syntax_ok_count


def _load_before_score(path: Path) -> int | None:
    """Count syntax-valid generations in the Component 6 results file, or return None if unavailable."""
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        generations = data["checkpoints"][0]["generations"]
        return sum(1 for gen in generations if gen["python_syntax_ok"])
    except (OSError, ValueError, KeyError, IndexError, TypeError) as exc:
        # The baseline is optional context: a malformed file must not abort the
        # run and discard the freshly computed Component 7 results.
        print(f"Warning: could not read Component 6 baseline from {path}: {exc}")
        return None


def main() -> None:
    """Run the benchmark end to end and write a JSON report.

    Steps:
    - Load the inference YAML and require CUDA.
    - Load the checkpoint and tokenizer.
    - Generate code for each prompt in PROMPTS (with retries).
    - Compare the syntax-valid count with the Component 6 results, when present.
    - Write everything to ``output.results_json`` under PROJECT_ROOT.

    Prints a short diagnostic and exits with status 1 on any error.
    """
    args = parse_args()
    try:
        cfg = load_yaml(_resolve_config_path(args.config))
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required for Component 7 benchmark.")
        device = torch.device("cuda")

        model_section = cfg["model"]
        inference = cfg.get("inference") or {}
        # Resolve the output location up front so a config mistake fails fast
        # instead of after all generation work.
        out_path = PROJECT_ROOT / cfg["output"]["results_json"]

        model, ckpt_path, payload = _load_model(model_section, device)
        tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / model_section["tokenizer_dir"]))
        dcfg = _build_decoding_config(inference)
        engine = InferenceEngine(model=model, tokenizer=tokenizer, device=device)

        language = str(inference.get("language", "python"))
        rows, syntax_ok_count = _run_prompts(engine, dcfg, language)

        before_path = PROJECT_ROOT / "artifacts" / "evaluation" / "component6_eval_results.json"
        before_score = _load_before_score(before_path)

        step = payload.get("step")
        out = {
            "checkpoint": str(ckpt_path),
            # A payload with step=None would otherwise crash int() after all the work.
            "step": int(step) if step is not None else -1,
            # Key names are kept for compatibility with existing consumers; the
            # counts are out of len(PROMPTS).
            "before_component6_syntax_ok_out_of_5": before_score,
            "after_component7_syntax_ok_out_of_5": syntax_ok_count,
            "prompts": rows,
        }

        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")

        total = len(PROMPTS)
        print("Component 7 inference benchmark completed.")
        if before_score is not None:
            print(f"Before (Component 6): {before_score}/{total} syntax-valid")
        print(f"After (Component 7): {syntax_ok_count}/{total} syntax-valid")
        print(f"Saved results: {out_path}")
    except Exception as exc:
        # Top-level boundary: report any failure briefly and exit non-zero.
        print("Component 7 benchmark failed.")
        print(f"What went wrong: {type(exc).__name__}: {exc}")
        print("Fix suggestion: verify checkpoint and tokenizer paths.")
        raise SystemExit(1) from exc
|
|
|
|
# Script entry point: run the benchmark only when executed directly, not on import.
# main() raises SystemExit(1) after printing a diagnostic if anything fails.
if __name__ == "__main__":
    main()
|
|