|
|
| import re
|
|
|
| from tqdm import tqdm
|
|
|
|
|
class EvalAIAnswerProcessor:
    """Normalize a VQA answer string the same way the EvalAI server does.

    Pipeline (see ``__call__``): lower-case / light tokenization, collapse
    newlines and tabs, strip punctuation, map spelled-out numbers to digits,
    drop articles, and expand common contraction misspellings.

    Adapted from
    https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
    """

    # Common contraction misspellings mapped to their canonical form.
    # NOTE(review): "somebody'd" -> "somebodyd" is reversed relative to its
    # sibling entries; kept byte-for-byte for parity with the upstream
    # evaluator — confirm against the official VQA eval before changing.
    CONTRACTIONS = {
        "aint": "ain't",
        "arent": "aren't",
        "cant": "can't",
        "couldve": "could've",
        "couldnt": "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        "didnt": "didn't",
        "doesnt": "doesn't",
        "dont": "don't",
        "hadnt": "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        "hasnt": "hasn't",
        "havent": "haven't",
        "hed": "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        "hes": "he's",
        "howd": "how'd",
        "howll": "how'll",
        "hows": "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        "Im": "I'm",
        "Ive": "I've",
        "isnt": "isn't",
        "itd": "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        "itll": "it'll",
        "let's": "let's",
        "maam": "ma'am",
        "mightnt": "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        "mightve": "might've",
        "mustnt": "mustn't",
        "mustve": "must've",
        "neednt": "needn't",
        "notve": "not've",
        "oclock": "o'clock",
        "oughtnt": "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        "shant": "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        "shouldve": "should've",
        "shouldnt": "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        "somebody'd": "somebodyd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        "somebodyll": "somebody'll",
        "somebodys": "somebody's",
        "someoned": "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        "someonell": "someone'll",
        "someones": "someone's",
        "somethingd": "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        "somethingll": "something'll",
        "thats": "that's",
        "thered": "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        "therere": "there're",
        "theres": "there's",
        "theyd": "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        "theyll": "they'll",
        "theyre": "they're",
        "theyve": "they've",
        "twas": "'twas",
        "wasnt": "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        "weve": "we've",
        "werent": "weren't",
        "whatll": "what'll",
        "whatre": "what're",
        "whats": "what's",
        "whatve": "what've",
        "whens": "when's",
        "whered": "where'd",
        "wheres": "where's",
        "whereve": "where've",
        "whod": "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        "wholl": "who'll",
        "whos": "who's",
        "whove": "who've",
        "whyll": "why'll",
        "whyre": "why're",
        "whys": "why's",
        "wont": "won't",
        "wouldve": "would've",
        "wouldnt": "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        "yall": "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        "youd": "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        "youll": "you'll",
        "youre": "you're",
        "youve": "you've",
    }

    # Spelled-out numbers ("none" counts as zero) mapped to digit strings.
    NUMBER_MAP = {
        "none": "0",
        "zero": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
    }
    # Articles dropped during normalization.
    ARTICLES = ["a", "an", "the"]
    # NOTE(review): "(?!<=\d)" is a near-no-op negative *lookahead* (it only
    # fails when the period is followed by "<=digit"); it was almost
    # certainly meant to be the lookbehind "(?<!\d)".  Kept as-is so results
    # stay identical to the official EvalAI/VQA evaluation code.
    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
    # A comma sandwiched between digits, e.g. the "," in "1,000".
    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
    # Characters removed (or replaced by a space) by process_punctuation.
    PUNCTUATIONS = [
        ";",
        r"/",
        "[",
        "]",
        '"',
        "{",
        "}",
        "(",
        ")",
        "=",
        "+",
        "\\",
        "_",
        "-",
        ">",
        "<",
        "@",
        "`",
        ",",
        "?",
        "!",
    ]

    def __init__(self, *args, **kwargs):
        # Stateless processor; accepts (and ignores) config-style arguments
        # so it can be constructed like other processors.
        pass

    def word_tokenize(self, word):
        """Lower-case *word*, drop commas/question marks, split off "'s"."""
        word = word.lower()
        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text):
        """Remove or space-replace punctuation, then strip stray periods.

        A punctuation mark adjacent to a space (or any text containing a
        digit-comma-digit run) is deleted outright; otherwise it is replaced
        by a space so the surrounding words stay separated.
        """
        out_text = in_text
        for p in self.PUNCTUATIONS:
            if (p + " " in in_text or " " + p in in_text) or (
                re.search(self.COMMA_STRIP, in_text) is not None
            ):
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        # BUGFIX: the original passed ``re.UNICODE`` as ``sub``'s positional
        # *count* argument (re.UNICODE == 32), silently capping the number
        # of period removals at 32 per string.
        out_text = self.PERIOD_STRIP.sub("", out_text)
        return out_text

    def process_digit_article(self, in_text):
        """Map number words to digits, drop articles, expand contractions."""
        out_text = []
        temp_text = in_text.lower().split()
        for word in temp_text:
            # BUGFIX: use ``get`` instead of ``setdefault`` — setdefault
            # inserted every unseen word into the shared class-level
            # NUMBER_MAP, growing it without bound across calls.  The
            # returned value is identical.
            word = self.NUMBER_MAP.get(word, word)
            if word not in self.ARTICLES:
                out_text.append(word)
        for word_id, word in enumerate(out_text):
            if word in self.CONTRACTIONS:
                out_text[word_id] = self.CONTRACTIONS[word]
        return " ".join(out_text)

    def __call__(self, item):
        """Return the fully normalized form of the answer string *item*."""
        item = self.word_tokenize(item)
        item = item.replace("\n", " ").replace("\t", " ").strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item
