import itertools
import json
import os
import re
from collections import namedtuple

import torch
from tqdm import tqdm


class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Shards dataset indices into contiguous, disjoint slices across distributed ranks."""

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(size, self._world_size, self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
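

# Hedged usage sketch (not part of the original file): how InferenceSampler is
# typically attached to a DataLoader once torch.distributed has been initialized.
# `some_dataset` is a placeholder for whatever Dataset the caller provides.
def _example_build_sharded_loader(some_dataset, batch_size=1):
    # Each rank iterates only its own contiguous slice of indices.
    sampler = InferenceSampler(len(some_dataset))
    return torch.utils.data.DataLoader(
        dataset=some_dataset,
        batch_size=batch_size,
        sampler=sampler,
    )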


def collate_fn_vqa(batches):
    '''Collate a list of VQA sample dicts into parallel lists for batched inference.'''
    image_paths = [_['image_path'] for _ in batches]
    questions = [_['question'] for _ in batches]
    gt_answers = [_['gt_answers'] for _ in batches]
    ocr_tokens = [_['ocr_tokens'] if 'ocr_tokens' in _ else None for _ in batches]
    question_ids = [_['question_id'] if 'question_id' in _ else None for _ in batches]
    question_type = [_['question_type'] if 'question_type' in _ else None for _ in batches]
    return image_paths, questions, gt_answers, ocr_tokens, question_ids, question_type


def has_word(sentence, word):
    if word[0].isalnum():
        start_pattern = r"\b"
    else:
        start_pattern = r""
    if word[-1].isalnum():
        end_pattern = r"\b"
    else:
        end_pattern = r""
    pattern = start_pattern + re.escape(word) + end_pattern
    match = re.search(pattern, sentence)
    return bool(match)


def remove_special_chars(s):
    pattern = r"[^a-zA-Z0-9\s]"
    s = re.sub(pattern, "", s)
    return s


def levenshtein_distance(s1, s2):
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    distances = range(len(s1) + 1)
    for i2, c2 in enumerate(s2):
        distances_ = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                distances_.append(distances[i1])
            else:
                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
        distances = distances_
    return distances[-1]
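

# Worked example: turning "kitten" into "sitting" takes three edits
# (substitute k->s, substitute e->i, insert g), so
# levenshtein_distance("kitten", "sitting") == 3.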


class VQAEval:
    def __init__(self):
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def clean_text(self, text):
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        text = self.processDigitArticle(text)
        return text

    def evaluate_vqa_human(self, answer, gt_answers):
        '''TextVQA, VQAv2, OKVQA, VizWiz'''
        answer = answer.replace("\n", " ").replace("\t", " ").strip()
        answer = self.processPunctuation(answer)
        answer = self.processDigitArticle(answer)
        gt_answers = [self.processPunctuation(ans) for ans in gt_answers]
        gt_answers = [self.processDigitArticle(ans) for ans in gt_answers]
        gtAcc = []
        for idx, gtAnsDatum in enumerate(gt_answers):
            otherGTAns = gt_answers[:idx] + gt_answers[idx + 1:]
            matchingAns = [item for item in otherGTAns if answer == item]
            acc = min(1, float(len(matchingAns)) / 3)
            gtAcc.append(acc)
        avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0
        return avgGTAcc

    def evaluate_anls(self, answer, gt_answers, threshold=0.5):
        '''DocVQA, InfographicsVQA, STVQA'''
        answer = ' '.join(answer.strip().lower().split())
        if not isinstance(gt_answers, list):
            gt_answers = [gt_answers]
        gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers]

        values = []
        for gt_answer in gt_answers:
            dist = levenshtein_distance(answer, gt_answer)
            length = max(len(answer), len(gt_answer))
            values.append(0.0 if length == 0 else float(dist) / float(length))

        score = 1 - min(values)
        score = 0 if score < threshold else score
        return score

    def processPunctuation(self, inText):
        outText = inText
        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        outText = self.periodStrip.sub("", outText, re.UNICODE)
        return outText

    def processDigitArticle(self, inText):
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            word = self.manualMap.setdefault(word, word)
            if word not in self.articles:
                outText.append(word)
            else:
                pass
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText
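

# Hedged usage sketch (not part of the original file): how the two scoring paths of
# VQAEval behave on toy inputs; the example strings below are made up for illustration.
def _example_vqa_eval():
    evaluator = VQAEval()
    # VQA-style soft accuracy: min(#matching human answers / 3, 1), averaged over answers.
    soft_acc = evaluator.evaluate_vqa_human("2", ["two", "2", "2", "three"])
    # ANLS: 1 - normalized edit distance, zeroed below the threshold.
    anls = evaluator.evaluate_anls("hello world", ["Hello   World"])
    return soft_acc, anls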


def evaluate_dataset(dataset_name, answer_file_path, model_name, method=None):
    with open(answer_file_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    evaluator = VQAEval()
    total_accuracy = 0
    num = 0
    Entry = namedtuple('Entry', ['text', 'bbox'])
    for item in predictions:
        gt_answers = item['gt_answers']
        answer = item['answer']
        if method is not None:
            pass
        if dataset_name in ["textVQA"]:
            if num == 0:
                print("evaluating vqa...")
            accuracy = evaluator.evaluate_vqa_human(answer, gt_answers)
        elif dataset_name in ['docVQA']:
            if num == 0:
                print("evaluating anls...")
            accuracy = evaluator.evaluate_anls(answer, gt_answers)
        else:
            # Fallback scorer; relies on an `evaluate_has` method that is not defined in this file.
            accuracy = evaluator.evaluate_has(answer, gt_answers)
        item['accuracy'] = accuracy
        total_accuracy += accuracy
        num += 1

    average_accuracy = total_accuracy / num
    print(f'{dataset_name}:{average_accuracy}')
    answer_model_method_path = answer_file_path.replace('.json', f'_{model_name}_{method}.json')
    with open(answer_model_method_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)
    return average_accuracy


def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    generate_method="interleave",
    answer_path='./answers',
):
    print(f"answer path:{answer_path}")
    sampler = None
    if torch.distributed.is_initialized():
        sampler = InferenceSampler(len(dataset))
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn_vqa
    )

    answer_dir = os.path.join(answer_path, model_name, time)
    os.makedirs(answer_dir, exist_ok=True)

    image_list = []
    for item in dataset:
        image_list.append(item["image_path"])

    predictions = []
    for batch in tqdm(dataloader, desc="Running inference"):
        image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch
        with torch.no_grad():
            if model_name != "minicpm":
                if model_name != "codellama":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    outputs = model.generate()
            else:
                if generate_method == "old":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                elif generate_method == "interleave":
                    outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    raise Exception(f"Wrong generate paradigm {generate_method}!")

        for i in range(len(outputs)):
            answer_dict = {
                'question_id': question_ids[i],
                'question': questions[i],
                'answer': outputs[i],
                'gt_answers': gt_answers[i],
                'image_path': image_paths[i],
                'model_name': model_name,
                'question_type': question_type[i]
            }
            predictions.append(answer_dict)

    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        world_size = torch.distributed.get_world_size()
        merged_predictions = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(merged_predictions, predictions)
        predictions = [_ for _ in itertools.chain.from_iterable(merged_predictions)]
        if torch.distributed.get_rank() != 0:
            return None

    answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
    print(f"answer_file_path:{answer_file_path}")
    with open(answer_file_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    if dataset_name in ["docVQATest"]:
        return -1.0
    return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)
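

# Hedged usage sketch (illustrative, not part of the original pipeline): a minimal
# single-process call to evaluate_VQA. `model` is assumed to expose
# .generate(images=..., questions=..., datasetname=...) and `dataset` to yield dicts
# with the fields read by collate_fn_vqa; both names are placeholders.
def _example_run_textvqa_eval(model, dataset):
    return evaluate_VQA(
        model=model,
        dataset=dataset,
        model_name="my_model",            # placeholder, any name other than "minicpm"/"codellama"
        dataset_name="textVQA",
        time="2024-01-01_00-00-00",       # placeholder timestamp used as the output sub-folder
        batch_size=1,
    )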