|
|
|
|
|
import os |
|
|
import sys |
|
|
from contextlib import contextmanager |
|
|
|
|
|
from swift.llm import git_clone_github |
|
|
from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run |
|
|
|
|
|
# Module-level logger; used to report epoch progress from the patched data iterator.
logger = get_logger()
|
|
|
|
|
|
|
|
def _patch_transformer_engine():
    """Expose ``FusedRoPEFunc`` under ``transformer_engine.pytorch.attention``.

    If the name can already be imported from
    ``transformer_engine.pytorch.attention`` nothing is changed. Otherwise the
    object found at ``transformer_engine.pytorch.dot_product_attention.rope``
    is bound onto the ``attention`` module under the same name.

    The patch is best-effort: if transformer_engine cannot be imported, or
    either attribute path is missing, the function returns without raising.
    """
    try:
        # Imported purely as a probe for the expected location.
        from transformer_engine.pytorch.attention import FusedRoPEFunc  # noqa: F401
        return
    except ImportError:
        pass

    try:
        import transformer_engine
        relocated = transformer_engine.pytorch.dot_product_attention.rope.FusedRoPEFunc
        transformer_engine.pytorch.attention.FusedRoPEFunc = relocated
    except (ImportError, AttributeError):
        # Nothing to alias (package absent or laid out differently).
        pass
|
|
|
|
|
|
|
|
def new_cyclic_iter(iter):
    """Yield the items of ``iter`` over and over, honouring ``args.max_epochs``.

    Installed by ``_patch_max_epochs`` as a replacement for
    ``megatron.training.training.cyclic_iter``.

    The Megatron global args are fetched once when iteration starts. Before
    each pass over ``iter``, and only while ``args.is_training`` is truthy,
    the number of completed passes is compared with ``max_epochs``: when the
    limit is set and reached, the generator logs a message and finishes;
    otherwise it logs the start of the next epoch. When ``is_training`` is
    unset or falsy the generator cycles without any limit check.

    Args:
        iter: A re-iterable object; it is iterated from the start on every pass.

    Yields:
        Each item produced by ``iter``, pass after pass.
    """
    from megatron.training import get_args

    megatron_args = get_args()
    epoch_limit = megatron_args.max_epochs
    epoch = 0
    while True:
        if getattr(megatron_args, 'is_training', False):
            if epoch_limit and epoch >= epoch_limit:
                logger.info(f'Training of {epoch} epochs has been completed, the training has finished.')
                return
            logger.info(f'The training of Epoch {epoch} starts...')
        for item in iter:
            yield item
        epoch += 1
|
|
|
|
|
|
|
|
@contextmanager
def _training_context():
    """Flag the Megatron global args as being inside a training step.

    On entry ``args.is_training`` is set to ``True``. On exit, whether the
    body finished normally or raised, the flag is put back to the value it
    had before entry (``False`` if the attribute had never been set). Restoring
    the prior value instead of unconditionally writing ``False`` keeps the flag
    set when contexts are nested, e.g. if ``train_step`` ends up wrapped twice.
    """
    from megatron.training import get_args

    args = get_args()
    previous = getattr(args, 'is_training', False)
    args.is_training = True
    try:
        yield
    finally:
        args.is_training = previous
|
|
|
|
|
|
|
|
def _patch_max_epochs():
    """Monkey-patch ``megatron.training.training`` to support an epoch limit.

    Two module attributes are replaced:

    * ``train_step`` is wrapped so that every call runs inside
      ``_training_context()``; a ``StopIteration`` escaping the original
      function is turned into a fixed return tuple instead of propagating.
    * ``cyclic_iter`` is replaced by ``new_cyclic_iter``.
    """
    from megatron.training import training as megatron_training

    original_train_step = megatron_training.train_step

    def train_step(*args, **kwargs):
        with _training_context():
            try:
                result = original_train_step(*args, **kwargs)
            except StopIteration:
                # NOTE(review): presumably mirrors the 7-element return layout of
                # Megatron's own train_step, with the flags set so the training
                # loop checkpoints and exits - confirm against the pinned version.
                result = ({}, True, True, True, 0, None, None)
        return result

    megatron_training.train_step = train_step
    megatron_training.cyclic_iter = new_cyclic_iter
|
|
|
|
|
|
|
|
def _patch_megatron():
    """Apply every Megatron-related monkey patch defined in this module.

    Runs the transformer_engine compatibility patch first, then the
    max-epochs patch.
    """
    for apply_patch in (_patch_transformer_engine, _patch_max_epochs):
        apply_patch()
|
|
|
|
|
|
|
|
def init_megatron_env() -> None:
    """Prepare the process for using Megatron-LM.

    Steps, in order:

    1. If the ``MEGATRON_LM_PATH`` environment variable is not set, obtain the
       Megatron-LM repository (branch ``core_r0.12.0``) via
       ``git_clone_github`` and store the returned value in that variable.
    2. Inside ``safe_ddp_context``, install that path in editable mode with
       pip when ``is_megatron_available()`` reports Megatron is missing.
    3. Put the path at the front of ``sys.path``.
    4. Apply this module's Megatron patches.
    """
    env_key = 'MEGATRON_LM_PATH'
    if env_key not in os.environ:
        repo_dir = git_clone_github('https://github.com/NVIDIA/Megatron-LM', branch='core_r0.12.0')
        os.environ[env_key] = repo_dir

    with safe_ddp_context(hash_id='megatron-lm'):
        if not is_megatron_available():
            install_cmd = [sys.executable, '-m', 'pip', 'install', '-e', os.environ[env_key]]
            subprocess_run(install_cmd)

    sys.path.insert(0, os.environ[env_key])
    _patch_megatron()
|
|
|