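"""Distill a teacher causal LM into a student via SFT on a chat dataset.

Combines TRL's SFTTrainer with NVIDIA ModelOpt's KDTrainer so the student is
fine-tuned on chat data while matching the teacher's output logits.

Illustrative launch command (the script filename is assumed; flag names are
generated by HfArgumentParser from the dataclass fields below):

    torchrun --nproc_per_node=8 main.py --teacher_name_or_path <teacher> --student_name_or_path <student> --per_device_train_batch_size 8 --output_dir ./distill_ckpt
"""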
import os

from dataclasses import dataclass

# Opt into expandable allocator segments (set before torch initializes CUDA)
# to reduce memory fragmentation from variable-length batches.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import transformers
from accelerate.logging import get_logger
from transformers import AutoTokenizer
from trl import SFTTrainer

import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

logger = get_logger(__name__, log_level="INFO")


@dataclass
class ModelArguments:
    teacher_name_or_path: str | None = None
    student_name_or_path: str | None = None


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    do_train: bool = True
    do_eval: bool = True
    save_strategy: str = "no"
    max_length: int = 1024
    optim: str = "adamw_torch"
    learning_rate: float = 1e-5
    lr_scheduler_type: str = "cosine"
    dataloader_drop_last: bool = True
    dataset_num_proc: int = 8
    bf16: bool = True


def _format_smoltalk_chat_template(sample, tokenizer):
    """Render a query/answer pair from the dataset as a chat-formatted string.

    The exact markup depends on the tokenizer's chat template, e.g. ChatML-style
    templates wrap each turn in <|im_start|>role ... <|im_end|> markers.
    """
    messages = [
        {"role": "user", "content": sample["query"]},
        {"role": "assistant", "content": sample["answer"]},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False)
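

# Mix-in order matters: Python resolves methods left to right, so KDTrainer's
# overrides (notably the distillation loss) take precedence, while SFTTrainer
# keeps providing the dataset formatting and tokenization pipeline.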
class KDSFTTrainer(KDTrainer, SFTTrainer):
    pass


def train():
    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Register ModelOpt's hooks so modified models save/restore correctly
    # through the standard HuggingFace checkpointing APIs.
    mto.enable_huggingface_checkpointing()

    # Fix the effective batch size and derive gradient accumulation from the
    # per-device batch size and world size (training_args.world_size is valid
    # for both single- and multi-process launches).
    total_batch_size = 64
    num_accum_steps = total_batch_size / (
        training_args.per_device_train_batch_size * training_args.world_size
    )
    if not num_accum_steps.is_integer():
        raise ValueError(
            f"`per_device_train_batch_size` * `world_size` must be a factor of {total_batch_size}"
        )
    training_args.gradient_accumulation_steps = int(num_accum_steps)
    logger.info(
        f"Using {int(num_accum_steps)} grad accumulation steps for effective batch size of {total_batch_size}."
    )
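
    # Example: per_device_train_batch_size=8 on 4 GPUs gives 64 / (8 * 4) = 2
    # accumulation steps; 8 GPUs at per-device batch size 8 gives exactly 1.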

    logger.info("Loading dataset...")
    dset = datasets.load_dataset("ReactiveAI/smol-smoltalk-Interaction-SFT", split="train")
    dset_splits = dset.train_test_split(train_size=12800, test_size=1280, seed=420)
    dset_train, dset_eval = dset_splits["train"], dset_splits["test"]
    logger.info("Dataset loaded.")

    logger.info("Loading tokenizer...")
    # Prefer the teacher's tokenizer when both paths are given; logits
    # distillation assumes student and teacher share a vocabulary.
    model_path = model_args.teacher_name_or_path or model_args.student_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    logger.info("Tokenizer loaded.")

    logger.info("Loading student model...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )
    logger.info("Student loaded.")
    logger.info("Loading teacher model...")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
    )
    logger.info("Teacher loaded.")

    # The teacher is used only to produce target logits; LMLogitsLoss matches
    # the student's output logits to the teacher's.
    kd_config = {
        "teacher_model": teacher_model,
        "criterion": LMLogitsLoss(),
    }

    # The base checkpoint may ship sampling parameters that trigger
    # GenerationConfig validation warnings when saving with do_sample=False;
    # clear them.
    model.generation_config.temperature = None
    model.generation_config.top_p = None

    trainer = KDSFTTrainer(
        model,
        training_args,
        distill_config=kd_config,
        train_dataset=dset_train,
        eval_dataset=dset_eval,
        formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
        processing_class=tokenizer,
    )
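
    # Note: each training step runs forward passes through both models, so plan
    # GPU memory for student + teacher weights; optimizer state is only needed
    # for the student.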

    if training_args.do_train:
        logger.info("Beginning training...")
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        logger.info("Training done.")

    if training_args.do_eval:
        logger.info("Evaluating...")
        eval_results = trainer.evaluate()
        logger.info(eval_results)
        logger.info("Evaluation complete.")

    logger.info("Saving checkpoint...")
    trainer.save_state()
    trainer.save_model(trainer.args.output_dir)
    logger.info("Checkpoint saved.")
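
    # With mto.enable_huggingface_checkpointing() active, the saved checkpoint
    # should retain the ModelOpt state and stay loadable via from_pretrained().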


if __name__ == "__main__":
    train()