# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from contextlib import contextmanager
from swift.llm import git_clone_github
from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run

logger = get_logger()
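

# Newer transformer_engine releases moved FusedRoPEFunc from
# transformer_engine.pytorch.attention to
# transformer_engine.pytorch.dot_product_attention.rope; re-export it at the
# old import path so Megatron-LM's imports keep working on both layouts.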
def _patch_transformer_engine():
    try:
        from transformer_engine.pytorch.attention import FusedRoPEFunc
    except ImportError:
        try:
            import transformer_engine
            transformer_engine.pytorch.attention.FusedRoPEFunc = (
                transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
        except (ImportError, AttributeError):
            pass
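

# Drop-in replacement for megatron.training.cyclic_iter: it behaves the same
# for evaluation, but during training it counts epochs and stops once
# args.max_epochs is reached (when set), instead of cycling forever.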
def new_cyclic_iter(iter):
    from megatron.training import get_args
    args = get_args()
    max_epochs = args.max_epochs
    i = 0
    while True:
        if getattr(args, 'is_training', False):
            if max_epochs and i >= max_epochs:
                logger.info(f'Training of {i} epochs has been completed; the training has finished.')
                break
            logger.info(f'The training of Epoch {i} starts...')
        for x in iter:
            yield x
        i += 1
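

# Set args.is_training only while a train step is running, so that
# new_cyclic_iter counts epochs (and enforces max_epochs) for the training
# iterator only, not while the evaluation loop consumes data.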
@contextmanager
def _training_context():
    from megatron.training import get_args
    args = get_args()
    args.is_training = True
    try:
        yield
    finally:
        args.is_training = False
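

# Once new_cyclic_iter stops after max_epochs, the exhausted data iterator
# raises StopIteration inside train_step; the wrapper below converts that into
# a return value asking Megatron's training loop to checkpoint and exit.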
def _patch_max_epochs():
    # Support max_epochs.
    from megatron.training import training
    train_step_origin = training.train_step

    def train_step(*args, **kwargs):
        with _training_context():
            try:
                return train_step_origin(*args, **kwargs)
            except StopIteration:
                # Empty loss dict plus flags that request a checkpoint and a
                # clean exit (exit code 0) from the training loop.
                return {}, True, True, True, 0, None, None

    training.train_step = train_step
    training.cyclic_iter = new_cyclic_iter
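

# Apply all Megatron-LM patches in one place.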
def _patch_megatron():
    _patch_transformer_engine()
    _patch_max_epochs()
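

# Ensure a usable Megatron-LM: clone the pinned branch if MEGATRON_LM_PATH is
# not set, install it once under a DDP-safe lock (so only one rank runs pip),
# put it on sys.path, and apply the patches above.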
def init_megatron_env() -> None:
    if 'MEGATRON_LM_PATH' not in os.environ:
        os.environ['MEGATRON_LM_PATH'] = git_clone_github(
            'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.12.0')
    with safe_ddp_context(hash_id='megatron-lm'):
        if not is_megatron_available():
            subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']])
    sys.path.insert(0, os.environ['MEGATRON_LM_PATH'])
    _patch_megatron()
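

# Typical usage (illustrative sketch; assumes this module is importable as
# shown, which is not confirmed by this file):
#
#     from swift.megatron.init import init_megatron_env
#     init_megatron_env()  # call once, before importing megatron.* modules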