File size: 2,533 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from functools import wraps

import torch
from megatron.training import get_args, global_vars, initialize, training

from swift.utils import JsonlWriter, is_master


@contextmanager
def patch_training_log():
    """Patch Megatron's ``training.training_log`` to also append metrics to a JSONL file.

    While the context is active, every ``log_interval``-th iteration on the
    master rank writes the loss dict (tensors converted to floats), grad/params
    norms, learning rate, consumed samples, and step progress to
    ``<args.save>/logging.jsonl`` via an async :class:`JsonlWriter`.
    The original ``training.training_log`` is always restored on exit.
    """
    jsonl_writer = None
    origin_training_log = training.training_log

    @wraps(origin_training_log)
    def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
                     report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad, *_args, **kwargs):
        nonlocal jsonl_writer
        args = get_args()
        if is_master() and iteration % args.log_interval == 0:
            logging_path = os.path.join(args.save, 'logging.jsonl')
            logs = {}
            for k, v in loss_dict.items():
                if isinstance(v, torch.Tensor):
                    v = v.item()
                logs[k] = round(v, 8)
            # Reference the parameters explicitly instead of the original
            # fragile ``locals()[k]`` lookup, which would silently break on a
            # parameter rename and iterated a set in nondeterministic order.
            extra_metrics = {'grad_norm': grad_norm, 'params_norm': params_norm, 'learning_rate': learning_rate}
            for k, v in extra_metrics.items():
                if v is not None:
                    logs[k] = round(v, 8)
            logs['consumed_samples'] = args.consumed_train_samples
            logs['global_step/max_steps'] = f'{iteration}/{args.train_iters}'
            # Lazily create the writer so the file is only opened once logging
            # actually happens (and only on the master rank).
            if jsonl_writer is None:
                jsonl_writer = JsonlWriter(logging_path, enable_async=True)
            jsonl_writer.append(logs)
        return origin_training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration,
                                   loss_scale, report_memory_flag, skipped_iter, grad_norm, params_norm,
                                   num_zeros_in_grad, *_args, **kwargs)

    training.training_log = training_log
    try:
        yield
    finally:
        training.training_log = origin_training_log


@contextmanager
def patch_megatron_data_collator(data_collator):
    """Temporarily route Megatron pretraining dataloaders through ``data_collator``.

    While the context is active, every dataloader returned by
    ``training.build_pretraining_data_loader`` has its ``collate_fn`` replaced
    with ``data_collator``, except when the dataloader type is ``'external'``
    or no dataloader is produced. The original builder is restored on exit.
    """
    original_builder = training.build_pretraining_data_loader

    def patched_builder(*builder_args, **builder_kwargs):
        megatron_args = get_args()
        dataloader = original_builder(*builder_args, **builder_kwargs)
        if dataloader is not None and megatron_args.dataloader_type != 'external':
            dataloader.collate_fn = data_collator
        return dataloader

    training.build_pretraining_data_loader = patched_builder
    try:
        yield
    finally:
        training.build_pretraining_data_loader = original_builder