File size: 2,571 Bytes
aa988a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path

# Absolute directory of this script; the helper modules imported right after
# this (``broken_code_generation``, ``report_io``) are looked up from here.
_METRICS_DIR = Path(__file__).resolve().parent
# Put that directory at the front of the import path, but only once, so the
# sibling imports resolve even when it is not already on ``sys.path``.
if all(entry != str(_METRICS_DIR) for entry in sys.path):
    sys.path.insert(0, str(_METRICS_DIR))

from broken_code_generation import FILE_TRAINING, MODEL_ID, TRAINER_STATE  # noqa: E402
from report_io import metrics_path, write_report  # noqa: E402


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the training-metrics report.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as the zero-argument call did;
            passing an explicit list lets tests or other scripts drive the
            parser without touching ``sys.argv``.

    Returns:
        Namespace with ``trainer_state`` (a ``Path``, defaulting to
        ``TRAINER_STATE``) and ``output`` (a ``Path``, or ``None`` when the
        option was not given).
    """
    parser = argparse.ArgumentParser(
        description=f"Training metrics for {MODEL_ID} only."
    )
    parser.add_argument(
        "--trainer_state",
        type=Path,
        default=TRAINER_STATE,
        help="Path to the trainer state JSON file (default: %(default)s).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Where to write the report; if omitted, a default path is used.",
    )
    return parser.parse_args(argv)


def _perplexity(loss: float | None) -> float | None:
    """Return ``exp(loss)`` rounded to 4 decimals, or ``None`` if *loss* is ``None``.

    Losses too large for ``math.exp`` (roughly above 709) give ``math.inf``
    rather than raising ``OverflowError``.
    """
    if loss is None:
        return None
    try:
        return round(math.exp(loss), 4)
    except OverflowError:
        return math.inf


def extract_metrics(state: dict) -> dict:
    """Summarise training and evaluation metrics from a trainer state mapping.

    Only ``state["log_history"]`` (a list of per-log dicts),
    ``state["num_train_epochs"]`` and ``state["global_step"]`` are read;
    missing keys yield ``None`` or an empty list.

    Returns:
        A dict with:

        * ``train_loss_final``: ``loss`` of the last log entry that has
          ``loss`` but no ``eval_loss``.
        * ``eval_loss_final`` / ``eval_mean_token_accuracy``: taken from the
          last entry that contains ``eval_loss``.
        * ``perplexity_validation``: ``exp(eval_loss_final)`` rounded to 4
          places (``None`` when there is no evaluation entry).
        * ``num_train_epochs`` / ``global_step``: copied from *state*.
        * ``eval_by_epoch``: one record per evaluation entry, in log order,
          with its epoch, loss, token accuracy and perplexity.
    """
    # ``or []`` also tolerates an explicit ``"log_history": null``.
    log_history = state.get("log_history") or []

    train_loss = None
    eval_by_epoch = []
    for entry in log_history:
        if "eval_loss" in entry:
            eval_loss = entry["eval_loss"]
            eval_by_epoch.append(
                {
                    "epoch": entry.get("epoch"),
                    "eval_loss": eval_loss,
                    "eval_mean_token_accuracy": entry.get("eval_mean_token_accuracy"),
                    "perplexity": _perplexity(eval_loss),
                }
            )
        elif "loss" in entry:
            # Entries with ``eval_loss`` were handled above, so a plain
            # ``loss`` here is treated as a training loss.
            train_loss = entry["loss"]

    # The final evaluation is simply the last record collected above.
    last_eval = eval_by_epoch[-1] if eval_by_epoch else {}

    return {
        "train_loss_final": train_loss,
        "eval_loss_final": last_eval.get("eval_loss"),
        "eval_mean_token_accuracy": last_eval.get("eval_mean_token_accuracy"),
        # An eval loss of exactly 0.0 is valid and gives perplexity 1.0;
        # only a missing evaluation yields None.
        "perplexity_validation": last_eval.get("perplexity"),
        "num_train_epochs": state.get("num_train_epochs"),
        "global_step": state.get("global_step"),
        "eval_by_epoch": eval_by_epoch,
    }


def main() -> None:
    """Build the training/perplexity report, write it, and echo its metrics.

    Reads the trainer state JSON named by ``--trainer_state`` and extracts
    its metrics with ``extract_metrics``. Writes the full report to
    ``--output`` (or ``metrics_path(FILE_TRAINING)`` when omitted) via
    ``write_report``, then prints the ``metrics`` section as indented JSON.
    """
    args = parse_args()
    output = args.output or metrics_path(FILE_TRAINING)
    trainer_state_path: Path = args.trainer_state
    state = json.loads(trainer_state_path.read_text(encoding="utf-8"))
    report = {
        "metric_group": "training_perplexity",
        "model": MODEL_ID,
        # Derive the adapter directory from the state file actually read
        # (rather than the module default) so it stays consistent with
        # "source" when --trainer_state is overridden. Keeps the original
        # grandparent-directory convention.
        # NOTE(review): assumes <adapter_dir>/<subdir>/trainer_state.json — confirm.
        "adapter_dir": str(trainer_state_path.parent.parent),
        "source": str(trainer_state_path),
        "metrics": extract_metrics(state),
    }
    write_report(output, report)
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))


# Generate the report only when executed as a script, not on import.
if __name__ == "__main__":
    main()