| | import gzip |
| | import random |
| |
|
| | from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW |
| | import sys |
| | import torch |
| | import transformers |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.cuda.amp import autocast |
| | import tqdm |
| | from datetime import datetime |
| | from shutil import copyfile |
| | import os |
| | |
| |
|
| | import gzip |
| | from collections import defaultdict |
| | import logging |
| | import tqdm |
| | import numpy as np |
| | import sys |
| | import pytrec_eval |
| | from sentence_transformers import SentenceTransformer, util, CrossEncoder |
| | import torch |
| |
|
| |
|
# CLI: first argument is the Hugging Face model name to fine-tune.
model_name = sys.argv[1]
max_length = 350  # max token count for (query, passage) pairs during training

# Load TREC-DL 2019 test queries: qid -> query text.
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
queries_eval = {}
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")[:2]
        queries_eval[qid] = query
| |
|
# Relevance judgments (qrels): qid -> pid -> graded relevance.
# Only positive judgments are stored; the nested defaultdict lets later
# code probe unjudged qids without KeyErrors.
rel = defaultdict(lambda: defaultdict(int))
with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
    for line in fIn:
        qid, _, pid, score_str = line.strip().split()
        score = int(score_str)
        if score > 0:
            rel[qid][pid] = score
| |
|
# Evaluate only test queries that have at least one positive judgment.
# NOTE: indexing the defaultdict creates empty entries for unjudged qids,
# exactly as the original loop did.
relevant_qid = [qid for qid in queries_eval if len(rel[qid]) > 0]
| |
|
| | |
# Top-1000 BM25 candidate passages per test query: qid -> [[pid, passage], ...]
passage_cand = {}
with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])
| |
|
| |
|
| |
|
def eval_modal(model_path):
    """Re-rank the TREC-DL 2019 candidates with the given cross-encoder
    checkpoint and return the mean NDCG@10 over all judged queries.

    Uses the module-level globals ``relevant_qid``, ``queries_eval``,
    ``passage_cand`` and ``rel``.
    """
    model = CrossEncoder(model_path, max_length=512)

    run = {}
    for qid in relevant_qid:
        query = queries_eval[qid]

        # Build (query, passage) pairs for every candidate of this query.
        pids, cross_inp = [], []
        for pid, passage in passage_cand[qid]:
            pids.append(pid)
            cross_inp.append([query, passage])

        if model.config.num_labels > 1:
            # Classification head: use P(relevant) from the softmax.
            cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
        else:
            # Regression head: raw logit, no activation.
            cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()

        run[qid] = {pid: float(score) for pid, score in zip(pids, cross_scores)}

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    per_query = evaluator.evaluate(run)
    scores_mean = np.mean([metrics["ndcg_cut_10"] for metrics in per_query.values()])

    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean
| |
|
| | |
| |
|
# Train on GPU when available; single-device training.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = AutoConfig.from_pretrained(model_name)
# One output logit: the model is trained as a regression-style re-ranker.
config.num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
| |
|
| |
|
| | |
# Optional second CLI argument: shrink the 12-layer BERT encoder to an
# evenly spaced subset of layers before fine-tuning (layer dropping).
if len(sys.argv) > 2:
    num_layers = int(sys.argv[2])
    # Supported target depths -> indices of the original layers to keep.
    layer_subsets = {
        6: [0, 2, 4, 6, 8, 10],
        4: [1, 4, 7, 10],
        2: [3, 7],
    }
    if num_layers not in layer_subsets:
        print("Unknown number of layers to keep:", num_layers)
        # sys.exit() instead of the interactive-only exit() builtin.
        sys.exit()
    layers_to_keep = layer_subsets[num_layers]

    print("Reduce model to {} layers".format(len(layers_to_keep)))
    # NOTE(review): assumes a BERT architecture (model.bert); other encoder
    # families expose the layer stack under a different attribute.
    new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(model.bert.encoder.layer) if i in layers_to_keep])
    model.bert.encoder.layer = new_layers
    model.bert.config.num_hidden_layers = len(layers_to_keep)
    model_name += "_L-{}".format(len(layers_to_keep))
| |
|
| |
|
| |
|
| |
|
| | |
| |
|
# NOTE(review): `queries` and `corpus` appear unused in this script — the
# MultilingualDataset instances below hold the actual data; confirm before removing.
queries = {}
corpus = {}