|
|
|
|
|
class TextVQAAccuracyEvaluator:
    """VQA-style soft accuracy over 10 human ground-truth answers."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """Return ``{unique_answer: soft score}`` for the 10 human answers.

        The soft score of an answer is the average over the 10 leave-one-out
        subsets of ``min(1, matches_among_other_9 / 3)`` — i.e. an answer
        given by 3+ of the remaining annotators scores 1.0 for that fold.
        """
        answers = [self.answer_processor(a) for a in raw_answers]
        assert len(answers) == 10
        unique_answer_scores = {}

        for unique_answer in set(answers):
            # Hoisted: count once per unique answer instead of rescanning
            # the other 9 answers for every fold (was O(n^2) per answer).
            total = answers.count(unique_answer)
            accs = []
            for gt_answer in answers:
                # Leave this annotator out; matches among the other nine.
                matching = total - 1 if gt_answer == unique_answer else total
                accs.append(min(1, float(matching) / 3))
            unique_answer_scores[unique_answer] = sum(accs) / len(accs)

        return unique_answer_scores

    def eval_pred_list(self, pred_list):
        """Return the mean soft accuracy over *pred_list*.

        Each entry must have a ``pred_answer`` string and a ``gt_answers``
        list of exactly 10 human answers.
        """
        if not pred_list:
            # Guard: the original raised ZeroDivisionError on empty input.
            return 0.0
        pred_scores = []
        for entry in tqdm(pred_list):
            pred_answer = self.answer_processor(entry["pred_answer"])
            unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
            score = unique_answer_scores.get(pred_answer, 0.0)
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy
|
|
|
|
|
class STVQAAccuracyEvaluator:
    """Exact-match accuracy: a prediction scores 1.0 iff its normalized
    form equals any normalized ground-truth answer."""

    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def eval_pred_list(self, pred_list):
        """Return mean exact-match accuracy over *pred_list* entries
        (each with ``pred_answer`` and ``gt_answers``)."""
        if not pred_list:
            # Guard: the original raised ZeroDivisionError on empty input.
            return 0.0
        pred_scores = []
        for entry in pred_list:
            pred_answer = self.answer_processor(entry["pred_answer"])
            gts = [self.answer_processor(a) for a in entry["gt_answers"]]
            score = 1.0 if pred_answer in gts else 0.0
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy
|
|
|
|
|
class STVQAANLSEvaluator:
    """Average Normalized Levenshtein Similarity (ANLS) evaluator."""

    def __init__(self):
        # Local import keeps the third-party dependency optional until an
        # evaluator is actually constructed.
        import editdistance

        self.get_edit_distance = editdistance.eval

    def get_anls(self, s1, s2):
        """Return the ANLS between *s1* and *s2*.

        Similarity is ``1 - edit_distance / max_len`` (case-insensitive,
        whitespace-stripped); scores below the 0.5 threshold are clamped
        to 0.0, per the ST-VQA metric definition.
        """
        s1 = s1.lower().strip()
        s2 = s2.lower().strip()
        if not s1 and not s2:
            # Guard: both empty -> identical strings; the original divided
            # by max(len, len) == 0 and raised ZeroDivisionError.
            return 1.0
        iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
        anls = iou if iou >= 0.5 else 0.0
        return anls

    def eval_pred_list(self, pred_list):
        """Return the mean of per-entry best ANLS against any gt answer."""
        if not pred_list:
            # Guard: the original raised ZeroDivisionError on empty input.
            return 0.0
        pred_scores = []
        for entry in pred_list:
            anls = max(
                self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
            )
            pred_scores.append(anls)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy
|
|
|
|
|
class TextCapsBleu4Evaluator:
    """Corpus-level BLEU-4 evaluator for TextCaps captions.

    Relies on the coco-caption PTB tokenizer and BLEU scorer; the import is
    deferred to construction time so the dependency stays optional.
    """

    def __init__(self):
        try:
            from pycocoevalcap.bleu.bleu import Bleu
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        except ModuleNotFoundError:
            print(
                "Please install pycocoevalcap module using "
                "pip install git+https://github.com/ronghanghu/coco-caption.git@python23"
            )
            raise

        self.tokenizer = PTBTokenizer()
        self.scorer = Bleu(4)

    def eval_pred_list(self, pred_list):
        """Return BLEU-4 of predicted captions against their references.

        Each entry of *pred_list* carries a ``pred_answer`` caption and a
        ``gt_answers`` list of reference captions.
        """
        # Build the {index: [{"caption": ...}, ...]} shape the coco-caption
        # tokenizer expects, keyed by the entry's position in pred_list.
        references = {
            idx: [{"caption": caption} for caption in entry["gt_answers"]]
            for idx, entry in enumerate(pred_list)
        }
        hypotheses = {
            idx: [{"caption": entry["pred_answer"]}]
            for idx, entry in enumerate(pred_list)
        }

        score, _ = self.scorer.compute_score(
            self.tokenizer.tokenize(references),
            self.tokenizer.tokenize(hypotheses),
        )
        # compute_score returns [BLEU-1, BLEU-2, BLEU-3, BLEU-4].
        return score[3]
|
|
|