| from torch.utils.data import DataLoader |
| from sentence_transformers import LoggingHandler |
| from sentence_transformers.cross_encoder import CrossEncoder |
| from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator |
| from sentence_transformers import InputExample |
| import logging |
| from datetime import datetime |
| import gzip |
| import sys |
| import numpy as np |
| import os |
| from shutil import copyfile |
| import csv |
| import tqdm |
|
|
| |
# Configure root logging: INFO and above, timestamped messages, emitted through
# the LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
| |
|
|
|
|
| |
# The single required command-line argument is the name or path of the base
# model handed to CrossEncoder. Fail with a usage message instead of a bare
# IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python {} <model_name>".format(sys.argv[0]))
model_name = sys.argv[1]
train_batch_size = 32
num_epochs = 1
# Timestamped output directory; '/' in the model name is replaced by '-' so the
# name stays a single path component under output/.
model_save_path = 'output/training_ms-marco_cross-encoder-' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
| |
# Cross-encoder with a single output (num_labels=1); the evaluator in this
# script treats it as one relevance score per (query, passage) pair.
# max_length=512 is presumably the maximum input length in tokens -- confirm
# against the CrossEncoder documentation.
model = CrossEncoder(
    model_name,
    max_length=512,
    num_labels=1,
)
|
|
|
|
| |
# Create the run's output directory and archive a copy of this script inside
# it, with the exact command line appended as a trailing comment, so the run
# can be reproduced later.
os.makedirs(model_save_path, exist_ok=True)

train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as script_copy:
    script_copy.write("\n\n# Script was called via:\n#python {}".format(" ".join(sys.argv)))
|
|
|
|
# Lookup tables keyed by the string ids found in the data files:
# passage id -> passage text, and query id -> query text.
corpus = {}
queries = {}

