| | import os |
| | import os.path as osp |
| | from mmengine.dist import master_only |
| | from .base_eval_dataset import BaseEvalDataset |
| |
|
| | from xtuner.registry import BUILDER |
| | from mmengine.logging import print_log |
| | import pandas as pd |
| | from xtuner.dataset.utils import decode_base64_to_image |
| | import numpy as np |
| | from .utils import custom_data_process |
| |
|
| |
|
def levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between two sequences.

    Implements the standard dynamic-programming recurrence (unit cost for
    insertion, deletion and substitution) while keeping only a single row of
    the table in memory. The row is sized by the shorter of the two inputs.

    Args:
        s1: First sequence (typically a ``str``).
        s2: Second sequence.

    Returns:
        int: Minimum number of single-element edits turning one input into
        the other.
    """
    # Iterate over the longer input so the kept row is as short as possible.
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    prev_row = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, ch_short in enumerate(shorter):
            if ch_short == ch_long:
                cost = prev_row[col_idx]
            else:
                cost = 1 + min(prev_row[col_idx], prev_row[col_idx + 1], cur_row[-1])
            cur_row.append(cost)
        prev_row = cur_row
    return prev_row[-1]
| |
|
| |
|
def anls_compute(groundtruth, prediction):
    """Return the normalized Levenshtein distance between two answers.

    Both answers are lower-cased and their whitespace is normalized (leading
    and trailing whitespace removed, internal runs collapsed to one space)
    before comparison. The edit distance is divided by the length of the
    longer *normalized* string, so the result lies in ``[0.0, 1.0]``:
    ``0.0`` means identical after normalization and ``1.0`` means completely
    different. Callers such as ``hit_calculate`` use ``1 - value`` as the
    ANLS similarity.

    Args:
        groundtruth: Reference answer. Non-string values (e.g. numbers read
            from a TSV cell) are converted with ``str``.
        prediction: Model prediction; converted with ``str`` likewise.

    Returns:
        float: The normalized edit distance, or ``0.0`` when both normalized
        strings are empty.
    """
    gt_answer = ' '.join(str(groundtruth).lower().split())
    det_answer = ' '.join(str(prediction).lower().split())
    dist = levenshtein_distance(gt_answer, det_answer)
    # Normalize by the strings the distance was actually computed on. Using
    # the raw inputs let padding whitespace (or case-mapping length changes
    # such as 'ß'.upper() == 'SS') shrink the ratio and inflate the score.
    length = max(len(gt_answer), len(det_answer))
    return 0.0 if length == 0 else float(dist) / float(length)
| |
|
| |
|
def hit_calculate(result, dataset_name, anls_threshold=0.5):
    """Convert per-sample match lists into per-sample hit scores.

    Args:
        result (list[dict]): Records that each carry a ``'match'`` list with
            one entry per reference answer.
        dataset_name (str): Dataset name; selects the scoring rule by
            substring.
        anls_threshold (float): For DocVQA/InfoVQA, the minimum similarity
            (``1 - min(match)``) that still earns credit.

    Returns:
        list: One score per record.
        - DocVQA/InfoVQA: the best similarity ``1 - min(match)``, or ``0.0``
          when it falls below ``anls_threshold``.
        - OCRVQA: ``max(match)``.

    Raises:
        NotImplementedError: For any other dataset name.
    """
    if 'DocVQA' in dataset_name or 'InfoVQA' in dataset_name:
        hits = []
        for record in result:
            # ``match`` holds distances, so the smallest is the best answer.
            similarity = 1 - np.min(record['match'])
            hits.append(0.0 if similarity < anls_threshold else similarity)
        return hits
    if 'OCRVQA' in dataset_name:
        return [np.max(record['match']) for record in result]
    raise NotImplementedError(f"Dataset {dataset_name} not supported for hit calculation")
| |
|
| |
|
def istype(s, type):
    """Check whether ``s`` is, or evaluates to, an instance of ``type``.

    ``s`` itself is tested first. Failing that, it is passed to ``eval`` and
    the evaluated value is tested. Any exception raised on that path (syntax
    errors, undefined names, non-string input) yields ``False``.

    NOTE(security): ``eval`` executes arbitrary Python. Callers pass cells
    read from benchmark data files, so this is only safe on trusted data;
    ``ast.literal_eval`` would be a safer parser for list literals.
    """
    if not isinstance(s, type):
        try:
            return isinstance(eval(s), type)
        except Exception:
            return False
    return True
| |
|
| |
|
class GeneralVQADataset(BaseEvalDataset):
    """Evaluation dataset for open-ended VQA benchmarks stored as TSV files.

    The TSV must have ``index``, ``image`` (base64 payload or a short
    reference to another row's ``index``) and ``question`` columns. It may
    also have ``answer`` and ``split`` columns.

    Scoring in :meth:`evaluate` is chosen from the file's basename:
    - names containing ``OCRVQA`` use case-insensitive exact match;
    - all others use normalized edit distance (ANLS).
    ``hit_calculate`` decides which names are actually supported.
    """
    METAINFO: dict = dict(name='gvqa')

    def __init__(self, data_file, image_processor,
                 pad_image_to_square=True,
                 anls_threshold=0.5, metainfo=None,):
        """
        Args:
            data_file (str): Path to the tab-separated benchmark file. Its
                basename without extension becomes ``self.name``.
            image_processor: Config passed to ``BUILDER.build``.
            pad_image_to_square (bool): Stored on the instance; presumably
                consumed by ``custom_data_process`` — TODO confirm.
            anls_threshold (float): Minimum ANLS similarity for a DocVQA /
                InfoVQA prediction to earn credit in :meth:`evaluate`.
            metainfo (dict, optional): Forwarded to ``BaseEvalDataset``.
        """
        super().__init__(metainfo)
        self.anls_threshold = anls_threshold
        self.data_file = data_file
        self.df = pd.read_csv(data_file, sep='\t')
        # OCR-style files match image references against ``index`` as-is
        # instead of casting them to ``int`` (see ``get_image``).
        self.ocr = 'OCR' in data_file

        # Rows without an image cannot be fed to the model; drop them.
        self.df = self.df[~pd.isna(self.df['image'])]

        self.image_processor = BUILDER.build(image_processor)
        self.pad_image_to_square = pad_image_to_square
        self.name = os.path.splitext(os.path.basename(data_file))[0]
        self.results_xlsx_path = self.name + '-results.xlsx'
        self.data = self.load_data_list()

    def get_image(self, image):
        """Decode an ``image`` cell with ``decode_base64_to_image``.

        A cell shorter than 16 characters is treated as a reference to the
        ``index`` of another row. That row's ``image`` cell is used instead,
        repeating until a longer (base64) payload is reached. Exactly one row
        must match each reference.
        """
        while len(image) < 16:
            if self.ocr:
                image = self.df[self.df['index'] == image]['image'].values
            else:
                image = self.df[self.df['index'] == int(image)]['image'].values
            assert len(image) == 1
            image = image[0]
        image = decode_base64_to_image(image)
        return image

    def __len__(self):
        """Return the number of rows kept after dropping image-less ones."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the model input for sample ``idx`` via ``custom_data_process``."""
        data = self.data[idx]
        data_dict = custom_data_process(self, data)
        return data_dict

    def load_data_list(self):
        """Convert ``self.df`` rows into a list of per-sample dicts.

        Each dict has the keys:
        - ``img``: the raw image cell;
        - ``question``;
        - ``answer``: ``None`` when the file has no answer column;
        - ``index``;
        - ``img_id``: the row's position in ``self.df``.
        ``split`` is added only when the file has a split column.
        """
        # Column presence is loop-invariant; check it once.
        has_split = 'split' in self.df.columns
        has_answer = 'answer' in self.df.columns

        data_list = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            data = {
                'img': row['image'],
                'question': row['question'],
                'answer': row['answer'] if has_answer else None,
                'index': row['index'],
                'img_id': idx,
            }
            split = row['split'] if has_split else None
            if split is not None:
                data['split'] = split
            data_list.append(data)
        return data_list

    @master_only
    def evaluate(self, results, work_dir):
        """Score predictions, save them to an Excel sheet and log the metrics.

        Decorated with ``master_only``.

        Args:
            results (list[dict]): One dict per prediction with an ``img_id``
                (as produced by ``load_data_list``) and a ``prediction``
                (presumably a string).
            work_dir (str): Directory that ``self.results_xlsx_path`` is
                written into.

        Returns:
            dict: Mean hit score (0-100). Keyed by split value when the data
            carries splits, otherwise ``{'overall': score}``.
        """
        # Map ``img_id`` to its first position in ``self.data``. This replaces
        # a linear ``list.index`` scan per prediction (O(n^2) overall).
        pos_by_img_id = {}
        for pos, item in enumerate(self.data):
            pos_by_img_id.setdefault(item['img_id'], pos)

        is_exact_match = 'OCRVQA' in self.name
        new_results = []
        for pred_dict in results:
            sample = self.data[pos_by_img_id[pred_dict['img_id']]]
            prediction = pred_dict['prediction']
            raw_answer = sample.get('answer')

            # A cell holds either one answer or a list literal of accepted
            # answers.
            # NOTE(security): ``istype``/``eval`` execute the cell as Python.
            # This is fine for trusted benchmark files; prefer
            # ``ast.literal_eval`` if data files may be untrusted.
            if isinstance(raw_answer, list):
                answers = raw_answer
            elif istype(raw_answer, list):
                answers = eval(raw_answer)
            else:
                answers = [raw_answer]

            if is_exact_match:
                norm_pred = str(prediction).strip().lower()
                match = [1.0 if str(ans).strip().lower() == norm_pred else 0.0
                         for ans in answers]
            else:
                # Normalized edit distance per reference answer (0 = identical).
                match = [anls_compute(ans, prediction) for ans in answers]

            new_results.append({
                'question': sample.get('question'),
                'split': sample.get('split'),
                'prediction': prediction,
                'index': sample.get('index'),
                # Separate column so the reference answer no longer clobbers
                # the sample index in the results sheet.
                'answer': raw_answer,
                'match': match,
            })

        results_df = pd.DataFrame(new_results)
        with pd.ExcelWriter(osp.join(work_dir, self.results_xlsx_path), engine='openpyxl') as writer:
            results_df.to_excel(writer, index=False)

        ret = dict()
        # Every record has a 'split' key (None when absent). Check the values,
        # not the DataFrame column, which always exists.
        if any(rec['split'] is not None for rec in new_results):
            # Group by split in first-seen order so the output order is stable.
            groups = {}
            for rec in new_results:
                groups.setdefault(rec['split'], []).append(rec)
            for sp, sub in groups.items():
                hit = hit_calculate(sub, self.name, anls_threshold=self.anls_threshold)
                ret[sp] = np.mean(hit) * 100
        else:
            hit = hit_calculate(new_results, self.name, anls_threshold=self.anls_threshold)
            ret['overall'] = np.mean(hit) * 100

        print_log('============================================', 'current')
        print_log(ret, 'current')
        print_log('============================================', 'current')
        print_log(f'{self.name} successfully finished evaluating', 'current')
        return ret
| |
|