# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from functools import wraps

import torch
from megatron.training import get_args, global_vars, initialize, training

from swift.utils import JsonlWriter, is_master


@contextmanager
def patch_training_log():
    """Temporarily wrap Megatron's ``training.training_log``.

    While active, every call to ``training_log`` on the master rank at a
    ``log_interval`` boundary additionally appends a JSON line of metrics
    (losses, grad/params norm, learning rate, consumed samples, step counter)
    to ``<args.save>/logging.jsonl`` via an async :class:`JsonlWriter`.
    The original function is always restored on exit.
    """
    jsonl_writer = None  # created lazily on first write (needs args.save)
    origin_training_log = training.training_log

    @wraps(origin_training_log)
    def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
                     report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
        nonlocal jsonl_writer
        args = get_args()
        if is_master() and iteration % args.log_interval == 0:
            logging_path = os.path.join(args.save, 'logging.jsonl')
            logs = {}
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                logs[k] = round(v, 8)
            # Explicit mapping instead of iterating a set and reading
            # ``locals()[k]``: the set gave nondeterministic key order in the
            # JSONL output (string hash randomization), and reflective
            # ``locals()`` lookup is fragile (see PEP 667 semantics changes).
            extra_metrics = {
                'grad_norm': grad_norm,
                'params_norm': params_norm,
                'learning_rate': learning_rate,
            }
            for k, v in extra_metrics.items():
                if v is not None:
                    logs[k] = round(v, 8)
            logs['consumed_samples'] = args.consumed_train_samples
            logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
            if jsonl_writer is None:
                jsonl_writer = JsonlWriter(logging_path, enable_async=True)
            jsonl_writer.append(logs)
        # Delegate to the original implementation so console/tensorboard
        # logging behaves exactly as before.
        return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
                                   loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
                                   num_zeros_in_grad, *_args, **kwargs)

    training.training_log = training_log
    try:
        yield
    finally:
        training.training_log = origin_training_log


@contextmanager
def patch_megatron_data_collator(data_collator):
    """Temporarily wrap Megatron's ``training.build_pretraining_data_loader``
    so that every non-external dataloader it builds uses *data_collator* as
    its ``collate_fn``. The original builder is restored on exit.

    Args:
        data_collator: callable assigned to ``DataLoader.collate_fn``.
    """
    origin_build_pretraining_data_loader = training.build_pretraining_data_loader

    def build_pretraining_data_loader(*_args, **kwargs):
        args = get_args()
        res = origin_build_pretraining_data_loader(*_args, **kwargs)
        # External dataloaders manage their own collation; leave them alone.
        if res is not None and args.dataloader_type != 'external':
            res.collate_fn = data_collator
        return res

    training.build_pretraining_data_loader = build_pretraining_data_loader
    try:
        yield
    finally:
        training.build_pretraining_data_loader = origin_build_pretraining_data_loader