Spaces:
Build error
Build error
zhenyundeng
committed on
Commit
·
afdeeca
1
Parent(s):
0334469
add files
Browse files- .gitattributes +3 -1
- README.md +6 -5
- app.py +390 -0
- averitec/data/all_samples.json +3 -0
- averitec/data/sample_claims.py +39 -0
- averitec/models/AveritecModule.py +312 -0
- averitec/models/DualEncoderModule.py +143 -0
- averitec/models/JustificationGenerationModule.py +193 -0
- averitec/models/NaiveSeqClassModule.py +145 -0
- averitec/models/SequenceClassificationModule.py +179 -0
- averitec/models/__pycache__/AveritecModule.cpython-38.pyc +0 -0
- averitec/models/__pycache__/DualEncoderModule.cpython-38.pyc +0 -0
- averitec/models/__pycache__/JustificationGenerationModule.cpython-38.pyc +0 -0
- averitec/models/__pycache__/SequenceClassificationModule.cpython-38.pyc +0 -0
- averitec/models/__pycache__/utils.cpython-38.pyc +0 -0
- averitec/models/utils.py +119 -0
- averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt +3 -0
- averitec/pretrained_models/bert_dual_encoder.ckpt +3 -0
- averitec/pretrained_models/bert_veracity.ckpt +3 -0
- requirements.txt +22 -0
.gitattributes
CHANGED
|
@@ -25,7 +25,6 @@
|
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
@@ -33,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 28 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 29 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 30 |
*.wasm filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
|
README.md
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 4.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
+
title: AVeriTeC
|
| 3 |
+
emoji: 🏆
|
| 4 |
+
colorFrom: purple
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.37.2
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Created by zd302 at 17/07/2024
|
| 4 |
+
|
| 5 |
+
from fastapi import FastAPI
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
# from averitec.models.AveritecModule import Wikipediaretriever, Googleretriever, veracity_prediction, justification_generation
|
| 8 |
+
import uvicorn
|
| 9 |
+
|
| 10 |
+
app = FastAPI()
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------------------------------------------------
|
| 13 |
+
import os
|
| 14 |
+
import torch
|
| 15 |
+
import numpy as np
|
| 16 |
+
import requests
|
| 17 |
+
from rank_bm25 import BM25Okapi
|
| 18 |
+
from bs4 import BeautifulSoup
|
| 19 |
+
|
| 20 |
+
from transformers import BartTokenizer, BartForConditionalGeneration
|
| 21 |
+
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
|
| 22 |
+
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
| 23 |
+
import pytorch_lightning as pl
|
| 24 |
+
|
| 25 |
+
from averitec.models.DualEncoderModule import DualEncoderModule
|
| 26 |
+
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
|
| 27 |
+
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------------------------------------------------
|
| 30 |
+
import wikipediaapi
|
| 31 |
+
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
|
| 32 |
+
|
| 33 |
+
import nltk
|
| 34 |
+
nltk.download('punkt')
|
| 35 |
+
from nltk import pos_tag, word_tokenize, sent_tokenize
|
| 36 |
+
|
| 37 |
+
import spacy
|
| 38 |
+
os.system("python -m spacy download en_core_web_sm")
|
| 39 |
+
nlp = spacy.load("en_core_web_sm")
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------------------------------------------------
|
| 42 |
+
# ---------- Load Veracity and Justification prediction model ----------
|
| 43 |
+
LABEL = [
|
| 44 |
+
"Supported",
|
| 45 |
+
"Refuted",
|
| 46 |
+
"Not Enough Evidence",
|
| 47 |
+
"Conflicting Evidence/Cherrypicking",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# Veracity
|
| 51 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 52 |
+
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 53 |
+
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
|
| 54 |
+
veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
|
| 55 |
+
veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
|
| 56 |
+
# Justification
|
| 57 |
+
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
| 58 |
+
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
| 59 |
+
best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
| 60 |
+
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
# ----------------------------------------------------------------------------
|
| 64 |
+
class Docs:
    """Lightweight container mimicking a LangChain Document.

    Holds a metadata dict describing one piece of retrieved evidence and
    the rendered page content shown to the user.
    """

    def __init__(self, metadata=None, page_content=""):
        """
        Args:
            metadata: dict of evidence fields (title, url, evidence, ...).
                Defaults to a fresh empty dict per instance.
            page_content: pre-rendered HTML/text summary of the evidence.
        """
        # BUGFIX: the original default `metadata=dict()` is evaluated once,
        # so every Docs() created without an explicit dict shared the SAME
        # mutable dict (classic mutable-default-argument pitfall).
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ------------------------------ Googleretriever -----------------------------
|
| 71 |
+
def Googleretriever(claim=None):
    """Placeholder for Google-based evidence retrieval (not implemented).

    Args:
        claim: the claim text. Accepted (with a default, so the old zero-arg
            call still works) because `fact_checking` invokes
            ``Googleretriever(claim)`` — the original zero-parameter
            signature made that call raise a TypeError.

    Returns:
        0 as a stub value until the retriever is implemented.
    """
    # TODO: implement Google-search-based evidence retrieval.
    return 0
|
| 75 |
+
|
| 76 |
+
# ------------------------------ Googleretriever -----------------------------
|
| 77 |
+
|
| 78 |
+
# ------------------------------ Wikipediaretriever --------------------------
|
| 79 |
+
def search_entity_wikipeida(entity):
    """Look up *entity* on Wikipedia.

    Returns a list containing a single ``[title, summary]`` pair when the
    page exists, otherwise an empty list. Network I/O happens here via the
    module-level ``wiki_wiki`` client.
    """
    page = wiki_wiki.page(entity)
    if not page.exists():
        return []
    # Callers expect [title, introduction] pairs; entity may be a spaCy
    # Span, hence the explicit str().
    return [[str(entity), page.summary]]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def clean_str(p):
    """Undo double-encoding artifacts in scraped Wikipedia search titles.

    Interprets backslash escapes in *p*, then recovers the raw bytes via a
    latin-1 round trip and re-decodes them as UTF-8.
    """
    unescaped = p.encode().decode("unicode-escape")
    return unescaped.encode("latin1").decode("utf-8")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def find_similar_wikipedia(entity, relevant_wikipages):
    """Top up *relevant_wikipages* to at most 5 entries using Wikipedia's
    full-text search for *entity*.

    Mutates and returns *relevant_wikipages* (list of [title, summary]
    pairs). Performs network I/O (HTML search page scrape).
    """
    query = entity.replace(" ", "+")
    url = f"https://en.wikipedia.org/w/index.php?search={query}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
    html = requests.get(url).text
    hits = BeautifulSoup(html, features="html.parser").find_all(
        "div", {"class": "mw-search-result-heading"}
    )

    if hits:
        # Only the first five search hits are considered as candidates.
        candidates = [clean_str(h.get_text().strip()) for h in hits][:5]
        known = [page[0] for page in relevant_wikipages] if relevant_wikipages else relevant_wikipages
        for title in candidates:
            if title not in known and len(relevant_wikipages) < 5:
                relevant_wikipages.extend(search_entity_wikipeida(title))

    return relevant_wikipages
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def find_evidence_from_wikipedia(claim):
    """Collect [title, summary] evidence pairs for every named entity
    spaCy finds in *claim*.

    For each entity: try the direct Wikipedia page first; if that yields
    fewer than 5 pages, pad with full-text search results.
    """
    evidence_pages = []
    for entity in nlp(claim).ents:
        pages = search_entity_wikipeida(entity)
        if len(pages) < 5:
            pages = find_similar_wikipedia(str(entity), pages)
        evidence_pages.extend(pages)

    return evidence_pages
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def bm25_retriever(query, corpus, topk=3):
    """Rank *corpus* (a list of token lists) against *query* with BM25.

    Returns:
        (indices, scores): indices of the top-*topk* documents in
        descending score order, and their BM25 scores.
    """
    ranker = BM25Okapi(corpus)
    doc_scores = ranker.get_scores(word_tokenize(query))
    # argsort ascending, reversed -> highest-scoring documents first.
    best = np.argsort(doc_scores)[::-1][:topk]
    return best, [doc_scores[i] for i in best]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Pick the *k* sentences most relevant to *query* from a list of
    (title, introduction) Wikipedia pairs.

    Returns:
        (sentences, titles): aligned lists for the BM25 top-k hits, so
        titles[i] is the page the i-th sentence came from.
    """
    # Flatten every introduction into per-sentence corpus entries,
    # remembering the source page of each sentence.
    tokenized_corpus = []
    flat_sentences = []
    flat_titles = []
    for title, intro in wiki_intro:
        for sentence in sent_tokenize(intro):
            tokenized_corpus.append(word_tokenize(sentence))
            flat_sentences.append(sentence)
            flat_titles.append(title)

    top_idx, _scores = bm25_retriever(query, tokenized_corpus, topk=k)
    return [flat_sentences[i] for i in top_idx], [flat_titles[i] for i in top_idx]
|
| 161 |
+
|
| 162 |
+
# ------------------------------ Wikipediaretriever -----------------------------
|
| 163 |
+
|
| 164 |
+
def Wikipediaretriever(claim):
    """Retrieve Wikipedia evidence for *claim*.

    Pipeline: NER-driven page lookup -> sentence-level BM25 ranking ->
    wrap the top sentences as Docs objects with display metadata.

    Returns:
        list[Docs]: one Docs per retrieved evidence sentence.
    """
    # 1. extract relevant wikipedia pages for the claim's named entities
    wikipedia_page = find_evidence_from_wikipedia(claim)

    # 2. extract the most relevant sentences from those pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = dict()
        metadata['name'] = claim
        metadata['url'] = page_url
        # BUGFIX: the original used "_".join(title), which joins the
        # title's individual CHARACTERS with underscores (e.g.
        # "Euro 2024" -> "E_u_r_o_ _2_0_2_4"); join its words instead,
        # matching the 'url' field above.
        metadata['cached_source_url'] = page_url
        metadata['short_name'] = "Evidence {}".format(i + 1)
        metadata['page_number'] = ""
        metadata['query'] = sent
        metadata['title'] = title
        metadata['evidence'] = sent
        metadata['answer'] = ""
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        page_content = f"""{metadata['page_content']}"""

        results.append(Docs(metadata, page_content))

    return results
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ------------------------------ Veracity Prediction ------------------------------
|
| 193 |
+
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """DataModule that serializes (claim, question, answer) evidence
    quadruples and tokenizes them for the BERT veracity classifier."""

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file          # kept for interface compat; unused at inference
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize a batch of strings.

        Returns:
            (input_ids, attention_masks) tensors from the tokenizer.
        """
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            # Pad to the model max only when explicitly requested;
            # otherwise pad to the longest string in the batch.
            padding="max_length" if pad_to_max_length else "longest",
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Serialize one evidence quadruple into the model input format:
        "[CLAIM] <claim> [QUESTION] <question> <answer>[, because <expl>]".
        """
        suffix = ""
        if bool_explanation is not None and len(bool_explanation) > 0:
            suffix = ", because " + bool_explanation.lower().strip()
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + suffix
        )
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def veracity_prediction(claim, evidence):
    """Predict a veracity label for *claim* given retrieved *evidence*.

    Serializes every evidence item, scores the batch with the global BERT
    veracity model, then aggregates the per-evidence class ids
    (0=Supported, 1=Refuted, 2/3=unanswerable) into one of the four
    AVeriTeC verdict strings.
    """
    loader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = [
        loader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]

    # No evidence at all (e.g. retrieval returned 0 pages): default to NEI.
    if not evidence_strings:
        return "Not Enough Evidence"

    input_ids, attention_mask = loader.tokenize_strings(evidence_strings)
    per_evidence = torch.argmax(
        veracity_model(input_ids.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)

    predicted = set(int(v) for v in per_evidence)
    has_true = 0 in predicted
    has_false = 1 in predicted
    # Original note: train/test label sets differ, so ids 2 and 3 both
    # count as "unanswerable".
    has_unanswerable = bool(predicted & {2, 3})

    # Aggregation: any unanswerable wins; otherwise unanimous support /
    # refutation; mixed signals fall through to "Conflicting Evidence".
    if has_unanswerable:
        verdict = 2
    elif has_true and not has_false:
        verdict = 0
    elif has_false and not has_true:
        verdict = 1
    else:
        verdict = 3

    return LABEL[verdict]
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# ------------------------------ Justification Generation ------------------------------
|
| 284 |
+
def extract_claim_str(claim, evidence, verdict_label):
    """Build the BART justification input string.

    Format: "[CLAIM] <claim> [EVIDENCE] <q?> <a.> ... [VERDICT] <label>"
    — spacing mirrors the training serialization, including the double
    space produced before "[VERDICT]".
    """
    serialized = "[CLAIM] " + claim + " [EVIDENCE] "

    for evi in evidence:
        question = evi.metadata['query'].strip()
        if not question:
            # Evidence without a question contributes nothing.
            continue
        if not question.endswith("?"):
            question += "?"

        serialized += question

        answer = evi.metadata['answer']
        if answer:
            if not answer.endswith("."):
                answer += "."
            serialized += " " + answer.strip()

        # NOTE(review): trailing space emitted once per evidence item,
        # matching the upstream AVeriTeC serialization — confirm against
        # the original (indentation lost in this view).
        serialized += " "

    serialized += " [VERDICT] " + verdict_label

    return serialized
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def justification_generation(claim, evidence, verdict_label):
    """Generate a natural-language justification for the predicted verdict.

    Serializes claim + evidence + verdict with extract_claim_str and feeds
    the result to the global BART justification model.
    """
    claim_str = extract_claim_str(claim, evidence, verdict_label)
    # BUGFIX: str.strip() returns a new string; the original called it and
    # discarded the result, passing the unstripped string to the model.
    claim_str = claim_str.strip()
    pred_justification = justification_model.generate(claim_str, device=device)

    return pred_justification.strip()
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
# ---------------------------------------------------------------------------------------------------------------------
|
| 323 |
+
class Item(BaseModel):
    """Request body for /predict/.

    Attributes:
        claim: the claim text to fact-check.
        source: evidence source selector ("Wikipedia" or "Google").
    """
    claim: str
    source: str
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
@app.get("/")
|
| 329 |
+
def greet_json():
|
| 330 |
+
return {"Hello": "World!"}
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
@app.post("/predict/")
|
| 334 |
+
def fact_checking(item: Item):
|
| 335 |
+
claim = item['claim']
|
| 336 |
+
source = item['source']
|
| 337 |
+
# claim = item.claim
|
| 338 |
+
# source = item.source
|
| 339 |
+
|
| 340 |
+
# Step1: Evidence Retrieval
|
| 341 |
+
if source == "Wikipedia":
|
| 342 |
+
evidence = Wikipediaretriever(claim)
|
| 343 |
+
elif source == "Google":
|
| 344 |
+
evidence = Googleretriever(claim)
|
| 345 |
+
|
| 346 |
+
# Step2: Veracity Prediction and Justification Generation
|
| 347 |
+
verdict_label = veracity_prediction(claim, evidence)
|
| 348 |
+
justification_label = justification_generation(claim, evidence, verdict_label)
|
| 349 |
+
|
| 350 |
+
evidence_list = []
|
| 351 |
+
for evi in evidence:
|
| 352 |
+
evidence_list.append(evi.metadata["query"])
|
| 353 |
+
|
| 354 |
+
return {"Verdict": verdict_label, "Justification": justification_label, "Evidence": evidence_list}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
if __name__ == "__main__":
|
| 358 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
# if __name__ == "__main__":
|
| 362 |
+
# item = {
|
| 363 |
+
# "claim": "England won the Euro 2024.",
|
| 364 |
+
# "source": "Wikipedia",
|
| 365 |
+
# }
|
| 366 |
+
#
|
| 367 |
+
# results = fact_checking(item)
|
| 368 |
+
#
|
| 369 |
+
# print(results)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
# # -----------------------------------------------------------------------------------------
|
| 374 |
+
# import requests
|
| 375 |
+
#
|
| 376 |
+
# # 定义API URL
|
| 377 |
+
# api_url = "https://zhenyundeng-zd-api.hf.space/generate/"
|
| 378 |
+
#
|
| 379 |
+
# # 定义请求数据
|
| 380 |
+
# item = {
|
| 381 |
+
# "name": "Alice"
|
| 382 |
+
# }
|
| 383 |
+
#
|
| 384 |
+
# # 发送Get请求
|
| 385 |
+
# # response = requests.get("https://zhenyundeng-zd-api.hf.space/")
|
| 386 |
+
# # 发送POST请求
|
| 387 |
+
# response = requests.post(api_url, json=item)
|
| 388 |
+
#
|
| 389 |
+
# # 打印响应
|
| 390 |
+
# print(response.json())
|
averitec/data/all_samples.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ef79bab962c2b17d56eb2582b9919bfe8023858fa13ba20c591900857b561854
|
| 3 |
+
size 11444395
|
averitec/data/sample_claims.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Created by zd302 at 09/05/2024
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# Example claims grouped by claim TYPE, used to populate the demo UI.
CLAIMS_Type = {
    "Claim": [
        "England won the Euro 2024.",
        "Albert Einstein works in the field of computer science.",
    ],
    "Event/Property Claim": [
        "Hunter Biden had no experience in Ukraine or in the energy sector when he joined the board of Burisma.",
        "After the police shooting of Jacob Blake, Gov. Tony Evers & Lt. Gov. Mandela Barnes did not call for peace or encourage calm.",
        "President Trump fully co-operated with the investigation into Russian interference in the 2016 U.S presidential campaign.",
    ],
    "Causal Claim": [
        "Anxiety levels among young teenagers dropped during the coronavirus pandemic, a study has suggested",
        "Auto workers across Michigan could have lost their jobs if not for Barack Obama and Joe Biden",
    ],
    "Numerical Claim": [
        "Sweden, despite never having had lockdown, has a lower COVID-19 death rate than Spain, Italy, and the United Kingdom.",
        "According to Harry Roque, even if 10,000 people die, 10 million COVID-19 cases in the country will not be a loss.",
    ],
}

# Example claims grouped by the fact-checking STRATEGY they exercise.
CLAIMS_FACT_CHECKING_STRATEGY = {
    "Written Evidence": [
        "Pretty Little Thing's terms and conditions state that its products may contain chemicals that can cause cancer, birth defects or other reproductive harm.",
        "Pretty Little Thing products may contain chemicals that can cause cancer, birth defects or other reproductive harm.",
    ],
    "Numerical Comparison": [
        "Congress party claims regarding shortfall in Government earnings",
        "On average, one person dies by suicide every 22 hours in West Virginia, United States.",
    ],
    "Consultation": [
        "Your reaction to an optical illusion is an indication of your state of mind.",
        "The last time people created a Hollywood blacklist, people ended up killing themselves. They were accused, and they lost their right to work.",
    ],
}
|
averitec/models/AveritecModule.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Created by zd302 at 17/07/2024
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import requests
|
| 8 |
+
from rank_bm25 import BM25Okapi
|
| 9 |
+
from bs4 import BeautifulSoup
|
| 10 |
+
|
| 11 |
+
from transformers import BartTokenizer, BartForConditionalGeneration
|
| 12 |
+
from transformers import BloomTokenizerFast, BloomForCausalLM, BertTokenizer, BertForSequenceClassification
|
| 13 |
+
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
| 14 |
+
import pytorch_lightning as pl
|
| 15 |
+
|
| 16 |
+
from averitec.models.DualEncoderModule import DualEncoderModule
|
| 17 |
+
from averitec.models.SequenceClassificationModule import SequenceClassificationModule
|
| 18 |
+
from averitec.models.JustificationGenerationModule import JustificationGenerationModule
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
import wikipediaapi
|
| 22 |
+
wiki_wiki = wikipediaapi.Wikipedia('AVeriTeC (zd302@cam.ac.uk)', 'en')
|
| 23 |
+
import os
|
| 24 |
+
import nltk
|
| 25 |
+
nltk.download('punkt')
|
| 26 |
+
from nltk import pos_tag, word_tokenize, sent_tokenize
|
| 27 |
+
|
| 28 |
+
import spacy
|
| 29 |
+
os.system("python -m spacy download en_core_web_sm")
|
| 30 |
+
nlp = spacy.load("en_core_web_sm")
|
| 31 |
+
|
| 32 |
+
# ---------- Load Veracity and Justification prediction model ----------
|
| 33 |
+
LABEL = [
|
| 34 |
+
"Supported",
|
| 35 |
+
"Refuted",
|
| 36 |
+
"Not Enough Evidence",
|
| 37 |
+
"Conflicting Evidence/Cherrypicking",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
# Veracity
|
| 41 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 42 |
+
veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
| 43 |
+
bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
|
| 44 |
+
veracity_model = SequenceClassificationModule.load_from_checkpoint("averitec/pretrained_models/bert_veracity.ckpt",
|
| 45 |
+
tokenizer=veracity_tokenizer, model=bert_model).to(device)
|
| 46 |
+
# Justification
|
| 47 |
+
justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
|
| 48 |
+
bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
|
| 49 |
+
best_checkpoint = 'averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
|
| 50 |
+
justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ----------------------------------------------------------------------------
|
| 55 |
+
class Docs:
    """Lightweight container mimicking a LangChain Document (duplicate of
    the definition in app.py — consider importing it instead).
    """

    def __init__(self, metadata=None, page_content=""):
        """
        Args:
            metadata: dict of evidence fields. Defaults to a fresh empty
                dict per instance.
            page_content: pre-rendered HTML/text summary of the evidence.
        """
        # BUGFIX: `metadata=dict()` as a default is evaluated once and
        # shared by every instance created without an explicit dict.
        self.metadata = {} if metadata is None else metadata
        self.page_content = page_content
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ------------------------------ Googleretriever -----------------------------
|
| 62 |
+
def Googleretriever():
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
return 0
|
| 66 |
+
|
| 67 |
+
# ------------------------------ Googleretriever -----------------------------
|
| 68 |
+
|
| 69 |
+
# ------------------------------ Wikipediaretriever --------------------------
|
| 70 |
+
def search_entity_wikipeida(entity):
|
| 71 |
+
find_evidence = []
|
| 72 |
+
|
| 73 |
+
page_py = wiki_wiki.page(entity)
|
| 74 |
+
if page_py.exists():
|
| 75 |
+
introduction = page_py.summary
|
| 76 |
+
find_evidence.append([str(entity), introduction])
|
| 77 |
+
|
| 78 |
+
return find_evidence
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def clean_str(p):
|
| 82 |
+
return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def find_similar_wikipedia(entity, relevant_wikipages):
|
| 86 |
+
# If the relevant wikipeida page of the entity is less than 5, find similar wikipedia pages.
|
| 87 |
+
ent_ = entity.replace(" ", "+")
|
| 88 |
+
search_url = f"https://en.wikipedia.org/w/index.php?search={ent_}&title=Special:Search&profile=advanced&fulltext=1&ns0=1"
|
| 89 |
+
response_text = requests.get(search_url).text
|
| 90 |
+
soup = BeautifulSoup(response_text, features="html.parser")
|
| 91 |
+
result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
|
| 92 |
+
|
| 93 |
+
if result_divs:
|
| 94 |
+
result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
|
| 95 |
+
similar_titles = result_titles[:5]
|
| 96 |
+
|
| 97 |
+
saved_titles = [ent[0] for ent in relevant_wikipages] if relevant_wikipages else relevant_wikipages
|
| 98 |
+
for _t in similar_titles:
|
| 99 |
+
if _t not in saved_titles and len(relevant_wikipages) < 5:
|
| 100 |
+
_evi = search_entity_wikipeida(_t)
|
| 101 |
+
# _evi = search_step(_t)
|
| 102 |
+
relevant_wikipages.extend(_evi)
|
| 103 |
+
|
| 104 |
+
return relevant_wikipages
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def find_evidence_from_wikipedia(claim):
|
| 108 |
+
#
|
| 109 |
+
doc = nlp(claim)
|
| 110 |
+
#
|
| 111 |
+
wikipedia_page = []
|
| 112 |
+
for ent in doc.ents:
|
| 113 |
+
relevant_wikipages = search_entity_wikipeida(ent)
|
| 114 |
+
|
| 115 |
+
if len(relevant_wikipages) < 5:
|
| 116 |
+
relevant_wikipages = find_similar_wikipedia(str(ent), relevant_wikipages)
|
| 117 |
+
|
| 118 |
+
wikipedia_page.extend(relevant_wikipages)
|
| 119 |
+
|
| 120 |
+
return wikipedia_page
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def bm25_retriever(query, corpus, topk=3):
    """Rank pre-tokenized ``corpus`` documents against ``query`` with BM25.

    Returns (indices of the ``topk`` best documents, their BM25 scores).
    """
    scorer = BM25Okapi(corpus)
    scores = scorer.get_scores(word_tokenize(query))
    best = np.argsort(scores)[::-1][:topk]
    best_scores = [scores[idx] for idx in best]
    return best, best_scores
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def relevant_sentence_retrieval(query, wiki_intro, k):
    """Split every Wikipedia intro into sentences and return the ``k``
    sentences (and their page titles) that BM25 ranks highest for ``query``."""
    # Flatten (title, intro) pairs into parallel sentence-level lists.
    tokenized_docs = []
    flat_sentences = []
    flat_titles = []
    for title, intro in wiki_intro:
        for sentence in sent_tokenize(intro):
            tokenized_docs.append(word_tokenize(sentence))
            flat_sentences.append(sentence)
            flat_titles.append(title)

    # ----- BM25 ranking over individual sentences
    top_idx, _top_scores = bm25_retriever(query, tokenized_docs, topk=k)
    picked_sentences = [flat_sentences[idx] for idx in top_idx]
    picked_titles = [flat_titles[idx] for idx in top_idx]

    return picked_sentences, picked_titles
|
| 152 |
+
|
| 153 |
+
# ------------------------------ Wikipediaretriever -----------------------------
|
| 154 |
+
|
| 155 |
+
def Wikipediaretriever(claim):
    """Retrieve evidence sentences for ``claim`` from Wikipedia.

    Pipeline:
      1. find Wikipedia pages for entities mentioned in the claim,
      2. BM25-rank individual sentences from those pages,
      3. wrap the top sentences as ``Docs`` objects for downstream modules.

    Returns a list of ``Docs``.
    """
    # 1. extract relevant wikipedia pages for the claim's entities
    wikipedia_page = find_evidence_from_wikipedia(claim)

    # 2. extract relevant sentences from the extracted wikipedia pages
    sents, titles = relevant_sentence_retrieval(claim, wikipedia_page, k=3)

    results = []
    for i, (sent, title) in enumerate(zip(sents, titles)):
        page_url = "https://en.wikipedia.org/wiki/" + "_".join(title.split())
        metadata = {
            'name': claim,
            'url': page_url,
            # BUG FIX: previously "_".join(title), which joined the
            # *characters* of the title with underscores; use the same
            # properly-built page URL as 'url'.
            'cached_source_url': page_url,
            'short_name': "Evidence {}".format(i + 1),
            'page_number': "",
            'query': sent,
            'title': title,
            'evidence': sent,
            'answer': "",
        }
        metadata['page_content'] = "<b>Title</b>: " + str(metadata['title']) + "<br>" + "<b>Evidence</b>: " + metadata['evidence']
        page_content = metadata['page_content']

        results.append(Docs(metadata, page_content))

    return results
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# ------------------------------ Veracity Prediction ------------------------------
|
| 184 |
+
class SequenceClassificationDataLoader(pl.LightningDataModule):
    """Tokenization helpers for the veracity classifier.

    Only ``tokenize_strings`` and ``quadruple_to_string`` are used at
    inference time; ``data_file``/``batch_size`` are kept for interface
    compatibility with the original training pipeline.
    """

    def __init__(self, tokenizer, data_file, batch_size, add_extra_nee=False):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_file = data_file
        self.batch_size = batch_size
        self.add_extra_nee = add_extra_nee

    def tokenize_strings(
        self,
        source_sentences,
        max_length=400,
        pad_to_max_length=False,
        return_tensors="pt",
    ):
        """Tokenize a batch of strings; returns (input_ids, attention_masks)."""
        padding_mode = "max_length" if pad_to_max_length else "longest"
        encoded = self.tokenizer(
            source_sentences,
            max_length=max_length,
            padding=padding_mode,
            truncation=True,
            return_tensors=return_tensors,
        )
        return encoded["input_ids"], encoded["attention_mask"]

    def quadruple_to_string(self, claim, question, answer, bool_explanation=""):
        """Flatten a (claim, question, answer, explanation) quadruple into the
        single tagged string the classifier was trained on."""
        if bool_explanation is not None and len(bool_explanation) > 0:
            explanation_suffix = ", because " + bool_explanation.lower().strip()
        else:
            explanation_suffix = ""
        return (
            "[CLAIM] "
            + claim.strip()
            + " [QUESTION] "
            + question.strip()
            + " "
            + answer.strip()
            + explanation_suffix
        )
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def veracity_prediction(claim, evidence):
    """Predict a veracity label for ``claim`` given a list of evidence Docs.

    Each evidence item is scored individually; the per-item stance votes are
    then aggregated into one of the labels in the global ``LABEL`` list.
    """
    loader = SequenceClassificationDataLoader(
        tokenizer=veracity_tokenizer,
        data_file="this_is_discontinued",
        batch_size=32,
        add_extra_nee=False,
    )

    evidence_strings = [
        loader.quadruple_to_string(claim, evi.metadata["query"], evi.metadata["answer"], "")
        for evi in evidence
    ]

    # No evidence at all (e.g. the search returned zero pages): output NEI.
    if not evidence_strings:
        return "Not Enough Evidence"

    tokenized_strings, attention_mask = loader.tokenize_strings(evidence_strings)
    logits = veracity_model(
        tokenized_strings.to(device), attention_mask=attention_mask.to(device)
    ).logits
    example_support = torch.argmax(logits, axis=1)

    # Per-item stance votes: 0 = supports, 1 = refutes, 2/3 = unanswerable
    # (train/test label sets differ, hence the two unanswerable ids).
    has_true = any(v == 0 for v in example_support)
    has_false = any(v == 1 for v in example_support)
    has_unanswerable = any(v in (2, 3) for v in example_support)

    if has_unanswerable:
        answer = 2
    elif has_true and not has_false:
        answer = 0
    elif not has_true and has_false:
        answer = 1
    else:
        # Mixed support/refute evidence, or no votes of any kind.
        answer = 3

    return LABEL[answer]
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# ------------------------------ Justification Generation ------------------------------
|
| 275 |
+
def extract_claim_str(claim, evidence, verdict_label):
    """Serialize claim + QA evidence + verdict into the tagged input string
    expected by the justification generator.

    Evidence items with an empty query are skipped; questions are forced to
    end with '?' and non-empty answers with '.'.
    """
    parts = ["[CLAIM] " + claim + " [EVIDENCE] "]

    for evi in evidence:
        question = evi.metadata['query'].strip()
        if not question:
            continue
        if not question.endswith("?"):
            question += "?"

        segment = question
        answer = evi.metadata['answer']
        if answer:
            if not answer.endswith("."):
                answer += "."
            segment += " " + answer.strip()

        parts.append(segment + " ")

    parts.append(" [VERDICT] " + verdict_label)
    return "".join(parts)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def justification_generation(claim, evidence, verdict_label):
    """Generate a natural-language justification for the predicted verdict.

    Builds the tagged claim/evidence/verdict string and feeds it to the
    fine-tuned BART justification model.
    """
    # BUG FIX: the original called claim_str.strip() and discarded the result
    # (str.strip returns a new string); bind it so the model input is
    # actually stripped.
    claim_str = extract_claim_str(claim, evidence, verdict_label).strip()
    pred_justification = justification_model.generate(claim_str, device=device)

    return pred_justification.strip()
|
| 311 |
+
|
| 312 |
+
|
averitec/models/DualEncoderModule.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
from transformers.optimization import AdamW
|
| 4 |
+
import torchmetrics
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class DualEncoderModule(pl.LightningModule):
    """Lightning wrapper for a pairwise relevance classifier.

    Each batch carries one positive encoding plus a stack of negative
    encodings; the wrapped HuggingFace sequence-classification model is
    trained to output label 1 for positives and 0 for negatives.
    """

    def __init__(self, tokenizer, model, learning_rate=1e-3):
        # `model` is a HuggingFace *ForSequenceClassification model (must
        # expose .num_labels); `tokenizer` is kept for downstream decoding.
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # Separate accuracy trackers per split so Lightning can reset them
        # independently at epoch boundaries.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        """Delegate straight to the wrapped HuggingFace model."""
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """AdamW over all trainable parameters."""
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """One optimization step over a (positive, negatives) batch.

        Loss is the sum of the positive-pair loss and the negative-pair loss;
        train accuracy is updated for both halves.
        """
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Negatives arrive stacked per example; flatten into one batch dim.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        # Label 1 for every positive, 0 for every negative.
        # NOTE(review): tensor.get_device() returns -1 for CPU tensors, so
        # this path presumably assumes GPU training — confirm before running
        # on CPU.
        pos_outputs = self(
            pos_ids,
            attention_mask=pos_mask,
            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
                pos_ids.get_device()
            ),
        )
        neg_outputs = self(
            neg_ids,
            attention_mask=neg_mask,
            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
                neg_ids.get_device()
            ),
        )

        # Equal weighting of positive and negative losses.
        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss

        pos_logits = pos_outputs.logits
        pos_preds = torch.argmax(pos_logits, axis=1)
        self.train_acc(
            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
        )

        neg_logits = neg_outputs.logits
        neg_preds = torch.argmax(neg_logits, axis=1)
        self.train_acc(
            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
        )

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Mirror of training_step that accumulates and logs val accuracy."""
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Flatten stacked negatives as in training_step.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        pos_outputs = self(
            pos_ids,
            attention_mask=pos_mask,
            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
                pos_ids.get_device()
            ),
        )
        neg_outputs = self(
            neg_ids,
            attention_mask=neg_mask,
            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
                neg_ids.get_device()
            ),
        )

        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss

        pos_logits = pos_outputs.logits
        pos_preds = torch.argmax(pos_logits, axis=1)
        self.val_acc(
            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
        )

        neg_logits = neg_outputs.logits
        neg_preds = torch.argmax(neg_logits, axis=1)
        self.val_acc(
            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
        )

        self.log("val_acc", self.val_acc)

        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        """Accuracy-only evaluation pass (the HF loss is computed because
        labels are passed, but it is neither logged nor returned)."""
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Flatten stacked negatives as in training_step.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        pos_outputs = self(
            pos_ids,
            attention_mask=pos_mask,
            labels=torch.ones(pos_ids.shape[0], dtype=torch.uint8).to(
                pos_ids.get_device()
            ),
        )
        neg_outputs = self(
            neg_ids,
            attention_mask=neg_mask,
            labels=torch.zeros(neg_ids.shape[0], dtype=torch.uint8).to(
                neg_ids.get_device()
            ),
        )

        pos_logits = pos_outputs.logits
        pos_preds = torch.argmax(pos_logits, axis=1)
        self.test_acc(
            pos_preds.cpu(), torch.ones(pos_ids.shape[0], dtype=torch.uint8).cpu()
        )

        neg_logits = neg_outputs.logits
        neg_preds = torch.argmax(neg_logits, axis=1)
        self.test_acc(
            neg_preds.cpu(), torch.zeros(neg_ids.shape[0], dtype=torch.uint8).cpu()
        )

        self.log("test_acc", self.test_acc)
