File size: 4,108 Bytes
ccefec1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# -*- coding: utf-8 -*-
import json
import logging
import os
import sys
import time
from transformers.trainer_callback import (ExportableState, TrainerCallback,
TrainerControl, TrainerState)
from transformers.training_args import TrainingArguments
def get_logger(name: str = None) -> logging.Logger:
    """Return the named logger, configured to print INFO records on rank 0.

    When the ``RANK`` environment variable is set and equals 0, the logger's
    level is set to INFO and a stdout ``StreamHandler`` with a timestamped
    format is attached. On any other process (or when ``RANK`` is unset) the
    logger is returned without modification.

    Calling this repeatedly with the same ``name`` is safe: the stdout handler
    is attached at most once, so records are not printed multiple times.

    Args:
        name: Name passed to ``logging.getLogger``; ``None`` selects the root
            logger.

    Returns:
        The ``logging.Logger`` registered under ``name``.

    Raises:
        ValueError: If ``RANK`` is set but is not an integer.
    """
    logger = logging.getLogger(name)
    if 'RANK' in os.environ and int(os.environ['RANK']) == 0:
        logger.setLevel(logging.INFO)
        # logging.getLogger hands back the same cached object for a name, so
        # without this check each call would stack another stdout handler.
        already_attached = any(
            isinstance(h, logging.StreamHandler) and getattr(h, "stream", None) is sys.stdout
            for h in logger.handlers
        )
        if not already_attached:
            formatter = logging.Formatter(
                fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
            )
            handler = logging.StreamHandler(sys.stdout)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger
# Name of the JSON-lines progress log kept inside the training ``output_dir``.
LOG_FILE_NAME = "trainer_log.jsonl"
# Module-wide logger; get_logger only attaches output when RANK is 0.
logger = get_logger(__name__)
class LogCallback(TrainerCallback, ExportableState):
    r"""
    Track throughput and append a JSON progress record on every log event.

    On each ``on_log`` call the callback renames ``num_input_tokens_seen`` to
    ``num_tokens``. It adds a ``throughput`` value: tokens per process per
    second of time accumulated between events. It then appends one JSON line
    to ``LOG_FILE_NAME`` inside ``args.output_dir``.

    The timing counters are exported with ``state()`` into
    ``TrainerState.stateful_callbacks["LogCallback"]``, so a resumed run can
    continue from them.

    NOTE(review): ``state()`` returns a flat ``{"start_time", "elapsed_time"}``
    dict, and ``on_train_begin`` reads that same layout back. Confirm this
    matches what the installed transformers version expects from
    ``ExportableState``.
    """

    def __init__(self, start_time: float = None, elapsed_time: float = None):
        """
        Args:
            start_time: Wall-clock start of training in ``time.time()``
                seconds; defaults to now.
            elapsed_time: Seconds already accumulated between log events;
                defaults to 0.
        """
        self.start_time = time.time() if start_time is None else start_time
        self.elapsed_time = 0 if elapsed_time is None else elapsed_time
        # Timestamp of the previous event; elapsed_time grows by the gap
        # between consecutive events.
        self.last_time = self.start_time

    @staticmethod
    def _is_logging_process(args: TrainingArguments, state: TrainerState) -> bool:
        """Return True on the process that should write the log file.

        That is the local main process when ``save_on_each_node`` is set, and
        the global main process otherwise.
        """
        if args.save_on_each_node:
            return state.is_local_process_zero
        return state.is_world_process_zero

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs
    ):
        r"""
        Event called at the beginning of training.

        On the local main process:
        - For a fresh run, the timing counters are reset.
        - For a resumed run, they are restored from
          ``state.stateful_callbacks``.

        On the logging process, the per-event timer is then started. If
        ``args.overwrite_output_dir`` is set, any existing log file is deleted.

        Args:
            args: Training arguments (``resume_from_checkpoint``,
                ``save_on_each_node``, ``output_dir`` and
                ``overwrite_output_dir`` are read).
            state: Trainer state to restore timing counters from.
            control: Unused.
        """
        if state.is_local_process_zero:
            saved = {}
            if args.resume_from_checkpoint:
                # A checkpoint may lack our entry (or store it in another
                # layout); fall back to a fresh start instead of KeyError.
                entry = (state.stateful_callbacks or {}).get('LogCallback')
                if isinstance(entry, dict):
                    saved = entry
            self.start_time = saved.get('start_time', time.time())
            self.elapsed_time = saved.get('elapsed_time', 0)
        if not self._is_logging_process(args, state):
            return
        self.last_time = time.time()
        log_path = os.path.join(args.output_dir, LOG_FILE_NAME)
        if os.path.exists(log_path) and args.overwrite_output_dir:
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(log_path)

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs,
        **kwargs
    ):
        r"""
        Event called when the trainer logs metrics.

        On the logging process only, this method:
        - advances the timing counters;
        - rewrites token statistics in ``logs`` and in the latest
          ``state.log_history`` entry;
        - snapshots this callback's state into ``state.stateful_callbacks``;
        - appends a JSON progress record to ``LOG_FILE_NAME``.

        Args:
            args: Training arguments (``save_on_each_node``, ``world_size``
                and ``output_dir`` are read).
            state: Trainer state; its last ``log_history`` entry is modified.
            control: Unused.
            logs: Metrics dict being logged; modified in place.
        """
        if not self._is_logging_process(args, state):
            return
        # Read the clock once so no time slips between the two updates.
        now = time.time()
        self.elapsed_time += now - self.last_time
        self.last_time = now

        # Tolerate an empty history instead of raising IndexError.
        last_entry = state.log_history[-1] if state.log_history else {}
        if 'num_input_tokens_seen' in logs:
            logs['num_tokens'] = logs.pop('num_input_tokens_seen')
            last_entry.pop('num_input_tokens_seen', None)
            # With a coarse clock elapsed_time can still be 0; skip the
            # throughput value rather than raise ZeroDivisionError.
            if self.elapsed_time > 0:
                throughput = logs['num_tokens'] / args.world_size / self.elapsed_time
                last_entry['throughput'] = logs['throughput'] = throughput
        # Snapshot on every event (not only when token counts are present) so
        # a saved checkpoint always carries the current timing counters.
        state.stateful_callbacks["LogCallback"] = self.state()

        record = dict(
            current_steps=state.global_step,
            total_steps=state.max_steps,
            loss=last_entry.get("loss", None),
            eval_loss=last_entry.get("eval_loss", None),
            predict_loss=last_entry.get("predict_loss", None),
            learning_rate=last_entry.get("learning_rate", None),
            epoch=last_entry.get("epoch", None),
            percentage=round(state.global_step / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
        )
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, LOG_FILE_NAME), "a", encoding="utf-8") as f:
            f.write(json.dumps(record) + "\n")

    def state(self) -> dict:
        """Return the timing counters to persist in a checkpoint."""
        return {
            'start_time': self.start_time,
            'elapsed_time': self.elapsed_time
        }

    @classmethod
    def from_state(cls, state):
        """Rebuild a callback from a dict produced by ``state()``."""
        return cls(state['start_time'], state['elapsed_time'])
|