import argparse
import logging
from torch.utils.data import Dataset, IterableDataset
import gzip
import json
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
import sys
from datetime import datetime
import torch
import random
from shutil import copyfile
import os
import wandb
import re
from datasets import load_dataset
import tqdm


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

parser = argparse.ArgumentParser()
parser.add_argument("--lang", required=True)
parser.add_argument("--model_name", default="google/mt5-base")
parser.add_argument("--epochs", default=4, type=int)
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--max_source_length", default=320, type=int)
parser.add_argument("--max_target_length", default=64, type=int)
parser.add_argument("--eval_size", default=1000, type=int)

args = parser.parse_args()

wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")


def main():
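    # Load all mMARCO queries for the selected language, keyed by query id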
    queries = {}
    for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
        queries[row['id']] = row['text']

    """
    collection = {}
    for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']):
        collection[row['id']] = row['text']
    """
    collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']

    train_pairs = []
    eval_pairs = []
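
    # qrels.train.tsv from MS MARCO: tab-separated lines mapping query id to a relevant
    # passage id, expected in the working directory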
    with open('qrels.train.tsv') as fIn:
        for line in fIn:
            qid, _, did, _ = line.strip().split("\t")

            qid = int(qid)
            did = int(did)

            assert did == collection[did]['id']
            text = collection[did]['text']

            pair = (queries[qid], text)
            if len(eval_pairs) < args.eval_size:
                eval_pairs.append(pair)
            else:
                train_pairs.append(pair)

    print(f"Train pairs: {len(train_pairs)}")

    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    save_steps = 1000

    output_dir = 'output/'+args.lang+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("Output dir:", output_dir)

    os.makedirs(output_dir, exist_ok=True)
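
    # Save a copy of this script and the command line it was invoked with into the
    # output directory, so the run can be reproduced later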
    train_script_path = os.path.join(output_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        bf16=True,
        per_device_train_batch_size=args.batch_size,
        evaluation_strategy="steps",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )
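
    # Sanity check: doc2query trains passage -> query, so the passage is the input
    # and the query is the target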
    print("Input:", train_pairs[0][1])
    print("Target:", train_pairs[0][0])

    print("Input:", eval_pairs[0][1])
    print("Target:", eval_pairs[0][0])
    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        label_pad_token_id = -100

        # Pad to a multiple of 8 for better tensor-core utilization when training in fp16/bf16
        pad_multiple = 8 if (training_args.fp16 or training_args.bf16) else None

        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=pad_multiple)

        with tokenizer.as_target_tokenizer():
            labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=pad_multiple)

        # Replace label padding token ids with -100 so they are ignored by the loss
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
        ]

        model_inputs["labels"] = torch.tensor(labels["input_ids"])
        return model_inputs
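
    # train_pairs / eval_pairs are plain lists of (query, passage) tuples; the Trainer's
    # DataLoader feeds them to data_collator, which tokenizes each batch on the fly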
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_pairs,
        eval_dataset=eval_pairs,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    train_result = trainer.train()
    trainer.save_model()


if __name__ == "__main__":
    main()