import copy
import regex
import string
import unicodedata
from typing import List
from collections import Counter

from ..core.logging import logger

#--------------------------
# QA Metrics
#--------------------------

# Normalization from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
def normalize_answer(s: str) -> str:
    """Normalize an answer string for comparison.

    The text is lower-cased, characters in ``string.punctuation`` are removed,
    the standalone words "a", "an" and "the" are replaced by a space, and runs
    of whitespace are collapsed to single spaces.
    """

    def remove_articles(text: str) -> str:
        return regex.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text: str) -> str:
        return ' '.join(text.split())

    def remove_punc(text: str) -> str:
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    return white_space_fix(remove_articles(remove_punc(s.lower())))


def exact_match_score(prediction: str, ground_truth: str) -> float:
    """Return 1.0 if the normalized prediction equals the normalized ground truth, else 0.0."""
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def ems(prediction: str, ground_truths: List[str]) -> float:
    """Return the best exact-match score of ``prediction`` over all ``ground_truths``.

    An empty ``ground_truths`` list yields 0.0 (nothing can match), which is
    consistent with ``acc_score`` for the same input.
    """
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    return max((exact_match_score(prediction, gt) for gt in ground_truths), default=0.0)


# F1 Evaluation from HotPotQA evaluation script: https://raw.githubusercontent.com/hotpotqa/hotpot/master/hotpot_evaluate_v1.py
def f1_score(prediction: str, ground_truth: str) -> float:
    """Return the token-level F1 between the normalized prediction and ground truth.

    If either side normalizes to "yes", "no" or "noanswer" and the two
    normalized strings differ, the score is 0.0 regardless of token overlap.
    Only the F1 value is returned (not precision or recall).
    """
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"

    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # These answers only count when they match exactly; partial overlap is meaningless.
    special_answers = ('yes', 'no', 'noanswer')
    if normalized_prediction != normalized_ground_truth and (
        normalized_prediction in special_answers or normalized_ground_truth in special_answers
    ):
        return 0.0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    # Multiset intersection: each shared token counts at most as often as it
    # appears on both sides.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also covers the case where either side has no tokens at all, so the
        # divisions below never see a zero denominator.
        return 0.0

    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)


def _normalize(text):
    """Return ``text`` in Unicode NFD (canonical decomposition) form."""
    return unicodedata.normalize('NFD', text)


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class Tokens(object):
    """A class to represent a list of tokenized text."""

    # Positions of the fields inside each per-token tuple stored in ``data``.
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a shallow copy of this object holding only tokens [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i:j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return "".join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if "pos" not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if "lemma" not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if "ner" not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return the ngram as a string vs list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        # Each entry is a half-open (start, end) index pair into ``words``.
        ngrams = [
            (s, e + 1)
            for s in range(len(words))
            for e in range(s, min(s + n, len(words)))
            if not _skip(words[s : e + 1])
        ]

        # Concatenate into strings
        if as_strings:
            ngrams = [" ".join(words[s:e]) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get("non_ent", "O")
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while idx < len(entities) and entities[idx] == ner_tag:
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class SimpleTokenizer(Tokenizer):
    """Regex tokenizer: runs of letters/numbers/marks, or single other visible characters."""

    # One or more characters from the Unicode Letter, Number or Mark categories.
    ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+"
    # Any single character outside the Unicode Separator and Other categories.
    NON_WS = r"[^\p{Z}\p{C}]"

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
        )
        # ``annotators=None`` is a documented value; the truthiness test avoids
        # calling len() on None.
        annotators = kwargs.get("annotators")
        if annotators:
            logger.warning(
                "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, annotators)
            )
        self.annotators = set()

    def tokenize(self, text):
        """Split ``text`` into a ``Tokens`` object of (token, token + trailing text, span) tuples."""
        data = []
        matches = list(self._regexp.finditer(text))
        for i, match in enumerate(matches):
            # Get text
            token = match.group()

            # Get whitespace: everything from this token's start up to the next
            # token's start (or this token's end for the last token).
            span = match.span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append(
                (
                    token,
                    text[start_ws:end_ws],
                    span,
                )
            )
        return Tokens(data, self.annotators)


def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    Returns False when ``pattern`` cannot be compiled.
    """
    try:
        pattern = regex.compile(pattern, flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE)
    except Exception:
        # An uncompilable pattern counts as "no match". KeyboardInterrupt and
        # SystemExit are deliberately not swallowed.
        return False
    return pattern.search(text) is not None


# Acknowledgement: https://github.com/facebookresearch/DPR/blob/main/dpr/data/qa_validation.py#L175
def has_answer(answers, text, match_type="string") -> bool:
    """Check if the text contains an answer string.

    If `match_type` is string, token matching is done between the text and answer.
    If `match_type` is regex, we search the whole text with the regex.
    Any other `match_type` returns False.

    Answers that produce no tokens (e.g. empty or whitespace-only strings) are
    skipped in string mode, since an empty token sequence would otherwise
    match every text.
    """
    text = _normalize(text)

    if match_type == "string":
        tokenizer = SimpleTokenizer()
        # Answer is a list of possible strings
        text = tokenizer.tokenize(text).words(uncased=True)
        for single_answer in answers:
            answer_tokens = tokenizer.tokenize(_normalize(single_answer)).words(uncased=True)
            if not answer_tokens:
                continue
            # Slide a window of the answer's length over the text tokens.
            for i in range(0, len(text) - len(answer_tokens) + 1):
                if answer_tokens == text[i : i + len(answer_tokens)]:
                    return True

    elif match_type == "regex":
        # Answer is a regex
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True

    return False


def acc_score(prediction: str, ground_truths: List[str]) -> float:
    """Return 1.0 if ``prediction`` contains any ground-truth answer as a token sequence, else 0.0."""
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    return float(has_answer(answers=ground_truths, text=prediction, match_type="string"))