|
averitec/models/JustificationGenerationModule.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import datasets
|
| 5 |
+
from transformers import MaxLengthCriteria, StoppingCriteriaList
|
| 6 |
+
from transformers.optimization import AdamW
|
| 7 |
+
import itertools
|
| 8 |
+
from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
|
| 9 |
+
from torchmetrics.text.rouge import ROUGEScore
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchmetrics
|
| 12 |
+
from torchmetrics.classification import F1Score
|
| 13 |
+
|
| 14 |
+
def freeze_params(model):
    """Disable gradient updates for every parameter of ``model``.

    BUG FIX: the original assigned ``requires_grade`` (a typo), which merely
    created an unused attribute and left every parameter trainable; the
    correct flag is ``requires_grad``.
    """
    for param in model.parameters():
        param.requires_grad = False
|
| 17 |
+
|
| 18 |
+
class JustificationGenerationModule(pl.LightningModule):
    """Lightning wrapper around a BART-style seq2seq model that generates a
    justification string from a tagged [CLAIM]/[EVIDENCE]/[VERDICT] input.

    The encoder and all embeddings are frozen at construction time; only the
    remaining decoder parameters are trained.
    """

    def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2, gen_max_length=100, should_pad_gen=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # Beam-search settings used by generate_for_batch.
        self.gen_num_beams = gen_num_beams
        self.gen_max_length = gen_max_length
        # When True, generations are right-padded to gen_max_length so they
        # can be stacked/compared across batches at epoch end.
        self.should_pad_gen = should_pad_gen

        #self.metrics = datasets.load_metric('meteor')

        freeze_params(self.model.get_encoder())
        self.freeze_embeds()

    def freeze_embeds(self):
        ''' freeze the positional embedding parameters of the model; adapted from finetune.py '''
        freeze_params(self.model.model.shared)
        for d in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)

    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """AdamW over the (non-frozen) parameters."""
        optimizer = AdamW(self.parameters(), lr = self.learning_rate)
        return optimizer

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids

    def run_model(self, batch):
        """Teacher-forced forward pass; returns the raw model outputs."""
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]

        decoder_input_ids = self.shift_tokens_right(
            tgt_ids, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id  # BART uses the EOS token to start generation as well. Might have to change for other models.
        )

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        return outputs

    def compute_loss(self, batch):
        """Token-level cross-entropy against the target ids (pad ignored)."""
        tgt_ids = batch[2]
        logits = self.run_model(batch)[0]

        cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))

        return loss

    def training_step(self, batch, batch_idx):
        """Standard teacher-forced training step."""
        loss = self.compute_loss(batch)

        self.log("train_loss", loss, on_epoch=True)

        return {'loss':loss}

    def validation_step(self, batch, batch_idx):
        """Compute loss and also generate, so epoch end can score text quality."""
        preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            # Pad generations to a fixed width so batches are comparable.
            preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)

        return {'loss': loss, 'pred': preds, 'target': tgts}

    def test_step(self, batch, batch_idx):
        """Same as validation_step, logged under the test_ prefix."""
        test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)

        return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}

    def test_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "test")

    def validation_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "val")

    def handle_end_of_epoch_scoring(self, outputs, prefix):
        """Detokenize all epoch generations and log per-example ROUGE-L
        precision and METEOR under `<prefix>_rouge` / `<prefix>_meteor`."""
        # NOTE(review): `gen` and `tgt` are never used below — presumably
        # left over from an earlier aggregation scheme.
        gen = {}
        tgt = {}
        rouge = ROUGEScore()
        rouge_metric = lambda x, y: rouge(x,y)["rougeL_precision"]
        for out in outputs:
            preds = out['pred']
            tgts = out['target']

            preds = self.do_batch_detokenize(preds)
            tgts = self.do_batch_detokenize(tgts)

            for pred, t in zip(preds, tgts):
                rouge_d = rouge_metric(pred, t)
                self.log(prefix+"_rouge", rouge_d)

                meteor_d = pairwise_meteor(pred, t)
                self.log(prefix+"_meteor", meteor_d)

    def generate_and_compute_loss_and_tgts(self, batch):
        """Return (generated ids, teacher-forced loss, target ids) for a batch."""
        src_ids = batch[0]
        loss = self.compute_loss(batch)
        pred_ids, _ = self.generate_for_batch(src_ids)

        tgt_ids = batch[2]

        return pred_ids, loss, tgt_ids

    def do_batch_detokenize(self, batch):
        """Decode a batch of token ids to strings, stripping special tokens."""
        tokens = self.tokenizer.batch_decode(
            batch,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        # Huggingface skipping of special tokens doesn't work for all models, so we do it manually as well for safety:
        tokens = [p.replace("<pad>", "") for p in tokens]
        tokens = [p.replace("<s>", "") for p in tokens]
        tokens = [p.replace("</s>", "") for p in tokens]

        return [t for t in tokens if t != ""]

    def generate_for_batch(self, batch):
        """Beam-search generate; returns (generated ids, decoded strings)."""
        generated_ids = self.model.generate(
            batch,
            decoder_start_token_id = self.tokenizer.pad_token_id,
            num_beams = self.gen_num_beams,
            max_length = self.gen_max_length
        )

        generated_tokens = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        return generated_ids, generated_tokens


    def generate(self, text, max_input_length=512, device=None):
        """Generate a justification string for a single raw input string.

        NOTE(review): `add_prefix_space=True` is a BPE-tokenizer argument
        (BART/RoBERTa style) — confirm the configured tokenizer supports it.
        """
        encoded_dict = self.tokenizer(
            [text],
            max_length=max_input_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            add_prefix_space = True
        )

        input_ids = encoded_dict['input_ids']

        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            _, generated_tokens = self.generate_for_batch(input_ids)

        return generated_tokens[0]
|
averitec/models/NaiveSeqClassModule.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import datasets
|
| 5 |
+
from transformers import MaxLengthCriteria, StoppingCriteriaList
|
| 6 |
+
from transformers.optimization import AdamW
|
| 7 |
+
import itertools
|
| 8 |
+
from utils import count_stats, f1_metric, pairwise_meteor
|
| 9 |
+
from torchmetrics.text.rouge import ROUGEScore
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchmetrics
|
| 12 |
+
from torchmetrics.classification import F1Score
|
| 13 |
+
|
| 14 |
+
class NaiveSeqClassModule(pl.LightningModule):
    """Lightning wrapper for a 4-way claim-veracity classifier
    (supported / refuted / nei / conflicting).

    With ``use_question_stance_approach`` enabled, accuracy/F1 are
    accumulated at epoch end in handle_end_of_epoch_scoring instead of
    per step.
    """

    # Instantiate the model
    def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # NOTE(review): argument-free Accuracy()/F1Score(num_classes=...) is
        # the pre-0.11 torchmetrics API — confirm the pinned torchmetrics
        # version supports it.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        self.train_f1 = F1Score(num_classes=4, average="macro")
        self.val_f1 = F1Score(num_classes=4, average=None)
        self.test_f1 = F1Score(num_classes=4, average=None)

        self.use_question_stance_approach = use_question_stance_approach


    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """AdamW over all trainable parameters."""
        optimizer = AdamW(self.parameters(), lr = self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Supervised step; the HF model computes the loss from labels."""
        x, x_mask, y = batch

        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss

        #cross_entropy = torch.nn.CrossEntropyLoss()
        #loss = cross_entropy(logits, y)

        preds = torch.argmax(logits, axis=1)

        self.train_acc(preds.cpu(), y.cpu())
        self.train_f1(preds.cpu(), y.cpu())

        self.log("train_loss", loss)

        return {'loss': loss}

    def training_epoch_end(self, outs):
        """Log epoch-level train metrics (Lightning resets them afterwards)."""
        self.log('train_acc_epoch', self.train_acc)
        self.log('train_f1_epoch', self.train_f1)

    def validation_step(self, batch, batch_idx):
        """Per-step metric updates are skipped under the question-stance
        approach; raw preds/targets are returned for epoch-end scoring."""
        x, x_mask, y = batch

        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss

        preds = torch.argmax(logits, axis=1)

        if not self.use_question_stance_approach:
            self.val_acc(preds, y)
            self.log('val_acc_step', self.val_acc)

            self.val_f1(preds, y)
        self.log("val_loss", loss)

        return {'val_loss':loss, "src": x, "pred": preds, "target": y}

    def validation_epoch_end(self, outs):
        """Aggregate and log validation accuracy and per-class F1."""
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)

        self.log('val_acc_epoch', self.val_acc)

        f1 = self.val_f1.compute()
        self.val_f1.reset()

        self.log('val_f1_epoch', torch.mean(f1))

        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("val_f1_" + c_name, f1[i])


    def test_step(self, batch, batch_idx):
        """Evaluation step; note labels are NOT passed, so no loss here."""
        x, x_mask, y = batch

        outputs = self(x, attention_mask=x_mask)
        logits = outputs.logits

        preds = torch.argmax(logits, axis=1)

        if not self.use_question_stance_approach:
            self.test_acc(preds, y)
            self.log('test_acc_step', self.test_acc)
            self.test_f1(preds, y)

        return {"src": x, "pred": preds, "target": y}

    def test_epoch_end(self, outs):
        """Aggregate and log test accuracy and per-class F1."""
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)

        self.log('test_acc_epoch', self.test_acc)

        f1 = self.test_f1.compute()
        self.test_f1.reset()
        self.log('test_f1_epoch', torch.mean(f1))

        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("test_f1_" + c_name, f1[i])

    def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
        """Feed every (pred, target) pair of the epoch into the scorers.

        NOTE(review): `gold_labels`/`question_support` are never used —
        presumably left over from an earlier per-question aggregation.
        NOTE(review): hard-codes "cuda:0", so this assumes single-GPU
        execution — confirm before multi-GPU/CPU runs.
        """
        gold_labels = {}
        question_support = {}
        for out in outputs:
            srcs = out['src']
            preds = out['pred']
            tgts = out['target']

            # Decoded sources are computed but only iterated for alignment.
            tokens = self.tokenizer.batch_decode(
                srcs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            for src, pred, tgt in zip(tokens, preds, tgts):
                acc_scorer(torch.as_tensor([pred]).to("cuda:0"), torch.as_tensor([tgt]).to("cuda:0"))
                f1_scorer(torch.as_tensor([pred]).to("cuda:0"), torch.as_tensor([tgt]).to("cuda:0"))
|
| 144 |
+
|
| 145 |
+
|
averitec/models/SequenceClassificationModule.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import datasets
|
| 5 |
+
from transformers import MaxLengthCriteria, StoppingCriteriaList
|
| 6 |
+
from transformers.optimization import AdamW
|
| 7 |
+
import itertools
|
| 8 |
+
# from utils import count_stats, f1_metric, pairwise_meteor
|
| 9 |
+
from torchmetrics.text.rouge import ROUGEScore
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torchmetrics
|
| 12 |
+
from torchmetrics.classification import F1Score
|
| 13 |
+
|
| 14 |
+
class SequenceClassificationModule(pl.LightningModule):
    """Lightning wrapper for a HuggingFace sequence-classification model.

    Two evaluation modes are supported:
      * ``use_question_stance_approach=False`` — plain per-example accuracy/F1.
      * ``use_question_stance_approach=True`` — per-question stance predictions
        are aggregated into a single claim-level verdict at epoch end
        (see ``handle_end_of_epoch_scoring``).
    """

    def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # Accuracy is tracked per split; num_classes follows the wrapped model.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
        self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)

        # Macro-F1 during training; per-class F1 (average=None) for val/test so
        # each class can be logged separately at epoch end.
        self.train_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average="macro")
        self.val_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)
        self.test_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)

        self.use_question_stance_approach = use_question_stance_approach

    def forward(self, input_ids, **kwargs):
        """Delegate straight to the wrapped HuggingFace model."""
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=self.learning_rate)

    def training_step(self, batch, batch_idx):
        x, x_mask, y = batch

        # The HF model computes cross-entropy internally when labels are given.
        outputs = self(x, attention_mask=x_mask, labels=y)
        loss = outputs.loss

        self.log("train_loss", loss)

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        x, x_mask, y = batch

        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss

        preds = torch.argmax(logits, axis=1)

        # In stance mode, per-example metrics are meaningless; scoring is
        # deferred to validation_epoch_end where questions are aggregated.
        if not self.use_question_stance_approach:
            self.val_acc(preds, y)
            self.log('val_acc_step', self.val_acc)

            self.val_f1(preds, y)
        self.log("val_loss", loss)

        return {'val_loss': loss, "src": x, "pred": preds, "target": y}

    def validation_epoch_end(self, outs):
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)

        self.log('val_acc_epoch', self.val_acc)

        f1 = self.val_f1.compute()
        self.val_f1.reset()

        self.log('val_f1_epoch', torch.mean(f1))

        # NOTE(review): assumes exactly these four labels in this order —
        # confirm against model.num_labels and the training data.
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("val_f1_" + c_name, f1[i])

    def test_step(self, batch, batch_idx):
        x, x_mask, y = batch

        # No labels passed at test time; only predictions are needed.
        outputs = self(x, attention_mask=x_mask)
        logits = outputs.logits

        preds = torch.argmax(logits, axis=1)

        if not self.use_question_stance_approach:
            self.test_acc(preds, y)
            self.log('test_acc_step', self.test_acc)
            self.test_f1(preds, y)

        return {"src": x, "pred": preds, "target": y}

    def test_epoch_end(self, outs):
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)

        self.log('test_acc_epoch', self.test_acc)

        f1 = self.test_f1.compute()
        self.test_f1.reset()
        self.log('test_f1_epoch', torch.mean(f1))

        # NOTE(review): same four-label assumption as validation_epoch_end.
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("test_f1_" + c_name, f1[i])

    def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
        """Aggregate per-question stance predictions into claim-level verdicts.

        Each input encodes a claim followed by a question; everything before the
        "[ question ]" marker identifies the claim. Per-claim aggregation:
        any unanswerable question -> 2 (nei); all-true -> 0 (supported);
        all-false -> 1 (refuted); mixed true/false -> 3 (conflicting).
        """
        gold_labels = {}
        question_support = {}
        for out in outputs:
            srcs = out['src']
            preds = out['pred']
            tgts = out['target']

            tokens = self.tokenizer.batch_decode(
                srcs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )

            for src, pred, tgt in zip(tokens, preds, tgts):
                # Claim text (prefix before the question marker) keys the group.
                claim_id = src.split("[ question ]")[0]

                if claim_id not in gold_labels:
                    gold_labels[claim_id] = tgt
                    question_support[claim_id] = []

                question_support[claim_id].append(pred)

        for k, gold_label in gold_labels.items():
            support = question_support[k]

            has_true = any(v == 0 for v in support)
            has_false = any(v == 1 for v in support)
            # TODO very ugly hack -- we cant have different numbers of labels
            # for train and test, so labels 2 and 3 both mean "unanswerable".
            has_unansw = any(v == 2 or v == 3 for v in support)

            if has_unansw:
                answer = 2
            elif has_true and not has_false:
                answer = 0
            elif has_false and not has_true:
                answer = 1
            elif has_true and has_false:
                answer = 3
            else:
                # Defensive default (previously `answer` could be unbound):
                # an empty support list falls back to "nei".
                answer = 2

            # Fix: metrics now live on the module's actual device instead of a
            # hard-coded "cuda:0" (the original TODO flagged this hack).
            device = self.device
            acc_scorer(torch.as_tensor([answer], device=device),
                       torch.as_tensor([gold_label], device=device))
            f1_scorer(torch.as_tensor([answer], device=device),
                      torch.as_tensor([gold_label], device=device))
|
| 178 |
+
|
| 179 |
+
|
averitec/models/__pycache__/AveritecModule.cpython-38.pyc
ADDED
|
Binary file (8.75 kB). View file
|
|
|
averitec/models/__pycache__/DualEncoderModule.cpython-38.pyc
ADDED
|
Binary file (3.28 kB). View file
|
|
|
averitec/models/__pycache__/JustificationGenerationModule.cpython-38.pyc
ADDED
|
Binary file (7.56 kB). View file
|
|
|
averitec/models/__pycache__/SequenceClassificationModule.cpython-38.pyc
ADDED
|
Binary file (4.8 kB). View file
|
|
|
averitec/models/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (4.12 kB). View file
|
|
|
averitec/models/utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import nltk
|
| 3 |
+
from nltk import word_tokenize
|
| 4 |
+
import numpy as np
|
| 5 |
+
from leven import levenshtein
|
| 6 |
+
from sklearn.cluster import DBSCAN, dbscan
|
| 7 |
+
|
| 8 |
+
def delete_if_exists(filepath):
    """Delete the file at *filepath*; a missing path is a silent no-op."""
    if not os.path.exists(filepath):
        return
    os.remove(filepath)
|
| 11 |
+
|
| 12 |
+
def pairwise_meteor(candidate, reference):  # Todo this is not thread safe, no idea how to make it so
    """METEOR score of a single *candidate* string against one *reference*."""
    candidate_tokens = word_tokenize(candidate)
    reference_tokens = word_tokenize(reference)
    return nltk.translate.meteor_score.single_meteor_score(reference_tokens, candidate_tokens)
|
| 14 |
+
|
| 15 |
+
def count_stats(candidate_dict, reference_dict):
    """Compare per-key list lengths between predictions and references.

    Returns the fraction of keys whose lists have equal length
    (``count_match_score``) and the mean absolute length difference
    (``count_diff_score``).
    """
    matches = []
    diffs = []

    for key in candidate_dict:
        n_pred = len(candidate_dict[key])
        n_tgt = len(reference_dict[key])

        matches.append(1 if n_pred == n_tgt else 0)
        diffs.append(abs(n_pred - n_tgt))

    return {
        "count_match_score": np.mean(matches),
        "count_diff_score": np.mean(diffs)
    }
|
| 35 |
+
|
| 36 |
+
def f1_metric(candidate_dict, reference_dict, pairwise_metric):
    """Soft precision/recall/F1 between prediction and reference lists.

    For each key, every prediction is scored against every reference with
    *pairwise_metric*; each side keeps its best match. Precision averages the
    predictions' best scores, recall the references' best scores. Empty lists
    score 1.0 for their side. ``p_unnorm`` averages the raw per-prediction
    best scores pooled across all keys.
    """
    precisions = []
    recalls = []
    pooled_pred_scores = []

    for key in candidate_dict:
        preds = candidate_dict[key]
        targets = reference_dict[key]

        pred_best = [0 for _ in preds]
        tgt_best = [0 for _ in targets]

        for p_idx, pred in enumerate(preds):
            for t_idx, target in enumerate(targets):
                score = pairwise_metric(pred, target)
                pred_best[p_idx] = max(pred_best[p_idx], score)
                tgt_best[t_idx] = max(tgt_best[t_idx], score)

        precisions.append(np.mean(pred_best) if len(pred_best) > 0 else 1.0)
        recalls.append(np.mean(tgt_best) if len(tgt_best) > 0 else 1.0)
        pooled_pred_scores.extend(pred_best)

    p_score = np.mean(precisions)
    r_score = np.mean(recalls)

    return {
        "p": p_score,
        "r": r_score,
        "avg": (p_score + r_score) / 2,
        # Epsilon guards against division by zero when both sides score 0.
        "f1": 2 * p_score * r_score / (p_score + r_score + 1e-8),
        "p_unnorm": np.mean(pooled_pred_scores),
    }
|
| 78 |
+
|
| 79 |
+
def edit_distance_dbscan(data):
    """Cluster strings with DBSCAN using Levenshtein edit distance.

    Uses the index-array trick from the scikit-learn FAQ
    (https://scikit-learn.org/stable/faq.html#how-do-i-deal-with-string-data-or-trees-graphs):
    points are integer indices into *data*, and the metric dereferences them.
    """
    def lev_metric(a, b):
        return levenshtein(data[int(a[0])], data[int(b[0])])

    index_points = np.arange(len(data)).reshape(-1, 1)

    return dbscan(index_points, metric=lev_metric, eps=20, min_samples=2, algorithm='brute')
|
| 89 |
+
|
| 90 |
+
def compute_all_pairwise_edit_distances(data):
    """Dense len(data) x len(data) matrix of Levenshtein distances."""
    size = len(data)
    distances = np.empty((size, size))

    for row, left in enumerate(data):
        for col, right in enumerate(data):
            distances[row][col] = levenshtein(left, right)

    return distances
|
| 98 |
+
|
| 99 |
+
def compute_all_pairwise_scores(src_data, tgt_data, metric):
    """Score every (src, tgt) pair with *metric* into a len(src) x len(tgt) matrix."""
    scores = np.empty((len(src_data), len(tgt_data)))

    for row, src_item in enumerate(src_data):
        for col, tgt_item in enumerate(tgt_data):
            scores[row][col] = metric(src_item, tgt_item)

    return scores
|
| 107 |
+
|
| 108 |
+
def compute_all_pairwise_meteor_scores(data):
    """Symmetrised pairwise METEOR matrix: mean of both scoring directions."""
    size = len(data)
    scores = np.empty((size, size))

    for i in range(size):
        for j in range(size):
            forward = pairwise_meteor(data[i], data[j])
            backward = pairwise_meteor(data[j], data[i])
            scores[i][j] = (forward + backward) / 2

    return scores
|
| 116 |
+
|
| 117 |
+
def edit_distance_custom(data, X, eps=0.5, min_samples=3):
    """Run DBSCAN over a precomputed distance matrix *X*; return cluster labels.

    NOTE(review): *data* is unused here — kept only for interface
    compatibility with the other clustering helpers; confirm with callers.
    """
    model = DBSCAN(metric="precomputed", eps=eps, min_samples=min_samples)
    model.fit(X)
    return model.labels_
|
averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e4b7bf02daaf10b3443f4f2cbe79c3c9f10c453dfdf818a4d14e44b2b4311cf4
|
| 3 |
+
size 4876206567
|
averitec/pretrained_models/bert_dual_encoder.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fee6737f655f4f1dfb46cc1bb812b5eaf9a72cfc0b69d4e5c05cde27ea7b6051
|
| 3 |
+
size 1314015751
|
averitec/pretrained_models/bert_veracity.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ddb8132a28ceff149904dd3ad3c3edd3e5f0c7de0169819207104a80e425c9a
|
| 3 |
+
size 1314034311
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
nltk
|
| 3 |
+
rank_bm25
|
| 4 |
+
accelerate
|
| 5 |
+
trafilatura
|
| 6 |
+
spacy
|
| 7 |
+
pytorch_lightning
|
| 8 |
+
transformers==4.29.2
|
| 9 |
+
datasets
|
| 10 |
+
leven
|
| 11 |
+
scikit-learn
|
| 12 |
+
pexpect
|
| 13 |
+
elasticsearch
|
| 14 |
+
torch
|
| 15 |
+
huggingface_hub
|
| 16 |
+
google-api-python-client
|
| 17 |
+
wikipedia-api
|
| 18 |
+
beautifulsoup4
|
| 19 |
+
azure-storage-file-share
|
| 20 |
+
azure-storage-blob
|
| 21 |
+
bm25s
|
| 22 |
+
PyStemmer
|