Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| import json | |
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from lavis.common.utils import is_convertible_to_int | |
| import lavis.common.dist_utils as dist_utils | |
| from lavis.common.registry import registry | |
| from lavis.common.vqa_tools.vqa import VQA | |
| from lavis.common.vqa_tools.vqa_eval import VQAEval | |
| from lavis.tasks.base_task import BaseTask | |
class VQATask(BaseTask):
    """Generic VQA train/eval task.

    Predicts answers via the model's ``predict_answers`` and, when COCO-format
    question/annotation files are available for a split, scores results with
    the official VQA evaluation tools.
    """

    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        evaluate,
        num_ans_candidates,
        inference_method="rank",
        prompt="",
        sample_id_key="",
        ques_files=None,
        anno_files=None,
        valid_splits=None,
    ):
        """
        Args:
            num_beams (int): beam size used during generation.
            max_len (int): maximum generated answer length.
            min_len (int): minimum generated answer length.
            evaluate (bool): evaluation-only mode flag.
            num_ans_candidates (int): answer candidates considered when ranking.
            inference_method (str): "rank" or "generate".
            prompt (str): optional prompt prepended to questions.
            sample_id_key (str): field naming the sample/image id
                (generalizes to non-COCO data).
            ques_files (dict | None): split -> COCO-format question file path.
            anno_files (dict | None): split -> COCO-format annotation file path.
            valid_splits (list | None): splits to evaluate (default ["val"]).
        """
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len

        self.evaluate = evaluate
        self.inference_method = inference_method
        self.num_ans_candidates = num_ans_candidates
        self.prompt = prompt

        self.answer_list = None

        # Use None sentinels instead of mutable default arguments so each
        # task instance gets its own dict/list (the originals were shared).
        self.ques_files = dict() if ques_files is None else ques_files
        self.anno_files = dict() if anno_files is None else anno_files

        # generalize to non coco data
        self.sample_id_key = sample_id_key

        self.valid_splits = ["val"] if valid_splits is None else valid_splits

    @classmethod  # restored: the registry invokes setup_task on the class itself
    def setup_task(cls, cfg):
        """Build the task from a run configuration."""
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
        max_len = run_cfg.get("max_len", 10)
        min_len = run_cfg.get("min_len", 1)
        evaluate = run_cfg.get("evaluate", False)

        inference_method = run_cfg.get("inference_method", "rank")
        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
        prompt = run_cfg.get("prompt", "")

        # generalize to non coco data
        sample_id_key = run_cfg.get("sample_id_key", "instance_id")
        ques_files = run_cfg.get("ques_files", dict())
        anno_files = run_cfg.get("anno_files", dict())
        valid_splits = run_cfg.get("valid_splits", ["val"])

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            num_ans_candidates=num_ans_candidates,
            inference_method=inference_method,
            prompt=prompt,
            sample_id_key=sample_id_key,
            ques_files=ques_files,
            anno_files=anno_files,
            valid_splits=valid_splits,
        )

    def build_datasets(self, cfg):
        """Build datasets and register COCO-format question/annotation files.

        When a dataset does not provide them, ground-truth files are generated
        (rank 0 only) into the cache directory, best-effort.
        """
        datasets = super().build_datasets(cfg)

        # get question file, annotation file and answer list in COCO format
        for ds_name, dataset in datasets.items():
            for split in self.valid_splits:
                if split not in dataset:
                    print(f"Split {split} not found in {ds_name}.")
                    # Skip missing splits; the original fell through and
                    # raised a KeyError on dataset[split].
                    continue

                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split] = dataset[split].coco_fmt_anno_file
                elif split not in self.ques_files:
                    # not precomputed and not passed in by the task builder:
                    # generate ground-truth files in the cache directory
                    gt_dir = os.path.join(
                        registry.get_path("cache_root"), f"{ds_name}_gt"
                    )
                    self.ques_files[split] = os.path.join(
                        gt_dir, f"{ds_name}_{split}_questions.json"
                    )
                    self.anno_files[split] = os.path.join(
                        gt_dir, f"{ds_name}_{split}_annotations.json"
                    )
                    if dist_utils.get_rank() == 0:
                        os.makedirs(gt_dir, exist_ok=True)
                        try:
                            convert_to_coco_gt(
                                dataset,
                                self.ques_files[split],
                                self.anno_files[split],
                                split,
                                self.sample_id_key,
                            )
                        except Exception:
                            # best-effort: tasks like vizwiz have no gt answer
                            pass

                try:
                    self.answer_list = dataset[split].answer_list
                except AttributeError:
                    # if answer_list is not provided, then leave it as None
                    pass

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets

    def valid_step(self, model, samples):
        """Run answer prediction for one batch.

        Returns:
            list[dict]: one {"question_id", "answer"} entry per sample.
        """
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        for answer, ques_id in zip(answers, question_id):
            ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
            # Normalize string ids like "42" to int. The original wrote
            # `ques_id != int`, comparing a value to the type object, which is
            # always true; isinstance is what was intended.
            if not isinstance(ques_id, int) and is_convertible_to_int(ques_id):
                ques_id = int(ques_id)
            pred_qa_pairs.append({"question_id": ques_id, "answer": answer})

        return pred_qa_pairs

    def after_evaluation(self, val_result, split_name, **kwargs):
        """Persist per-question results, then compute and return metrics."""
        result_file = self.save_result(
            val_result,
            result_dir=registry.get_path("result_dir"),
            filename=f"{split_name}_vqa_result",
            remove_duplicate="question_id",
        )

        metrics = self._report_metrics(result_file=result_file, split=split_name)

        return metrics

    def _report_metrics(self, result_file, split):
        """
        Use official VQA evaluation script to report metrics.

        Only scores splits that have both question and annotation files;
        otherwise returns an empty metrics dict.
        """
        metrics = {}

        if split in self.ques_files and split in self.anno_files:
            vqa = VQA(self.anno_files[split], self.ques_files[split])
            vqa_result = vqa.loadRes(
                resFile=result_file, quesFile=self.ques_files[split]
            )
            # create vqaEval object by taking vqa and vqaRes
            # n is precision of accuracy (number of places after decimal), default is 2
            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
            logging.info("Start VQA evaluation.")
            vqa_scorer.evaluate()

            # print accuracies
            overall_acc = vqa_scorer.accuracy["overall"]
            metrics["agg_metrics"] = overall_acc

            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
            logging.info("Per Answer Type Accuracy is the following:")

            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
                logging.info(
                    "%s : %.02f"
                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
                )
                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]

            with open(
                os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
            ) as f:
                f.write(json.dumps(metrics) + "\n")

        return metrics
