| import argparse | |
| import json | |
| import torch | |
| import tqdm | |
| from transformers import BertTokenizer, BertForSequenceClassification | |
| from src.models.DualEncoderModule import DualEncoderModule | |
def triple_to_string(x):
    """Join the items of ``x`` into one string separated by " </s> ".

    Leading and trailing whitespace is stripped from every item before
    joining. ``x`` is any iterable of strings.
    """
    stripped_parts = (part.strip() for part in x)
    return " </s> ".join(stripped_parts)
if __name__ == "__main__":
    # Script entry point: score every (claim, question, answer) triple of each
    # example with a fine-tuned BERT classifier and write the highest-scoring
    # question-answer pairs per claim to a JSON Lines file.
    parser = argparse.ArgumentParser(
        description="Rerank the QA pairs and keep top 3 QA pairs as evidence using a pre-trained BERT model."
    )
    parser.add_argument(
        "-i",
        "--top_k_qa_file",
        default="data_store/dev_top_k_qa.json",
        help="Json file with claim and top k generated question-answer pairs.",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        default="data_store/dev_top_3_rerank_qa.json",
        help="Json file with the top3 reranked questions.",
    )
    parser.add_argument(
        "-ckpt",
        "--best_checkpoint",
        type=str,
        default="pretrained_models/bert_dual_encoder.ckpt",
    )
    parser.add_argument(
        "--top_n",
        type=int,
        default=3,
        help="top_n question answer pairs as evidence to keep.",
    )
    args = parser.parse_args()

    # The input is read as JSON Lines (one JSON object per line). Read it as
    # UTF-8 to match the encoding used for the output, and skip blank lines
    # (e.g. a trailing newline) instead of failing in json.loads.
    with open(args.top_k_qa_file, encoding="utf-8") as f:
        examples = [json.loads(line) for line in f if line.strip()]

    bert_model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    bert_model = BertForSequenceClassification.from_pretrained(
        bert_model_name, num_labels=2, problem_type="single_label_classification"
    )

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    trained_model = DualEncoderModule.load_from_checkpoint(
        args.best_checkpoint, tokenizer=tokenizer, model=bert_model
    ).to(device)
    # This script only performs inference: make sure train-time behaviour such
    # as dropout is switched off so the scores are deterministic.
    trained_model.eval()

    with open(args.output_file, "w", encoding="utf-8") as output_file:
        for example in tqdm.tqdm(examples):
            claim = example["claim"]
            # "bm25_qau" holds (question, answer, url) triples; a missing or
            # null entry is treated as "no candidates".
            bm25_qau = example.get("bm25_qau") or []

            values = [[question, answer, url] for question, answer, url in bm25_qau]
            strs_to_score = [
                triple_to_string([claim, question, answer])
                for question, answer, _url in values
            ]

            if strs_to_score:
                # All candidates of one claim are scored as a single batch,
                # padded to the longest sequence and truncated to 512 tokens.
                encoded_dict = tokenizer(
                    strs_to_score,
                    max_length=512,
                    padding="longest",
                    truncation=True,
                    return_tensors="pt",
                ).to(device)

                # No gradients are needed for scoring; disabling autograd
                # avoids building a graph and lowers memory use.
                with torch.no_grad():
                    logits = trained_model(
                        encoded_dict["input_ids"],
                        attention_mask=encoded_dict["attention_mask"],
                    ).logits

                # Rank candidates by the softmax probability of class index 1
                # (presumably the "relevant" label -- confirm against training).
                scores = torch.softmax(logits, dim=-1)[:, 1]
                # Convert to plain ints so they can index the Python list.
                top_n = torch.argsort(scores, descending=True)[: args.top_n].tolist()
                evidence = [
                    {
                        "question": values[i][0],
                        "answer": values[i][1],
                        "url": values[i][2],
                    }
                    for i in top_n
                ]
            else:
                evidence = []

            json_data = {
                "claim_id": example["claim_id"],
                "claim": claim,
                "evidence": evidence,
            }
            output_file.write(json.dumps(json_data, ensure_ascii=False) + "\n")
            # Flush after every example so partial results survive an
            # interrupted run.
            output_file.flush()