# Fine-tune facebook/bart-large-cnn for abstractive summarization of arXiv papers,
# then push the result to the Hugging Face Hub.
from datasets import load_dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
|
# Load the pretrained BART checkpoint and its tokenizer.
model_name = "facebook/bart-large-cnn"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
|
# Load the promptsource variant of the scientific_papers (arXiv) dataset.
dataset = load_dataset("marcov/scientific_papers_arxiv_promptsource")

# Keep small subsets of each split so the run stays short.
dataset["train"] = dataset["train"].select(range(1000))
dataset["validation"] = dataset["validation"].select(range(200))
|
# BART's encoder accepts at most 1024 tokens; cap target summaries at 200 tokens.
max_input_length = 1024
max_output_length = 200
|
def preprocess_function(batch):
    # Tokenize articles as encoder inputs and reference summaries as labels,
    # truncating each to the length caps defined above.
    inputs = tokenizer(batch["article"], max_length=max_input_length, truncation=True)
    outputs = tokenizer(batch["summary"], max_length=max_output_length, truncation=True)
    batch["input_ids"] = inputs["input_ids"]
    batch["attention_mask"] = inputs["attention_mask"]
    batch["labels"] = outputs["input_ids"]
    return batch
|
# Tokenize both splits and drop the raw text columns.
tokenized_train = dataset["train"].map(preprocess_function, batched=True, remove_columns=dataset["train"].column_names)
tokenized_val = dataset["validation"].map(preprocess_function, batched=True, remove_columns=dataset["validation"].column_names)
|
# Training configuration: evaluate and checkpoint every 500 steps, push to the Hub.
training_args = TrainingArguments(
    output_dir="./bart-finetuned-arxiv-hub",
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    learning_rate=3e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=False,
    logging_dir="./logs",
    logging_steps=100,
    push_to_hub=True,
    # Target repo on the Hub; Trainer.push_to_hub() takes a commit message,
    # not a repo name, so the repo is set here instead.
    hub_model_id="username/bart-finetuned-arxiv",
)
|
# DataCollatorForSeq2Seq pads input_ids, attention_mask, and labels dynamically
# per batch; labels are padded with -100 so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
|
trainer.train()

# Upload the final model, tokenizer, and model card to the Hub repo set in hub_model_id.
trainer.push_to_hub()

print("Fine-tuning complete.") |
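
# A minimal inference sketch (illustrative, not part of the original script):
# summarize one validation article with the fine-tuned model. The generation
# settings below (beam search, length caps) are assumptions, not values from above.
sample_text = dataset["validation"][0]["article"]
inputs = tokenizer(sample_text, max_length=max_input_length, truncation=True, return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_length=max_output_length, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))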