| | import sys |
| | import json |
| | from torch.utils.data import DataLoader |
| | from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample |
| | import logging |
| | from datetime import datetime |
| | import gzip |
| | import os |
| | import tarfile |
| | import tqdm |
| | from torch.utils.data import Dataset |
| | import random |
| | from shutil import copyfile |
| | import pickle |
| | import argparse |
| |
|
| | |
# Route all log output through sentence-transformers' LoggingHandler so
# progress messages carry a consistent timestamp.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
| | |
| |
|
| |
|
# Command-line configuration for the training run.
parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", default=64, type=int)
parser.add_argument("--max_seq_length", default=250, type=int)
parser.add_argument("--model_name", default="nicoladecao/msmarco-word2vec256000-distilbert-base-uncased")
parser.add_argument("--max_passages", default=0, type=int)
parser.add_argument("--epochs", default=30, type=int)
parser.add_argument("--pooling", default="mean")
# Typo fixed in help text: "seperated" -> "separated".
parser.add_argument("--negs_to_use", default=None,
                    help="From which systems should negatives be used? Multiple systems separated by comma. None = all")
parser.add_argument("--warmup_steps", default=1000, type=int)
parser.add_argument("--lr", default=2e-5, type=float)
parser.add_argument("--num_negs_per_system", default=5, type=int)
parser.add_argument("--use_all_queries", default=False, action="store_true")
args = parser.parse_args()

# Record the effective configuration in the log for reproducibility.
logging.info(str(args))
| |
|
| |
|
| |
|
| | |
# Short module-level aliases for the most frequently used hyper-parameters.
model_name = args.model_name
max_seq_length = args.max_seq_length
train_batch_size = args.train_batch_size
num_epochs = args.epochs
# Cap on how many training queries to load (0 = no limit), and how many
# negatives to keep per retrieval system.
max_passages = args.max_passages
num_negs_per_system = args.num_negs_per_system
| |
|
| | |
| |
|
logging.info("Create new SBERT model")
# Transformer backbone + pooling head assembled into a fresh SentenceTransformer.
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), args.pooling)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Freeze the (very large, word2vec-initialized) embedding layer so it is not
# updated during training.  BUG FIX: assigning `requires_grad = False` on the
# nn.Module itself only sets a plain attribute and does NOT stop gradients --
# the flag must be set on each Parameter.
for param in word_embedding_model.auto_model.embeddings.parameters():
    param.requires_grad = False
| |
|
# Output directory is unique per run: model name, batch size and timestamp.
run_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = 'output/train_bi-encoder-margin_mse-word2vec-{}-batch_size_{}-{}'.format(
    model_name.replace("/", "-"), train_batch_size, run_id)
os.makedirs(model_save_path, exist_ok=True)

# Archive this training script next to the model, together with the exact
# command line it was invoked with, for reproducibility.
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
| |
|
| |
|
| | |
data_folder = 'msmarco-data'

# ---- MS MARCO passage collection: pid (int) -> passage text ----------------
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    # NOTE(review): extractall of a downloaded archive trusts the archive's
    # member paths; source here is the official MS MARCO blob store.
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

logging.info("Read corpus: collection.tsv")
corpus = {}
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[int(pid)] = passage
| |
|
| |
|
| | |
# ---- MS MARCO training queries: qid (int) -> query text --------------------
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

queries = {}
with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[int(qid)] = query
| |
|
| |
|
| | |
| | |
# Cross-encoder (MiniLM) teacher scores, ce_scores[qid][pid] -> float.
# These are the regression targets for MarginMSE distillation.
ce_scores_file = os.path.join(data_folder, 'cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz')
if not os.path.exists(ce_scores_file):
    logging.info("Download cross-encoder scores file")
    util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz', ce_scores_file)

logging.info("Load CrossEncoder scores dict")
# NOTE(review): pickle.load executes arbitrary code if the file is tampered
# with; the file comes from the official sentence-transformers HF dataset.
with gzip.open(ce_scores_file, 'rb') as fIn:
    ce_scores = pickle.load(fIn)
