import os
from functools import partial
from typing import List, Union

from datasets import Dataset as HfDataset

from swift.plugin import extra_callbacks, get_loss_func, get_metric
from swift.trainers import TrainerFactory
from swift.utils import (append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array,
                         use_torchacc)
from ..argument import TrainArguments
from ..base import SwiftPipeline
from ..dataset import (EncodePreprocessor, GetLengthPreprocessor, IterablePackingDataset, LazyLLMDataset,
                       PackingDataset, load_dataset)
from ..infer import prepare_generation_config
from ..model import HfConfigFactory, get_model_arch
from ..utils import deep_getattr, dynamic_gradient_checkpointing
from .tuner import TunerMixin

logger = get_logger()


class SwiftSft(SwiftPipeline, TunerMixin):
    args_class = TrainArguments
    args: args_class

    def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None:
        super().__init__(args)
        self.train_msg = {}
        self._prepare_model_tokenizer()
        self._prepare_template()
        self._prepare_callbacks()

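    # Disable the KV cache for training and, when gradient checkpointing is
    # requested, enable it on the language model and any multimodal vision towers.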
    def _prepare_gradient_checkpointing(self):
        args = self.args
        HfConfigFactory.set_model_config_attr(self.model, 'use_cache', False)
        if args.gradient_checkpointing:
            self.model.supports_gradient_checkpointing = True
            dynamic_gradient_checkpointing(self.model)
            self.model.enable_input_require_grads()
            model_meta = self.model.model_meta
            model_arch = get_model_arch(model_meta.model_arch)
            if model_meta.is_multimodal and model_arch:
                for vision_tower_name in model_arch.vision_tower:
                    vision_tower = deep_getattr(self.model, vision_tower_name)
                    if hasattr(vision_tower, 'enable_input_require_grads'):
                        try:
                            vision_tower.enable_input_require_grads()
                        except NotImplementedError:
                            pass

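    # Keep the original generation config around, then replace it with one built
    # from the request config so generation during training follows the arguments.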
    def _prepare_generation_config(self):
        args = self.args
        self.model.origin_generation_config = self.model.generation_config
        self.model.generation_config = prepare_generation_config(
            self.model.generation_config, args.get_request_config(), self.tokenizer)
        logger.info(f'model.generation_config: {self.model.generation_config}')

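    # Load the model and processor, optionally initializing sequence parallelism first.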
    def _prepare_model_tokenizer(self):
        args = self.args
        if args.sequence_parallel_size > 1:
            from swift.trainers.sequence_parallel import sequence_parallel
            sequence_parallel.init_sequence_parallel(args.sequence_parallel_size)
        self.model, self.processor = args.get_model_processor()

        if hasattr(self.model, 'hf_device_map'):
            logger.info(f'model.hf_device_map: {self.model.hf_device_map}')

        logger.info(f'model_info: {self.model.model_info}')

        self._prepare_generation_config()
        self._prepare_gradient_checkpointing()

    def _prepare_template(self) -> None:
        template = self.args.get_template(self.processor)
        if self.args.task_type == 'causal_lm':
            template.set_mode('train')
        if template.use_model:
            template.model = self.model
        self.template = template

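    # Load the training set and obtain a validation set: either split off from the
    # training data via split_dataset_ratio, or loaded separately via val_dataset.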
    def _get_dataset(self):
        args = self.args
        dataset_kwargs = args.get_dataset_kwargs()
        train_dataset, val_dataset = load_dataset(
            args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs)
        if len(args.val_dataset) > 0:
            # A separate validation dataset is given, so no split may be taken from the training set.
            _, val_dataset = load_dataset(
                args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs)
            assert args.split_dataset_ratio == 0.
        logger.info(f'train_dataset: {train_dataset}')
        logger.info(f'val_dataset: {val_dataset}')

        return train_dataset, val_dataset

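    # Fall back to the 'loss_scale' loss when no loss_type is given but a
    # non-default loss_scale is configured.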
    def _get_loss_func(self):
        args = self.args
        loss_type = args.loss_type
        if loss_type is None and args.loss_scale != 'default':
            loss_type = 'loss_scale'
        return get_loss_func(loss_type)

    def _get_data_collator(self):
        args = self.args
        template = self.template
        padding_to = args.max_length if args.train_type == 'longlora' else None
        return partial(template.data_collator, padding_to=padding_to)

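    # Persist the validation split as JSONL on the master process so it can be inspected later.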
    @staticmethod
    def _save_val_dataset(output_dir: str, val_dataset):
        if is_master() and isinstance(val_dataset, HfDataset):
            os.makedirs(output_dir, exist_ok=True)
            val_dataset_path = os.path.join(output_dir, 'val_dataset.jsonl')
            append_to_jsonl(val_dataset_path, val_dataset.to_list())
            logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.')

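    # End-to-end entry point: load and encode the datasets, apply the tuner,
    # build the trainer, and launch training.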
    def run(self):
        args = self.args

        train_dataset, val_dataset = self._get_dataset()
        train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)

        if args.task_type == 'seq_cls':
            args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
            logger.info(f'args.problem_type: {args.problem_type}')
        args.save_args()

        data_collator = self._get_data_collator()

        # Wrap the model with the configured tuner (e.g. LoRA/adapters or full fine-tuning).
        self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
        logger.info(f'model: {self.model}')
        model_parameter_info = get_model_parameter_info(self.model)
        self.train_msg['model_parameter_info'] = model_parameter_info
        logger.info(f'model_parameter_info: {model_parameter_info}')

        trainer_cls = TrainerFactory.get_trainer_cls(args)
        trainer = trainer_cls(
            model=self.model,
            args=self.args.training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            callbacks=self.callbacks,
            template=self.template,
            **self._get_trainer_kwargs(),
        )
        return self.train(trainer)

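    # Pick the metric: an explicit args.metric wins, predict_with_generate uses NLG
    # metrics, and token accuracy is the default otherwise.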
    def _get_trainer_kwargs(self):
        args = self.args
        if args.metric is not None:
            compute_metrics, preprocess_logits_for_metrics = get_metric(args.metric)
        elif args.predict_with_generate:
            compute_metrics, preprocess_logits_for_metrics = get_metric('nlg')
        else:
            compute_metrics, preprocess_logits_for_metrics = get_metric('acc')
            compute_metrics = partial(
                compute_metrics, acc_strategy=args.acc_strategy, is_encoder_decoder=self.template.is_encoder_decoder)
        return {
            'compute_metrics': compute_metrics,
            'preprocess_logits_for_metrics': preprocess_logits_for_metrics,
            'compute_loss_func': self._get_loss_func()
        }

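    # Collect checkpoint paths, metrics, and memory statistics after training,
    # optionally creating 'last'/'best' symlinks and plotting tensorboard curves.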
    def _save_trainer_state(self, trainer):
        training_args = trainer.args
        state = trainer.state
        if hasattr(state, 'last_model_checkpoint'):
            if self.args.create_checkpoint_symlink:
                last_checkpoint = os.path.join(self.args.output_dir, 'last')
                best_checkpoint = os.path.join(self.args.output_dir, 'best')
                os.symlink(state.last_model_checkpoint, last_checkpoint)
                os.symlink(state.best_model_checkpoint, best_checkpoint)
                state.last_model_checkpoint = last_checkpoint
                state.best_model_checkpoint = best_checkpoint
        else:
            state.last_model_checkpoint = None
            logger.warning('No training was carried out, which may be due to the dataset being too small '
                           'or incorrect usage of resume_from_checkpoint.')
        logger.info(f'last_model_checkpoint: {state.last_model_checkpoint}')
        logger.info(f'best_model_checkpoint: {state.best_model_checkpoint}')

        # Visualization and hub upload (master process only).
        if is_master() and not use_torchacc():
            if 'tensorboard' in training_args.report_to:
                images_dir = os.path.join(training_args.output_dir, 'images')
                logger.info(f'images_dir: {images_dir}')
                plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9)
            if training_args.push_to_hub:
                trainer.push_to_hub()

        self.train_msg.update({
            'last_model_checkpoint': state.last_model_checkpoint,
            'best_model_checkpoint': state.best_model_checkpoint,
            'best_metric': state.best_metric,
            'global_step': state.global_step,
            'log_history': state.log_history,
            'memory': trainer.max_memory,
        })
        if is_master():
            jsonl_path = os.path.join(training_args.output_dir, 'logging.jsonl')
            append_to_jsonl(jsonl_path, self.train_msg)
        return self.train_msg

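    # The finally-block ensures the trainer state is saved even if training is interrupted.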
    def train(self, trainer):
        logging_path = os.path.join(trainer.args.output_dir, 'logging.jsonl')
        logger.info(f'The logging file will be saved in: {logging_path}')
        try:
            trainer.train(trainer.args.resume_from_checkpoint)
        finally:
            res = self._save_trainer_state(trainer)
        return res

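    # Assemble trainer callbacks: LISA layer switching (full-parameter training only),
    # the AdaLoRA adapter callback, and any plugin-registered extra callbacks.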
    def _prepare_callbacks(self):
        from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
        args = self.args
        callbacks = []
        if args.lisa_activated_layers > 0:
            assert args.train_type == 'full', 'LISA only supports full parameter training.'
            lisa_callback = DynamicLayerActivationCallback(
                n_layers=args.lisa_activated_layers,
                step_interval=args.lisa_step_interval,
                model=self.model)
            lisa_callback.switch_active_layers()
            callbacks.append(lisa_callback)

        if args.is_adapter and args.train_type == 'adalora':
            callbacks.append(TrainerAdapterCallback(args))
        callbacks += extra_callbacks
        self.callbacks = callbacks

    def _stat_dataset(self, dataset: Union[HfDataset, PackingDataset]):
        args = self.args
        if isinstance(dataset, HfDataset):
            dataset = GetLengthPreprocessor()(dataset, num_proc=args.dataset_num_proc)
            length = dataset['length']
        else:
            # Packed rows hold pre-encoded features; take the longest *input_ids entry per row.
            length = []
            for row in dataset:
                length.append(max([len(row[k]) for k in row.keys() if k.endswith('input_ids')]))
        _, stat_str = stat_array(length)
        logger.info(f'Dataset Token Length: {stat_str}')
        return stat_str

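    # Encode the datasets along one of three paths (skipped entirely for GRPO):
    # packing, lazy per-sample tokenization, or eager preprocessing.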
    def _encode_dataset(self, train_dataset, val_dataset):
        template = self.template
        args = self.args
        output_dir = getattr(args, 'output_dir', None) or getattr(args, 'save')
        self._save_val_dataset(output_dir, val_dataset)
        is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
        predict_with_generate = getattr(args, 'predict_with_generate', False)
        if not is_grpo:
            if args.packing:
                packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset
                train_dataset = packing_dataset_cls(
                    self.template, train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None:
                    val_dataset = packing_dataset_cls(
                        self.template, val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
            elif args.lazy_tokenize:
                train_dataset = LazyLLMDataset(
                    train_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = LazyLLMDataset(
                        val_dataset, template.encode, strict=args.strict, random_state=args.data_seed)
            else:
                preprocessor = EncodePreprocessor(template=template)
                train_dataset = preprocessor(train_dataset, num_proc=args.dataset_num_proc, strict=args.strict)
                if val_dataset is not None and not predict_with_generate:
                    val_dataset = preprocessor(val_dataset, num_proc=args.dataset_num_proc, strict=args.strict)

            if is_master():
                inputs = train_dataset[0] if hasattr(train_dataset, '__len__') else next(iter(train_dataset))
                template.print_inputs(inputs, tokenizer_kwargs=inputs.pop('tokenizer_kwargs', None) or {})
            if isinstance(train_dataset, (HfDataset, PackingDataset)):
                self.train_msg['train_dataset'] = self._stat_dataset(train_dataset)
                if val_dataset is not None and not predict_with_generate:
                    self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)

        return train_dataset, val_dataset


def sft_main(args: Union[List[str], TrainArguments, None] = None):
    return SwiftSft(args).main()
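

# A minimal usage sketch (illustrative only; the model and dataset IDs below are
# hypothetical examples, not defaults of this module):
#
#   from swift.llm import sft_main, TrainArguments
#   sft_main(TrainArguments(
#       model='Qwen/Qwen2.5-0.5B-Instruct',
#       dataset=['AI-ModelScope/alpaca-gpt4-data-en'],
#       train_type='lora'))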