import argparse
import logging
import math
import re

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from datasets import load_dataset
# transformers' AdamW is deprecated; torch.optim.AdamW is the recommended replacement
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    SchedulerType,
    get_scheduler,
)

logger = logging.getLogger(__name__)


def get_parser():
    parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="vblagoje/lfqa",
        help="The name of the dataset to use (via the datasets library).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--pretrained_model_name",
        type=str,
        default="facebook/bart-large",
        help="Name or path of the pretrained seq2seq model to fine-tune.",
    )
    parser.add_argument(
        "--model_save_name",
        type=str,
        default="eli5_bart_model",
        help="Base name used for the saved model checkpoint files.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-4,
        help="Initial learning rate (after the potential warmup period).",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.0,
        help="Weight decay to use.",
    )
    parser.add_argument(
        "--log_freq",
        type=int,
        default=100,
        help="Log train/validation loss every log_freq update steps.",
    )
    parser.add_argument(
        "--ignore_pad_token_for_loss",
        type=bool,
        default=True,
        help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=3,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=16,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--pad_to_max_length",
        action="store_true",
        help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
    )
    parser.add_argument(
        "--overwrite_cache",
        action="store_true",
        help="Overwrite the cached training and evaluation sets.",
    )
    parser.add_argument(
        "--max_source_length",
        type=int,
        default=1024,
        help="The maximum total input sequence length after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        default=360,
        help="The maximum total sequence length for target text after tokenization. "
        "Sequences longer than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=None,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--warmup_percentage",
        type=float,
        default=0.08,
        help="Fraction of total training steps used for warmup when num_warmup_steps is not set.",
    )
    return parser


def cleanup_references(text):
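    """Remove web-citation artifacts from ELI5 text.

    Three passes: drop numbered reference links such as ``[1](http://...)``,
    keep only the anchor text of remaining markdown links ``[text](url)``,
    and delete ``_URL_0_``-style placeholders left in the source data.
    """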
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, flags=re.MULTILINE)
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", r"\1", result, flags=re.MULTILINE)
    result = re.sub(r"_URL_\d_", "", result, flags=re.MULTILINE)
    return result


def clean_answer(text):
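    """Normalize an answer: strip citation/link artifacts, collapse whitespace,
    and drop "BULLET::::-" markers left over from upstream preprocessing."""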
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = re.sub(r"BULLET::::-", "", result)
    return result.strip()


def clean_question(text):
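    """Normalize a question: strip citation/link artifacts, collapse whitespace,
    drop "[deleted]" placeholders, and lowercase the result."""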
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = result.replace("[deleted]", "")
    return result.lower().strip()


def prepare_support_docs(example):
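    """Join the provenance passages of an example's last output entry into a
    single context string, separating passages with "<P>" markers."""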
    provenances = example["output"][-1]["provenance"]
    context = "<P> " + " <P> ".join([p["text"] for p in provenances])
    return {"context": context}


def preprocess_eli5(examples, **fn_kwargs):
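    """Flatten ELI5 examples into (question, context, answer) triples.

    During training, every answer scoring above ``extra_answer_threshold`` is
    kept, falling back to the first answer if none qualify; outside training
    only the first answer is used. Answers matching ``exclude_answer_patterns``
    are skipped, and each question is paired with its cached support document.
    """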
    document_cache = fn_kwargs["document_cache"]
    training = fn_kwargs.get("training", True)
    extra_answer_threshold = fn_kwargs.get("extra_answer_threshold", 3)
    include_selftext = fn_kwargs.get("include_selftext", False)
    exclude_answer_patterns = fn_kwargs.get("exclude_answer_patterns", [])

    questions, contexts, answers = [], [], []
    for q_id, question, selftext, answer in zip(examples["q_id"], examples["title"], examples["selftext"],
                                                examples["answers"]):
        accepted_answer_idx = []
        if training:
            accepted_answer_idx = [idx for idx, score in enumerate(answer["score"])
                                   if score > extra_answer_threshold]
        if not training or not accepted_answer_idx:
            accepted_answer_idx = [0]
        document = document_cache[q_id]
        for idx in accepted_answer_idx:
            skip_answer = any(p.search(answer["text"][idx]) for p in exclude_answer_patterns)
            if skip_answer:
                continue
            if include_selftext:
                questions.append(clean_question(f"{question} {selftext}"))
            else:
                questions.append(clean_question(question))
            contexts.append(document.lower().strip())
            answers.append(clean_answer(answer["text"][idx]))

    return {"question": questions, "context": contexts, "answer": answers}


def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
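    """Run a full pass over the evaluation dataloader and return the mean loss."""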
    model.eval()
    num_eval_steps = len(dataloader)
    progress_bar = tqdm(range(num_eval_steps), disable=not accelerator.is_local_main_process)
    total_loss = 0.0
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.item()
            progress_bar.update(1)
            progress_bar.set_postfix(loss=round((total_loss / (step + 1)), 3))
    return total_loss / (step + 1)


def train(config):
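    """Fine-tune a seq2seq LFQA model end to end: load the dataset and its
    support documents, build (question, context, answer) triples, tokenize,
    and run the training loop with periodic evaluation and per-epoch checkpoints.
    """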
    set_seed(42)
    args = config["args"]
    eli5 = load_dataset(args.dataset_name)
    support_docs = load_dataset("vblagoje/lfqa_support_docs")

    accelerator = Accelerator()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
    logger.info(accelerator.state)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)
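
    # Following the usual transformer fine-tuning recipe, biases and LayerNorm
    # weights are excluded from weight decay.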
no_decay = ["bias", "LayerNorm.weight"] |
|
|
optimizer_grouped_parameters = [ |
|
|
{ |
|
|
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], |
|
|
"weight_decay": args.weight_decay, |
|
|
}, |
|
|
{ |
|
|
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], |
|
|
"weight_decay": 0.0, |
|
|
}, |
|
|
] |
|
|
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay) |

    processed_datasets = {}
    support_docs_prepared = {}
    with accelerator.main_process_first():
        for split in ["train", "validation"]:
            support_docs_prepared[split] = support_docs[split].map(
                prepare_support_docs,
                batched=False,
                cache_file_name=f"./support_docs_{split}.arrow",
                load_from_cache_file=not args.overwrite_cache,
                desc="Preparing support docs",
            )
        column_names = eli5["train"].column_names
        for split in ["train", "validation"]:
            d_cache = {e["id"]: e["context"] for e in tqdm(support_docs_prepared[split],
                                                           desc=f"Adding support docs to LFQA {split}")}
            processed_datasets[split] = eli5[split].map(
                preprocess_eli5,
                batched=True,
                remove_columns=column_names,
                cache_file_name=f"./processed_datasets_{split}.arrow",
                load_from_cache_file=not args.overwrite_cache,
                desc="Preparing dataset for tokenization",
                fn_kwargs={
                    "document_cache": d_cache,
                    "training": split == "train",
                    "exclude_answer_patterns": [re.compile("not sure what you"), re.compile("\n\n >")],
                },
            )

    padding = "max_length" if args.pad_to_max_length else False
    max_target_length = args.max_target_length
    # -100 is the label index that PyTorch's cross-entropy loss ignores
    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id

    def tokenize_dataset(examples):
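        """Tokenize "question: ... context: ..." inputs and build shifted decoder tensors.

        ``decoder_input_ids`` take all target tokens but the last, and ``labels``
        all but the first, giving the standard one-position teacher-forcing shift;
        pad positions in the labels are masked with ``label_pad_token_id`` so they
        do not contribute to the loss.
        """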
        inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
        targets = examples["answer"]
        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
                               return_tensors="np")
        # decoder inputs keep the real pad ids; only the loss labels are masked below
        model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()
        labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
                                       label_pad_token_id, labels["input_ids"])
        model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
        return model_inputs

    tokenized_datasets = {}
    with accelerator.main_process_first():
        for split, dataset in processed_datasets.items():
            tokenized_datasets[split] = dataset.map(
                tokenize_dataset,
                batched=True,
                cache_file_name=f"./tokenized_dataset_{split}.arrow",
                remove_columns=dataset.column_names,
                load_from_cache_file=not args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    train_dataset.set_format(type="torch")
    eval_dataset.set_format(type="torch")
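
    # "max_length" padding makes the collator pad every batch to a fixed width
    # rather than to its longest example.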
    data_collator = DataCollatorWithPadding(tokenizer, "max_length")

    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
                                  collate_fn=data_collator)
    eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)
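
    # prepare() moves the model and optimizer to the right device(s) and wraps
    # the dataloaders so the same script runs single-GPU or distributed.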
    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
                                                                              eval_dataloader)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
                                                                                     args.warmup_percentage)
    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num eval examples = {len(eval_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    logger.info(f"  Warmup steps = {num_warmup_steps}")
    logger.info(f"  Logging training progress every {args.log_freq} optimization steps")

    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    switched_train_dataloader = False
    for epoch in range(args.num_train_epochs):
        model.train()
        # the first epoch runs over the dataset in its original order; from the
        # second epoch onward the dataloader is rebuilt with shuffling enabled
        if epoch > 0 and not switched_train_dataloader:
            train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
                                          shuffle=True, collate_fn=data_collator)
            train_dataloader = accelerator.prepare(train_dataloader)
            switched_train_dataloader = True

        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = torch.mean(outputs.loss)
            # scale the loss so accumulated gradients average over the micro-batches
            accelerator.backward(loss / args.gradient_accumulation_steps)
            if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                progress_bar.set_postfix(loss=round(loss.item(), 3))
                completed_steps += 1

            if completed_steps >= args.max_train_steps:
                break

            if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
                validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
                model.train()
                logger.info(f"Train loss {loss.item()}, validation loss {validation_loss}")
                if args.wandb and accelerator.is_local_main_process:
                    import wandb

                    wandb.log({"loss": loss.item(),
                               "lr": scheduler.get_last_lr()[0],
                               "validation_loss": validation_loss,
                               "completed_steps": completed_steps})

        logger.info("Saving model {}".format(args.model_save_name))
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))

        validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)

        logger.info("Epoch: {}".format(epoch))
        logger.info("Validation loss: {}".format(validation_loss))


def main():
    parser = get_parser()
    parser.add_argument(
        "--wandb",
        action="store_true",
        help="If passed, log training metrics to Weights & Biases.",
    )
    main_args, _ = parser.parse_known_args()
    config = {"args": main_args}
    if main_args.wandb:
        import wandb

        wandb.init(project="Bart_ELI5")
    train(config=config)


if __name__ == "__main__":
    main()