Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from contextlib import contextmanager
from swift.llm import git_clone_github
from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run
logger = get_logger()
def _patch_transformer_engine():
try:
from transformer_engine.pytorch.attention import FusedRoPEFunc
except ImportError:
try:
import transformer_engine
transformer_engine.pytorch.attention.FusedRoPEFunc = (
transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc)
except (ImportError, AttributeError):
pass
def new_cyclic_iter(iter):
from megatron.training import get_args
args = get_args()
max_epochs = args.max_epochs
i = 0
while True:
if getattr(args, 'is_training', False):
if max_epochs and i >= max_epochs:
logger.info(f'Training of {i} epochs has been completed, the training has finished.')
break
logger.info(f'The training of Epoch {i} starts...')
for x in iter:
yield x
i += 1
@contextmanager
def _training_context():
from megatron.training import get_args
args = get_args()
args.is_training = True
try:
yield
finally:
args.is_training = False
def _patch_max_epochs():
# support max_epochs
from megatron.training import training
train_step_origin = training.train_step
def train_step(*args, **kwargs):
with _training_context():
try:
return train_step_origin(*args, **kwargs)
except StopIteration:
return {}, True, True, True, 0, None, None
training.train_step = train_step
training.cyclic_iter = new_cyclic_iter
def _patch_megatron():
_patch_transformer_engine()
_patch_max_epochs()
def init_megatron_env() -> None:
if 'MEGATRON_LM_PATH' not in os.environ:
os.environ['MEGATRON_LM_PATH'] = git_clone_github(
'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.12.0')
with safe_ddp_context(hash_id='megatron-lm'):
if not is_megatron_available():
subprocess_run([sys.executable, '-m', 'pip', 'install', '-e', os.environ['MEGATRON_LM_PATH']])
sys.path.insert(0, os.environ['MEGATRON_LM_PATH'])
_patch_megatron()