import copy
import string
import unicodedata
from collections import Counter
from typing import List

import regex

from ..core.logging import logger


def normalize_answer(s: str) -> str:
    """Lower-case text and strip punctuation, articles, and extra whitespace."""

    def remove_articles(text: str) -> str:
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text: str) -> str:
        return ' '.join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    return white_space_fix(remove_articles(remove_punc(s.lower())))


def exact_match_score(prediction: str, ground_truth: str) -> float:
    """Return 1.0 if the normalized prediction equals the normalized ground truth, else 0.0."""
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def ems(prediction: str, ground_truths: List[str]) -> float:
    """Return 1.0 if the prediction exactly matches any of the ground truths, else 0.0."""
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    return max(exact_match_score(prediction, gt) for gt in ground_truths)


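# Illustrative only (made-up strings): normalization lowercases, strips
# punctuation and articles, and collapses whitespace, so exact match is
# robust to surface differences:
#
#   ems("The Eiffel Tower!", ["eiffel tower", "la tour eiffel"])  # -> 1.0

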
def f1_score(prediction: str, ground_truth: str) -> float:
    """Compute token-level F1 between the normalized prediction and ground truth."""
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # Yes/no/no-answer predictions only score when both sides agree exactly.
    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return 0.0
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return 0.0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0

    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


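# Worked example (made-up strings): f1_score("the cat sat", "cat sat down")
# normalizes both sides to ['cat', 'sat'] and ['cat', 'sat', 'down'] (the
# article "the" is dropped). The token overlap is 2, so precision = 2/2 = 1.0,
# recall = 2/3, and F1 = 2 * 1.0 * (2/3) / (1.0 + 2/3) = 0.8.

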
def _normalize(text: str) -> str:
    """Apply NFD unicode normalization so accented strings compare consistently."""
    return unicodedata.normalize('NFD', text)


class Tokenizer(object):
    """Base tokenizer class.

    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class Tokens(object):
    """A class to represent a list of tokenized text."""

    # Column indices into each token tuple.
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i:j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return "".join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token.

        Args:
            uncased: lower-cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.

        Returns None if this annotation was not included.
        """
        if "pos" not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.

        Returns None if this annotation was not included.
        """
        if "lemma" not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.

        Returns None if this annotation was not included.
        """
        if "ner" not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower-cases text
            filter_fn: user function that takes in an ngram list and returns
                True or False to keep or not keep the ngram
            as_strings: return each ngram as a string instead of a token list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        # Collect [s, e) token spans for every ngram of length 1..n.
        ngrams = [
            (s, e + 1)
            for s in range(len(words))
            for e in range(s, min(s + n, len(words)))
            if not _skip(words[s : e + 1])
        ]

        if as_strings:
            ngrams = [" ".join(words[s:e]) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get("non_ent", "O")
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            if ner_tag != non_ent:
                # Absorb the run of tokens sharing this tag into one group.
                start = idx
                while idx < len(entities) and entities[idx] == ner_tag:
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


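# Illustrative sketch of how a Tokens object is consumed (values assume the
# SimpleTokenizer defined below; the strings are made up for the example):
#
#   tokens = SimpleTokenizer().tokenize("red apple pie")
#   tokens.words()        # ['red', 'apple', 'pie']
#   tokens.ngrams(n=2)    # ['red', 'red apple', 'apple', 'apple pie', 'pie']
#   tokens.untokenize()   # 'red apple pie'

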
class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+"
    NON_WS = r"[^\p{Z}\p{C}]"

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE,
        )
        if len(kwargs.get("annotators", {})) > 0:
            logger.warning(
                "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, kwargs.get("annotators"))
            )
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = list(self._regexp.finditer(text))
        for i in range(len(matches)):
            # The token text itself.
            token = matches[i].group()

            # Extend the token's whitespace span up to the start of the next
            # match so untokenize() can reproduce the original string.
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Store (TEXT, TEXT_WS, SPAN) for each token.
            data.append(
                (
                    token,
                    text[start_ws:end_ws],
                    span,
                )
            )
        return Tokens(data, self.annotators)


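# Minimal usage sketch (made-up input; offsets follow from the regex above):
#
#   t = SimpleTokenizer().tokenize("Hello, world!")
#   t.words(uncased=True)   # ['hello', ',', 'world', '!']
#   t.offsets()             # [(0, 5), (5, 6), (7, 12), (12, 13)]
#   t.untokenize()          # 'Hello, world!'

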
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = regex.compile(pattern, flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE)
    except regex.error:
        # An answer string that is not a valid pattern can never match.
        return False
    return pattern.search(text) is not None


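# For example, regex_match("born in 1984", r"19\d\d") is True, while an
# invalid pattern such as "(" fails to compile and yields False.

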
def has_answer(answers, text, match_type="string") -> bool:
    """Check if the text contains an answer string.

    If `match_type` is string, token matching is done between the text and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)
    tokenizer = SimpleTokenizer()

    if match_type == "string":
        # Slide each answer's token sequence over the text tokens.
        text = tokenizer.tokenize(text).words(uncased=True)
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)
            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i : i + len(single_answer)]:
                    return True

    elif match_type == "regex":
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True

    return False


def acc_score(prediction: str, ground_truths: List[str]) -> float:
    """Return 1.0 if any ground truth appears as a contiguous token span of the prediction, else 0.0."""
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    return float(has_answer(answers=ground_truths, text=prediction, match_type="string"))
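

# Illustrative check (hypothetical strings): the prediction is tokenized and
# scanned for any gold answer's token sequence, so matches at token
# boundaries count:
#
#   acc_score("It opened in New York City.", ["new york", "NYC"])  # -> 1.0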