| import json |
| import logging |
| import os |
| import signal |
| import sys |
| import time |
| from concurrent.futures import ThreadPoolExecutor |
| from datetime import timedelta |
| from typing import TYPE_CHECKING, Any, Dict, Optional |
|
|
| import transformers |
| from transformers import TrainerCallback |
| from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length |
|
|
| from .constants import TRAINER_LOG |
| from .logging import LoggerHandler, get_logger |
| from .misc import fix_valuehead_checkpoint |
|
|
|
|
| if TYPE_CHECKING: |
| from transformers import TrainerControl, TrainerState, TrainingArguments |
|
|
|
|
# Module-level logger for this file; the callbacks use it for the "previous
# trainer log will be deleted" warning and for Web UI loss/lr/epoch messages.
logger = get_logger(__name__)
|
|
|
|
class FixValueHeadModelCallback(TrainerCallback):
    r"""
    Trainer callback that post-processes every saved checkpoint with
    ``fix_valuehead_checkpoint``.
    """

    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.

        Does nothing on processes where ``args.should_save`` is false. Otherwise the
        ``model`` entry is taken out of ``kwargs`` (a missing entry raises ``KeyError``).
        The checkpoint directory is ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``.
        That model and directory, together with ``args.save_safetensors``, are passed to
        ``fix_valuehead_checkpoint``.
        """
        if not args.should_save:
            return

        model = kwargs.pop("model")
        checkpoint_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        fix_valuehead_checkpoint(
            model=model,
            output_dir=checkpoint_dir,
            safe_serialization=args.save_safetensors,
        )
|
|
|
|
class LogCallback(TrainerCallback):
    r"""
    Trainer callback that records training / evaluation progress to ``TRAINER_LOG``.

    Progress records are appended as JSON lines to ``<output_dir>/TRAINER_LOG`` by a
    single-worker ``ThreadPoolExecutor``. Only processes with ``args.should_save``
    write anything.

    Web UI mode is enabled when the ``LLAMABOARD_ENABLED`` environment variable is
    ``"true"`` or ``"1"`` (case-insensitive). In that mode:

    - a SIGABRT handler is installed that requests a graceful stop;
    - a ``LoggerHandler(output_dir)`` is attached to the root and transformers loggers.
    """

    def __init__(self, output_dir: str) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.

        Args:
            output_dir: directory passed to ``LoggerHandler`` when Web UI mode is on.
        """
        # Progress
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # Status
        self.aborted = False
        self.do_train = False
        # Web UI
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(output_dir)
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        # Signal handler: only flips a flag. The step / prediction hooks act on it.
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        # Restart progress tracking: new start time, zero steps, fresh total.
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        # Update the step counter plus elapsed / remaining time (as "H:MM:SS" strings).
        # The remaining-time estimate is a linear extrapolation of the average step time.
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        # Append one JSON record per line; runs on the worker thread.
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        # Shut down any previous pool first so its worker thread is not leaked and
        # its pending writes are flushed before the new pool takes over.
        self._close_thread_pool()
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        # Block until queued writes finish, then drop the pool.
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def _stop_training_if_aborted(self, control: "TrainerControl") -> None:
        # Translate an abort request into a graceful stop of the training loop.
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.

        Deletes an existing trainer log when ``overwrite_output_dir`` is set, so the
        new run does not append to stale records.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()

    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of an substep during gradient accumulation.
        """
        self._stop_training_if_aborted(control)

    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        self._stop_training_if_aborted(control)

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.

        In inference-only runs this ends the current progress session.
        """
        if not self.do_train:
            self._close_thread_pool()
            # Clear the finished loop's step count / total so that a subsequent
            # evaluate or predict call re-initializes tracking and the log writer
            # in `on_prediction_step` (which only does so when max_steps == 0).
            self._reset()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.

        In inference-only runs this ends the current progress session.
        """
        if not self.do_train:
            self._close_thread_pool()
            # See `on_evaluate`: allow a later inference loop to start fresh.
            self._reset()

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.

        Builds a progress record from the latest ``state.log_history`` entry, drops
        missing fields, and queues it for writing.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        last_log = state.log_history[-1]
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=last_log.get("loss", None),
            eval_loss=last_log.get("eval_loss", None),
            predict_loss=last_log.get("predict_loss", None),
            reward=last_log.get("reward", None),
            accuracy=last_log.get("rewards/accuracies", None),
            learning_rate=last_log.get("learning_rate", None),
            epoch=last_log.get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
        )
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"]
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.

        Only tracks progress for inference-only runs. It exits the process when an
        abort was requested, and writes a progress record every 5 steps.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            # First step of a new inference loop: start tracking and open the writer.
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
|