# Source: bbb / ms-swift / swift/trainers/callback.py
# Uploaded by Student0809 via "Add files using upload-large-folder tool" (commit 2742ed8, verified)
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import time
from tqdm import tqdm
from transformers import trainer
from transformers.trainer_callback import (DefaultFlowCallback, PrinterCallback, ProgressCallback, TrainerControl,
TrainerState)
from transformers.trainer_utils import IntervalStrategy, has_length, speed_metrics
from swift.utils import append_to_jsonl, is_pai_training_job, use_torchacc
from ..utils.utils import format_time
from .arguments import TrainingArguments
def add_train_message(logs, state, start_time) -> None:
    """Add progress and timing information to ``logs`` in place.

    Fields written:
      * ``global_step/max_steps`` -- e.g. ``'10/100'``;
      * ``percentage`` -- completed fraction formatted with two decimals and ``%``;
      * ``elapsed_time`` -- wall-clock time since ``start_time`` (via ``format_time``);
      * ``remaining_time`` -- only once some progress was made; a linear
        extrapolation ``elapsed / fraction - elapsed``.
    Afterwards every float value in ``logs`` is rounded to 8 decimal places.

    Args:
        logs: Mutable dict of log values; modified in place.
        state: Trainer state exposing ``global_step`` and ``max_steps``.
        start_time: ``time.time()`` timestamp taken when training started.
    """
    done, total = state.global_step, state.max_steps
    logs['global_step/max_steps'] = f'{done}/{total}'
    # Guard against ``max_steps == 0`` (unknown / not yet set).
    fraction = done / total if total else 0.
    logs['percentage'] = f'{fraction * 100:.2f}%'
    spent = time.time() - start_time
    logs['elapsed_time'] = format_time(spent)
    if fraction:
        logs['remaining_time'] = format_time(spent / fraction - spent)
    # Only existing keys are overwritten, so the dict's key order is unchanged.
    logs.update({key: round(value, 8) for key, value in logs.items() if isinstance(value, float)})