def convert_to_coco_gt(data, outpath_questions, outpath_annotations, split, sample_id_key):
    """Export one dataset split as COCO-format VQA question/annotation JSON.

    Args:
        data: mapping of split name -> iterable of sample dicts.
        outpath_questions (str): destination path for the questions file.
        outpath_annotations (str): destination path for the annotations file.
        split (str): split to export; no-op if absent from ``data``.
        sample_id_key (str): sample field holding the image/sample id.
    """
    if split not in data:
        return

    questions_data = {'info': "", 'task_type': "", 'data_type': "", 'license': "", 'data_subtype': "", 'questions': []}
    annotations_data = {'info': "", 'task_type': "", 'data_type': "", 'license': "", 'data_subtype': "", 'annotations': []}

    print("Generating ground truth annotations...")
    for ann in tqdm(data[split]):
        # skip corrupt/missing samples (identity check, not ==)
        if ann is None:
            continue

        ques_id = ann["question_id"]
        ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
        # Normalize string ids like "42" to int. The original wrote
        # `ques_id != int` (value vs. type object, always true); isinstance
        # is what was intended.
        if not isinstance(ques_id, int) and is_convertible_to_int(ques_id):
            ques_id = int(ques_id)

        questions_data["questions"].append({
            "question": ann["text_input"],
            "image_id": ann[sample_id_key],
            "question_id": ques_id,
        })
        # "answers" may be a single value or a list; normalize to the COCO
        # annotation shape either way.
        annotations_data["annotations"].append({
            "question_type": "" if "question_type" not in ann else ann["question_type"],
            "multiple_choice_answer": ann["answers"][0] if isinstance(ann["answers"], list) else ann["answers"],
            "answers": [{"answer": ans, "answer_id": i} for i, ans in enumerate(ann["answers"])] if isinstance(ann["answers"], list) else [{"answer": ann["answers"], "answer_id": 0}],
            "image_id": ann[sample_id_key],
            "question_id": ques_id,
            "answer_type": "" if "answer_type" not in ann else ann["answer_type"],
        })

    # context managers so file handles are closed even if json.dump raises
    with open(outpath_questions, 'w') as f:
        json.dump(questions_data, f)
    print(f"Saved questions data at {outpath_questions}")
    with open(outpath_annotations, 'w') as f:
        json.dump(annotations_data, f)
    print(f"Saved annotation data at {outpath_annotations}")
