File size: 2,399 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from contextlib import contextmanager

from swift.llm import git_clone_github
from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run

logger = get_logger()


def _patch_transformer_engine():
    try:
        from transformer_engine.pytorch.attention import FusedRoPEFunc
    except ImportError:
        try:
            import transformer_engine
            transformer_engine.pytorch.attention.FusedRoPEFunc = (
                transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
        except (ImportError, AttributeError):
            pass


def new_cyclic_iter(iter):
    """Endlessly cycle over *iter*, honoring ``args.max_epochs``.

    Replacement for megatron's ``cyclic_iter``: while training is active
    (``args.is_training``), stop after ``max_epochs`` complete passes over
    the iterable instead of looping forever. Outside training (e.g. during
    evaluation) it cycles indefinitely, as the original does.

    NOTE: the parameter is named ``iter`` (shadowing the builtin) to match
    megatron's original signature; do not rename.
    """
    from megatron.training import get_args
    args = get_args()
    max_epochs = args.max_epochs
    epoch = 0
    while True:
        in_training = getattr(args, 'is_training', False)
        if in_training and max_epochs and epoch >= max_epochs:
            logger.info(f'Training of {epoch} epochs has been completed, the training has finished.')
            break
        if in_training:
            logger.info(f'The training of Epoch {epoch} starts...')
        yield from iter
        epoch += 1


@contextmanager
def _training_context():
    """Flag megatron args as training for the duration of the with-block.

    Sets ``args.is_training = True`` on entry and always resets it to
    ``False`` on exit, even if the body raises.
    """
    from megatron.training import get_args
    megatron_args = get_args()
    megatron_args.is_training = True
    try:
        yield
    finally:
        megatron_args.is_training = False


def _patch_max_epochs():
    """Patch megatron's training module so ``max_epochs`` is supported.

    Two monkey-patches:
    1. ``training.train_step`` is wrapped to run inside ``_training_context``
       and to translate the ``StopIteration`` raised by the epoch-limited
       iterator into a result tuple that ends training
       (presumably (loss_dict, skipped, checkpoint, exit, code, grad_norm,
       num_zeros) matching train_step's return arity — verify against the
       pinned megatron version).
    2. ``training.cyclic_iter`` is replaced with ``new_cyclic_iter``, which
       enforces the epoch limit.
    """
    from megatron.training import training
    _original_train_step = training.train_step

    def train_step(*step_args, **step_kwargs):
        with _training_context():
            try:
                return _original_train_step(*step_args, **step_kwargs)
            except StopIteration:
                return {}, True, True, True, 0, None, None

    training.train_step = train_step
    training.cyclic_iter = new_cyclic_iter


def _patch_megatron():
    """Apply every swift-side megatron monkey-patch, in order."""
    for apply_patch in (_patch_transformer_engine, _patch_max_epochs):
        apply_patch()


def init_megatron_env() -> None:
    """Make Megatron-LM importable and apply swift's patches.

    Clones the pinned Megatron-LM branch if ``MEGATRON_LM_PATH`` is not
    already set, pip-installs it (editable) exactly once across DDP ranks
    via ``safe_ddp_context``, prepends the checkout to ``sys.path`` so it
    wins over any other installation, and finally monkey-patches it.
    """
    if 'MEGATRON_LM_PATH' not in os.environ:
        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.12.0')
    megatron_path = os.environ['MEGATRON_LM_PATH']
    with safe_ddp_context(hash_id='megatron-lm'):
        if not is_megatron_available():
            subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', megatron_path])
    sys.path.insert(0, megatron_path)
    _patch_megatron()