| |
|
| | import sys
|
| | import json
|
| | from torch.utils.data import DataLoader
|
| | from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
|
| | import logging
|
| | from datetime import datetime
|
| | import gzip
|
| | import os
|
| | import tarfile
|
| | from collections import defaultdict
|
| | from torch.utils.data import IterableDataset
|
| | import tqdm
|
| | from torch.utils.data import Dataset
|
| | import random
|
| | from shutil import copyfile
|
| |
|
| | import argparse
|
| |
|
# Command-line interface: all training settings are supplied as flags.
parser = argparse.ArgumentParser()

# Batching / sequence truncation
parser.add_argument("--train_batch_size", type=int, default=64)
parser.add_argument("--max_seq_length", type=int, default=300)

# Base checkpoint (mandatory) and how it is turned into a sentence encoder
parser.add_argument("--model_name", required=True)
# A value > 0 stops reading the training file once that many training queries were collected.
parser.add_argument("--max_passages", type=int, default=0)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--pooling", default="cls")

# Selection of hard negatives
parser.add_argument(
    "--negs_to_use",
    default=None,
    help="From which systems should negatives be used? Multiple systems seperated by comma. None = all",
)

# Optimisation settings
parser.add_argument("--warmup_steps", type=int, default=1000)
parser.add_argument("--lr", type=float, default=2e-5)

# Free-form tag that becomes part of the output directory name
parser.add_argument("--name", default='')
parser.add_argument("--num_negs_per_system", type=int, default=5)

# Boolean switches (off unless given on the command line)
parser.add_argument("--use_pre_trained_model", action="store_true", default=False)
parser.add_argument("--use_all_queries", action="store_true", default=False)

args = parser.parse_args()

# Echo the effective configuration so it ends up in the job output.
print(args)
|
| |
|
| |
|
# Send log records at INFO level and above through the LoggingHandler imported
# from sentence_transformers, prefixed with a timestamp.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

# Short module-level aliases for the CLI values used throughout the script.
model_name = args.model_name
max_seq_length = args.max_seq_length
train_batch_size = args.train_batch_size
max_passages = args.max_passages

num_epochs = args.epochs
num_negs_per_system = args.num_negs_per_system

if not args.use_pre_trained_model:
    # Assemble an encoder from a plain transformer checkpoint plus a pooling
    # layer configured by --pooling.
    print("Create new SBERT model")
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    embedding_dim = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(embedding_dim, args.pooling)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
else:
    # Load model_name directly as a SentenceTransformer and only override its
    # maximum sequence length.
    print("use pretrained SBERT model")
    model = SentenceTransformer(model_name)
    model.max_seq_length = max_seq_length
|
| |
|
# Output directory for this run. The name encodes the --name tag, the model
# name (with "/" replaced by "-"), the batch size and the start timestamp.
model_save_path = f'output/train_bi-encoder-margin_mse_en-{args.name}-{model_name.replace("/", "-")}-batch_size_{train_batch_size}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'

os.makedirs(model_save_path, exist_ok=True)

# Keep a copy of this script next to the model and append the exact command
# line as a trailing comment.
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
# Append as UTF-8 explicitly: the copied file is Python source (UTF-8 by
# default), and a locale-dependent encoding could write bytes that do not match
# it when an argument contains non-ASCII characters.
with open(train_script_path, 'a', encoding='utf8') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
| |
|
| |
|
| |
|
# Folder that holds the MS MARCO archives and the extracted TSV files.
data_folder = 'msmarco-data'

# Passage lookup: passage id -> passage text. Both are the raw strings read
# from collection.tsv.
corpus = {}
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    # The TSV is missing: unpack it from the archive, downloading the archive
    # first when it is not on disk either.
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get(
            'https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz',
            tar_filepath,
        )

    # NOTE(review): extractall() unpacks every member of the archive as-is;
    # this assumes the downloaded archive is trusted.
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

logging.info("Read corpus: collection.tsv")
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    # Each line must split into exactly two tab-separated fields.
    rows = (line.strip().split("\t") for line in fIn)
    for pid, passage in rows:
        corpus[pid] = passage
|
| |
|
| |
|
| |
|
# Query lookup: query id -> query text, as read from queries.train.tsv.
queries = {}
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    # Same fallback chain as for the corpus: use the local archive if present,
    # otherwise download it, then extract into data_folder.
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get(
            'https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz',
            tar_filepath,
        )

    # NOTE(review): extractall() unpacks every member of the archive as-is;
    # this assumes the downloaded archive is trusted.
    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Exactly two tab-separated fields are expected per line.
        fields = line.strip().split("\t")
        qid, query = fields
        queries[qid] = query
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Gzipped JSON-lines file with hard negatives. As accessed below, every line is
# an object with:
#   'qid' - query id
#   'pos' - list of {'pid': ..., 'ce-score': ...}
#   'neg' - dict: system name -> list of {'pid': ..., 'ce-score': ...}
train_filepath = '/home/msmarco/data/hard-negatives/msmarco-hard-negatives-v6.jsonl.gz'