# Decode explicitly as UTF-8. Text-mode gzip.open()/open() otherwise use the
# platform's locale encoding, which fails or garbles non-ASCII text on
# non-UTF-8 systems (assumes the data files are UTF-8 encoded).
# split("\t", 1) keeps any further tab characters inside the text field
# instead of failing the two-way unpack.
with gzip.open('../data/collection.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t", 1)
        corpus[pid] = passage

with open('../data/queries.train.tsv', 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t", 1)
        queries[qid] = query
|
|
|
|
|
|
# Training-loop state. A training line whose running count is divisible by
# pos_neg_ration contributes its positive passage; every other line
# contributes its negative one.
pos_neg_ration = (4+1)
cnt = 0
train_samples = []
dev_samples = {}

# Dev set: the first num_dev_queries distinct queries seen in the train-eval
# triples file, each with every positive passage found for it and at most
# num_max_dev_negatives distinct negative passages.
num_dev_queries = 125
num_max_dev_negatives = 200

with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as triples_file:
    for row in triples_file:
        qid, pos_id, neg_id = row.strip().split()

        if qid not in dev_samples:
            if len(dev_samples) >= num_dev_queries:
                # Dev set is already full; ignore triples of new queries.
                continue
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}

        entry = dev_samples[qid]
        entry['positive'].add(corpus[pos_id])
        if len(entry['negative']) < num_max_dev_negatives:
            entry['negative'].add(corpus[neg_id])
|
|
# Build the training set from (qid, positive pid, negative pid) triples.
# Lines whose running count is divisible by pos_neg_ration contribute their
# positive passage (label 1); all other lines contribute their negative
# passage (label 0). With pos_neg_ration == 5 that is one positive per four
# negatives. Collection stops after 2e7 samples.
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
    for line in tqdm.tqdm(fIn, unit_scale=True):
        cnt += 1
        qid, pos_id, neg_id = line.strip().split()

        # Never train on a query that is also used for evaluation, otherwise
        # the dev MRR is inflated. This is a no-op if the train and train-eval
        # files are disjoint. It sits after the counter, so the pos/neg
        # pattern of the remaining lines is unchanged.
        if qid in dev_samples:
            continue

        query = queries[qid]
        if (cnt % pos_neg_ration) == 0:
            passage = corpus[pos_id]
            label = 1
        else:
            passage = corpus[neg_id]
            label = 0

        train_samples.append(InputExample(texts=[query, passage], label=label))

        if len(train_samples) >= 2e7:
            break
|
|
|
|
|
|
# Batch the InputExample pairs for CrossEncoder.fit; shuffle=True randomizes
# the sample order.
train_dataloader = DataLoader(
    train_samples,
    batch_size=train_batch_size,
    shuffle=True,
)
|
|
| |
|
|
class CERerankingEvaluator:
    """Evaluate a cross-encoder by re-ranking candidate passages and reporting MRR@k.

    For each query, the model scores every (query, passage) pair built from
    its positive and negative passages. The reciprocal rank of the first
    relevant passage within the top ``mrr_at_k`` predictions (0 if none) is
    averaged over all evaluated queries.

    Args:
        samples: A list of dicts, or a dict whose values are such dicts. Each
            dict needs the keys ``'query'`` (query text), ``'positive'`` and
            ``'negative'`` (iterables of passage texts). Queries with no
            positive or no negative passage are skipped.
        mrr_at_k: Only the top ``mrr_at_k`` ranked passages count toward MRR.
        name: Optional dataset name used in log messages and the CSV filename.
    """

    def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
        self.samples = samples
        self.name = name
        self.mrr_at_k = mrr_at_k

        # Accept a {qid: sample} mapping directly; only the values are needed.
        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Score all samples with *model* and return the mean MRR@k.

        Args:
            model: Object exposing ``predict(pairs, convert_to_numpy=True,
                show_progress_bar=False)`` that returns a 1-D numpy array
                with one score per ``[query, passage]`` pair.
            output_path: If given, a row ``[epoch, steps, MRR]`` is appended
                to ``self.csv_file`` in this directory. The header row is
                written first when the file does not exist yet.
            epoch: Training epoch, used only for logging and the CSV row
                (-1 means not provided).
            steps: Training step within the epoch, used only for logging and
                the CSV row (-1 means not provided).

        Returns:
            Mean MRR@k over the evaluated queries as a float, or 0.0 when no
            sample has both a positive and a negative passage.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        all_mrr_scores = []
        num_positives = []
        num_negatives = []
        for instance in self.samples:
            query = instance['query']
            positive = list(instance['positive'])
            negative = list(instance['negative'])

            # Without a positive the score is always 0, and without a negative
            # it is always 1, whatever the model does; skip such queries.
            if not positive or not negative:
                continue

            num_positives.append(len(positive))
            num_negatives.append(len(negative))

            docs = positive + negative
            is_relevant = [True] * len(positive) + [False] * len(negative)

            model_input = [[query, doc] for doc in docs]
            pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
            # Indices of docs ordered by descending predicted score.
            pred_scores_argsort = np.argsort(-pred_scores)

            # Reciprocal rank of the *first* relevant doc in the top-k. Stop at
            # the first hit so lower-ranked positives cannot overwrite it.
            mrr_score = 0.0
            for rank, index in enumerate(pred_scores_argsort[:self.mrr_at_k]):
                if is_relevant[index]:
                    mrr_score = 1 / (rank + 1)
                    break

            all_mrr_scores.append(mrr_score)

        num_queries = len(all_mrr_scores)
        if num_queries == 0:
            # np.min/np.max raise on empty input; report 0 instead of crashing.
            logging.warning("CERerankingEvaluator: no sample has both positive and negative passages; MRR@{} set to 0".format(self.mrr_at_k))
            mean_mrr = 0.0
        else:
            mean_mrr = float(np.mean(all_mrr_scores))
            logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr * 100))

        if output_path is not None:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline='' as required by the csv module, so rows are not
            # separated by blank lines on platforms that translate newlines.
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline='') as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, mean_mrr])

        return mean_mrr
|
|
|
|
# Re-ranking evaluator over the dev queries (mrr_at_k left at its default of 10).
evaluator = CERerankingEvaluator(samples=dev_samples)
|
|
| |
# Number of warm-up steps handed to CrossEncoder.fit.
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))

# Fine-tune the cross-encoder. The dev evaluator is supplied together with
# evaluation_steps=5000, fit() receives model_save_path as its output
# directory, and use_amp=True enables automatic mixed precision.
model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=5000,
    output_path=model_save_path,
    use_amp=True,
)

# Additionally save the final weights to a sibling '-latest' directory,
# separate from whatever fit() wrote to model_save_path.
model.save(model_save_path + '-latest')
|
|
| |
| |