|
|
""" |
|
|
Utilities for checkpointing evaluation-related state (e.g. evaluation results).
|
|
|
|
|
Evaluation results for each checkpoint step are saved as a separate JSON file (`step_{checkpoint_step}.json`) inside the run's evaluation results directory.
|
|
""" |
|
|
|
|
|
import json |
|
|
import os |
|
|
from typing import Any, Dict |
|
|
|
|
|
from huggingface_hub import upload_folder |
|
|
from lightning.fabric import Fabric |
|
|
from lightning.fabric.utilities.rank_zero import rank_zero_only |
|
|
|
|
|
from src.config import CheckpointingConfig |
|
|
from src.training.utils.io import use_backoff |
|
|
|
|
|
|
|
|
@rank_zero_only
@use_backoff()
def save_evaluation_results(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    fabric: Fabric,
    evaluation_results: Dict[str, Any],
) -> None:
    """Save evaluation results to disk and optionally to the HuggingFace Hub.

    The evaluation results are saved in the following directory structure:

        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.evaluation.eval_results_dir}/
                └── step_{checkpoint_step}.json

    The JSON is first written to a temporary sibling file and then moved into
    place with ``os.replace``. An interrupted write therefore never leaves a
    truncated ``step_{checkpoint_step}.json`` behind. Such a file would
    otherwise be picked up by the Hub upload, which pushes the whole folder.

    If ``checkpointing_config.save_to_hf`` is truthy, the *entire*
    evaluation-results directory (all steps, not just this one) is uploaded.
    The upload goes to ``checkpointing_config.hf_checkpoint.repo_id`` on the
    branch named ``checkpointing_config.run_name``, authenticated with the
    ``HF_TOKEN`` environment variable. NOTE(review): this assumes that branch
    already exists in the repo; verify where it is created.

    NOTE: the ``rank_zero_only`` decorator means this only runs on rank 0, to
    avoid write conflicts. It assumes the evaluation results have already been
    gathered on rank 0. The function is also wrapped in ``use_backoff()``,
    which presumably retries on failure (see ``src.training.utils.io``).

    Args:
        checkpointing_config: Configuration object containing checkpoint settings.
        checkpoint_step: Current training checkpoint step (i.e. number of
            learning steps taken); used in the output file name.
        fabric: Lightning Fabric instance. Not used by this function; kept
            for interface compatibility with callers.
        evaluation_results: Dictionary containing evaluation metrics; must be
            JSON-serializable.
    """
    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    eval_results_dir = os.path.join(
        run_dir, checkpointing_config.evaluation.eval_results_dir
    )
    os.makedirs(eval_results_dir, exist_ok=True)

    curr_eval_results_path = os.path.join(
        eval_results_dir, f"step_{checkpoint_step}.json"
    )
    tmp_eval_results_path = f"{curr_eval_results_path}.tmp"

    try:
        with open(tmp_eval_results_path, "w", encoding="utf-8") as f:
            json.dump(evaluation_results, f)
        # Atomic rename (same directory => same filesystem): readers and the
        # Hub upload see either the old file or the complete new one.
        os.replace(tmp_eval_results_path, curr_eval_results_path)
    finally:
        # After a successful os.replace the temp file no longer exists; this
        # only removes a partial temp file left by a failed write, so it is
        # never uploaded with the rest of the folder.
        if os.path.exists(tmp_eval_results_path):
            os.remove(tmp_eval_results_path)

    if checkpointing_config.save_to_hf:
        upload_folder(
            folder_path=eval_results_dir,
            path_in_repo=checkpointing_config.evaluation.eval_results_dir,
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )
|
|
|