Liyan06
committed on
Commit
·
2d158d3
1
Parent(s):
93e9112
add entity highlight
Browse files- handler.py +22 -3
handler.py
CHANGED
|
@@ -3,6 +3,16 @@ from web_retrieval import *
|
|
| 3 |
from nltk.tokenize import sent_tokenize
|
| 4 |
import evaluate
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
| 8 |
'''
|
|
@@ -19,7 +29,13 @@ def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
|
| 19 |
ranked_docs, scores = zip(*ranked_doc_score)
|
| 20 |
|
| 21 |
return ranked_docs, scores
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
class EndpointHandler():
|
| 25 |
def __init__(self, path="./"):
|
|
@@ -30,6 +46,7 @@ class EndpointHandler():
|
|
| 30 |
def __call__(self, data):
|
| 31 |
|
| 32 |
claim = data['inputs']['claims'][0]
|
|
|
|
| 33 |
|
| 34 |
# Using user-provided document to do fact-checking
|
| 35 |
if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
|
|
@@ -48,7 +65,8 @@ class EndpointHandler():
|
|
| 48 |
outputs = {
|
| 49 |
'ranked_docs': ranked_docs,
|
| 50 |
'scores': scores,
|
| 51 |
-
'span_to_highlight': span_to_highlight
|
|
|
|
| 52 |
}
|
| 53 |
|
| 54 |
else:
|
|
@@ -69,7 +87,8 @@ class EndpointHandler():
|
|
| 69 |
'ranked_docs': ranked_docs,
|
| 70 |
'scores': scores,
|
| 71 |
'ranked_urls': ranked_urls,
|
| 72 |
-
'span_to_highlight': span_to_highlight
|
|
|
|
| 73 |
}
|
| 74 |
|
| 75 |
return outputs
|
|
|
|
| 3 |
from nltk.tokenize import sent_tokenize
|
| 4 |
import evaluate
|
| 5 |
|
| 6 |
import spacy
from spacy.cli import download

# Load the large English spaCy pipeline once at import time; download it
# on first use if it is not installed.
try:
    nlp = spacy.load("en_core_web_lg")
except OSError:
    # spacy.load raises OSError when the model package is missing.
    # Catching OSError specifically (instead of a bare `except:`) avoids
    # swallowing KeyboardInterrupt/SystemExit and unrelated errors.
    download("en_core_web_lg")
    nlp = spacy.load("en_core_web_lg")
| 16 |
|
| 17 |
def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
|
| 18 |
'''
|
|
|
|
| 29 |
ranked_docs, scores = zip(*ranked_doc_score)
|
| 30 |
|
| 31 |
return ranked_docs, scores
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def extract_entities(text):
    """Return the unique named-entity surface strings found in *text*.

    Runs the module-level spaCy pipeline ``nlp`` over *text* and collects
    ``ent.text`` for each recognized entity. Duplicates are removed while
    preserving first-occurrence order — the original set-based dedup
    (`list({...})`) returned entities in nondeterministic order, which made
    downstream output (e.g. highlight order) unstable across runs.

    :param text: raw text to analyze (str).
    :return: list of unique entity strings in order of first appearance.
    """
    doc = nlp(text)  # renamed local: don't shadow the `text` parameter
    return list(dict.fromkeys(ent.text for ent in doc.ents))
|
| 38 |
+
|
| 39 |
|
| 40 |
class EndpointHandler():
|
| 41 |
def __init__(self, path="./"):
|
|
|
|
| 46 |
def __call__(self, data):
|
| 47 |
|
| 48 |
claim = data['inputs']['claims'][0]
|
| 49 |
+
ents = extract_entities(claim)
|
| 50 |
|
| 51 |
# Using user-provided document to do fact-checking
|
| 52 |
if len(data['inputs']['docs']) == 1 and data['inputs']['docs'][0] != '':
|
|
|
|
| 65 |
outputs = {
|
| 66 |
'ranked_docs': ranked_docs,
|
| 67 |
'scores': scores,
|
| 68 |
+
'span_to_highlight': span_to_highlight,
|
| 69 |
+
'entities': ents
|
| 70 |
}
|
| 71 |
|
| 72 |
else:
|
|
|
|
| 87 |
'ranked_docs': ranked_docs,
|
| 88 |
'scores': scores,
|
| 89 |
'ranked_urls': ranked_urls,
|
| 90 |
+
'span_to_highlight': span_to_highlight,
|
| 91 |
+
'entities': ents
|
| 92 |
}
|
| 93 |
|
| 94 |
return outputs
|