logging.info("Read train dataset")
train_queries = {}  # qid -> {'qid', 'query', 'pos': [pid, ...], 'neg': {pid, ...}}
ce_scores = {}      # qid -> {pid: ce-score}
negs_to_use = None  # resolved lazily on the first line of the file
# Decode explicitly as UTF-8, like the other readers in this script, instead of
# relying on the locale's default encoding.
with gzip.open(train_filepath, 'rt', encoding='utf8') as fIn:
    for line in tqdm.tqdm(fIn):
        # --max_passages > 0 caps the number of collected training queries.
        if max_passages > 0 and len(train_queries) >= max_passages:
            break

        data = json.loads(line)
        qid = data['qid']
        query_scores = ce_scores.setdefault(qid, {})

        # Positives: remember their scores and ids.
        for item in data['pos']:
            query_scores[item['pid']] = item['ce-score']
        pos_pids = [item['pid'] for item in data['pos']]

        # Negatives: decide once which systems to draw from, either the
        # comma-separated --negs_to_use list or every system on the first line.
        neg_pids = set()
        if negs_to_use is None:
            if args.negs_to_use is not None:
                negs_to_use = args.negs_to_use.split(",")
            else:
                negs_to_use = list(data['neg'].keys())
            print("Using negatives from the following systems:", negs_to_use)

        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            system_negs = data['neg'][system_name]

            negs_added = 0
            for item in system_negs:
                # The score is recorded for every negative that is visited,
                # including ids already contributed by another system.
                query_scores[item['pid']] = item['ce-score']

                pid = item['pid']
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    # Stop once this system contributed its quota of new ids.
                    if negs_added >= num_negs_per_system:
                        break

        # Without --use_all_queries a query needs at least one positive and one
        # negative to be kept.
        if args.use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
            train_queries[qid] = {'qid': qid, 'query': queries[qid], 'pos': pos_pids, 'neg': neg_pids}

logging.info("Train queries: {}".format(len(train_queries)))
|
| |
|
| |
|
| |
|
class MSMARCODataset(Dataset):
    """Dataset of (query, positive, negative) triplets labelled with a score margin.

    Each item is an ``InputExample`` whose texts are
    ``[query text, positive passage, negative passage]`` and whose label is
    ``ce_scores[qid][pos_id] - ce_scores[qid][neg_id]``.

    Args:
        queries: Mapping ``qid -> {'qid', 'query', 'pos', 'neg'}`` where
            ``'pos'`` and ``'neg'`` are iterables of passage ids. The inner
            dicts are modified in place: both id collections are converted to
            lists, the negatives are shuffled, and both lists are rotated on
            every ``__getitem__`` call.
        corpus: Mapping ``pid -> passage text``.
        ce_scores: Mapping ``qid -> {pid: score}``.

    Queries without any negative cannot form a triplet and are therefore not
    indexed. Previously they raised ``IndexError`` (pop from empty list) inside
    ``__getitem__``. ``len(dataset)`` counts only the indexed queries.
    """

    def __init__(self, queries, corpus, ce_scores):
        self.queries = queries
        self.corpus = corpus
        self.ce_scores = ce_scores

        for query in self.queries.values():
            query['pos'] = list(query['pos'])
            query['neg'] = list(query['neg'])
            random.shuffle(query['neg'])

        # Only queries that have at least one negative are addressable.
        self.queries_ids = [qid for qid, query in self.queries.items() if query['neg']]

        skipped = len(self.queries) - len(self.queries_ids)
        if skipped:
            logging.info("Skipping {} queries without negatives".format(skipped))

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']

        if len(query['pos']) > 0:
            # Take the first positive and move it to the end of the list, so
            # repeated accesses cycle through all positives.
            pos_id = query['pos'].pop(0)
            pos_text = self.corpus[pos_id]
            query['pos'].append(pos_id)
        else:
            # No positive available: a negative stands in as the "positive".
            # With a single negative, the same passage is used for both roles.
            pos_id = query['neg'].pop(0)
            pos_text = self.corpus[pos_id]
            query['neg'].append(pos_id)

        # Take the next negative and move it to the end of the list as well.
        neg_id = query['neg'].pop(0)
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        pos_score = self.ce_scores[qid][pos_id]
        neg_score = self.ce_scores[qid][neg_id]

        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score-neg_score)

    def __len__(self):
        return len(self.queries_ids)
|
| |
|
| |
|
# Wire the triplet dataset into a shuffling loader; drop_last=True discards a
# trailing batch that is smaller than train_batch_size.
train_dataset = MSMARCODataset(
    queries=train_queries,
    corpus=corpus,
    ce_scores=ce_scores,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    drop_last=True,
)
train_loss = losses.MarginMSELoss(model=model)

# Train with the single (dataloader, loss) objective. Checkpoint settings point
# at the run's output directory.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=args.warmup_steps,
    optimizer_params={'lr': args.lr},
    use_amp=True,
    checkpoint_path=model_save_path,
    checkpoint_save_steps=10000,
    checkpoint_save_total_limit=0,
)

# Store the final model in the same directory.
model.save(model_save_path)
|
| |
|
| |
|
| | |
| | |