""" Component 6: Evaluation system. - Computes validation loss for selected checkpoints. - Generates code for 5 simple Python prompts. - Performs syntax validity checks. - Saves results JSON. """ from __future__ import annotations import argparse import json import math import sys from pathlib import Path from typing import Any, Dict, List import torch import yaml from torch.utils.data import DataLoader # Ensure src imports work. PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) from src.evaluation_system.code_eval import python_syntax_ok, restore_code_from_structured, save_json # noqa: E402 from src.model_architecture.code_transformer import CodeTransformerLM, ModelConfig, get_model_presets # noqa: E402 from src.tokenizer.code_tokenizer import CodeTokenizer # noqa: E402 from src.training_pipeline.tokenized_dataset import CausalCollator, TokenizedJsonlDataset # noqa: E402 PROMPTS = [ "Write a Python function to check if a number is prime.", "Write Python code to reverse a string without using slicing.", "Create a Python function that returns Fibonacci numbers up to n.", "Write Python code to count word frequency in a sentence.", "Write a Python function to sort a list of dictionaries by a key.", ] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run Component 6 evaluation.") parser.add_argument("--config", default="configs/component6_evaluation_config.yaml") return parser.parse_args() def load_yaml(path: Path) -> Dict[str, Any]: if not path.exists(): raise FileNotFoundError(f"Config not found: {path}") data = yaml.safe_load(path.read_text(encoding="utf-8")) if not isinstance(data, dict): raise ValueError("Invalid YAML config.") return data def build_model_config(model_cfg_path: Path) -> ModelConfig: cfg = load_yaml(model_cfg_path) preset = cfg.get("preset") model_cfg = cfg.get("model", {}) if preset: presets = get_model_presets() if preset not in presets: raise ValueError(f"Unknown preset: {preset}") merged = presets[preset].__dict__.copy() merged.update(model_cfg) return ModelConfig(**merged) return ModelConfig(**model_cfg) @torch.no_grad() def eval_val_loss(model: CodeTransformerLM, val_loader: DataLoader, device: torch.device, max_batches: int = 50) -> float: model.eval() losses = [] for i, (input_ids, labels) in enumerate(val_loader): if i >= max_batches: break input_ids = input_ids.to(device) labels = labels.to(device) with torch.amp.autocast("cuda", enabled=(device.type == "cuda"), dtype=torch.float16): out = model(input_ids=input_ids, labels=labels) losses.append(float(out["loss"].item())) model.train() if not losses: return 1e9 return sum(losses) / len(losses) @torch.no_grad() def generate_code( model: CodeTransformerLM, tokenizer: CodeTokenizer, prompt: str, device: torch.device, max_new_tokens: int, temperature: float, top_p: float, ) -> str: model.eval() prompt_text = tokenizer.format_training_sample(prompt=prompt, code="", language="python") # Remove trailing empty code marker noise. prompt_text = prompt_text.replace(" ", "").strip() ids = tokenizer.encode(prompt_text) eos_id = tokenizer.special_token_ids.get("", None) # Remove trailing EOS from prompt so generation continues naturally. if eos_id is not None and len(ids) > 1 and ids[-1] == int(eos_id): ids = ids[:-1] input_ids = torch.tensor([ids], dtype=torch.long, device=device) for _ in range(max_new_tokens): out = model(input_ids=input_ids) logits = out["logits"][:, -1, :] if temperature <= 0: next_id = torch.argmax(logits, dim=-1, keepdim=True) else: logits = logits / temperature probs = torch.softmax(logits, dim=-1) # Top-p (nucleus) sampling. sorted_probs, sorted_idx = torch.sort(probs, descending=True) cumulative = torch.cumsum(sorted_probs, dim=-1) cutoff = cumulative > top_p cutoff[..., 1:] = cutoff[..., :-1].clone() cutoff[..., 0] = False sorted_probs[cutoff] = 0.0 sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True) sampled = torch.multinomial(sorted_probs, num_samples=1) next_id = sorted_idx.gather(-1, sampled) input_ids = torch.cat([input_ids, next_id], dim=1) if eos_id is not None and int(next_id.item()) == int(eos_id): break decoded = tokenizer.decode(input_ids[0].tolist()) code = restore_code_from_structured(decoded) return code def main() -> None: args = parse_args() try: cfg = load_yaml(Path(args.config)) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device.type != "cuda": raise RuntimeError("CUDA is required for this evaluation run.") model_cfg = build_model_config(Path(cfg["model"]["model_config_path"])) model_cfg.max_seq_len = int(cfg["inference"]["max_seq_len"]) tokenizer = CodeTokenizer.load(str(PROJECT_ROOT / "artifacts" / "tokenizer" / "code_tokenizer_v1")) val_ds = TokenizedJsonlDataset( path=str(PROJECT_ROOT / cfg["data"]["tokenized_jsonl_path"]), split="val", val_ratio=float(cfg["data"].get("val_ratio", 0.02)), split_seed=int(cfg["data"].get("split_seed", 17)), ) val_loader = DataLoader( val_ds, batch_size=1, shuffle=False, collate_fn=CausalCollator(pad_token_id=0, max_seq_len=model_cfg.max_seq_len), ) ckpt_results: List[Dict[str, Any]] = [] for ckpt_rel in cfg["model"]["checkpoint_paths"]: ckpt_path = PROJECT_ROOT / ckpt_rel if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") model = CodeTransformerLM(model_cfg).to(device) payload = torch.load(ckpt_path, map_location=device) model.load_state_dict(payload["model_state"]) model.half() val_loss = eval_val_loss(model, val_loader, device=device, max_batches=50) generations = [] for p in PROMPTS: code = generate_code( model=model, tokenizer=tokenizer, prompt=p, device=device, max_new_tokens=int(cfg["inference"].get("max_new_tokens", 160)), temperature=float(cfg["inference"].get("temperature", 0.8)), top_p=float(cfg["inference"].get("top_p", 0.9)), ) generations.append( { "prompt": p, "generated_code": code, "python_syntax_ok": python_syntax_ok(code), } ) ckpt_results.append( { "checkpoint": str(ckpt_path), "step": int(payload.get("step", -1)), "best_val_in_checkpoint": float(payload.get("best_val", math.nan)), "eval_val_loss_now": float(val_loss), "generations": generations, } ) # Basic fit flags from checkpoint trend. fit_flag = "healthy" if ckpt_results and ckpt_results[-1]["eval_val_loss_now"] > 1.5: fit_flag = "underfitting" out = { "fit_flag": fit_flag, "checkpoints": ckpt_results, "recommended_prompts": PROMPTS, } out_path = str(PROJECT_ROOT / cfg["output"]["results_json"]) save_json(out_path, out) print("Component 6 evaluation completed.") print(f"Saved results: {out_path}") print(f"Fit flag: {fit_flag}") for row in ckpt_results: print(f"Checkpoint step={row['step']} val_loss={row['eval_val_loss_now']:.4f}") ok_count = sum(1 for g in row["generations"] if g["python_syntax_ok"]) print(f"Python syntax valid in generated samples: {ok_count}/5") except Exception as exc: print("Component 6 evaluation failed.") print(f"What went wrong: {exc}") print("Fix suggestion: verify checkpoint path and tokenizer path.") raise SystemExit(1) if __name__ == "__main__": main()