| | import itertools |
| | import json |
| | import os |
| | import re |
| | from collections import namedtuple |
| |
|
| | import torch |
| | from tqdm import tqdm |
| |
|
| |
|
class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Sampler that shards a dataset across distributed ranks for inference.

    Each rank iterates a contiguous, non-overlapping slice of the dataset
    indices; slice lengths differ by at most one element. Requires
    ``torch.distributed`` to be initialized before construction.
    """

    def __init__(self, size):
        """
        Args:
            size: total number of samples in the dataset (must be > 0).

        Raises:
            ValueError: if ``size`` is not positive.
        """
        self._size = int(size)
        # Raise instead of `assert` so the validation survives `python -O`.
        if self._size <= 0:
            raise ValueError(f"size must be positive, got {size}")
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(self._size, self._world_size,
                                                      self._rank)

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """Return the contiguous index range owned by ``rank``.

        The first ``total_size % world_size`` ranks each get one extra item.
        """
        shard_size = total_size // world_size
        left = total_size % world_size
        shard_sizes = [shard_size + int(r < left) for r in range(world_size)]

        begin = sum(shard_sizes[:rank])
        end = min(sum(shard_sizes[:rank + 1]), total_size)
        return range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)
| | |
def collate_fn_vqa(batches):
    """Collate a list of VQA sample dicts into per-field parallel lists.

    Args:
        batches: list of sample dicts; each must contain 'image_path',
            'question' and 'gt_answers'. 'ocr_tokens', 'question_id' and
            'question_type' are optional (None is substituted when absent).

    Returns:
        Tuple of six parallel lists: (image_paths, questions, gt_answers,
        ocr_tokens, question_ids, question_type).
    """
    image_paths = [sample['image_path'] for sample in batches]
    questions = [sample['question'] for sample in batches]
    gt_answers = [sample['gt_answers'] for sample in batches]
    # dict.get already yields None for missing optional keys, replacing the
    # `x[k] if k in x else None` pattern.
    ocr_tokens = [sample.get('ocr_tokens') for sample in batches]
    question_ids = [sample.get('question_id') for sample in batches]
    question_type = [sample.get('question_type') for sample in batches]

    return image_paths, questions, gt_answers, ocr_tokens, question_ids, question_type
| |
|
def has_word(sentence, word):
    """Return True if ``word`` occurs in ``sentence`` as a standalone token.

    Word-boundary anchors (``\\b``) are applied only on the sides of ``word``
    that start/end with an alphanumeric character, so punctuation-edged
    words (e.g. "$5") can still match. An empty ``word`` never matches.
    """
    # Guard: the boundary checks below would raise IndexError on "".
    if not word:
        return False

    start_pattern = r"\b" if word[0].isalnum() else r""
    end_pattern = r"\b" if word[-1].isalnum() else r""

    # re.escape so regex metacharacters inside `word` are matched literally.
    pattern = start_pattern + re.escape(word) + end_pattern
    return bool(re.search(pattern, sentence))
| |
|
def remove_special_chars(s):
    """Strip every character that is not ASCII alphanumeric or whitespace."""
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
| |
|
def levenshtein_distance(s1, s2):
    """Compute the edit distance (insert/delete/substitute) between two strings.

    Uses the classic single-row dynamic-programming formulation with
    O(min(len(s1), len(s2))) extra memory.
    """
    # Keep the shorter string along the row so the row stays small.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    previous_row = list(range(len(s1) + 1))
    for row_idx, ch2 in enumerate(s2, start=1):
        current_row = [row_idx]
        for col_idx, ch1 in enumerate(s1):
            if ch1 == ch2:
                cost = previous_row[col_idx]
            else:
                cost = 1 + min(previous_row[col_idx],
                               previous_row[col_idx + 1],
                               current_row[-1])
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]
| |
|
class VQAEval:
    """Answer normalization and scoring for VQA-style benchmarks.

    Implements the normalization rules of the official VQA evaluation
    script (contraction restoration, spelled-out numbers -> digits,
    article removal, punctuation stripping) plus two scoring schemes:

    - ``evaluate_vqa_human``: human-consensus accuracy
      (TextVQA / VQAv2 / OKVQA / VizWiz).
    - ``evaluate_anls``: Average Normalized Levenshtein Similarity
      (DocVQA / InfographicVQA / ST-VQA).
    """

    def __init__(self):
        # Maps de-apostrophized contractions back to their canonical form.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            # BUG FIX: this entry was inverted ("somebody'd": "somebodyd");
            # the official VQA eval maps the stripped form to the contraction.
            "somebodyd": "somebody'd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Spelled-out numbers zero..ten -> digit strings.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        # Articles dropped during normalization.
        self.articles = ["a", "an", "the"]

        # NOTE(review): "(?!<=\d)" looks like a typo for the negative
        # lookbehind "(?<!\d)", but it is kept byte-identical to the
        # official VQA eval script so reported scores stay comparable.
        self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
        # Detects a comma used as a digit-group separator (e.g. "1,000").
        self.commaStrip = re.compile("(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def clean_text(self, text):
        """Fully normalize ``text``: whitespace, punctuation, digits, articles."""
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        text = self.processDigitArticle(text)
        return text

    def evaluate_vqa_human(self, answer, gt_answers):
        '''TextVQA, VQAv2, OKVQA, vizwiz

        Human-consensus accuracy: each ground-truth answer is scored
        against the remaining annotators via min(1, #matches / 3), then
        the per-annotator scores are averaged. Returns 0 when
        ``gt_answers`` is empty.
        '''
        answer = answer.replace("\n", " ").replace("\t", " ").strip()
        answer = self.processPunctuation(answer)
        answer = self.processDigitArticle(answer)
        gt_answers = [self.processPunctuation(ans) for ans in gt_answers]
        gt_answers = [self.processDigitArticle(ans) for ans in gt_answers]

        gtAcc = []

        for idx, gtAnsDatum in enumerate(gt_answers):
            # Leave-one-out: compare against the other annotators only.
            otherGTAns = gt_answers[:idx] + gt_answers[idx + 1:]

            matchingAns = [item for item in otherGTAns if answer == item]

            acc = min(1, float(len(matchingAns)) / 3)
            gtAcc.append(acc)

        avgGTAcc = float(sum(gtAcc)) / len(gtAcc) if gtAcc else 0

        return avgGTAcc

    def evaluate_anls(self, answer, gt_answers, threshold=0.5):
        '''DOcVQA, InfographicsVQA, STVQA

        ANLS: 1 - min normalized edit distance over the ground-truth
        answers, zeroed when it falls below ``threshold``.
        '''
        answer = ' '.join(answer.strip().lower().split())
        if not isinstance(gt_answers, list):
            gt_answers = [gt_answers]
        gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers]

        values = []
        for gt_answer in gt_answers:
            dist = levenshtein_distance(answer, gt_answer)
            length = max(len(answer), len(gt_answer))
            values.append(0.0 if length == 0 else float(dist) / float(length))

        score = 1 - min(values)

        # Scores below the threshold are treated as fully wrong.
        score = 0 if score < threshold else score

        return score

    def processPunctuation(self, inText):
        """Strip or space-replace punctuation following the official VQA rules."""
        outText = inText
        for p in self.punct:
            # Drop the mark entirely when it touches a space or when the text
            # contains a digit-group comma; otherwise replace it with a space.
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # BUG FIX: the third positional argument of Pattern.sub is `count`;
        # the original passed re.UNICODE (== 32) there, silently capping the
        # number of substitutions at 32. Remove all matches instead.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        """Map number words to digits, drop articles, restore contractions."""
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            # BUG FIX: use .get() instead of .setdefault(), which mutated
            # (and grew) manualMap with every previously-unseen word.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText
| |
|
| |
|
def evaluate_dataset(dataset_name, answer_file_path, model_name, method=None):
    """Score a saved predictions file and write an annotated copy.

    Args:
        dataset_name: benchmark name; selects the metric ("textVQA" ->
            consensus accuracy, "docVQA" -> ANLS).
        answer_file_path: path to a JSON list of prediction dicts, each
            containing 'answer' and 'gt_answers'.
        model_name: embedded in the annotated output filename.
        method: optional tag embedded in the annotated output filename.

    Returns:
        Average accuracy over all predictions (0 for an empty file).
    """
    with open(answer_file_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    # Renamed from `eval` to avoid shadowing the builtin.
    evaluator = VQAEval()
    total_accuracy = 0
    num = 0

    for item in predictions:
        gt_answers = item['gt_answers']
        answer = item['answer']
        if dataset_name in ["textVQA"]:
            if num == 0:
                print("evaluating vqa...")
            accuracy = evaluator.evaluate_vqa_human(answer, gt_answers)
        elif dataset_name in ['docVQA']:
            if num == 0:
                print("evaluating anls...")
            accuracy = evaluator.evaluate_anls(answer, gt_answers)
        else:
            # NOTE(review): VQAEval defines no `evaluate_has`; this branch
            # raises AttributeError for any other dataset name — confirm the
            # intended fallback metric (has_word-based containment?).
            accuracy = evaluator.evaluate_has(answer, gt_answers)
        item['accuracy'] = accuracy

        total_accuracy += accuracy
        num += 1

    # Guard against an empty predictions file (was a ZeroDivisionError).
    average_accuracy = total_accuracy / num if num else 0
    print(f'{dataset_name}:{average_accuracy}')

    # Persist the per-item accuracies alongside the raw predictions.
    answer_model_method_path = answer_file_path.replace('.json', f'_{model_name}_{method}.json')
    with open(answer_model_method_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    return average_accuracy
| |
|
| |
|
def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    generate_method="interleave",
    answer_path='./answers',
):
    """Run VQA inference over ``dataset``, save predictions, and score them.

    Works both single-process and under torch.distributed: in the
    distributed case each rank infers on its shard, predictions are
    gathered, and only rank 0 writes/scores (other ranks return None).

    Args:
        model: wrapper exposing ``generate(...)`` (and, for "minicpm",
            ``generate_with_interleaved(...)``).
        dataset: map-style dataset yielding VQA sample dicts.
        model_name: model identifier; also selects the generate call.
        dataset_name: benchmark identifier; selects the scoring metric.
        time: timestamp string used in the answer directory path.
        batch_size: DataLoader batch size.
        generate_method: "old" or "interleave" (minicpm only).
        answer_path: root directory for answer files.

    Returns:
        Average accuracy, -1.0 for unlabeled test splits, or None on
        non-zero ranks.

    Raises:
        Exception: if ``generate_method`` is unrecognized for minicpm.
    """
    print(f"answer path:{answer_path}")

    sampler = None
    if torch.distributed.is_initialized():
        sampler = InferenceSampler(len(dataset))

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn_vqa
    )

    # BUG FIX: dropped the unconditional torch.distributed.get_rank() call
    # (crashed when distributed was not initialized, and the result was
    # unused) and an unused full pass over `dataset` building `image_list`.

    answer_dir = os.path.join(answer_path, model_name, time)
    os.makedirs(answer_dir, exist_ok=True)

    predictions = []

    for batch in tqdm(dataloader, desc="Running inference"):
        image_paths, questions, gt_answers, ocr_tokens_list, question_ids, question_type = batch

        with torch.no_grad():
            if model_name == "minicpm":
                if generate_method == "old":
                    outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
                elif generate_method == "interleave":
                    outputs = model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
                else:
                    raise Exception(f"Wrong generate paradigm {generate_method}!")
            elif model_name == "codellama":
                outputs = model.generate()
            else:
                outputs = model.generate(images=image_paths, questions=questions, datasetname=dataset_name)

        for i in range(len(outputs)):
            predictions.append({
                'question_id': question_ids[i],
                'question': questions[i],
                'answer': outputs[i],
                'gt_answers': gt_answers[i],
                'image_path': image_paths[i],
                'model_name': model_name,
                'question_type': question_type[i]
            })

    if torch.distributed.is_initialized():
        # Gather every rank's predictions onto all ranks.
        torch.distributed.barrier()
        world_size = torch.distributed.get_world_size()
        merged_predictions = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(merged_predictions, predictions)
        predictions = list(itertools.chain.from_iterable(merged_predictions))

        # Only rank 0 writes and scores.
        if torch.distributed.get_rank() != 0:
            return None

    answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
    print(f"answer_file_path:{answer_file_path}")

    with open(answer_file_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)

    # Test split ships without ground-truth labels; skip scoring.
    if dataset_name in ["docVQATest"]:
        return -1.0

    return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)
| |
|