| |
|
| | |
# Hard negatives mined by multiple retrieval systems (one JSON object per query).
hard_negatives_filepath = os.path.join(data_folder, 'msmarco-hard-negatives.jsonl.gz')
if not os.path.exists(hard_negatives_filepath):
    # Fixed copy/pasted log message: this downloads the hard-negatives file,
    # not the cross-encoder scores file.
    logging.info("Download hard-negatives file")
    util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/msmarco-hard-negatives.jsonl.gz', hard_negatives_filepath)
| |
|
| |
|
logging.info("Read hard negatives train file")
# Build train_queries[qid] = {'qid', 'query', 'pos': [pid...], 'neg': {pid...}}.
train_queries = {}
negs_to_use = None
with gzip.open(hard_negatives_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        # --max_passages caps the number of training queries loaded (0 = all).
        if max_passages > 0 and len(train_queries) >= max_passages:
            break
        data = json.loads(line)

        pos_pids = data['pos']

        # Decide once (on the first record) which retrieval systems to draw
        # negatives from: the --negs_to_use list, or every system present.
        if negs_to_use is None:
            if args.negs_to_use is not None:
                negs_to_use = args.negs_to_use.split(",")
            else:
                negs_to_use = list(data['neg'].keys())
            logging.info("Using negatives from the following systems: {}".format(", ".join(negs_to_use)))

        # Collect up to num_negs_per_system fresh negatives from each system.
        neg_pids = set()
        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue
            negs_added = 0
            for pid in data['neg'][system_name]:
                if pid in neg_pids:
                    continue
                neg_pids.add(pid)
                negs_added += 1
                if negs_added >= num_negs_per_system:
                    break

        # Keep the query only if it yields at least one training triplet
        # (unless --use_all_queries forces inclusion).
        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
            train_queries[data['qid']] = {'qid': data['qid'],
                                          'query': queries[data['qid']],
                                          'pos': pos_pids,
                                          'neg': neg_pids}

logging.info("Train queries: {}".format(len(train_queries)))
| |
|
| | |
| | |
class MSMARCODataset(Dataset):
    """Yields (query, positive, negative) triplets for MarginMSE training.

    Each __getitem__ rotates round-robin through a query's positive passages
    and its (shuffled once) negative passages, so successive epochs see
    different pairings.  The label is the cross-encoder teacher margin
    score(query, pos) - score(query, neg).
    """

    def __init__(self, queries, corpus, ce_scores):
        self.queries = queries              # qid -> {'qid', 'query', 'pos', 'neg'}
        self.queries_ids = list(queries.keys())
        self.corpus = corpus                # pid -> passage text
        self.ce_scores = ce_scores          # qid -> {pid -> teacher score}

        # Normalize pos/neg to lists and randomize the negative order once.
        for qid in self.queries:
            entry = self.queries[qid]
            entry['pos'] = list(entry['pos'])
            entry['neg'] = list(entry['neg'])
            random.shuffle(entry['neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']

        # Next "positive": taken from the positives round-robin; if the query
        # has no positives at all, the next negative stands in for it.
        source = 'pos' if len(query['pos']) > 0 else 'neg'
        pos_id = query[source].pop(0)
        pos_text = self.corpus[pos_id]
        query[source].append(pos_id)

        # Next negative, also round-robin.
        neg_id = query['neg'].pop(0)
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        pos_score = self.ce_scores[qid][pos_id]
        neg_score = self.ce_scores[qid][neg_id]

        # Label is the teacher's score margin between the two passages.
        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score - neg_score)

    def __len__(self):
        return len(self.queries)
| |
|
| | |
# Wire the data pipeline into the distillation loss.
train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
train_loss = losses.MarginMSELoss(model=model)

# Mixed-precision training with periodic checkpoints into the run directory.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=args.warmup_steps,
    use_amp=True,
    checkpoint_path=model_save_path,
    checkpoint_save_steps=10000,
    optimizer_params={'lr': args.lr},
)

# Persist the final model.
model.save(model_save_path)
| |
|
| | |
| | |