# mindi-backup/scripts/run_component7_inference_benchmark.py
# Author: Mindigenous
# Commit: Initial full project backup with Git LFS
# Revision: 53f0cc2
"""
Run Component 7 inference benchmark on the same 5 Python prompts.
Outputs before/after syntax-valid score.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict
import torch
import yaml
# Ensure imports work from project root.
# PROJECT_ROOT is parents[1] of this script's resolved path, i.e. the
# directory that contains the folder this script lives in.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Prepend (index 0) so the `src.*` imports that follow are looked up
    # under PROJECT_ROOT before any other sys.path entry.
    sys.path.insert(0, str(PROJECT_ROOT))
from src.inference_engine.inference_engine import DecodingConfig, InferenceEngine # noqa: E402
from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets # noqa: E402
from src.tokenizer.code_tokenizer import CodeTokenizer # noqa: E402
# Fixed benchmark prompts; main() sends each one through the inference engine.
# The result keys ("..._out_of_5") and the printed "/5" summary in main()
# assume this list holds exactly five prompts.
PROMPTS = [
    "Write a Python function to check if a number is prime.",
    "Write Python code to reverse a string without using slicing.",
    "Create a Python function that returns Fibonacci numbers up to n.",
    "Write Python code to count word frequency in a sentence.",
    "Write a Python function to sort a list of dictionaries by a key.",
]
def parse_args() -> argparse.Namespace:
    """Parse the command line and return the populated namespace.

    The single option, ``--config``, is the path of the benchmark YAML file
    that ``main`` loads.
    """
    cli = argparse.ArgumentParser(description="Run Component 7 inference benchmark.")
    cli.add_argument(
        "--config",
        default="configs/component7_inference_config.yaml",
    )
    namespace = cli.parse_args()
    return namespace
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read *path* as UTF-8 YAML and return its top-level mapping.

    Raises:
        FileNotFoundError: If *path* does not exist.
        ValueError: If the parsed document is not a mapping.
    """
    if path.exists():
        parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
        if isinstance(parsed, dict):
            return parsed
        # Covers empty files (None) as well as top-level lists/scalars.
        raise ValueError("Invalid YAML config.")
    raise FileNotFoundError(f"Config not found: {path}")
