import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
|
|
|
|
|
|
def main():
    """Fine-tune ``FacebookAI/roberta-base`` as a causal LM in a distributed job.

    Reads ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` from the environment,
    joins the NCCL process group, tokenizes ``../../data/m500_clean.jsonl``,
    trains with the Hugging Face ``Trainer`` and finally writes the model and
    tokenizer to ``./fine_tuned_model`` (from rank 0 only).

    Raises:
        KeyError: If one of the three rank environment variables is not set.
    """
    # Presumably exported by the distributed launcher (e.g. torchrun); a
    # missing variable fails fast with KeyError.
    local_rank = int(os.environ['LOCAL_RANK'])
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    # Every process must derive the same train/test partition, otherwise the
    # per-rank shards are cut from different splits and evaluation rows can
    # end up in another rank's training data.
    split_seed = 42

    torch.distributed.init_process_group("nccl")
    try:
        print(f"Local Rank = {local_rank}/{world_size}")

        dataset = load_dataset(
            'json', data_files='../../data/m500_clean.jsonl', split='train'
        )

        model_name = "FacebookAI/roberta-base"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # RoBERTa is an encoder checkpoint: without is_decoder=True the
        # attention stays bidirectional, so the next-token objective could
        # simply read the answer from the future positions.
        model = AutoModelForCausalLM.from_pretrained(model_name, is_decoder=True)

        # The collator needs a pad token to batch variable-length sequences.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            """Tokenize the ``text`` column, truncating to 512 tokens."""
            return tokenizer(examples["text"], truncation=True, max_length=512)

        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # Hold out 10% for evaluation, identically on every rank.
        split_dataset = tokenized_dataset.train_test_split(
            test_size=0.1, seed=split_seed
        )

        # mlm=False selects the causal-LM collation mode.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, mlm=False
        )

        # NOTE(review): eval_steps is set but no evaluation strategy is
        # passed, so periodic evaluation may never trigger with the Trainer
        # defaults -- confirm against the installed transformers version.
        training_args = TrainingArguments(
            output_dir="./results",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            dataloader_num_workers=8,
            eval_steps=500,
            save_steps=1000,
            warmup_steps=500,
            prediction_loss_only=True,
            logging_dir="./logs",
            logging_steps=100,
            learning_rate=5e-5,
            fp16=True,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=split_dataset["train"],
            eval_dataset=split_dataset["test"],
            data_collator=data_collator,
        )

        trainer.train()
    finally:
        # Tear the process group down even when setup or training fails.
        torch.distributed.destroy_process_group()

    # Only the global rank-0 process writes the final artifacts; concurrent
    # writes from every rank to the same directory would race on the files.
    if rank == 0:
        model.save_pretrained("./fine_tuned_model")
        tokenizer.save_pretrained("./fine_tuned_model")
|
|
# Script entry point: start the training run only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|