# Timestamped output dir for the best checkpoint, plus a "-latest" dir that
# is overwritten at every auto-save interval.
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
output_save_path_latest = output_save_path+"-latest"
tokenizer.save_pretrained(output_save_path)
tokenizer.save_pretrained(output_save_path_latest)
| |
|
| |
|
| | |
# Archive this training script (with the exact invocation command appended)
# into both checkpoint directories for reproducibility.
for save_dir in (output_save_path, output_save_path_latest):
    train_script_path = os.path.join(save_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
| |
|
| |
|
| |
|
| | |
class MultilingualDataset(Dataset):
    """Text lookup table that stores several language variants per id and
    serves a random one on each access.

    Layout: id -> {language -> [text variants]}. ``__getitem__`` first picks
    a language uniformly at random, then a variant within that language.
    """

    def __init__(self):
        self.examples = defaultdict(lambda: defaultdict(list))

    def add(self, lang, filepath):
        """Load a tab-separated ``id\\ttext`` file (optionally gzipped) and
        register every line under the given language tag."""
        opener = gzip.open if filepath.endswith('.gz') else open
        with opener(filepath, 'rt') as fIn:
            for line in fIn:
                pid, passage = line.strip().split("\t")
                self.examples[pid][lang].append(passage)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        per_language = self.examples[item]
        variants = random.choice(list(per_language.values()))
        return random.choice(variants)
| |
|
| |
|
# Passage corpus: English MS MARCO plus two independent German machine
# translations; each lookup returns a random language/translation variant.
train_corpus = MultilingualDataset()
train_corpus.add('en', 'msmarco-data/collection.tsv')
train_corpus.add('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz')
train_corpus.add('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz')

# Training queries, same multilingual treatment as the corpus.
train_queries = MultilingualDataset()
train_queries.add('en', 'msmarco-data/queries.train.tsv')
train_queries.add('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz')
train_queries.add('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz')
| |
|
| | |
class MSEDataset(Dataset):
    """MarginMSE training triples with teacher score margins.

    Each input line is ``pos_score\\tneg_score\\tqid\\tpid1\\tpid2``; each
    example becomes ``[qid, pid1, pid2, pos_score - neg_score]``.
    """

    def __init__(self, filepath):
        super().__init__()
        self.examples = []
        with open(filepath) as fIn:
            for line in fIn:
                pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
                margin = float(pos_score) - float(neg_score)
                self.examples.append([qid, pid1, pid2, margin])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]
| |
|
train_batch_size = 16
# Teacher margins distilled from a BERT-CAT ensemble over MS MARCO triples.
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
# drop_last avoids a ragged final batch (and squeeze() edge cases downstream).
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)
| |
|
| |
|
| | |
| |
|
# Optimizer with the standard BERT weight-decay split: biases and LayerNorm
# parameters are excluded from decay.
weight_decay = 0.01
max_grad_norm = 1
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# Linear warmup (1000 steps) then linear decay over one epoch of batches.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
scaler = torch.cuda.amp.GradScaler()

loss_fct = torch.nn.MSELoss()

model.to(device)
| |
|
auto_save = 10000        # evaluate + checkpoint every N steps
best_ndcg_score = 0

# Fix: from_pretrained() returns the model in eval mode, which silently
# disables dropout; switch to train mode before optimizing.
model.train()

for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    # Default collate transposes the [qid, pid1, pid2, margin] examples, so
    # batch[0]=qids, batch[1]=positive pids, batch[2]=negative pids, batch[3]=margins.
    batch_queries = [train_queries[qid] for qid in batch[0]]
    batch_pos = [train_corpus[cid] for cid in batch[1]]
    batch_neg = [train_corpus[cid] for cid in batch[2]]
    scores = batch[3].float().to(device)

    with autocast():
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_pos = model(**inp_pos).logits.squeeze()

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze()

        # MarginMSE: regress the student margin (pos - neg) onto the teacher margin.
        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)

    # AMP backward: unscale before clipping so the norm is computed on real grads.
    scaler.scale(loss_value).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()
    scheduler.step()

    if (step_idx+1) % auto_save == 0:
        print("Step:", step_idx+1)
        model.save_pretrained(output_save_path_latest)
        # eval_modal loads the just-saved checkpoint into a fresh CrossEncoder;
        # the in-memory training model is not touched.
        ndcg_score = eval_modal(output_save_path_latest)

        if ndcg_score >= best_ndcg_score:
            best_ndcg_score = ndcg_score
            print("Save to:", output_save_path)
            model.save_pretrained(output_save_path)

# NOTE(review): this unconditional final save overwrites the best-NDCG
# checkpoint kept in output_save_path with the last model state — confirm
# that is intended.
model.save_pretrained(output_save_path)
| |
|
| |
|
| | |
| | |