File size: 2,615 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List, Union
from megatron.core.enums import ModelType
from megatron.training import pretrain
from swift.llm.train import SwiftSft
from swift.utils import get_logger, is_master, plot_images
from ..argument import MegatronTrainArguments
from ..utils import patch_megatron_tokenizer
from .patcher import patch_megatron_data_collator, patch_training_log
from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
# Module-level logger from swift.utils.get_logger(), shared by everything in this file.
logger = get_logger()
class MegatronSft(SwiftSft):
    """Supervised fine-tuning pipeline whose training loop is Megatron's ``pretrain``.

    Reuses SwiftSft's dataset/template helpers (``_get_dataset``, ``_encode_dataset``,
    ``_prepare_template``). The model itself is not loaded here; ``pretrain`` is given
    ``args.megatron_model_meta.model_provider`` instead.
    """
    args_class = MegatronTrainArguments
    args: args_class

    def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None) -> None:
        """Parse arguments and prepare the processor and template (no model weights loaded).

        Args:
            args: CLI-style argument list, an already-built ``MegatronTrainArguments``,
                or ``None``; interpretation is delegated to the base-class ``__init__``.
        """
        self.train_msg = {}
        # Deliberately skip SwiftSft.__init__ and call the next __init__ in the MRO after it.
        # NOTE(review): presumably to avoid SwiftSft's own model setup, since Megatron builds
        # the model through ``model_provider`` — confirm against SwiftSft.
        super(SwiftSft, self).__init__(args)
        args = self.args
        # Only the processor is needed here; load_model=False skips loading the model.
        _, self.processor = args.get_model_processor(load_model=False)
        patch_megatron_tokenizer(self.processor)
        # Initialise model-related arguments from the processor's model config.
        args.init_model_args(self.processor.model_info.config)
        self._prepare_template()
        self.template.use_megatron = True
        args.save_args(args.save)

    def run(self):
        """Encode the datasets, train with Megatron ``pretrain``, then plot training curves.

        Plotting is attempted on the master rank whether training succeeds or fails. When
        training raises, a plotting failure is only logged and the original training
        exception is re-raised, so a visualization error can never mask the real cause.
        """
        args = self.args
        train_dataset, val_dataset = self._get_dataset()
        train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
        data_collator = self.template.data_collator
        if args.streaming:
            # In streaming mode, wrap the datasets with build_streaming_dataloader.
            train_dataset = build_streaming_dataloader(args, train_dataset, data_collator)
            if val_dataset is not None:
                val_dataset = build_streaming_dataloader(args, val_dataset, data_collator)
        datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset)
        # NOTE(review): attribute consumed by Megatron's pretrain; presumably lets every rank
        # build datasets rather than only a subset of ranks — confirm.
        datasets_provider.is_distributed = True

        logging_path = os.path.join(args.save, 'logging.jsonl')
        logger.info(f'The logging file will be saved in: {logging_path}')
        try:
            with patch_training_log(), patch_megatron_data_collator(data_collator):
                pretrain(
                    datasets_provider,
                    args.megatron_model_meta.model_provider,
                    ModelType.encoder_or_decoder,
                    forward_step,
                    args_defaults=args.extra_args)
        except BaseException:
            # BaseException mirrors the previous ``finally`` (also covers SystemExit and
            # KeyboardInterrupt). Plotting is best-effort here so it cannot replace the
            # training failure as the propagated exception.
            try:
                self._plot_training_images()
            except Exception:
                logger.warning('Failed to plot training images after a training error.', exc_info=True)
            raise
        # Success path: plot exactly as before; plotting errors still propagate.
        self._plot_training_images()

    def _plot_training_images(self) -> None:
        """On the master rank, call ``plot_images`` for ``<args.save>/images`` and ``args.tensorboard_dir``."""
        if not is_master():
            return
        args = self.args
        images_dir = os.path.join(args.save, 'images')
        logger.info(f'images_dir: {images_dir}')
        plot_images(images_dir, args.tensorboard_dir)
def megatron_sft_main(args: Union[List[str], MegatronTrainArguments, None] = None):
    """Build a ``MegatronSft`` pipeline from ``args`` and return the result of its ``main()``.

    Args:
        args: CLI-style argument list, a ``MegatronTrainArguments`` instance, or ``None``;
            forwarded unchanged to ``MegatronSft``.
    """
    pipeline = MegatronSft(args)
    result = pipeline.main()
    return result
|