Spaces:
Running on Zero
Running on Zero
File size: 1,756 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | from transformers.trainer_callback import (
EarlyStoppingCallback,
TrainerCallback,
)
from trl import ModelConfig, ScriptArguments
from trl.data_utils import DatasetDict
from linalg_zero.config.data import SFTRunConfig
from linalg_zero.sft.tool_calling_accuracy import ToolCallingAccuracyCallback
from linalg_zero.sft.tool_evaluation import PushToHubRevisionCallback
# Registry of the callback names that may appear in a run config's
# ``callbacks`` list, mapped to the class ``get_callbacks`` instantiates for
# each name. Registering a new callback here also requires teaching
# ``get_callbacks`` its constructor arguments unless it accepts the model
# config as its single positional argument.
CALLBACKS: dict[str, type[TrainerCallback]] = dict(
    push_to_hub_revision=PushToHubRevisionCallback,
    tool_calling_accuracy=ToolCallingAccuracyCallback,
    early_stopping=EarlyStoppingCallback,
)
def get_callbacks(
    train_config: SFTRunConfig, model_config: ModelConfig, script_args: ScriptArguments, dataset: DatasetDict
) -> list[TrainerCallback]:
    """Build the trainer callbacks requested by ``train_config.callbacks``.

    Every requested name is resolved through the module-level ``CALLBACKS``
    registry. The resolved class is then constructed with the arguments it
    expects:

    * ``"tool_calling_accuracy"`` receives the model name/path, the dataset
      name, and the split of ``dataset`` named by
      ``script_args.dataset_test_split``.
    * ``"early_stopping"`` receives ``train_config.early_stopping_patience``
      and ``train_config.early_stopping_threshold``.
    * Any other registered name receives ``model_config`` as its single
      positional argument.

    Args:
        train_config: Run configuration; ``callbacks`` lists the names to build.
        model_config: Model configuration passed to the callbacks that need it.
        script_args: Script arguments supplying the dataset name and the name
            of the test split.
        dataset: Dataset dictionary indexed by split name.

    Returns:
        The constructed callbacks, in the order they appear in
        ``train_config.callbacks``.

    Raises:
        ValueError: If a requested name is not registered in ``CALLBACKS``.
    """
    built: list[TrainerCallback] = []
    for callback_name in train_config.callbacks:
        if callback_name not in CALLBACKS:
            raise ValueError(f"Callback {callback_name} not found in CALLBACKS.")
        callback_cls = CALLBACKS[callback_name]

        # Constructor signatures differ between callbacks, so each special case
        # assembles its own keyword arguments.
        if callback_name == "tool_calling_accuracy":
            callback = callback_cls(
                model_name=model_config.model_name_or_path,
                dataset_name=script_args.dataset_name,
                eval_dataset=dataset[script_args.dataset_test_split],
            )
        elif callback_name == "early_stopping":
            callback = callback_cls(
                early_stopping_patience=train_config.early_stopping_patience,
                early_stopping_threshold=train_config.early_stopping_threshold,
            )
        else:
            # Remaining registered callbacks take the model config positionally.
            callback = callback_cls(model_config)

        built.append(callback)
    return built
|