| | import argparse |
| | import json |
| | import tqdm |
| | import torch |
| | import pytorch_lightning as pl |
| | from transformers import BertTokenizer, BertForSequenceClassification |
| | from src.models.SequenceClassificationModule import SequenceClassificationModule |
| |
|
| |
|
# Claim-level veracity classes. Position in the list is the integer class id:
# the script maps a predicted id to its name with ``LABEL[id]``, so the order
# must stay in sync with the classifier's 4 output classes.
LABEL = [
    "Supported",  # id 0
    "Refuted",  # id 1
    "Not Enough Evidence",  # id 2
    "Conflicting Evidence/Cherrypicking",  # id 3
]
| |
|
| |
|
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Lightning data module wrapping a tokenizer plus text-formatting helpers.

    Only ``tokenizer`` is read by the methods defined in this class;
    ``data_file``, ``batch_size`` and ``add_extra_nee`` are stored as
    attributes but not used by any method here.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize ``source_sentences`` with the stored tokenizer.

        Args:
            source_sentences: Text or batch of texts forwarded to the tokenizer.
            max_length: Truncation length; also the padded length when
                ``pad_to_max_length`` is true.
            pad_to_max_length: If true, pad every row to ``max_length``;
                otherwise pad only to the longest row in the batch.
            return_tensors: Tensor type forwarded to the tokenizer.

        Returns:
            Tuple ``(input_ids, attention_mask)`` taken from the tokenizer output.
        """
        padding_strategy = "longest"
        if pad_to_max_length:
            padding_strategy = "max_length"

        batch = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_strategy,
            truncation=True,
            return_tensors=return_tensors,
        )
        return batch["input_ids"], batch["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Render a claim and one question/answer pair as classifier input text.

        The result is ``"[CLAIM] <claim> [QUESTION] <question> <answer>"`` with
        each piece stripped. If ``bool_explanation`` is non-empty, the suffix
        ``", because <explanation>"`` is appended (explanation lower-cased and
        stripped); ``None`` or an empty value adds nothing.
        """
        if bool_explanation is None or len(bool_explanation) == 0:
            suffix = ""
        else:
            suffix = ", because " + bool_explanation.lower().strip()

        return "".join(
            (
                "[CLAIM] ",
                claim.strip(),
                " [QUESTION] ",
                question.strip(),
                " ",
                answer.strip(),
                suffix,
            )
        )
| |
|
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser( |
| | description="Given a claim and its 3 QA pairs as evidence, we use another pre-trained BERT model to predict the veracity label." |
| | ) |
| | parser.add_argument( |
| | "-i", |
| | "--claim_with_evidence_file", |
| | default="data_store/dev_top_3_rerank_qa.json", |
| | help="Json file with claim and top question-answer pairs as evidence.", |
| | ) |
| | parser.add_argument( |
| | "-o", |
| | "--output_file", |
| | default="data_store/dev_veracity_prediction.json", |
| | help="Json file with the veracity predictions.", |
| | ) |
| | parser.add_argument( |
| | "-ckpt", |
| | "--best_checkpoint", |
| | type=str, |
| | default="pretrained_models/bert_veracity.ckpt", |
| | ) |
| | args = parser.parse_args() |
| |
|
| | examples = [] |
| | with open(args.claim_with_evidence_file) as f: |
| | for line in f: |
| | examples.append(json.loads(line)) |
| |
|
| | bert_model_name = "bert-base-uncased" |
| |
|
| | tokenizer = BertTokenizer.from_pretrained(bert_model_name) |
| | bert_model = BertForSequenceClassification.from_pretrained( |
| | bert_model_name, num_labels=4, problem_type="single_label_classification" |
| | ) |
| | device = "cuda:0" if torch.cuda.is_available() else "cpu" |
| | trained_model = SequenceClassificationModule.load_from_checkpoint( |
| | args.best_checkpoint, tokenizer=tokenizer, model=bert_model |
| | ).to(device) |
| |
|
| | dataLoader = SequenceClassificationDataLoader( |
| | tokenizer=tokenizer, |
| | data_file="this_is_discontinued", |
| | batch_size=32, |
| | add_extra_nee=False, |
| | ) |
| |
|
| | predictions = [] |
| |
|
| | for example in tqdm.tqdm(examples): |
| | example_strings = [] |
| | for evidence in example["evidence"]: |
| | example_strings.append( |
| | dataLoader.quadruple_to_string( |
| | example["claim"], evidence["question"], evidence["answer"], "" |
| | ) |
| | ) |
| |
|
| | if ( |
| | len(example_strings) == 0 |
| | ): |
| | example["label"] = "Not Enough Evidence" |
| | continue |
| |
|
| | tokenized_strings, attention_mask = dataLoader.tokenize_strings(example_strings) |
| | example_support = torch.argmax( |
| | trained_model( |
| | tokenized_strings.to(device), attention_mask=attention_mask.to(device) |
| | ).logits, |
| | axis=1, |
| | ) |
| |
|
| | has_unanswerable = False |
| | has_true = False |
| | has_false = False |
| |
|
| | for v in example_support: |
| | if v == 0: |
| | has_true = True |
| | if v == 1: |
| | has_false = True |
| | if v in ( |
| | 2, |
| | 3, |
| | ): |
| | has_unanswerable = True |
| |
|
| | if has_unanswerable: |
| | answer = 2 |
| | elif has_true and not has_false: |
| | answer = 0 |
| | elif not has_true and has_false: |
| | answer = 1 |
| | else: |
| | answer = 3 |
| |
|
| | json_data = { |
| | "claim_id": example["claim_id"], |
| | "claim": example["claim"], |
| | "evidence": example["evidence"], |
| | "pred_label": LABEL[answer], |
| | } |
| | predictions.append(json_data) |
| |
|
| | with open(args.output_file, "w", encoding="utf-8") as output_file: |
| | json.dump(predictions, output_file, ensure_ascii=False, indent=4) |
| |
|