File size: 2,381 Bytes
feba2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Utilities for checkpointing evaluation-related state (e.g. evaluation results).

Evaluation results for each checkpoint step are saved as a JSON file (``step_{checkpoint_step}.json``) in the run's evaluation results directory.
"""

import json
import os
from typing import Any, Dict

from huggingface_hub import upload_folder
from lightning.fabric import Fabric
from lightning.fabric.utilities.rank_zero import rank_zero_only

from src.config import CheckpointingConfig
from src.training.utils.io import use_backoff


@rank_zero_only
@use_backoff()
def save_evaluation_results(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    fabric: Fabric,
    evaluation_results: Dict[str, Any],
) -> None:
    """Save evaluation results to disk and optionally to HuggingFace Hub.

    The evaluation results are saved in the following directory structure:
    {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.evaluation.eval_results_dir}/
                └── step_{checkpoint_step}.json

    If ``checkpointing_config.save_to_hf`` is set, the entire local evaluation results
    directory (including files from earlier steps) is uploaded to the Hub repo
    ``checkpointing_config.hf_checkpoint.repo_id``:
    - it goes under the same relative path (``evaluation.eval_results_dir``);
    - it targets the revision named ``checkpointing_config.run_name``;
    - it uses the token from the ``HF_TOKEN`` environment variable.

    NOTE: this function is only called on rank 0 to avoid conflicts; assumes that the evaluation
    results are gathered on rank 0.

    Args:
        checkpointing_config: Configuration object containing checkpoint settings
        checkpoint_step: Current training checkpoint step (i.e. number of learning steps taken)
        fabric: Lightning Fabric instance (currently unused by this function)
        evaluation_results: Dictionary containing evaluation metrics; must be JSON-serializable

    Raises:
        TypeError: If ``evaluation_results`` contains a value that is not JSON-serializable.
            ``ValueError`` is raised instead for circular references. In both cases the error
            is raised before anything is written to disk. Whether failures are retried is
            decided by the ``use_backoff`` decorator (presumably retry-with-backoff; see its
            definition).
    """
    # Serialize fully in memory *before* touching the filesystem. ``json.dump`` writes
    # encoded chunks to the file as it goes, so a non-serializable value deep inside the
    # dict would otherwise leave a truncated step_*.json on disk, and a later
    # upload_folder call would publish that corrupt file.
    serialized_results = json.dumps(evaluation_results)

    run_dir = os.path.join(checkpointing_config.runs_dir, checkpointing_config.run_name)
    eval_results_dir = os.path.join(
        run_dir, checkpointing_config.evaluation.eval_results_dir
    )

    os.makedirs(eval_results_dir, exist_ok=True)

    curr_eval_results_path = os.path.join(
        eval_results_dir, f"step_{checkpoint_step}.json"
    )

    # save out as json; explicit encoding so the output does not depend on the platform locale
    with open(curr_eval_results_path, "w", encoding="utf-8") as f:
        f.write(serialized_results)

    if checkpointing_config.save_to_hf:
        # Uploads the whole directory, so the Hub copy mirrors every step saved locally.
        upload_folder(
            folder_path=eval_results_dir,
            path_in_repo=checkpointing_config.evaluation.eval_results_dir,
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Evaluation Results -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )