Vilyam888's picture
Upload folder using huggingface_hub
aa988a7 verified
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
# Directory containing this script. It is prepended to sys.path (once) so the
# project modules imported right after this (broken_code_generation,
# report_io) can be found -- presumably they live alongside this file; confirm.
# Must run before those imports, hence their `# noqa: E402` markers.
_METRICS_DIR = Path(__file__).resolve().parent
if str(_METRICS_DIR) not in sys.path:
    sys.path.insert(0, str(_METRICS_DIR))
from broken_code_generation import FILE_TRAINING, MODEL_ID, TRAINER_STATE # noqa: E402
from report_io import metrics_path, write_report # noqa: E402
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the training-metrics report.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, i.e. the previous behaviour; passing a list
            allows programmatic or test invocation.

    Returns:
        Namespace with ``trainer_state`` (path of the trainer state JSON to
        read, defaulting to ``TRAINER_STATE``) and ``output`` (report path, or
        ``None`` so the caller picks its default location).
    """
    parser = argparse.ArgumentParser(
        description=f"Training metrics for {MODEL_ID} only."
    )
    parser.add_argument(
        "--trainer_state",
        type=Path,
        default=TRAINER_STATE,
        help="Path to the trainer state JSON file to read.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Where to write the report (defaults to the standard metrics path).",
    )
    return parser.parse_args(argv)
def extract_metrics(state: dict) -> dict:
train_loss = eval_loss = eval_acc = None
eval_by_epoch = []
for entry in state.get("log_history", []):
if "eval_loss" in entry:
eval_by_epoch.append(
{
"epoch": entry.get("epoch"),
"eval_loss": entry.get("eval_loss"),
"eval_mean_token_accuracy": entry.get("eval_mean_token_accuracy"),
"perplexity": round(math.exp(entry["eval_loss"]), 4),
}
)
if "loss" in entry and "eval_loss" not in entry:
train_loss = entry["loss"]
for entry in reversed(state.get("log_history", [])):
if "eval_loss" in entry:
eval_loss = entry["eval_loss"]
eval_acc = entry.get("eval_mean_token_accuracy")
break
return {
"train_loss_final": train_loss,
"eval_loss_final": eval_loss,
"eval_mean_token_accuracy": eval_acc,
"perplexity_validation": round(math.exp(eval_loss), 4) if eval_loss else None,
"num_train_epochs": state.get("num_train_epochs"),
"global_step": state.get("global_step"),
"eval_by_epoch": eval_by_epoch,
}
def main() -> None:
    """Read the trainer state, build the training-metrics report, and write it.

    The report is written via ``write_report`` to ``--output`` (or to
    ``metrics_path(FILE_TRAINING)`` when omitted), and its ``metrics`` section
    is also printed to stdout as indented JSON.

    Raises:
        SystemExit: if the trainer state path is not an existing file.
    """
    args = parse_args()
    trainer_state: Path = args.trainer_state
    if not trainer_state.is_file():
        # Fail with a readable message rather than a FileNotFoundError traceback.
        raise SystemExit(f"trainer state file not found: {trainer_state}")
    output = args.output or metrics_path(FILE_TRAINING)
    state = json.loads(trainer_state.read_text(encoding="utf-8"))
    report = {
        "metric_group": "training_perplexity",
        "model": MODEL_ID,
        # Derive from the file actually read so "adapter_dir" stays consistent
        # with "source" when --trainer_state overrides the default.
        "adapter_dir": str(trainer_state.parent.parent),
        "source": str(trainer_state),
        "metrics": extract_metrics(state),
    }
    write_report(output, report)
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()