theachyuttiwari committed on
Commit
29b1daa
·
1 Parent(s): 5bb64ce

Upload run_seq2seq_no_trainer.py

Browse files
Files changed (1) hide show
  1. run_seq2seq_no_trainer.py +446 -0
run_seq2seq_no_trainer.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import re
5
+
6
+ import numpy as np
7
+ import torch
8
+ from accelerate import Accelerator
9
+ from accelerate.utils import set_seed
10
+ from torch.utils.data import DataLoader
11
+ from tqdm.auto import tqdm
12
+ from transformers import get_scheduler, AutoTokenizer, AdamW, SchedulerType, AutoModelForSeq2SeqLM, \
13
+ DataCollatorWithPadding
14
+
15
+ from datasets import load_dataset
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def get_parser():
21
+ parser = argparse.ArgumentParser(description="Train ELI5 seq2seq answer generation model")
22
+ parser.add_argument(
23
+ "--dataset_name",
24
+ type=str,
25
+ default="vblagoje/lfqa",
26
+ help="The name of the dataset to use (via the datasets library).",
27
+ )
28
+
29
+ parser.add_argument(
30
+ "--per_device_train_batch_size",
31
+ type=int,
32
+ default=4,
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--per_device_eval_batch_size",
37
+ type=int,
38
+ default=4,
39
+ help="Batch size (per device) for the evaluation dataloader.",
40
+ )
41
+
42
+ parser.add_argument(
43
+ "--pretrained_model_name",
44
+ type=str,
45
+ default="facebook/bart-large",
46
+ )
47
+
48
+ parser.add_argument(
49
+ "--model_save_name",
50
+ type=str,
51
+ default="eli5_bart_model",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--learning_rate",
56
+ type=float,
57
+ default=2e-4,
58
+ )
59
+
60
+ parser.add_argument(
61
+ "--weight_decay",
62
+ type=float,
63
+ default=0.0,
64
+ help="Weight decay to use."
65
+ )
66
+
67
+ parser.add_argument(
68
+ "--log_freq",
69
+ type=int,
70
+ default=100,
71
+ help="Log train/validation loss every log_freq update steps"
72
+ )
73
+
74
+ parser.add_argument(
75
+ "--ignore_pad_token_for_loss",
76
+ type=bool,
77
+ default=True,
78
+ help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
79
+ )
80
+
81
+ parser.add_argument(
82
+ "--num_train_epochs",
83
+ type=int,
84
+ default=3,
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--max_train_steps",
89
+ type=int,
90
+ default=None,
91
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--gradient_accumulation_steps",
96
+ type=int,
97
+ default=16,
98
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--pad_to_max_length",
103
+ action="store_true",
104
+ help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
105
+ )
106
+
107
+ parser.add_argument(
108
+ "--overwrite_cache", type=bool, default=None, help="Overwrite the cached training and evaluation sets"
109
+ )
110
+
111
+ parser.add_argument(
112
+ "--max_source_length",
113
+ type=int,
114
+ default=1024,
115
+ help="The maximum total input sequence length after "
116
+ "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
117
+ )
118
+
119
+ parser.add_argument(
120
+ "--max_target_length",
121
+ type=int,
122
+ default=360,
123
+ help="The maximum total sequence length for target text after "
124
+ "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
125
+ )
126
+
127
+ parser.add_argument(
128
+ "--lr_scheduler_type",
129
+ type=SchedulerType,
130
+ default="linear", # this is linear with warmup
131
+ help="The scheduler type to use.",
132
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
133
+ )
134
+
135
+ parser.add_argument(
136
+ "--num_warmup_steps",
137
+ type=int,
138
+ default=None,
139
+ help="Number of steps for the warmup in the lr scheduler."
140
+ )
141
+
142
+ parser.add_argument(
143
+ "--warmup_percentage",
144
+ type=float,
145
+ default=0.08,
146
+ help="Number of steps for the warmup in the lr scheduler."
147
+ )
148
+ return parser
149
+
150
+
151
+ def cleanup_references(text):
152
+ # URL reference where we need to remove both the link text and URL
153
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal
154
+ # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
155
+ # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
156
+ result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)
157
+
158
+ # URL reference where we need to preserve link text but remove URL
159
+ # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
160
+ # At the outbreak of the Civil War, Leyburn left his church and joined the South.
161
+ result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)
162
+
163
+ # lastly remove just dangling _URL_[0-9]_ URL references
164
+ result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
165
+ return result
166
+
167
+
168
+ def clean_answer(text):
169
+ result = cleanup_references(text)
170
+ result = result.replace("\n", " ")
171
+ result = re.sub(r"\s\s+", " ", result)
172
+ result = re.sub(r"BULLET::::-", "", result)
173
+ return result.strip()
174
+
175
+
176
+ def clean_question(text):
177
+ result = cleanup_references(text)
178
+ result = result.replace("\n", " ")
179
+ result = re.sub(r"\s\s+", " ", result)
180
+ result = result.replace("[deleted]", "")
181
+ return result.lower().strip()
182
+
183
+
184
+ def prepare_support_docs(example):
185
+ provenances = example["output"][-1]["provenance"]
186
+ context = "<P> " + " <P> ".join([p["text"] for p in provenances])
187
+ return {"context": context}
188
+
189
+
190
+ def preprocess_eli5(examples, **fn_kwargs):
191
+ document_cache = fn_kwargs["document_cache"]
192
+ training = fn_kwargs.get("training", True)
193
+ extra_answer_threshold = fn_kwargs.get("extra_answer_threshold", 3)
194
+ include_selftext = fn_kwargs.get("include_selftext", False)
195
+ exclude_answer_patterns = fn_kwargs.get("exclude_answer_patterns", [])
196
+
197
+ questions, contexts, answers = [], [], []
198
+ for q_id, question, selftext, answer in zip(examples["q_id"], examples["title"], examples["selftext"],
199
+ examples["answers"]):
200
+ accepted_answer_idx = []
201
+ if training:
202
+ accepted_answer_idx = [idx for idx, score in enumerate(answer["score"]) if
203
+ score > extra_answer_threshold]
204
+ if not training or not accepted_answer_idx:
205
+ accepted_answer_idx = [0]
206
+ document = document_cache[q_id]
207
+ for idx in accepted_answer_idx:
208
+ skip_answer = any([p.search(answer["text"][idx]) for p in exclude_answer_patterns])
209
+ if skip_answer:
210
+ continue
211
+ if include_selftext:
212
+ questions.append(clean_question(f"{question} {selftext}"))
213
+ else:
214
+ questions.append(clean_question(question))
215
+ contexts.append(document.lower().strip())
216
+ answers.append(clean_answer(answer["text"][idx]))
217
+
218
+ return {"question": questions, "context": contexts, "answer": answers}
219
+
220
+
221
+ def eval_qa_s2s_epoch(model, dataloader, accelerator, args):
222
+ model.eval()
223
+ num_eval_steps = math.ceil(len(dataloader))
224
+ progress_bar = tqdm(range(num_eval_steps), disable=not accelerator.is_local_main_process)
225
+ total_loss = 0.
226
+ with torch.no_grad():
227
+ for step, batch in enumerate(dataloader):
228
+ outputs = model(**batch)
229
+ loss = outputs.loss
230
+ total_loss += loss.item()
231
+ progress_bar.update(1)
232
+ progress_bar.set_postfix(loss=round((total_loss / (step + 1)), 3))
233
+ return total_loss / (step + 1)
234
+
235
+
236
+ def train(config):
237
+ set_seed(42)
238
+ args = config["args"]
239
+ eli5 = load_dataset(args.dataset_name)
240
+
241
+ support_docs = load_dataset("vblagoje/lfqa_support_docs")
242
+
243
+ # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
244
+ accelerator = Accelerator()
245
+ # Make one log on every process with the configuration for debugging.
246
+ logging.basicConfig(
247
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
248
+ datefmt="%m/%d/%Y %H:%M:%S",
249
+ level=logging.INFO,
250
+ )
251
+ logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
252
+ logger.info(accelerator.state)
253
+
254
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name)
255
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.pretrained_model_name)
256
+
257
+ # Optimizer
258
+ # Split weights in two groups, one with weight decay and the other not.
259
+ no_decay = ["bias", "LayerNorm.weight"]
260
+ optimizer_grouped_parameters = [
261
+ {
262
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
263
+ "weight_decay": args.weight_decay,
264
+ },
265
+ {
266
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
267
+ "weight_decay": 0.0,
268
+ },
269
+ ]
270
+ optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
271
+
272
+ processed_datasets = {}
273
+ support_docs_prepared = {}
274
+ with accelerator.main_process_first():
275
+ for split in ["train", "validation"]:
276
+ support_docs_prepared[split] = support_docs[split].map(prepare_support_docs,
277
+ batched=False,
278
+ cache_file_name=f"./support_docs_{split}.arrow",
279
+ load_from_cache_file=not args.overwrite_cache,
280
+ desc="Preparing support docs",
281
+ )
282
+ column_names = eli5["train"].column_names
283
+ for split in ["train", "validation"]:
284
+ d_cache = dict([(e["id"], e["context"]) for e in tqdm(support_docs_prepared[split],
285
+ desc=f"Adding support docs to LFQA {split}")])
286
+ processed_datasets[split] = eli5[split].map(preprocess_eli5,
287
+ batched=True,
288
+ remove_columns=column_names,
289
+ cache_file_name=f"./processed_datasets_{split}.arrow",
290
+ load_from_cache_file=not args.overwrite_cache,
291
+ desc="Preparing dataset for tokenization",
292
+ fn_kwargs={"document_cache": d_cache,
293
+ "training": split == "train",
294
+ "exclude_answer_patterns": [re.compile("not sure what you"),
295
+ re.compile("\n\n >")]}
296
+ )
297
+
298
+ padding = "max_length" if args.pad_to_max_length else False
299
+ # Temporarily set max_target_length for training.
300
+ max_target_length = args.max_target_length
301
+
302
+ label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
303
+
304
+ def tokenize_dataset(examples):
305
+ inputs = ["question: {} context: {}".format(q, c) for q, c in zip(examples["question"], examples["context"])]
306
+ targets = examples["answer"]
307
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)
308
+
309
+ # Setup the tokenizer for targets
310
+ with tokenizer.as_target_tokenizer():
311
+ labels = tokenizer(targets, max_length=max_target_length, padding=True, truncation=True,
312
+ return_tensors="np")
313
+
314
+ model_inputs["decoder_input_ids"] = labels["input_ids"][:, :-1].tolist()
315
+ # replace pad_token_id with label_pad_token_id to avoid loss calculation on those tokens
316
+ labels["input_ids"] = np.where(labels["input_ids"] == tokenizer.pad_token_id,
317
+ label_pad_token_id, labels["input_ids"])
318
+
319
+ model_inputs["labels"] = labels["input_ids"][:, 1:].tolist()
320
+ return model_inputs
321
+
322
+ tokenized_datasets = {}
323
+ with accelerator.main_process_first():
324
+ for split, dataset in processed_datasets.items():
325
+ tokenized_datasets[split] = dataset.map(
326
+ tokenize_dataset,
327
+ batched=True,
328
+ cache_file_name=f"./tokenized_dataset_{split}.arrow",
329
+ remove_columns=dataset.column_names,
330
+ load_from_cache_file=not args.overwrite_cache,
331
+ desc="Running tokenizer on dataset"
332
+ )
333
+
334
+ train_dataset = tokenized_datasets["train"]
335
+ eval_dataset = tokenized_datasets["validation"]
336
+ train_dataset.set_format(type='torch')
337
+ eval_dataset.set_format(type='torch')
338
+
339
+ data_collator = DataCollatorWithPadding(tokenizer, "max_length")
340
+
341
+ # first epoch we don't shuffle
342
+ train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=args.per_device_train_batch_size,
343
+ collate_fn=data_collator)
344
+ eval_dataloader = DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, collate_fn=data_collator)
345
+
346
+ # train the model
347
+ model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,
348
+ eval_dataloader)
349
+ # Scheduler and math around the number of training steps.
350
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
351
+ if args.max_train_steps is None:
352
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
353
+ else:
354
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
355
+
356
+ num_warmup_steps = args.num_warmup_steps if args.num_warmup_steps else math.ceil(args.max_train_steps *
357
+ args.warmup_percentage)
358
+ scheduler = get_scheduler(
359
+ name=args.lr_scheduler_type,
360
+ optimizer=optimizer,
361
+ num_warmup_steps=num_warmup_steps,
362
+ num_training_steps=args.max_train_steps,
363
+ )
364
+ # Train!
365
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
366
+
367
+ logger.info("***** Running training *****")
368
+ logger.info(f" Num examples = {len(train_dataset)}")
369
+ logger.info(f" Num eval examples = {len(eval_dataset)}")
370
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
371
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
372
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
373
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
374
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
375
+ logger.info(f" Warmup steps = {num_warmup_steps}")
376
+ logger.info(f" Logging training progress every {args.log_freq} optimization steps")
377
+
378
+ # Only show the progress bar once on each machine.
379
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
380
+ completed_steps = 0
381
+ switched_train_dataloader = False
382
+ for epoch in range(args.num_train_epochs):
383
+ model.train()
384
+ if epoch > 0 and not switched_train_dataloader:
385
+ train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size,
386
+ shuffle=True, collate_fn=data_collator)
387
+ train_dataloader = accelerator.prepare(train_dataloader)
388
+ switched_train_dataloader = True
389
+
390
+ for step, batch in enumerate(train_dataloader):
391
+ outputs = model(**batch)
392
+ loss = torch.mean(outputs.loss)
393
+ accelerator.backward(loss)
394
+ if ((step + 1) % args.gradient_accumulation_steps == 0) or (step + 1 == len(train_dataloader)):
395
+ optimizer.step()
396
+ scheduler.step()
397
+ optimizer.zero_grad()
398
+ progress_bar.update(1)
399
+ progress_bar.set_postfix(loss=round(loss.item(), 3))
400
+ completed_steps += 1
401
+
402
+ if completed_steps >= args.max_train_steps:
403
+ break
404
+
405
+ if step % (args.log_freq * args.gradient_accumulation_steps) == 0:
406
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
407
+ model.train()
408
+ logger.info(f"Train loss {loss.item()} , validation loss {validation_loss}")
409
+ if args.wandb and accelerator.is_local_main_process:
410
+ import wandb
411
+ wandb.log({"loss": loss.item(),
412
+ "lr": scheduler.get_last_lr()[0],
413
+ "validation_loss": validation_loss,
414
+ "completed_steps": completed_steps})
415
+
416
+ logger.info("Saving model {}".format(args.model_save_name))
417
+ accelerator.wait_for_everyone()
418
+ unwrapped_model = accelerator.unwrap_model(model)
419
+ accelerator.save(unwrapped_model.state_dict(), "{}_{}.bin".format(args.model_save_name, epoch))
420
+
421
+ # Calculating the validation loss over epoch
422
+ validation_loss = eval_qa_s2s_epoch(model, eval_dataloader, accelerator, args)
423
+
424
+ logger.info("Epoch: {}".format(epoch))
425
+ logger.info("Validation loss: {}".format(validation_loss))
426
+
427
+
428
+ def main():
429
+ parser = get_parser()
430
+ parser.add_argument(
431
+ "--wandb",
432
+ action="store_true",
433
+ help="If true, use W&B logging",
434
+ )
435
+ main_args, _ = parser.parse_known_args()
436
+ config = {"args": main_args}
437
+ if main_args.wandb:
438
+ import wandb
439
+ wandb.init(project="Bart_ELI5")
440
+ train(config=config)
441
+
442
+
443
+ main()
444
+
445
+
446
+