| | """MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import functools |
| | from collections.abc import Sequence |
| | from copy import deepcopy |
| | from typing import TYPE_CHECKING, Any |
| |
|

from ...data import (
    SFTDataCollatorWith4DAttentionMask,
    get_dataset,
    get_template_and_fix_tokenizer,
)
from ...data.collator import (
    PairwiseDataCollatorWithPadding,
)
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_mcore_adapter_available
from ...extras.ploting import plot_loss
from ...model import load_tokenizer
from ..callbacks import SaveProcessorCallback


if not is_mcore_adapter_available():
    raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments


if TYPE_CHECKING:
    from transformers import DataCollatorForSeq2Seq, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


logger = get_logger(__name__)


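# mcore_adapter models consume pre-shifted sequences: the wrapper below drops the
# first token from every *labels key and the last token from *input_ids,
# attention_mask, and position_ids, so each example shrinks by exactly one token
# relative to its tokenized length. The workflows below tokenize with
# cutoff_len + 1 to compensate.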
def _data_collator_wrapper(data_collator: Any):
    @functools.wraps(data_collator)
    def wrapper(features: Sequence[dict[str, Any]]):
        labels_key = [k for k in features[0].keys() if k.endswith("labels")]
        input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
        for feature in features:
            if len(labels_key) == 0:
                feature["labels"] = deepcopy(feature["input_ids"])[1:]
            for k in labels_key:
                feature[k] = feature[k][1:]
            for k in input_ids_key:
                feature[k] = feature[k][:-1]
            for k in ["attention_mask", "position_ids"]:
                if k in feature:
                    feature[k] = feature[k][:-1]
        return data_collator(features)

    return wrapper


def _check_model_support(model_args: ModelArguments):
    from transformers import AutoConfig as HfAutoConfig

    config = HfAutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if config.model_type not in MCA_SUPPORTED_MODELS:
        raise ValueError(f"Model {config.model_type} is not supported by MCA.")


def run_pt(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
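    """Run the PT (pre-training) stage on an mcore_adapter model."""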
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # tokenize with one extra token; the shifting collator trims it back off
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    data_args.cutoff_len -= 1

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    from transformers import DataCollatorForSeq2Seq

    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    if training_args.do_train:
        train_result = trainer.train(training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
            else:
                keys += ["eval_loss"]
            plot_loss(training_args.output_dir, keys=keys)


def run_sft(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
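    """Run the SFT stage on an mcore_adapter model."""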
    # neat_packing (data side) and sequence_packing (mca side) imply each other,
    # so enabling either flag turns both on
    data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
    data_args.packing = data_args.neat_packing or data_args.packing

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    # tokenize with one extra token; the shifting collator trims it back off
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    data_args.cutoff_len -= 1

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    # apply the freeze options to qwen2-vl style multimodal models
    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
        params_to_freeze = []
        if finetuning_args.freeze_vision_tower:
            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])

        if finetuning_args.freeze_multi_modal_projector:
            params_to_freeze.extend(["multi_modal_projector"])

        if finetuning_args.freeze_language_model:
            params_to_freeze.extend(["embedding", "decoder", "output_layer"])

        if params_to_freeze:
            for name, p in model.named_parameters():
                if any(name.startswith(k) for k in params_to_freeze):
                    p.requires_grad_(False)
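
    # pad to a fixed length when expert parallelism is enabled, presumably so
    # every EP rank sees identically shaped batches (same choice in run_dpo)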
    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=64,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
        keys = ["loss"]
        if isinstance(dataset_module.get("eval_dataset"), dict):
            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
        else:
            keys += ["eval_loss"]
        plot_loss(training_args.output_dir, keys=keys)


def run_dpo(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: McaSeq2SeqTrainingArguments,
    finetuning_args: FinetuningArguments,
    callbacks: list[TrainerCallback] | None = None,
):
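    """Run the DPO stage on an mcore_adapter model."""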
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

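    # the optional reference model is a second mca model initialized with an
    # exact copy of the policy weights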
    if finetuning_args.use_ref_model:
        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
        ref_model = AutoModel.from_config(ref_config)
        ref_model.load_state_dict(model.state_dict())
    else:
        ref_model = None

    # DPO trains on pairwise preference data, so reuse the "rm" stage dataset;
    # tokenize with one extra token for the shifting collator to trim off
    data_args.cutoff_len += 1
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
    data_args.cutoff_len -= 1

    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
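    # map LLaMA-Factory's preference hyper-parameters onto mca's DPOConfig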
    dpo_config = DPOConfig(
        beta=finetuning_args.pref_beta,
        pref_loss=finetuning_args.pref_loss,
        label_smoothing=finetuning_args.dpo_label_smoothing,
    )
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        pad_to_multiple_of=64,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_config=dpo_config,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    if finetuning_args.include_effective_tokens_per_second:
        train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
            dataset_module["train_dataset"], train_result.metrics, stage="rm"
        )

    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
        keys = ["loss", "rewards/accuracies"]
        if isinstance(dataset_module.get("eval_dataset"), dict):
            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
        else:
            keys += ["eval_loss"]
        plot_loss(training_args.output_dir, keys=keys)