class ProgressCallbackNew(ProgressCallback):
    """tqdm progress callback that also enriches and persists training logs.

    Compared with the parent ``ProgressCallback`` it:
      * labels the progress bars ``Train`` / ``Val``;
      * adds progress/timing fields to every log dict via ``add_train_message``;
      * appends every log dict to ``<output_dir>/logging.jsonl`` on the main
        process, unless running as a PAI training job;
      * under TorchAcc, merges throughput metrics measured after a warmup period
        into the log dict that carries ``train_samples_per_second``.
    """

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the ``Train`` bar on the main process and reset timing/warmup state."""
        if state.is_world_process_zero:
            self.training_bar = tqdm(desc='Train', total=state.max_steps, dynamic_ncols=True)
        self.current_step = 0
        self.start_time = time.time()
        if use_torchacc():
            # 0 marks "warmup start not recorded yet" (checked in ``on_log``).
            self.warmup_start_time = 0
            self.warmup_metric = None
            warmup = args.metric_warmup_step
            # Values below 1 are a fraction of ``max_steps``; otherwise an absolute step count.
            self.metric_warmup_step = int(warmup * state.max_steps) if warmup < 1 else warmup

    def on_prediction_step(self, args, state: TrainerState, control, eval_dataloader=None, **kwargs):
        """Advance the ``Val`` bar (creating it lazily) on the main process for sized dataloaders."""
        if state.is_world_process_zero and has_length(eval_dataloader):
            if self.prediction_bar is None:
                if self.training_bar is not None:
                    # Start the validation bar on a fresh line below the training bar.
                    self.training_bar.fp.write('\n')
                self.prediction_bar = tqdm(
                    desc='Val', total=len(eval_dataloader), leave=True, dynamic_ncols=True, position=0)
            self.prediction_bar.update()

    def on_log(self, args: TrainingArguments, state: TrainerState, control, logs=None, **kwargs):
        """Enrich ``logs``, append them to ``logging.jsonl`` and forward them to the parent callback.

        Under TorchAcc, the time from the first log at or after ``metric_warmup_step``
        until ``global_step == max_steps`` is turned into ``warmup_train_*`` speed
        metrics. Those metrics are merged into the log dict containing
        ``train_samples_per_second`` if they have been computed.
        """
        if use_torchacc():
            if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
                self.warmup_start_time = time.time()
                # Measure from the step at which timing actually began.
                self.metric_warmup_step = state.global_step
            if state.max_steps == state.global_step and self.warmup_metric is None:
                # NOTE(review): if the warmup step was never reached, warmup_start_time is still 0
                # and the resulting runtime is meaningless -- confirm intended handling.
                num_steps = state.max_steps - self.metric_warmup_step
                num_total_samples = args.train_dataset_sample
                # Estimate post-warmup samples assuming a constant number of samples per step.
                num_after_warmup_samples = int(num_total_samples / state.max_steps * num_steps)
                self.warmup_metric = speed_metrics('warmup_train', self.warmup_start_time, num_after_warmup_samples,
                                                   num_steps)
                self.warmup_metric['num_total_samples'] = num_total_samples
                self.warmup_metric['num_after_warmup_samples'] = num_after_warmup_samples
            # ``warmup_metric`` is only set once ``max_steps`` is reached. Training can stop
            # earlier (e.g. via ``args.max_epochs``), and ``logs.update(None)`` would then
            # raise ``TypeError``.
            if self.warmup_metric is not None and 'train_samples_per_second' in logs:
                logs.update(self.warmup_metric)
                state.log_history[-1] = logs
        add_train_message(logs, state, self.start_time)
        if not is_pai_training_job() and state.is_world_process_zero:
            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, logs)
        super().on_log(args, state, control, logs, **kwargs)
        if state.is_world_process_zero and self.training_bar is not None:
            self.training_bar.refresh()
class DefaultFlowCallbackNew(DefaultFlowCallback):
    """``DefaultFlowCallback`` that forces a final evaluation/checkpoint.

    On the last step (``global_step == max_steps``) it requests an evaluation and
    a save. The same happens when ``args.max_epochs`` is reached at the end of an
    epoch, and training is then also stopped. Each request is only made when its
    strategy is not ``IntervalStrategy.NO``.
    """

    @staticmethod
    def _eval_strategy_of(args):
        # Read whichever attribute name the arguments object defines.
        return args.eval_strategy if hasattr(args, 'eval_strategy') else args.evaluation_strategy

    @staticmethod
    def _request_eval_and_save(args, control, eval_strategy):
        # Only ever switch flags on; requests already made by the parent are kept.
        if eval_strategy != IntervalStrategy.NO:
            control.should_evaluate = True
        if args.save_strategy != IntervalStrategy.NO:
            control.should_save = True

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Apply the parent's flow, then request evaluation/save on the final step."""
        control = super().on_step_end(args, state, control, **kwargs)
        eval_strategy = self._eval_strategy_of(args)
        # save the last ckpt
        if state.global_step == state.max_steps:
            self._request_eval_and_save(args, control, eval_strategy)
        return control

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Apply the parent's flow; once ``max_epochs`` is reached, evaluate/save and stop."""
        control = super().on_epoch_end(args, state, control, **kwargs)
        eval_strategy = self._eval_strategy_of(args)
        max_epochs = args.max_epochs
        # ``state.epoch`` may be fractional, so round it up before comparing.
        if max_epochs is None or max_epochs > math.ceil(state.epoch):
            return control
        self._request_eval_and_save(args, control, eval_strategy)
        control.should_training_stop = True
        return control
class PrinterCallbackNew(PrinterCallback):
    """``PrinterCallback`` that enriches, persists and prints each log dict."""

    def on_train_begin(self, args, state, control, **kwargs):
        """Record the training start time (used for elapsed/remaining time), then defer to the parent."""
        self.start_time = time.time()
        return super().on_train_begin(args, state, control, **kwargs)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Add progress fields, append to ``logging.jsonl`` (main process, non-PAI) and print."""
        add_train_message(logs, state, self.start_time)
        persist = not is_pai_training_job() and state.is_world_process_zero
        if persist:
            append_to_jsonl(os.path.join(args.output_dir, 'logging.jsonl'), logs)
        # Drop ``total_flos`` before printing; this happens after the optional jsonl append.
        logs.pop('total_flos', None)
        if state.is_world_process_zero:
            print(logs, flush=True)
# monkey patching: replace the module-level default callbacks in ``transformers.trainer``
# with the customized versions defined above.
trainer.PrinterCallback = PrinterCallbackNew
trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew