# linalg_zero/sft/tool_evaluation.py
import logging
from typing import Any

from transformers.trainer_callback import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.training_args import TrainingArguments
from trl.trainer.model_config import ModelConfig

from linalg_zero.sft.hub import push_to_hub_revision

logger = logging.getLogger(__name__)
class EvaluationState:
    """Tracks per-sample state accumulated while evaluating tool use."""

    def __init__(self) -> None:
        self.messages: list[dict[str, Any]] = []  # chat transcript built up so far
        self.sample: dict[str, Any] | None = None  # dataset sample under evaluation
        self.strict_format_match = 0.0  # score for exact output-format compliance
        self.partial_format_score = 0.0  # partial credit for near-miss formatting
        self.tool_parse_success = False  # whether a tool call was parsed successfully
        self.generated_answer: Any | None = None  # final answer extracted from the generation
        self.early_stop_reason: str | None = None  # set when generation is cut short
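
# A minimal sketch of how an evaluation loop might drive EvaluationState; the real
# loop lives elsewhere in linalg_zero, so the field updates below are illustrative:
#
#     state = EvaluationState()
#     state.sample = {"query": "Compute det([[1, 2], [3, 4]])", "answer": "-2"}
#     state.messages.append({"role": "user", "content": state.sample["query"]})
#     # ... run generation, then record the outcome ...
#     state.tool_parse_success = True
#     state.generated_answer = "-2"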
class DummyConfig:
    """Minimal attribute bag that stands in for a real config object."""

    def __init__(self, **kwargs: Any) -> None:
        for k, v in kwargs.items():
            setattr(self, k, v)
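
# DummyConfig simply exposes whatever keyword arguments it was given as attributes,
# so any code that reads config fields by name accepts it, e.g.:
#
#     cfg = DummyConfig(hub_model_id="user/repo", output_dir="outputs")
#     cfg.hub_model_id  # -> "user/repo"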
class PushToHubRevisionCallback(TrainerCallback):
    """Pushes each saved checkpoint to a step-specific revision on the Hub."""

    def __init__(self, model_config: ModelConfig) -> None:
        self.model_config = model_config

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs: Any,
    ) -> None:
        if state.is_world_process_zero:
            global_step = state.global_step
            # WARNING: dataclasses.replace(args, ...) breaks the accelerator's distributed
            # state, and instantiating a fresh SFTConfig breaks it as well, hence the
            # minimal DummyConfig stand-in below.
            dummy_config = DummyConfig(
                hub_model_id=args.hub_model_id,
                hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
                output_dir=f"{args.output_dir}/checkpoint-{global_step}",
                system_prompt=args.system_prompt,
            )
            # Ignore "*.pt" files so optimizer states are not pushed.
            _ = push_to_hub_revision(dummy_config, extra_ignore_patterns=["*.pt"])
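
# A usage sketch, assuming this repo's SFT setup where the training config carries
# the extra hub_model_revision and system_prompt fields that on_save reads (the
# trainer, dataset, and hub arguments below are illustrative, not defined here):
#
#     from trl import SFTConfig, SFTTrainer
#
#     model_config = ModelConfig(model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct")
#     trainer = SFTTrainer(
#         model=model_config.model_name_or_path,
#         args=SFTConfig(output_dir="outputs", push_to_hub=True, hub_model_id="user/repo"),
#         train_dataset=train_dataset,
#         callbacks=[PushToHubRevisionCallback(model_config)],
#     )
#     trainer.train()  # each checkpoint save also pushes a "-step-NNNNNNNNN" revision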