def build_model_config(path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` from the YAML file at *path*.

    The file may name a ``preset`` (looked up in ``get_model_presets()``) and
    may carry a ``model`` mapping of explicit field values. When both are
    present, the explicit values override the preset's fields.

    Raises:
        FileNotFoundError: If *path* does not exist (from ``load_yaml``).
        ValueError: If the file is not a YAML mapping (from ``load_yaml``).
        KeyError: If ``preset`` names a preset that is not available.
    """
    cfg = load_yaml(path)
    preset = cfg.get("preset")
    # A bare `model:` key parses to None; treat it like a missing section
    # instead of crashing in dict.update() / ** unpacking below.
    overrides = cfg.get("model") or {}
    if not preset:
        return ModelConfig(**overrides)

    presets = get_model_presets()
    try:
        base = presets[preset]
    except KeyError:
        # Same exception type as before, but say which names are valid.
        raise KeyError(
            f"Unknown model preset {preset!r}; available: {list(presets)}"
        ) from None

    # Copy so the shared preset object is never mutated by the overrides.
    merged = vars(base).copy()
    merged.update(overrides)
    return ModelConfig(**merged)
def _build_decoding_config(inference_cfg: Dict[str, Any]) -> DecodingConfig:
    """Translate the ``inference`` config section into a ``DecodingConfig``."""
    get = inference_cfg.get
    return DecodingConfig(
        max_new_tokens=int(get("max_new_tokens", 180)),
        greedy_temperature=float(get("greedy_temperature", 0.0)),
        retry2_temperature=float(get("retry2_temperature", 0.25)),
        retry2_top_p=float(get("retry2_top_p", 0.85)),
        retry3_temperature=float(get("retry3_temperature", 0.35)),
        retry3_top_p=float(get("retry3_top_p", 0.90)),
        max_retries=int(get("max_retries", 3)),
        min_tokens_before_stop_check=int(get("min_tokens_before_stop_check", 24)),
    )


def _read_before_score(results_path: Path) -> int | None:
    """Return the Component 6 syntax-valid count, or None when unavailable.

    The count is the number of entries in ``checkpoints[0]["generations"]``
    whose ``python_syntax_ok`` is truthy. This lookup is best-effort: a
    missing, unreadable, malformed, or unexpectedly shaped file yields None
    rather than failing the benchmark that has already run.
    """
    if not results_path.exists():
        return None
    try:
        data = json.loads(results_path.read_text(encoding="utf-8"))
        generations = data["checkpoints"][0]["generations"]
        return sum(1 for gen in generations if gen["python_syntax_ok"])
    except (OSError, ValueError, KeyError, IndexError, TypeError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return None


def main() -> None:
    """Run the benchmark prompts, write the results JSON, and print a summary.

    Loads the benchmark config named by ``--config``, builds the model,
    tokenizer and inference engine, generates code for every entry in
    ``PROMPTS``, and writes per-prompt results plus the syntax-valid totals to
    the configured ``output.results_json`` path under ``PROJECT_ROOT``.

    Raises:
        SystemExit: With status 1 if any step fails; the cause is printed.
    """
    args = parse_args()
    try:
        cfg = load_yaml(Path(args.config))

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if device.type != "cuda":
            raise RuntimeError("CUDA is required for Component 7 benchmark.")

        model_section = cfg["model"]
        inference_section = cfg["inference"]

        model_cfg = build_model_config(PROJECT_ROOT / model_section["model_config_path"])
        model = CodeTransformerLM(model_cfg).to(device)
        ckpt_path = PROJECT_ROOT / model_section["checkpoint_path"]
        # NOTE(security): torch.load unpickles the file, so only point this at
        # trusted checkpoints. Consider weights_only=True if the checkpoint
        # contents allow it — TODO confirm against the training script.
        payload = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(payload["model_state"])
        model.half()
        # Benchmark in eval mode so train-only behaviour (e.g. dropout) cannot
        # perturb generations. Harmless if InferenceEngine also does this.
        model.eval()

        tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / model_section["tokenizer_dir"]))
        dcfg = _build_decoding_config(inference_section)
        engine = InferenceEngine(model=model, tokenizer=tokenizer, device=device)
        # Loop-invariant: resolve the target language once.
        language = str(inference_section.get("language", "python"))

        rows = []
        for prompt in PROMPTS:
            res = engine.generate_with_retry(prompt=prompt, language=language, cfg=dcfg)
            final = res["final"]
            rows.append(
                {
                    "prompt": prompt,
                    "final_code": final["code"],
                    "syntax_ok": bool(final["syntax_ok"]),
                    "attempt_used": final["attempt"],
                    "generated_tokens": final["generated_tokens"],
                    "attempts": res["attempts"],
                }
            )
        syntax_ok_count = sum(1 for row in rows if row["syntax_ok"])

        before_score = _read_before_score(
            PROJECT_ROOT / "artifacts" / "evaluation" / "component6_eval_results.json"
        )

        out = {
            "checkpoint": str(ckpt_path),
            "step": int(payload.get("step", -1)),
            "before_component6_syntax_ok_out_of_5": before_score,
            "after_component7_syntax_ok_out_of_5": syntax_ok_count,
            "prompts": rows,
        }
        out_path = PROJECT_ROOT / cfg["output"]["results_json"]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")

        print("Component 7 inference benchmark completed.")
        if before_score is not None:
            print(f"Before (Component 6): {before_score}/5 syntax-valid")
        print(f"After (Component 7): {syntax_ok_count}/5 syntax-valid")
        print(f"Saved results: {out_path}")
    except Exception as exc:
        # Top-level boundary: report the failure and exit non-zero, keeping
        # the original exception chained for debugging.
        print("Component 7 benchmark failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: verify checkpoint and tokenizer paths.")
        raise SystemExit(1) from exc
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()