# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import List, Union

from datasets import Dataset as HfDataset

from swift.plugin import extra_callbacks, get_loss_func, get_metric
from swift.trainers import TrainerFactory
from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array,
                         use_torchacc)
from ..argument import TrainArguments
from ..base import SwiftPipeline
from ..dataset import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset,
                       PackingDataset, load_dataset)
from ..infer import prepare_generation_config
from ..model import HfConfigFactory, get_model_arch
from ..utils import deep_getattr, dynamic_gradient_checkpointing
from .tuner import TunerMixin

logger = get_logger()


class SwiftSft(SwiftPipeline, TunerMixin):
    args_class = TrainArguments
    args: args_class

    def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None:
        super().__init__(args)
        self.train_msg = {}
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_callbacks()

    def _prepare_gradient_checkpointing(self):
        args = self.args
        HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False)
        if args.gradient_checkpointing:
            self.model.supports_gradient_checkpointing = True
            dynamic_gradient_checkpointing(self.model)
            self.model.enable_input_require_grads()
        model_meta = self.model.model_meta
        model_arch = get_model_arch(model_meta.model_arch)
        if model_meta.is_multimodal and model_arch:
            for vision_tower_name in model_arch.vision_tower:
                vision_tower = deep_getattr(self.model, vision_tower_name)
                if hasattr(vision_tower, 'enable_input_require_grads'):
                    try:
                        vision_tower.enable_input_require_grads()
                    except NotImplementedError:
                        pass

    def _prepare_generation_config(self):
        args = self.args
        self.model.origin_generation_config = self.model.generation_config
        self.model.generation_config = prepare_generation_config(self.model.generation_config,
                                                                 args.get_request_config(), self.tokenizer)
        logger.info(f'model.generation_config: {self.model.generation_config}')

    def _prepare_model_tokenizer(self):
        args = self.args
        if args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(args.sequence_parallel_size)
        self.model, self.processor = args.get_model_processor()
        if hasattr(self.model, 'hf_device_map'):
            logger.info(f'model.hf_device_map: {self.model.hf_device_map}')
        logger.info(f'model_info: {self.model.model_info}')

        self._prepare_generation_config()
        self._prepare_gradient_checkpointing()

    def _prepare_template(self) -> None:
        template = self.args.get_template(self.processor)
        if self.args.task_type == 'causal_lm':
            template.set_mode('train')
        if template.use_model:
            template.model = self.model
        self.template = template

    def _get_dataset(self):
        # The random shuffling of the training set occurs in the dataloader of the trainer.
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        train_dataset, val_dataset = load_dataset(
            args.dataset,
            split_dataset_ratio=args.split_dataset_ratio,
            shuffle=args.dataset_shuffle,
            **dataset_kwargs)
        if len(args.val_dataset) > 0:
            # Loading val dataset
            _, val_dataset = load_dataset(
                args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)
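            # An explicit val_dataset replaces any validation split taken from the
            # training set, so split_dataset_ratio must be 0 in this branch.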
            assert args.split_dataset_ratio == 0.
        logger.info(f'train_dataset: {train_dataset}')
        logger.info(f'val_dataset: {val_dataset}')

        return train_dataset, val_dataset

    def _get_loss_func(self):
        args = self.args
        loss_type = args.loss_type
        if loss_type is None and args.loss_scale != 'default':
            loss_type = 'loss_scale'
        return get_loss_func(loss_type)

    def _get_data_collator(self):
        args = self.args
        template = self.template
        padding_to = args.max_length if args.train_type == 'longlora' else None
        return partial(template.data_collator, padding_to=padding_to)

    @staticmethod
    def _save_val_dataset(output_dir: str, val_dataset):
        if is_master() and isinstance(val_dataset, HfDataset):
            os.makedirs(output_dir, exist_ok=True)
            val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
            append_to_jsonl(val_dataset_path, val_dataset.to_list())
            logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.')

    def run(self):
        args = self.args

        train_dataset, val_dataset = self._get_dataset()
        train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
        if args.task_type == 'seq_cls':
            args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
            logger.info(f'args.problem_type: {args.problem_type}')
        args.save_args()

        data_collator = self._get_data_collator()
        # Some tuners (e.g. LoRA-GA) require train_dataset and data_collator for preparation.
        self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
        logger.info(f'model: {self.model}')
        model_parameter_info = get_model_parameter_info(self.model)
        self.train_msg['model_parameter_info'] = model_parameter_info
        logger.info(f'model_parameter_info: {model_parameter_info}')

        trainer_cls = TrainerFactory.get_trainer_cls(args)
        trainer = trainer_cls(
            model=self.model,
            args=self.args.training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=self.callbacks,
            template=self.template,
            **self._get_trainer_kwargs(),
        )
        return self.train(trainer)

    def _get_trainer_kwargs(self):
        args = self.args
        if args.metric is not None:
            compute_metrics, preprocess_logits_for_metrics = get_metric(args.metric)
        elif args.predict_with_generate:
            compute_metrics, preprocess_logits_for_metrics = get_metric('nlg')
        else:
            compute_metrics, preprocess_logits_for_metrics = get_metric('acc')
            compute_metrics = partial(
                compute_metrics, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
        return {
            'compute_metrics': compute_metrics,
            'preprocess_logits_for_metrics': preprocess_logits_for_metrics,
            'compute_loss_func': self._get_loss_func()
        }

    def _save_trainer_state(self, trainer):
        training_args = trainer.args
        state = trainer.state
        if hasattr(state, 'last_model_checkpoint'):
            if self.args.create_checkpoint_symlink:
                last_checkpoint = os.path.join(self.args.output_dir, 'last')
                best_checkpoint = os.path.join(self.args.output_dir, 'best')
                os.symlink(state.last_model_checkpoint, last_checkpoint)
                os.symlink(state.best_model_checkpoint, best_checkpoint)
                state.last_model_checkpoint = last_checkpoint
                state.best_model_checkpoint = best_checkpoint
        else:
            state.last_model_checkpoint = None
            logger.warning('No training was carried out, which may be due to the dataset being too small '
                           'or incorrect usage of resume_from_checkpoint.')
        logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}')
        logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}')

        # Visualization
        if is_master() and not use_torchacc():
            if 'tensorboard' in training_args.report_to:
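                # Export the logged curves (here 'train/loss', smoothing factor 0.9)
                # from the tensorboard logging_dir as image files.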
                images_dir = os.path.join(training_args.output_dir, 'images')
                logger.info(f'images_dir: {images_dir}')
                plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9)
            if training_args.push_to_hub:
                trainer.push_to_hub()

        self.train_msg.update({
            'last_model_checkpoint': state.last_model_checkpoint,
            'best_model_checkpoint': state.best_model_checkpoint,
            'best_metric': state.best_metric,
            'global_step': state.global_step,
            'log_history': state.log_history,
            'memory': trainer.max_memory,
        })
        if is_master():
            jsonl_path = os.path.join(training_args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, self.train_msg)
        return self.train_msg

    def train(self, trainer):
        logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl')
        logger.info(f'The logging file will be saved in: {logging_path}')
        try:
            trainer.train(trainer.args.resume_from_checkpoint)
        finally:
            res = self._save_trainer_state(trainer)
        return res

    def _prepare_callbacks(self):
        from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
        args = self.args
        callbacks = []
        if args.lisa_activated_layers > 0:
            assert args.train_type == 'full', 'LISA only supports full parameter training.'
            lisa_callback = DynamicLayerActivationCallback(
                n_layers=args.lisa_activated_layers,  # Number of layers to activate
                step_interval=args.lisa_step_interval,  # Step interval to update active layers
                model=self.model)
            lisa_callback.switch_active_layers()  # Switch once so the trainable-parameter count is reported correctly.
            callbacks.append(lisa_callback)

        if args.is_adapter and args.train_type == 'adalora':
            callbacks.append(TrainerAdapterCallback(args))
        callbacks += extra_callbacks
        self.callbacks = callbacks

    def _stat_dataset(self, dataset: HfDataset):
        args = self.args
        if isinstance(dataset, HfDataset):
            dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc)
            length = dataset['length']
        else:
            length = []
            for row in dataset:
                length.append(max([len(row[k]) for k in row.keys() if k.endswith('input_ids')]))
        _, stat_str = stat_array(length)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str

    def _encode_dataset(self, train_dataset, val_dataset):
        template = self.template
        args = self.args

        output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
        self._save_val_dataset(output_dir, val_dataset)
        is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
        predict_with_generate = getattr(args, 'predict_with_generate', False)
        if not is_grpo:
            if args.packing:
                packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
                train_dataset = packing_dataset_cls(
                    self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None:
                    val_dataset = packing_dataset_cls(
                        self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
            elif args.lazy_tokenize:
                train_dataset = LazyLLMDataset(
                    train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = LazyLLMDataset(
                        val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
            else:
                preprocessor = EncodePreprocessor(template=template)
                train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)

            if is_master():
                inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
                template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
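            # Record token-length statistics of the encoded datasets in train_msg.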
            if isinstance(train_dataset, (HfDataset, PackingDataset)):
                self.train_msg['train_dataset'] = self._stat_dataset(train_dataset)
                if val_dataset is not None and not predict_with_generate:
                    self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)

        return train_dataset, val_dataset


def sft_main(args: Union[List[str], TrainArguments, None] = None):
    return SwiftSft(args).main()
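

# Example usage (a sketch; the model/dataset values below are illustrative, not defaults):
#   sft_main(['--model', 'Qwen/Qwen2.5-7B-Instruct', '--dataset', 'AI-ModelScope/alpaca-gpt4-data-en'])
# The same pipeline is reachable from the command line via `swift sft ...`.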