class AOKVQATask(VQATask):
    """VQA task for A-OKVQA direct-answer evaluation."""

    def valid_step(self, model, samples):
        """Predict answers and pair each with its ground-truth direct answers."""
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
        )

        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["direct_answers"]

        for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            pred_qa_pairs.append(
                {"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
            )

        return pred_qa_pairs

    def _report_metrics(self, result_file, split):
        """
        Implementing accuracy computation for AOKVQA, see
        https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.

        Returns None (after writing a leaderboard file) when the split has no
        ground-truth answers, e.g. a hidden test split.
        """
        # TODO add evaluation for multi-choice

        # context manager so the results file handle is closed promptly
        with open(result_file, "r") as f:
            results = json.load(f)
        acc = []

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                self._save_result_leaderboard(results)
                return

            pred = res["pred_ans"]
            gt_ans = res["gt_ans"]

            # soft accuracy: full credit once 3 annotators agree with pred
            num_match = sum([pred == gt for gt in gt_ans])
            vqa_acc = min(1.0, num_match / 3.0)

            acc.append(vqa_acc)

        # guard against empty results to avoid ZeroDivisionError
        accuracy = sum(acc) / len(acc) * 100 if acc else 0.0
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics

    def _save_result_leaderboard(self, results):
        """
        Saving the results in the format required for leaderboard evaluation.

        [TODO] add support for multi-choice.
        """
        result_leaderboard = dict()
        for res in results:
            result_leaderboard[res["question_id"]] = {
                "direct_answer": res["pred_ans"],
                "multiple_choice": "",
            }

        result_file = registry.get_path("result_dir") + "_leaderboard.json"

        with open(result_file, "w") as f:
            json.dump(result_leaderboard, f)

        logging.info(f"Saved results for leaderboard evaluation at {result_file}")
class GQATask(VQATask):
    """VQA task for GQA-style datasets with a single ground-truth answer."""

    def valid_step(self, model, samples):
        """Predict answers and pair each with its single ground-truth answer."""
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
            pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer})

        return pred_qa_pairs

    def build_datasets(self, cfg):
        """Build datasets; only register precomputed COCO-format files.

        Unlike the parent, this never generates ground-truth files.
        """
        datasets = BaseTask.build_datasets(self, cfg)

        # get question file, annotation file and answer list in COCO format
        for ds_name, dataset in datasets.items():
            for split in dataset:
                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split] = dataset[split].coco_fmt_anno_file

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets

    def _report_metrics(self, result_file, split):
        """
        TODO: add other evaluation metrics for GQA
        """
        # context manager so the results file handle is closed promptly
        with open(result_file, "r") as f:
            results = json.load(f)
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            if res["gt_ans"] is None:
                # prepare test results for leaderboard evaluation
                # NOTE(review): _save_result_leaderboard is defined on
                # AOKVQATask, not on this class or VQATask — confirm GQA
                # splits used here always carry gt answers.
                self._save_result_leaderboard(results)
                return

            gt_ans = res["gt_ans"]
            pred = res["pred_ans"]

            # normalize both sides with the official VQA preprocessing
            pred = vqa_tool.processPunctuation(pred)
            pred = vqa_tool.processDigitArticle(pred)
            # added to ensure that the ground truth format of answers is as
            # expected for non-gqa but similar tasks
            gt_ans = vqa_tool.processPunctuation(gt_ans)
            gt_ans = vqa_tool.processDigitArticle(gt_ans)

            vqa_acc = 1 if pred == gt_ans else 0

            acc.append(vqa_acc)

        # guard against empty results to avoid ZeroDivisionError
        accuracy = sum(acc) / len(acc) * 100 if acc else 0.0
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics
class DisCRNTask(VQATask):
    """Discriminative cross-modal reasoning task: the prediction is matched
    against a list of acceptable ground-truth answers."""

    def valid_step(self, model, samples):
        """Predict answers; returns [] when the batch yields no predictions."""
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        # identity check (was `== None`): corrupt videos produce no answers
        if answers is None:
            return []

        pred_qa_pairs = []

        question_id = samples["question_id"]
        gt_answers = samples["answer"]

        for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
            ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
            pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer})

        return pred_qa_pairs

    def build_datasets(self, cfg):
        """Build datasets without any COCO-format ground-truth handling."""
        return BaseTask.build_datasets(self, cfg)

    def _report_metrics(self, result_file, split):
        """Compute accuracy: a prediction counts if it contains (as a token)
        or equals any of the ground-truth answers."""
        # context manager so the results file handle is closed promptly
        with open(result_file, "r") as f:
            results = json.load(f)
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            gt_ans = res["gt_ans"]
            pred = res["pred_ans"]

            # normalize the prediction with the official VQA preprocessing
            pred = vqa_tool.processPunctuation(pred)
            pred = vqa_tool.processDigitArticle(pred)

            # if any gt option appears among the predicted tokens, collapse
            # the prediction to that option before the membership test
            tokenized_pred = pred.strip().split(" ")
            for ans in gt_ans:
                if ans in tokenized_pred:
                    pred = ans
                    break
            vqa_acc = 1 if pred in gt_ans else 0

            acc.append(vqa_acc)

        # guard against empty results to avoid ZeroDivisionError
        accuracy = sum(acc) / len(acc) * 100 if acc else 0.0
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics