| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
|
|
|
|
| __author__ = "aagrawal"
|
|
|
|
|
|
|
| import sys
|
| import re
|
|
|
|
|
class VQAEval:
    """Score predicted VQA answers against ground-truth annotations.

    The evaluator normalises each predicted answer (whitespace, punctuation,
    number words, articles, contractions), compares it with the human answers
    of the same question and stores accuracies as percentages rounded to
    ``n`` decimal places.

    Attributes:
        n: Number of decimal places passed to ``round`` for every stored score.
        accuracy: Filled by ``setAccuracy`` with the keys ``"overall"``,
            ``"perQuestionType"`` and ``"perAnswerType"``.
        evalQA: Maps question id -> score of that question.
        evalQuesType: Maps question type -> {question id -> score}.
        evalAnsType: Maps answer type -> {question id -> score}.
        params: ``{"question_id": vqa.getQuesIds()}``; only set when ``vqa``
            is given.
    """

    def __init__(self, vqa=None, vqaRes=None, n=2):
        """Store the annotation/result holders and build the lookup tables.

        Args:
            vqa: Ground-truth holder. ``evaluate`` reads ``vqa.qa[quesId]``
                and this constructor calls ``vqa.getQuesIds()``.
            vqaRes: Prediction holder. ``evaluate`` reads
                ``vqaRes.qa[quesId]["answer"]``.
            n: Number of decimal places kept in the reported percentages.
        """
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.evalQuesType = {}
        self.evalAnsType = {}
        self.vqa = vqa
        self.vqaRes = vqaRes
        if vqa is not None:
            self.params = {"question_id": vqa.getQuesIds()}

        # Apostrophe-less (or misplaced-apostrophe) spelling -> canonical
        # contraction, applied word by word in processDigitArticle.
        # NOTE(review): words are lowercased before the lookup, so the keys
        # containing capitals ("Id've", "I'dve", "Im", "Ive") can never match,
        # and the "somebody'd" entry maps in the opposite direction to all the
        # others. The table is kept unchanged so scores stay comparable with
        # earlier runs -- confirm before correcting.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }

        # Number words (and "none") -> digit strings.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }

        # Words dropped from answers by processDigitArticle.
        self.articles = ["a", "an", "the"]

        # NOTE(review): "(?!<=\d)" is a negative lookahead for the literal
        # text "<=" followed by a digit, which is always satisfied right
        # before a ".", so this pattern effectively removes every period that
        # is not followed by a digit. A lookbehind "(?<!\d)" was presumably
        # intended; the pattern is kept as is because correcting it would
        # change how answers such as "3." are normalised -- confirm first.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # A comma sitting between two digits, e.g. "1,000".
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.punct = [
            ";",
            "/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def evaluate(self, quesIds=None):
        """Compute per-question and aggregate accuracies.

        For every question the normalised prediction is compared with the
        ground-truth answers in a leave-one-out fashion: for each human
        answer, the prediction scores ``min(1, matches / 3)`` where
        ``matches`` counts the *other* answer entries equal to the
        prediction. The question's accuracy is the mean of those scores.

        Results are stored on the instance (``evalQA``, ``evalQuesType``,
        ``evalAnsType``, ``accuracy``); nothing is returned.

        Note that when a question has more than one distinct ground-truth
        answer, the ``"answer"`` fields of ``self.vqa.qa`` are rewritten in
        place with their punctuation-normalised form.

        Args:
            quesIds: Question ids to score. Defaults to
                ``self.params["question_id"]``.

        Raises:
            KeyError: If an id is missing from ``vqa.qa`` or ``vqaRes.qa``.
            ZeroDivisionError: If ``quesIds`` is empty or a question has no
                ground-truth answers.
        """
        if quesIds is None:
            quesIds = list(self.params["question_id"])

        # Resolve every id up front so a missing entry fails before any
        # score is recorded.
        gts = {}
        res = {}
        for quesId in quesIds:
            gts[quesId] = self.vqa.qa[quesId]
            res[quesId] = self.vqaRes.qa[quesId]

        accQA = []
        accQuesType = {}
        accAnsType = {}
        print("computing accuracy")
        for step, quesId in enumerate(quesIds):
            gt = gts[quesId]

            resAns = res[quesId]["answer"]
            resAns = resAns.replace("\n", " ").replace("\t", " ").strip()
            resAns = self.processPunctuation(resAns)
            resAns = self.processDigitArticle(resAns)

            gtAnswers = gt["answers"]
            # Ground-truth answers are only normalised when the annotators
            # disagree; unanimous answers are compared verbatim.
            if len({ans["answer"] for ans in gtAnswers}) > 1:
                for ansDic in gtAnswers:
                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])

            gtAcc = []
            for gtAnsDatum in gtAnswers:
                # Entries equal to the held-out one (dict equality) are
                # excluded before counting matches.
                matches = sum(
                    1
                    for item in gtAnswers
                    if item != gtAnsDatum and item["answer"] == resAns
                )
                gtAcc.append(min(1, float(matches) / 3))

            quesType = gt["question_type"]
            ansType = gt["answer_type"]
            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)

            accQA.append(avgGTAcc)
            accQuesType.setdefault(quesType, []).append(avgGTAcc)
            accAnsType.setdefault(ansType, []).append(avgGTAcc)
            self.setEvalQA(quesId, avgGTAcc)
            self.setEvalQuesType(quesId, quesType, avgGTAcc)
            self.setEvalAnsType(quesId, ansType, avgGTAcc)

            if step % 100 == 0:
                self.updateProgress(step / float(len(quesIds)))

        self.setAccuracy(accQA, accQuesType, accAnsType)
        print("Done computing accuracy")

    def processPunctuation(self, inText):
        """Strip punctuation marks and stray periods from ``inText``.

        Each mark in ``self.punct`` is deleted outright when it touches a
        space in the original text or when the text contains a comma between
        two digits; otherwise it is replaced by a single space. Afterwards
        every match of ``self.periodStrip`` is removed.

        Args:
            inText: The raw answer string.

        Returns:
            The normalised string.
        """
        # Independent of the punctuation mark, so evaluated once up front.
        hasDigitComma = self.commaStrip.search(inText) is not None

        outText = inText
        for p in self.punct:
            # The adjacency test deliberately looks at the original text,
            # not at the partially rewritten one.
            if hasDigitComma or p + " " in inText or " " + p in inText:
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")

        # No extra positional argument here: the third parameter of
        # Pattern.sub is ``count``, not ``flags``.
        return self.periodStrip.sub("", outText)

    def processDigitArticle(self, inText):
        """Lowercase, map number words to digits, drop articles, fix contractions.

        Args:
            inText: The answer string (already punctuation-normalised by the
                caller in ``evaluate``).

        Returns:
            The processed words joined by single spaces.
        """
        kept = []
        for word in inText.lower().split():
            # Read-only lookup: the table must not grow with every new word.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                kept.append(word)
        return " ".join(self.contractions.get(word, word) for word in kept)

    def _meanPercent(self, values):
        """Return the mean of ``values`` as a percentage rounded to ``self.n``."""
        return round(100 * float(sum(values)) / len(values), self.n)

    def setAccuracy(self, accQA, accQuesType, accAnsType):
        """Store the aggregate accuracies in ``self.accuracy``.

        Args:
            accQA: Per-question accuracies (fractions).
            accQuesType: Maps question type -> list of accuracies.
            accAnsType: Maps answer type -> list of accuracies.

        Raises:
            ZeroDivisionError: If any of the lists is empty.
        """
        self.accuracy["overall"] = self._meanPercent(accQA)
        self.accuracy["perQuestionType"] = {
            quesType: self._meanPercent(scores)
            for quesType, scores in accQuesType.items()
        }
        self.accuracy["perAnswerType"] = {
            ansType: self._meanPercent(scores)
            for ansType, scores in accAnsType.items()
        }

    def setEvalQA(self, quesId, acc):
        """Record the percentage score of one question."""
        self.evalQA[quesId] = round(100 * acc, self.n)

    def setEvalQuesType(self, quesId, quesType, acc):
        """Record the percentage score of one question under its question type."""
        self.evalQuesType.setdefault(quesType, {})[quesId] = round(
            100 * acc, self.n
        )

    def setEvalAnsType(self, quesId, ansType, acc):
        """Record the percentage score of one question under its answer type."""
        self.evalAnsType.setdefault(ansType, {})[quesId] = round(
            100 * acc, self.n
        )

    def updateProgress(self, progress):
        """Write a 20-character text progress bar to ``sys.stdout``.

        Args:
            progress: Completed fraction. Integers are converted to float,
                any other non-float value is reported as an error and treated
                as 0, and the value is clamped to the range [0, 1].
        """
        barLength = 20
        status = ""
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"

        block = int(round(barLength * progress))
        bar = "#" * block + "-" * (barLength - block)
        # Leading "\r" makes successive calls redraw the same terminal line.
        text = "\rFinshed Percent: [{0}] {1}% {2}".format(
            bar, int(progress * 100), status
        )
        sys.stdout.write(text)
        sys.stdout.flush()
|
|
|