# mindi-backup/scripts/run_component6_evaluation.py
# Repo: Mindigenous — "Initial full project backup with Git LFS" (commit 53f0cc2)
"""
Component 6: Evaluation system.
- Computes validation loss for selected checkpoints.
- Generates code for 5 simple Python prompts.
- Performs syntax validity checks.
- Saves results JSON.
"""
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any, Dict, List
import torch
import yaml
from torch.utils.data import DataLoader
# Ensure src imports work.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.evaluation_system.code_eval import python_syntax_ok, restore_code_from_structured, save_json # noqa: E402
from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets # noqa: E402
from src.tokenizer.code_tokenizer import CodeTokenizer # noqa: E402
from src.training_pipeline.tokenized_dataset import CausalCollator, TokenizedJsonlDataset # noqa: E402
# Fixed set of five beginner-level Python tasks used to smoke-test each
# checkpoint: every prompt's generation is syntax-checked in main().
PROMPTS = [
    "Write a Python function to check if a number is prime.",
    "Write Python code to reverse a string without using slicing.",
    "Create a Python function that returns Fibonacci numbers up to n.",
    "Write Python code to count word frequency in a sentence.",
    "Write a Python function to sort a list of dictionaries by a key.",
]
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the evaluation run.

    Returns:
        Namespace with a single ``config`` attribute (path to the YAML config).
    """
    ap = argparse.ArgumentParser(description="Run Component 6 evaluation.")
    ap.add_argument(
        "--config",
        default="configs/component6_evaluation_config.yaml",
    )
    return ap.parse_args()
def load_yaml(path: Path) -> Dict[str, Any]:
    """Read *path* as YAML and return its top-level mapping.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the parsed document is not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    text = path.read_text(encoding="utf-8")
    document = yaml.safe_load(text)
    if isinstance(document, dict):
        return document
    raise ValueError("Invalid YAML config.")
def build_model_config(model_cfg_path: Path) -> ModelConfig:
    """Build a ModelConfig from a YAML file, optionally layered on a preset.

    The YAML may name a ``preset``; its fields are used as the base and any
    keys under ``model`` override them. Without a preset, the ``model``
    section alone defines the config.

    Raises:
        ValueError: if the named preset is unknown.
    """
    raw = load_yaml(model_cfg_path)
    overrides = raw.get("model", {})
    preset_name = raw.get("preset")
    if not preset_name:
        return ModelConfig(**overrides)
    presets = get_model_presets()
    if preset_name not in presets:
        raise ValueError(f"Unknown preset: {preset_name}")
    base = presets[preset_name].__dict__.copy()
    base.update(overrides)
    return ModelConfig(**base)
@torch.no_grad()
def eval_val_loss(model: CodeTransformerLM, val_loader: DataLoader, device: torch.device, max_batches: int = 50) -> float:
model.eval()
losses = []
for i, (input_ids, labels) in enumerate(val_loader):
if i >= max_batches:
break
input_ids = input_ids.to(device)
labels = labels.to(device)
with torch.amp.autocast("cuda", enabled=(device.type == "cuda"), dtype=torch.float16):
out = model(input_ids=input_ids, labels=labels)
losses.append(float(out["loss"].item()))
model.train()
if not losses:
return 1e9
return sum(losses) / len(losses)
@torch.no_grad()
def generate_code(
    model: CodeTransformerLM,
    tokenizer: CodeTokenizer,
    prompt: str,
    device: torch.device,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> str:
    """Autoregressively generate code for *prompt* and return the restored source.

    Decoding is greedy when ``temperature <= 0``; otherwise temperature-scaled
    top-p (nucleus) sampling. Generation stops at the <EOS> token or after
    *max_new_tokens* steps. Leaves the model in eval mode.
    """
    model.eval()
    # Format the prompt the same way training samples were formatted, with an
    # empty code field so the model continues with the code itself.
    prompt_text = tokenizer.format_training_sample(prompt=prompt, code="", language="python")
    # Remove trailing empty code marker noise.
    prompt_text = prompt_text.replace(" <NL>", "").strip()
    ids = tokenizer.encode(prompt_text)
    eos_id = tokenizer.special_token_ids.get("<EOS>", None)
    # Remove trailing EOS from prompt so generation continues naturally.
    if eos_id is not None and len(ids) > 1 and ids[-1] == int(eos_id):
        ids = ids[:-1]
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    for _ in range(max_new_tokens):
        # Full forward pass each step (no KV cache); only the last position's
        # logits feed the next-token distribution.
        out = model(input_ids=input_ids)
        logits = out["logits"][:, -1, :]
        if temperature <= 0:
            # Greedy decoding.
            next_id = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            # Top-p (nucleus) sampling.
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)
            cutoff = cumulative > top_p
            # Shift the mask right so the token that first crosses top_p is
            # kept, and never mask the single most probable token.
            cutoff[..., 1:] = cutoff[..., :-1].clone()
            cutoff[..., 0] = False
            sorted_probs[cutoff] = 0.0
            # Renormalize the surviving nucleus before sampling from it.
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            sampled = torch.multinomial(sorted_probs, num_samples=1)
            # Map the sampled rank back to its vocabulary index.
            next_id = sorted_idx.gather(-1, sampled)
        input_ids = torch.cat([input_ids, next_id], dim=1)
        if eos_id is not None and int(next_id.item()) == int(eos_id):
            break
    # NOTE(review): input_ids grows unbounded in the loop — this assumes
    # prompt length + max_new_tokens stays within the model's max_seq_len;
    # confirm against the inference config.
    decoded = tokenizer.decode(input_ids[0].tolist())
    code = restore_code_from_structured(decoded)
    return code
