from __future__ import annotations import argparse import json import math import sys from pathlib import Path _METRICS_DIR = Path(__file__).resolve().parent if str(_METRICS_DIR) not in sys.path: sys.path.insert(0, str(_METRICS_DIR)) from broken_code_generation import FILE_TRAINING, MODEL_ID, TRAINER_STATE # noqa: E402 from report_io import metrics_path, write_report # noqa: E402 def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description=f"Training metrics for {MODEL_ID} only." ) parser.add_argument("--trainer_state", type=Path, default=TRAINER_STATE) parser.add_argument("--output", type=Path, default=None) return parser.parse_args() def extract_metrics(state: dict) -> dict: train_loss = eval_loss = eval_acc = None eval_by_epoch = [] for entry in state.get("log_history", []): if "eval_loss" in entry: eval_by_epoch.append( { "epoch": entry.get("epoch"), "eval_loss": entry.get("eval_loss"), "eval_mean_token_accuracy": entry.get("eval_mean_token_accuracy"), "perplexity": round(math.exp(entry["eval_loss"]), 4), } ) if "loss" in entry and "eval_loss" not in entry: train_loss = entry["loss"] for entry in reversed(state.get("log_history", [])): if "eval_loss" in entry: eval_loss = entry["eval_loss"] eval_acc = entry.get("eval_mean_token_accuracy") break return { "train_loss_final": train_loss, "eval_loss_final": eval_loss, "eval_mean_token_accuracy": eval_acc, "perplexity_validation": round(math.exp(eval_loss), 4) if eval_loss else None, "num_train_epochs": state.get("num_train_epochs"), "global_step": state.get("global_step"), "eval_by_epoch": eval_by_epoch, } def main() -> None: args = parse_args() output = args.output or metrics_path(FILE_TRAINING) state = json.loads(args.trainer_state.read_text(encoding="utf-8")) report = { "metric_group": "training_perplexity", "model": MODEL_ID, "adapter_dir": str(TRAINER_STATE.parent.parent), "source": str(args.trainer_state), "metrics": extract_metrics(state), } write_report(output, report) print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) if __name__ == "__main__": main()