def main() -> None:
    """Run Component 6 evaluation end to end.

    Loads the YAML config, then for each configured checkpoint: loads the
    model in fp16, measures current validation loss, generates code for the
    five fixed PROMPTS, syntax-checks each generation, and finally writes a
    single aggregated results JSON.

    Raises:
        SystemExit: with code 1 on any failure inside the try block.
    """
    args = parse_args()
    try:
        cfg = load_yaml(Path(args.config))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Checkpoints are run in fp16 (model.half() below), so CPU-only
        # environments are rejected up front.
        if device.type != "cuda":
            raise RuntimeError("CUDA is required for this evaluation run.")
        model_cfg = build_model_config(Path(cfg["model"]["model_config_path"]))
        # The inference context window comes from the eval config and may
        # differ from the length used at training time.
        model_cfg.max_seq_len = int(cfg["inference"]["max_seq_len"])
        tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / "artifacts" / "tokenizer" / "code_tokenizer_v1"))
        val_ds = TokenizedJsonlDataset(
            path=str(PROJECT_ROOT / cfg["data"]["tokenized_jsonl_path"]),
            split="val",
            val_ratio=float(cfg["data"].get("val_ratio", 0.02)),
            split_seed=int(cfg["data"].get("split_seed", 17)),
        )
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            shuffle=False,
            collate_fn=CausalCollator(pad_token_id=0, max_seq_len=model_cfg.max_seq_len),
        )
        ckpt_results: List[Dict[str, Any]] = []
        for ckpt_rel in cfg["model"]["checkpoint_paths"]:
            ckpt_path = PROJECT_ROOT / ckpt_rel
            if not ckpt_path.exists():
                raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
            # Fresh model instance per checkpoint so no state leaks between
            # evaluations.
            model = CodeTransformerLM(model_cfg).to(device)
            payload = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(payload["model_state"])
            model.half()
            val_loss = eval_val_loss(model, val_loader, device=device, max_batches=50)
            generations = []
            for p in PROMPTS:
                code = generate_code(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=p,
                    device=device,
                    max_new_tokens=int(cfg["inference"].get("max_new_tokens", 160)),
                    temperature=float(cfg["inference"].get("temperature", 0.8)),
                    top_p=float(cfg["inference"].get("top_p", 0.9)),
                )
                generations.append(
                    {
                        "prompt": p,
                        "generated_code": code,
                        "python_syntax_ok": python_syntax_ok(code),
                    }
                )
            ckpt_results.append(
                {
                    "checkpoint": str(ckpt_path),
                    # step/best_val fall back to -1/NaN for older payloads
                    # that lack those keys.
                    "step": int(payload.get("step", -1)),
                    "best_val_in_checkpoint": float(payload.get("best_val", math.nan)),
                    "eval_val_loss_now": float(val_loss),
                    "generations": generations,
                }
            )
        # Basic fit flags from checkpoint trend.
        # NOTE(review): only the LAST checkpoint's loss feeds this flag and
        # 1.5 is a hard-coded threshold — confirm both are intentional.
        fit_flag = "healthy"
        if ckpt_results and ckpt_results[-1]["eval_val_loss_now"] > 1.5:
            fit_flag = "underfitting"
        out = {
            "fit_flag": fit_flag,
            "checkpoints": ckpt_results,
            "recommended_prompts": PROMPTS,
        }
        out_path = str(PROJECT_ROOT / cfg["output"]["results_json"])
        save_json(out_path, out)
        print("Component 6 evaluation completed.")
        print(f"Saved results: {out_path}")
        print(f"Fit flag: {fit_flag}")
        for row in ckpt_results:
            print(f"Checkpoint step={row['step']} val_loss={row['eval_val_loss_now']:.4f}")
            ok_count = sum(1 for g in row["generations"] if g["python_syntax_ok"])
            print(f"Python syntax valid in generated samples: {ok_count}/5")
    except Exception as exc:
        # Broad top-level catch: convert any failure into a friendly message
        # and a non-zero exit code for the calling pipeline.
        print("Component 6 evaluation failed.")
        print(f"What went wrong: {exc}")
        print("Fix suggestion: verify checkpoint path and tokenizer path.")
        raise SystemExit(1)
# Script entry point.
if __name__ == "__main__":
    main()