Add files using upload-large-folder tool
Browse files
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py +106 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/__init__.py +4 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/file.py +344 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/log.py +47 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/misc.py +291 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/vlm.py +179 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/__init__.py +7 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/matching_util.py +69 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/mp_util.py +72 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py +97 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/__init__.py +6 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/base.py +198 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/minicpm_v.py +727 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/README.md +3 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/__init__.py +0 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/vqa_dataset.py +116 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval.py +106 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval_utils/cal_metric.py +40 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/requirements.txt +49 -0
- r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/transform_docvqatest_for_submission.py +16 -0
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
from collections import Counter
|
| 6 |
+
|
| 7 |
+
# local import
|
| 8 |
+
from .common import BaseMetric
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def token_normalize(token_text, is_lower=False, is_alphanum_only=False):
    """Normalize one token: optionally lowercase it and strip non-alphanumerics."""
    normalized = token_text.lower() if is_lower else token_text
    if is_alphanum_only:
        normalized = re.sub('[^A-Za-z0-9]+', '', normalized)
    return normalized
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def text_normalize_and_tokenize(text, is_keep_blank=True, is_lower=True, is_alphanum_only=False):
    """Clean raw text, split into tokens (words or characters), and normalize each.

    When is_keep_blank is True, tokens are blank-separated words; otherwise
    blanks are removed and tokens are single characters.
    """
    # Normalize whitespace variants and drop markdown-ish separators.
    cleaned = text.replace("\t", " ").replace("\n", " ").replace("###", "").replace("***", "")
    cleaned = re.sub(r'\s+', ' ', cleaned)
    if not is_keep_blank:
        cleaned = cleaned.replace(" ", "")
    tokens = cleaned.split(" ") if is_keep_blank else list(cleaned)
    normalized = (token_normalize(tok, is_lower, is_alphanum_only) for tok in tokens)
    # Normalization can empty a token (e.g. pure punctuation) — drop those.
    return [tok for tok in normalized if tok]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def evaluate_single_sample(gts, preds):
    """Count matched tokens between ground truth and prediction.

    This is the multiset intersection size: each ground-truth token counts as
    matched at most as many times as it appears in the prediction.
    """
    overlap = Counter(gts) & Counter(preds)
    return sum(overlap.values())
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def calculate_metrics(response_info, gt_info, is_verbose=False):
    """Compute macro/micro precision, recall and F1 over per-image token lists.

    Args:
        response_info: mapping file_name -> predicted token list.
        gt_info: mapping file_name -> ground-truth token list.
        is_verbose: when True, return all six metrics; otherwise only the
            two F1 scores.
    """
    eps = 1e-9  # guards every division against zero denominators
    recall_per_image, precision_per_image, f1_per_image = [], [], []
    gt_total, pred_total, right_total = 0, 0, 0
    for file_name, gt_tokens in gt_info.items():
        pred_tokens = response_info.get(file_name, [])
        matched = evaluate_single_sample(gt_tokens, pred_tokens)
        right_total += matched
        gt_total += len(gt_tokens)
        pred_total += len(pred_tokens)

        recall = matched / (len(gt_tokens) + eps)
        precision = matched / (len(pred_tokens) + eps)
        f1 = 2 * recall * precision / (recall + precision + eps)
        recall_per_image.append(recall)
        precision_per_image.append(precision)
        f1_per_image.append(f1)

    # Macro: unweighted average of per-image scores.
    macro_recall = sum(recall_per_image) / (len(recall_per_image) + eps)
    macro_precision = sum(precision_per_image) / (len(precision_per_image) + eps)
    macro_f1 = sum(f1_per_image) / (len(f1_per_image) + eps)

    # Micro: pooled over all tokens of all images.
    micro_recall = right_total / (gt_total + eps)
    micro_precision = right_total / (pred_total + eps)
    micro_f1 = 2 * micro_recall * micro_precision / (micro_recall + micro_precision + eps)
    # NOTE: the 'mirco_f1_score' key spelling is kept on purpose — downstream
    # consumers read this exact key.
    full_result = {
        'macro_recall': macro_recall, 'macro_precision': macro_precision, 'macro_f1_score': macro_f1,
        'micro_recall': micro_recall, 'micro_precision': micro_precision, 'mirco_f1_score': micro_f1
    }
    if is_verbose:
        return full_result
    return {'macro_f1_score': macro_f1, 'mirco_f1_score': micro_f1}
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class OcrEvaluator(BaseMetric):
    """Full-text OCR metric: token-level F1 between predictions and ground truth."""

    def response_post_func(self, response_text, **kwargs):
        # Raw model output is scored as-is; no post-processing for OCR.
        return response_text

    def evaluate(self, response_info, gt_info, **kwargs):
        """Tokenize predictions/ground truth per language rules, then score.

        Args:
            response_info: mapping file_name -> raw predicted text.
            gt_info: mapping file_name -> raw ground-truth text.
            kwargs: must carry 'dataset' (the split/language name).
        Returns:
            dict with 'summary' (metric values) and 'metric_config'.
        """
        dataset_name = kwargs['dataset']
        # Defaults: word-level tokens, lowercased, punctuation kept.
        is_word_level, is_lower, is_alphanum_only = True, True, False
        # Scripts without blank-separated words are scored at character level.
        if "zh" in dataset_name or dataset_name in ["Arabic", "Japanese", "Korean"]:
            is_word_level = False
        # Word-level multi-scene OCR keeps alphanumerics only.
        if is_word_level and "multi_scene_ocr" in self.group_name:
            is_alphanum_only = True
        eval_config = {"word_level": is_word_level, "alphanum_only": is_alphanum_only, "lowercase": is_lower}

        image_pdt_info, image_gt_info = {}, {}
        for file_name, gt_src in gt_info.items():
            pred_src = response_info.get(file_name, "")
            image_pdt_info[file_name] = text_normalize_and_tokenize(
                str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
            image_gt_info[file_name] = text_normalize_and_tokenize(
                str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
        eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
        return {"summary": eval_result, "metric_config": eval_config}
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Module is import-only; no CLI entry point is provided.
if __name__ == '__main__':
    pass
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .file import *
|
| 2 |
+
from .vlm import *
|
| 3 |
+
from .misc import *
|
| 4 |
+
from .log import *
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/file.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pickle
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import os
|
| 5 |
+
import csv
|
| 6 |
+
import hashlib
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import time
|
| 9 |
+
import numpy as np
|
| 10 |
+
import validators
|
| 11 |
+
import mimetypes
|
| 12 |
+
import multiprocessing as mp
|
| 13 |
+
from .misc import toliststr
|
| 14 |
+
from .vlm import decode_base64_to_image_file
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def decode_img_omni(tup):
    """Decode one record's base64 image payload(s) into file(s) under a root dir.

    Args:
        tup: (root_dir, image_or_images, path_or_paths), packed as one tuple
            so the function can be used with mp.Pool.map.
    Returns:
        List of absolute target paths; already-existing files are untouched.
    """
    root, raw_images, raw_paths = tup
    image_list = toliststr(raw_images)
    path_list = toliststr(raw_paths)
    # Several images but a single path: derive per-image names by index suffix.
    if len(image_list) > 1 and len(path_list) == 1:
        stem, ext = osp.splitext(raw_paths)
        path_list = [f'{stem}_{idx}{ext}' for idx in range(len(image_list))]

    assert len(image_list) == len(path_list)
    full_paths = [osp.join(root, rel) for rel in path_list]
    for target, payload in zip(full_paths, image_list):
        if osp.exists(target):
            continue
        # Strings longer than 64 chars are base64 payloads (shorter ones are
        # cross-references to another row's index — see localize_df).
        if isinstance(payload, str) and len(payload) > 64:
            decode_base64_to_image_file(payload, target)
    return full_paths
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def localize_df(data, dname, nproc=32):
    """Materialize base64 images embedded in a DataFrame as local image files.

    Args:
        data: DataFrame with 'index' and 'image' columns. 'image' holds a
            base64 payload, or a short (<= 64 char) cross-reference to the
            index of another row that carries the real payload.
        dname: dataset name, used as the subdirectory under LMUDataRoot()/images.
        nproc: number of worker processes used for decoding.
    Returns:
        The DataFrame with the 'image' column dropped and 'image_path' set.
    """
    assert 'image' in data
    indices_str = [str(x) for x in data['index']]
    images = list(data['image'])
    image_map = dict(zip(indices_str, images))

    root = osp.join(LMUDataRoot(), 'images', dname)
    os.makedirs(root, exist_ok=True)

    if 'image_path' in data:
        img_paths = list(data['image_path'])
    else:
        img_paths = []
        for i in indices_str:
            if len(image_map[i]) <= 64:
                # Short value is a reference to the row carrying the payload.
                idx = image_map[i]
                assert idx in image_map and len(image_map[idx]) > 64
                img_paths.append(f'{idx}.jpg')
            else:
                img_paths.append(f'{i}.jpg')

    tups = [(root, im, p) for p, im in zip(img_paths, images)]

    # BUGFIX: the pool size was hard-coded to 32, silently ignoring `nproc`.
    pool = mp.Pool(nproc)
    ret = pool.map(decode_img_omni, tups)
    pool.close()
    pool.join()
    data.pop('image')
    if 'image_path' not in data:
        data['image_path'] = [x[0] if len(x) == 1 else x for x in ret]
    return data
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def LMUDataRoot():
    """Return the data root directory, preferring the LMUData env variable."""
    env_root = os.environ.get('LMUData')
    if env_root and osp.exists(env_root):
        return env_root
    # Fall back to ~/LMUData, creating it on first use.
    fallback = osp.join(osp.expanduser('~'), 'LMUData')
    os.makedirs(fallback, exist_ok=True)
    return fallback
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def HFCacheRoot():
    """Return the HuggingFace hub cache directory, honoring cache env vars."""
    for env_name in ('HUGGINGFACE_HUB_CACHE', 'HF_HOME'):
        cache_dir = os.environ.get(env_name)
        if cache_dir is None or not osp.exists(cache_dir):
            continue
        # Accept either the hub directory itself or its parent.
        if cache_dir.split('/')[-1] == 'hub':
            return cache_dir
        return osp.join(cache_dir, 'hub')
    default = osp.join(osp.expanduser('~'), '.cache', 'huggingface', 'hub')
    os.makedirs(default, exist_ok=True)
    return default
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def MMBenchOfficialServer(dataset_name):
    """Check whether a local MMBench TSV with complete (non-NaN) answers exists."""

    def _has_full_answers(pth):
        # True when the TSV exists and every row carries a non-NaN answer.
        if not osp.exists(pth):
            return False
        data = load(pth)
        return 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0

    root = LMUDataRoot()
    if dataset_name in ['MMBench', 'MMBench_V11', 'MMBench_CN', 'MMBench_CN_V11']:
        if _has_full_answers(f'{root}/{dataset_name}.tsv'):
            return True

    if dataset_name in ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11']:
        # TEST splits may also be satisfied by the corresponding base TSV.
        mapp = {
            'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_CN': 'MMBench_CN',
            'MMBench_TEST_EN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11',
        }
        candidates = [f'{root}/{dataset_name}.tsv', f'{root}/{mapp[dataset_name]}.tsv']
        if any(_has_full_answers(c) for c in candidates):
            return True
    return False
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalars/arrays into native Python types."""

    def default(self, obj):
        # BUGFIX: use the abstract scalar base classes. The concrete aliases
        # np.float_ / np.complex_ were removed in numpy 2.0, so enumerating
        # them raises AttributeError there; np.integer / np.floating /
        # np.complexfloating cover all concrete widths on every numpy version.
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.complexfloating):
            # real/imag are numpy floats and get re-encoded by this encoder.
            return {'real': obj.real, 'imag': obj.imag}
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.bool_):
            return bool(obj)
        elif isinstance(obj, np.void):
            return None
        return json.JSONEncoder.default(self, obj)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# LOAD & DUMP
def dump(data, f, **kwargs):
    """Serialize `data` to path `f`, dispatching on the file extension.

    Supports pkl / json / jsonl / xlsx / csv / tsv; raises KeyError for any
    other suffix.
    """
    def dump_pkl(data, pth, **kwargs):
        # BUGFIX: use context managers so handles are closed even on error
        # (the originals left the file open until garbage collection).
        with open(pth, 'wb') as fout:
            pickle.dump(data, fout)

    def dump_json(data, pth, **kwargs):
        # Explicit utf-8: ensure_ascii=False emits raw non-ASCII characters,
        # which breaks on platforms whose default encoding is not utf-8.
        with open(pth, 'w', encoding='utf-8') as fout:
            json.dump(data, fout, indent=4, ensure_ascii=False, cls=NumpyEncoder)

    def dump_jsonl(data, f, **kwargs):
        lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
        with open(f, 'w', encoding='utf8') as fout:
            fout.write('\n'.join(lines))

    def dump_xlsx(data, f, **kwargs):
        data.to_excel(f, index=False, engine='xlsxwriter')

    def dump_csv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)

    def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)

    handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
    suffix = f.split('.')[-1]
    return handlers[suffix](data, f, **kwargs)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def load(f, fmt=None):
    """Deserialize path `f`, dispatching on `fmt` or the file extension.

    Supports pkl / json / jsonl / xlsx / csv / tsv.
    """
    def load_pkl(pth):
        # BUGFIX: context managers — the originals leaked open file handles.
        with open(pth, 'rb') as fin:
            return pickle.load(fin)

    def load_json(pth):
        with open(pth, 'r', encoding='utf-8') as fin:
            return json.load(fin)

    def load_jsonl(f):
        with open(f, encoding='utf-8') as fin:
            lines = [x.strip() for x in fin.readlines()]
        # Drop a trailing blank line; BUGFIX: guard against a fully empty
        # file (the original raised IndexError on lines[-1]).
        if lines and lines[-1] == '':
            lines = lines[:-1]
        return [json.loads(x) for x in lines]

    def load_xlsx(f):
        return pd.read_excel(f)

    def load_csv(f):
        return pd.read_csv(f)

    def load_tsv(f):
        return pd.read_csv(f, sep='\t')

    handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
    if fmt is not None:
        return handlers[fmt](f)

    suffix = f.split('.')[-1]
    return handlers[suffix](f)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def download_file(url, filename=None):
    """Download `url` to `filename` (default: last URL segment) with a progress bar.

    On failure for a huggingface.co URL, retries once via the hf-mirror.com
    mirror; otherwise (or if the retry also fails) raises Exception.
    Returns the local filename.
    """
    import urllib.request
    from tqdm import tqdm

    class DownloadProgressBar(tqdm):
        # urlretrieve reporthook: b blocks of bsize bytes transferred so far,
        # tsize total size (may be None when the server omits Content-Length).
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            # Advance by the delta since the last call (self.n is tqdm's count).
            self.update(b * bsize - self.n)

    if filename is None:
        filename = url.split('/')[-1]

    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
            urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
    except Exception as e:
        import logging
        logging.warning(f'{type(e)}: {e}')
        # Handle Failed Downloads from huggingface.co
        if 'huggingface.co' in url:
            # Retry through the mirror host; a second failure is fatal.
            url_new = url.replace('huggingface.co', 'hf-mirror.com')
            try:
                download_file(url_new, filename)
                return filename
            except Exception as e:
                logging.warning(f'{type(e)}: {e}')
                raise Exception(f'Failed to download {url}')
        else:
            raise Exception(f'Failed to download {url}')

    return filename
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def ls(dirname='.', match=None, mode='all', level=1):
    """List directory entries with substring filters and optional recursion.

    Args:
        dirname: directory to list ('.' keeps bare names, otherwise entries
            are joined with dirname).
        match: substring filter or list of filters; a leading '!' excludes.
        mode: 'all' | 'dir' | 'file'.
        level: recursion depth (int >= 1), or a string like '2+' to collect
            files at every depth from 1 up to that number.
    Returns:
        List of matching entry paths.
    """
    # BUGFIX: `match=[]` was a mutable default argument; use None sentinel.
    if match is None:
        match = []
    if isinstance(level, str):
        assert '+' in level
        level = int(level[:-1])
        res = []
        for i in range(1, level + 1):
            # The '+' form collects files only, regardless of `mode`
            # (kept from the original behavior).
            res.extend(ls(dirname, match=match, mode='file', level=i))
        return res

    if dirname == '.':
        ans = os.listdir(dirname)
    else:
        ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
    assert mode in ['all', 'dir', 'file']
    assert level >= 1 and isinstance(level, int)
    if level == 1:
        if isinstance(match, str):
            match = [match]
        for m in match:
            if len(m) == 0:
                continue
            if m[0] != '!':
                ans = [x for x in ans if m in x]
            else:
                # '!' prefix: keep entries NOT containing the substring.
                ans = [x for x in ans if m[1:] not in x]
        if mode == 'dir':
            ans = [x for x in ans if osp.isdir(x)]
        elif mode == 'file':
            ans = [x for x in ans if not osp.isdir(x)]
        return ans
    else:
        # Deeper levels: recurse into subdirectories only.
        dirs = [x for x in ans if osp.isdir(x)]
        res = []
        for d in dirs:
            res.extend(ls(d, match=match, mode=mode, level=level - 1))
        return res
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def mrlines(fname, sp='\n'):
    """Read `fname`, split on `sp`, and drop trailing empty segments."""
    # BUGFIX: the original opened the file without ever closing it; use a
    # context manager for deterministic cleanup.
    with open(fname) as fin:
        parts = fin.read().split(sp)
    while parts != [] and parts[-1] == '':
        parts = parts[:-1]
    return parts
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def mwlines(lines, fname):
    """Write `lines` to `fname`, newline-separated (no trailing newline)."""
    joined = '\n'.join(lines)
    with open(fname, 'w') as fout:
        fout.write(joined)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def md5(s):
    """MD5 hex digest of a file's contents (when `s` is an existing path) or of `s` itself."""
    digest = hashlib.new('md5')
    if osp.exists(s):
        # Hash the file in 1 MiB chunks to bound memory usage.
        with open(s, 'rb') as fin:
            for chunk in iter(lambda: fin.read(2**20), b''):
                digest.update(chunk)
    else:
        digest.update(s.encode('utf-8'))
    return str(digest.hexdigest())
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def last_modified(pth):
    """Return pth's mtime formatted as a 12-digit 'YYMMDDHHMMSS' string."""
    mtime = osp.getmtime(pth)
    parsed = time.strptime(time.ctime(mtime))
    # Drop the century so the stamp is 12 digits.
    return time.strftime('%Y%m%d%H%M%S', parsed)[2:]
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def parse_file(s):
    """Classify `s` as a local file, data URI, URL, or plain string.

    Returns a (mime, location) pair:
      - existing file: (mime-or-'unknown', s)
      - 'data:image/...' URI: payload is decoded to a cached PNG and re-parsed
      - URL with a known media suffix: downloaded to the cache, (mime, path)
      - other URL: ('url', s)
      - anything else: (None, s)
    """
    if osp.exists(s) and s != '.':
        assert osp.isfile(s)
        suffix = osp.splitext(s)[1].lower()
        mime = mimetypes.types_map.get(suffix, 'unknown')
        return (mime, s)
    elif s.startswith('data:image/'):
        # To be compatible with OPENAI base64 format
        content = s[11:]
        mime = content.split(';')[0]
        content = ';'.join(content.split(';')[1:])
        dname = osp.join(LMUDataRoot(), 'files')
        assert content.startswith('base64,')
        b64 = content[7:]
        os.makedirs(dname, exist_ok=True)
        # Content-addressed cache name: md5 of the payload.
        tgt = osp.join(dname, md5(b64) + '.png')
        decode_base64_to_image_file(b64, tgt)
        return parse_file(tgt)
    elif validators.url(s):
        suffix = osp.splitext(s)[1].lower()
        if suffix in mimetypes.types_map:
            mime = mimetypes.types_map[suffix]
            dname = osp.join(LMUDataRoot(), 'files')
            os.makedirs(dname, exist_ok=True)
            # Cache under md5 of the URL, preserving the media suffix.
            tgt = osp.join(dname, md5(s) + suffix)
            download_file(s, tgt)
            return (mime, tgt)
        else:
            return ('url', s)
    else:
        return (None, s)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def file_size(f, unit='GB'):
    """Return the size of file `f` expressed in `unit` ('GB' | 'MB' | 'KB')."""
    divisors = {'GB': 2 ** 30, 'MB': 2 ** 20, 'KB': 2 ** 10}
    return os.stat(f).st_size / divisors[unit]
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def parquet_to_tsv(file_path):
    """Convert a parquet file to a TSV with the same basename, alongside it."""
    frame = pd.read_parquet(file_path)
    parent = '/'.join(file_path.split('/')[:-1])
    stem = file_path.split('/')[-1].split('.')[0]
    frame.to_csv(osp.join(parent, f'{stem}.tsv'), sep='\t', index=False)
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/log.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
logging.basicConfig(
    format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')

# Names (and name prefixes) whose loggers have already been configured.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Create (or fetch) a logger with stream output and an optional log file.

    Only rank 0 attaches the file handler and logs at `log_level`; other
    ranks are clamped to ERROR to avoid duplicated output in distributed runs.
    """
    logger = logging.getLogger(name)
    # Already configured, either exactly or via an ancestor prefix: reuse.
    if name in logger_initialized:
        return logger
    for configured in logger_initialized:
        if name.startswith(configured):
            return logger

    handlers = [logging.StreamHandler()]

    # Determine the distributed rank; fall back to 0 when torch is missing
    # or the process group is not initialized.
    try:
        import torch.distributed as dist
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    except ImportError:
        rank = 0

    if rank == 0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    formatter = logging.Formatter(
        '[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    logger.setLevel(log_level if rank == 0 else logging.ERROR)

    logger_initialized[name] = True
    return logger
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/misc.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F401, F403
|
| 2 |
+
import abc
|
| 3 |
+
import argparse
|
| 4 |
+
import csv
|
| 5 |
+
import multiprocessing as mp
|
| 6 |
+
import os
|
| 7 |
+
import os.path as osp
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import copy as cp
|
| 10 |
+
import random as rd
|
| 11 |
+
import requests
|
| 12 |
+
import shutil
|
| 13 |
+
import subprocess
|
| 14 |
+
import warnings
|
| 15 |
+
import pandas as pd
|
| 16 |
+
from collections import OrderedDict, defaultdict
|
| 17 |
+
from multiprocessing import Pool, current_process
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
import datetime
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
from tabulate import tabulate
|
| 22 |
+
from json import JSONDecoder
|
| 23 |
+
from huggingface_hub import scan_cache_dir
|
| 24 |
+
from huggingface_hub.utils._cache_manager import _scan_cached_repo
|
| 25 |
+
from sty import fg, bg, ef, rs
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def modelscope_flag_set():
    """Whether the VLMEVALKIT_USE_MODELSCOPE environment flag is turned on."""
    flag = os.environ.get('VLMEVALKIT_USE_MODELSCOPE', None)
    return flag in ['1', 'True']
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def process_punctuation(inText):
    """Strip/normalize punctuation for VQA-style answer comparison.

    Mirrors the reference VQA evaluation: punctuation adjacent to a space
    (or inside a digit-comma-digit number) is removed outright, otherwise it
    is replaced by a space; periods not part of a decimal are then removed.
    """
    import re
    outText = inText
    punct = [
        ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
        '>', '<', '@', '`', ',', '?', '!'
    ]
    commaStrip = re.compile('(\d)(,)(\d)')  # noqa: W605
    # NOTE(review): '(?!<=\d)' looks like a typo for the lookbehind '(?<!\d)';
    # as written it is a (vacuous) negative lookahead. Kept as-is because the
    # reference VQA eval uses the same pattern — changing it would change
    # benchmark scores.
    periodStrip = re.compile('(?!<=\d)(\.)(?!\d)')  # noqa: W605
    for p in punct:
        if (p + ' ' in inText or ' ' + p in inText) or (re.search(
                commaStrip, inText) is not None):
            outText = outText.replace(p, '')
        else:
            outText = outText.replace(p, ' ')
    # NOTE(review): re.UNICODE (= 32) is passed positionally as Pattern.sub's
    # `count` argument, so only the first 32 matches are substituted. This is
    # a known quirk inherited from the reference implementation; kept for
    # score parity.
    outText = periodStrip.sub('', outText, re.UNICODE)
    return outText
|
| 49 |
+
|
| 50 |
+
def h2r(value):
    """Convert a '#RRGGBB' (or 'RRGGBB') hex color to an (r, g, b) tuple."""
    hex_part = value[1:] if value[0] == '#' else value
    assert len(hex_part) == 6
    return tuple(int(hex_part[pos:pos + 2], 16) for pos in (0, 2, 4))
|
| 55 |
+
|
| 56 |
+
def r2h(rgb):
    """Convert an (r, g, b) tuple to a '#rrggbb' hex string."""
    red, green, blue = rgb
    return '#%02x%02x%02x' % (red, green, blue)
|
| 58 |
+
|
| 59 |
+
def colored(s, color):
    """Wrap `s` in sty foreground color codes.

    `color` is either a named fg attribute (e.g. 'red'), a hex string, or an
    RGB tuple.
    """
    if isinstance(color, str):
        if hasattr(fg, color):
            return getattr(fg, color) + s + fg.rs
        # Hex string: convert to an RGB tuple first.
        color = h2r(color)
    return fg(*color) + s + fg.rs
|
| 65 |
+
|
| 66 |
+
def istype(s, type):
    """True if `s` — or the value `eval(s)` parses to — is an instance of `type`."""
    if isinstance(s, type):
        return True
    try:
        value = eval(s)  # NOTE: eval on arbitrary strings; trusted input only
    except Exception:
        return False
    return isinstance(value, type)
|
| 73 |
+
|
| 74 |
+
def bincount(lst):
    """Histogram of `lst`; the result is a defaultdict so absent keys read as 0."""
    counts = defaultdict(int)
    for element in lst:
        counts[element] = counts[element] + 1
    return counts
|
| 79 |
+
|
| 80 |
+
def get_cache_path(repo_id, branch='main', repo_type='datasets'):
    """Return the local cache directory of a hub repo, or None on any failure.

    Uses the modelscope cache when the VLMEVALKIT_USE_MODELSCOPE flag is set;
    otherwise scans the HuggingFace hub cache for the newest snapshot whose
    refs contain `branch` (pass branch=None to skip the ref filter).
    """
    try:
        if modelscope_flag_set():
            from modelscope.hub.file_download import create_temporary_directory_and_cache
            # modelscope uses the singular 'dataset' repo-type spelling.
            if repo_type == 'datasets':
                repo_type = 'dataset'
            _, cache = create_temporary_directory_and_cache(model_id=repo_id, repo_type=repo_type)
            cache_path = cache.get_root_location()
            return cache_path
        else:
            from .file import HFCacheRoot
            cache_path = HFCacheRoot()
            org, repo_name = repo_id.split('/')
            # HF hub cache layout: <root>/<repo_type>--<org>--<name>/
            repo_path = Path(osp.join(cache_path, f'{repo_type}--{org}--{repo_name}/'))
            hf_cache_info = _scan_cached_repo(repo_path=repo_path)
            revs = {r.refs: r for r in hf_cache_info.revisions}
            if branch is not None:
                revs = {refs: r for refs, r in revs.items() if branch in refs}
            # Keep the most recently modified matching revision.
            rev2keep = max(revs.values(), key=lambda r: r.last_modified)
            return str(rev2keep.snapshot_path)
    except Exception as e:
        # Best-effort lookup: any failure (missing repo, parse error, ...)
        # is logged and reported as "not cached".
        import logging
        logging.warning(f'{type(e)}: {e}')
        return None
|
| 104 |
+
|
| 105 |
+
def proxy_set(s):
    """Point all common proxy environment variables at `s`."""
    import os
    proxy_keys = ('http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY')
    for key in proxy_keys:
        os.environ[key] = s
|
| 109 |
+
|
| 110 |
+
def get_rank_and_world_size():
    """Read (RANK, WORLD_SIZE) from the environment, defaulting to (0, 1)."""
    return int(os.environ.get('RANK', 0)), int(os.environ.get('WORLD_SIZE', 1))
|
| 114 |
+
|
| 115 |
+
def splitlen(s, sym='/'):
    """Number of segments obtained by splitting `s` on `sym`."""
    # Equivalent to len(s.split(sym)): non-overlapping occurrences + 1.
    return s.count(sym) + 1
|
| 117 |
+
|
| 118 |
+
def listinstr(lst, s):
    """True if any element of `lst` occurs as a substring of `s`."""
    assert isinstance(lst, list)
    return any(item in s for item in lst)
|
| 124 |
+
|
| 125 |
+
def d2df(D):
    """Wrap a flat dict into a single-row DataFrame (one column per key)."""
    single_row = {key: [value] for key, value in D.items()}
    return pd.DataFrame(single_row)
|
| 127 |
+
|
| 128 |
+
def cn_string(s):
    """Whether `s` contains at least one CJK unified ideograph."""
    import re
    return re.search(u'[\u4e00-\u9fff]', s) is not None
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
import decord
|
| 136 |
+
except ImportError:
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
def timestr(granularity='second'):
|
| 140 |
+
s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
|
| 141 |
+
assert granularity in ['second', 'minute', 'hour', 'day']
|
| 142 |
+
if granularity == 'second':
|
| 143 |
+
return s
|
| 144 |
+
elif granularity == 'minute':
|
| 145 |
+
return s[:-2]
|
| 146 |
+
elif granularity == 'hour':
|
| 147 |
+
return s[:-4]
|
| 148 |
+
elif granularity == 'day':
|
| 149 |
+
return s[:-6]
|
| 150 |
+
|
| 151 |
+
def _minimal_ext_cmd(cmd, cwd=None):
|
| 152 |
+
env = {}
|
| 153 |
+
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
|
| 154 |
+
v = os.environ.get(k)
|
| 155 |
+
if v is not None:
|
| 156 |
+
env[k] = v
|
| 157 |
+
env['LANGUAGE'] = 'C'
|
| 158 |
+
env['LANG'] = 'C'
|
| 159 |
+
env['LC_ALL'] = 'C'
|
| 160 |
+
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=cwd).communicate()[0]
|
| 161 |
+
return out
|
| 162 |
+
|
| 163 |
+
def githash(fallback='unknown', digits=8):
|
| 164 |
+
if digits is not None and not isinstance(digits, int):
|
| 165 |
+
raise TypeError('digits must be None or an integer')
|
| 166 |
+
try:
|
| 167 |
+
import vlmeval
|
| 168 |
+
except ImportError as e:
|
| 169 |
+
import logging
|
| 170 |
+
logging.error(f'ImportError: {str(e)}')
|
| 171 |
+
return fallback
|
| 172 |
+
try:
|
| 173 |
+
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'], cwd=vlmeval.__path__[0])
|
| 174 |
+
sha = out.strip().decode('ascii')
|
| 175 |
+
if digits is not None:
|
| 176 |
+
sha = sha[:digits]
|
| 177 |
+
except OSError:
|
| 178 |
+
sha = fallback
|
| 179 |
+
return sha
|
| 180 |
+
|
| 181 |
+
def dict_merge(dct, merge_dct):
|
| 182 |
+
for k, _ in merge_dct.items():
|
| 183 |
+
if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
|
| 184 |
+
dict_merge(dct[k], merge_dct[k])
|
| 185 |
+
else:
|
| 186 |
+
dct[k] = merge_dct[k]
|
| 187 |
+
|
| 188 |
+
def youtube_dl(idx):
|
| 189 |
+
cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
|
| 190 |
+
os.system(cmd)
|
| 191 |
+
|
| 192 |
+
def run_command(cmd):
|
| 193 |
+
if isinstance(cmd, str):
|
| 194 |
+
cmd = cmd.split()
|
| 195 |
+
return subprocess.check_output(cmd).decode()
|
| 196 |
+
|
| 197 |
+
def load_env():
|
| 198 |
+
import logging
|
| 199 |
+
logging.basicConfig(
|
| 200 |
+
format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
|
| 201 |
+
datefmt='%Y-%m-%d %H:%M:%S')
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
import vlmeval
|
| 205 |
+
except ImportError:
|
| 206 |
+
logging.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
|
| 207 |
+
return
|
| 208 |
+
pth = osp.realpath(vlmeval.__path__[0])
|
| 209 |
+
pth = osp.join(pth, '../.env')
|
| 210 |
+
pth = osp.realpath(pth)
|
| 211 |
+
if not osp.exists(pth):
|
| 212 |
+
logging.error(f'Did not detect the .env file at {pth}, failed to load. ')
|
| 213 |
+
return
|
| 214 |
+
|
| 215 |
+
from dotenv import dotenv_values
|
| 216 |
+
values = dotenv_values(pth)
|
| 217 |
+
for k, v in values.items():
|
| 218 |
+
if v is not None and len(v):
|
| 219 |
+
os.environ[k] = v
|
| 220 |
+
logging.info(f'API Keys successfully loaded from {pth}')
|
| 221 |
+
|
| 222 |
+
def pip_install_robust(package):
|
| 223 |
+
import sys
|
| 224 |
+
retry = 3
|
| 225 |
+
while retry > 0:
|
| 226 |
+
try:
|
| 227 |
+
package_base = package.split('=')[0]
|
| 228 |
+
module = __import__(package)
|
| 229 |
+
return True
|
| 230 |
+
except ImportError:
|
| 231 |
+
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
|
| 232 |
+
retry -= 1
|
| 233 |
+
return False
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def version_cmp(v1, v2, op='eq'):
|
| 237 |
+
from packaging import version
|
| 238 |
+
import operator
|
| 239 |
+
op_func = getattr(operator, op)
|
| 240 |
+
return op_func(version.parse(v1), version.parse(v2))
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def toliststr(s):
|
| 244 |
+
if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
|
| 245 |
+
return [str(x) for x in eval(s)]
|
| 246 |
+
elif isinstance(s, str):
|
| 247 |
+
return [s]
|
| 248 |
+
elif isinstance(s, list):
|
| 249 |
+
return [str(x) for x in s]
|
| 250 |
+
raise NotImplementedError
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def extract_json_objects(text, decoder=JSONDecoder()):
|
| 254 |
+
pos = 0
|
| 255 |
+
while True:
|
| 256 |
+
match = text.find('{', pos)
|
| 257 |
+
if match == -1: break
|
| 258 |
+
try:
|
| 259 |
+
result, index = decoder.raw_decode(text[match:])
|
| 260 |
+
yield result
|
| 261 |
+
pos = match + index
|
| 262 |
+
except ValueError:
|
| 263 |
+
pos = match + 1
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_gpu_memory():
|
| 267 |
+
import subprocess
|
| 268 |
+
try:
|
| 269 |
+
command = "nvidia-smi --query-gpu=memory.free --format=csv"
|
| 270 |
+
memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
|
| 271 |
+
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
|
| 272 |
+
return memory_free_values
|
| 273 |
+
except Exception as e:
|
| 274 |
+
print(f'{type(e)}: {str(e)}')
|
| 275 |
+
return []
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def auto_split_flag():
|
| 279 |
+
flag = os.environ.get('AUTO_SPLIT', '0')
|
| 280 |
+
if flag == '1':
|
| 281 |
+
return True
|
| 282 |
+
_, world_size = get_rank_and_world_size()
|
| 283 |
+
try:
|
| 284 |
+
import torch
|
| 285 |
+
device_count = torch.cuda.device_count()
|
| 286 |
+
if device_count > world_size and device_count % world_size == 0:
|
| 287 |
+
return True
|
| 288 |
+
else:
|
| 289 |
+
return False
|
| 290 |
+
except:
|
| 291 |
+
return False
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/vlm.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import string
|
| 6 |
+
from uuid import uuid4
|
| 7 |
+
import os.path as osp
|
| 8 |
+
import base64
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
Image.MAX_IMAGE_PIXELS = 1e9
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rescale_img(img, tgt=None):
|
| 16 |
+
assert isinstance(tgt, tuple) and -1 in tgt
|
| 17 |
+
w, h = img.size
|
| 18 |
+
if tgt[0] != -1:
|
| 19 |
+
new_w, new_h = tgt[0], int(tgt[0] / w * h)
|
| 20 |
+
elif tgt[1] != -1:
|
| 21 |
+
new_w, new_h = int(tgt[1] / h * w), tgt[1]
|
| 22 |
+
img = img.resize((new_w, new_h))
|
| 23 |
+
return img
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def concat_images_vlmeval(images, target_size=-1, mode='h', return_image=False):
|
| 27 |
+
from .file import md5
|
| 28 |
+
|
| 29 |
+
ims = [Image.open(im) for im in images]
|
| 30 |
+
if target_size != -1:
|
| 31 |
+
ims = [
|
| 32 |
+
rescale_img(im, (-1, target_size) if mode == 'h' else (target_size, -1))
|
| 33 |
+
for im in ims
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
ws, hs = [x.width for x in ims], [x.height for x in ims]
|
| 37 |
+
if mode == 'h':
|
| 38 |
+
new_w, new_h = sum(ws), max(hs)
|
| 39 |
+
dst = Image.new('RGB', (new_w, new_h))
|
| 40 |
+
for i, im in enumerate(ims):
|
| 41 |
+
dst.paste(im, (sum(ws[:i]), 0))
|
| 42 |
+
elif mode == 'v':
|
| 43 |
+
new_w, new_h = max(ws), sum(hs)
|
| 44 |
+
dst = Image.new('RGB', (new_w, new_h))
|
| 45 |
+
for i, im in enumerate(ims):
|
| 46 |
+
dst.paste(im, (sum(ws[:i], 0)))
|
| 47 |
+
if return_image:
|
| 48 |
+
return dst
|
| 49 |
+
else:
|
| 50 |
+
_str = '\n'.join(images)
|
| 51 |
+
str_md5 = md5(_str)
|
| 52 |
+
tgt = osp.join('/tmp', str_md5 + '.jpg')
|
| 53 |
+
dst.save(tgt)
|
| 54 |
+
return tgt
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def mmqa_display(question, target_size=512):
|
| 58 |
+
question = {k.lower(): v for k, v in question.items()}
|
| 59 |
+
keys = list(question.keys())
|
| 60 |
+
keys = [k for k in keys if k not in ['index', 'image']]
|
| 61 |
+
|
| 62 |
+
images = question['image']
|
| 63 |
+
if isinstance(images, str):
|
| 64 |
+
images = [images]
|
| 65 |
+
|
| 66 |
+
idx = question.pop('index', 'XXX')
|
| 67 |
+
print(f'INDEX: {idx}')
|
| 68 |
+
|
| 69 |
+
for im in images:
|
| 70 |
+
image = decode_base64_to_image(im, target_size=target_size)
|
| 71 |
+
display(image) # noqa: F821
|
| 72 |
+
|
| 73 |
+
for k in keys:
|
| 74 |
+
try:
|
| 75 |
+
if not pd.isna(question[k]):
|
| 76 |
+
print(f'{k.upper()}. {question[k]}')
|
| 77 |
+
except ValueError:
|
| 78 |
+
if False in pd.isna(question[k]):
|
| 79 |
+
print(f'{k.upper()}. {question[k]}')
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def encode_image_to_base64(img, target_size=-1, fmt='JPEG'):
|
| 83 |
+
# if target_size == -1, will not do resizing
|
| 84 |
+
# else, will set the max_size ot (target_size, target_size)
|
| 85 |
+
if img.mode in ('RGBA', 'P'):
|
| 86 |
+
img = img.convert('RGB')
|
| 87 |
+
if target_size > 0:
|
| 88 |
+
img.thumbnail((target_size, target_size))
|
| 89 |
+
img_buffer = io.BytesIO()
|
| 90 |
+
img.save(img_buffer, format=fmt)
|
| 91 |
+
image_data = img_buffer.getvalue()
|
| 92 |
+
ret = base64.b64encode(image_data).decode('utf-8')
|
| 93 |
+
return ret
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def encode_image_file_to_base64(image_path, target_size=-1):
|
| 97 |
+
image = Image.open(image_path)
|
| 98 |
+
return encode_image_to_base64(image, target_size=target_size)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def decode_base64_to_image(base64_string, target_size=-1):
|
| 102 |
+
image_data = base64.b64decode(base64_string)
|
| 103 |
+
image = Image.open(io.BytesIO(image_data))
|
| 104 |
+
if image.mode in ('RGBA', 'P'):
|
| 105 |
+
image = image.convert('RGB')
|
| 106 |
+
if target_size > 0:
|
| 107 |
+
image.thumbnail((target_size, target_size))
|
| 108 |
+
return image
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
|
| 112 |
+
image = decode_base64_to_image(base64_string, target_size=target_size)
|
| 113 |
+
image.save(image_path)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def build_option_str(option_dict):
|
| 117 |
+
s = 'There are several options: \n'
|
| 118 |
+
for c, content in option_dict.items():
|
| 119 |
+
if not pd.isna(content):
|
| 120 |
+
s += f'{c}. {content}\n'
|
| 121 |
+
return s
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def isimg(s):
|
| 125 |
+
return osp.exists(s) or s.startswith('http')
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def read_ok(img_path):
|
| 129 |
+
if not osp.exists(img_path):
|
| 130 |
+
return False
|
| 131 |
+
try:
|
| 132 |
+
im = Image.open(img_path)
|
| 133 |
+
assert im.size[0] > 0 and im.size[1] > 0
|
| 134 |
+
return True
|
| 135 |
+
except:
|
| 136 |
+
return False
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def gpt_key_set():
|
| 140 |
+
openai_key = os.environ.get('OPENAI_API_KEY', None)
|
| 141 |
+
return isinstance(openai_key, str) and openai_key.startswith('sk-')
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def apiok(wrapper):
|
| 145 |
+
s = wrapper.generate('Hello!')
|
| 146 |
+
return wrapper.fail_msg not in s
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def circular_pred(df, extract_func=None):
|
| 150 |
+
if extract_func is None:
|
| 151 |
+
extract_func = lambda x: x # noqa: E731
|
| 152 |
+
df = df.sort_values('index')
|
| 153 |
+
from vlmeval.utils import can_infer_option
|
| 154 |
+
|
| 155 |
+
shift = int(1e6)
|
| 156 |
+
|
| 157 |
+
choices = [extract_func(x) for x in df['prediction']]
|
| 158 |
+
pred_map = {i: c for i, c in zip(df['index'], choices)}
|
| 159 |
+
flag_map = {i: True for i in pred_map if i < 1e6}
|
| 160 |
+
valid_map = {i: True for i in pred_map if i < 1e6}
|
| 161 |
+
for i in df['index']:
|
| 162 |
+
if i >= shift and pred_map[i] and pred_map[i - shift]:
|
| 163 |
+
if pred_map[i] not in list(
|
| 164 |
+
string.ascii_uppercase
|
| 165 |
+
) or pred_map[ # noqa: W504
|
| 166 |
+
i - shift
|
| 167 |
+
] not in list(
|
| 168 |
+
string.ascii_uppercase
|
| 169 |
+
):
|
| 170 |
+
|
| 171 |
+
valid_map[i % shift] = False
|
| 172 |
+
continue
|
| 173 |
+
if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
|
| 174 |
+
continue
|
| 175 |
+
else:
|
| 176 |
+
flag_map[i % shift] = False
|
| 177 |
+
flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
|
| 178 |
+
flags = list(flag_map.values())
|
| 179 |
+
return np.mean(flags)
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .matching_util import can_infer, can_infer_option, can_infer_text
|
| 2 |
+
from .mp_util import track_progress_rich
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
|
| 7 |
+
]
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/matching_util.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import string
|
| 2 |
+
import copy as cp
|
| 3 |
+
import os
|
| 4 |
+
from ..smp import *
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def can_infer_option(answer, choices):
|
| 8 |
+
verbose = os.environ.get('VERBOSE', 0)
|
| 9 |
+
# Choices is a dictionary
|
| 10 |
+
if 'Failed to obtain answer via API' in answer:
|
| 11 |
+
return False
|
| 12 |
+
|
| 13 |
+
reject_to_answer = [
|
| 14 |
+
"Sorry, I can't help with images of people yet.",
|
| 15 |
+
"I can't process this file.",
|
| 16 |
+
"I'm sorry, but without the image provided",
|
| 17 |
+
'Cannot determine the answer'
|
| 18 |
+
]
|
| 19 |
+
for err in reject_to_answer:
|
| 20 |
+
if err in answer:
|
| 21 |
+
return 'Z'
|
| 22 |
+
|
| 23 |
+
def count_choice(splits, choices, prefix='', suffix=''):
|
| 24 |
+
cnt = 0
|
| 25 |
+
for c in choices:
|
| 26 |
+
if prefix + c + suffix in splits:
|
| 27 |
+
cnt += 1
|
| 28 |
+
return cnt
|
| 29 |
+
|
| 30 |
+
answer_mod = cp.copy(answer)
|
| 31 |
+
chars = '.()[],:;!*#{}'
|
| 32 |
+
for c in chars:
|
| 33 |
+
answer_mod = answer_mod.replace(c, ' ')
|
| 34 |
+
|
| 35 |
+
splits = [x.strip() for x in answer_mod.split()]
|
| 36 |
+
count = count_choice(splits, choices)
|
| 37 |
+
|
| 38 |
+
if count == 1:
|
| 39 |
+
for ch in choices:
|
| 40 |
+
if 'A' in splits and len(splits) > 3 and verbose:
|
| 41 |
+
logger = get_logger('Evaluation')
|
| 42 |
+
logger.info(f'A might be a quantifier in the string: {answer}.')
|
| 43 |
+
return False
|
| 44 |
+
if ch in splits:
|
| 45 |
+
return ch
|
| 46 |
+
elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
|
| 47 |
+
return 'Z'
|
| 48 |
+
return False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def can_infer_text(answer, choices):
|
| 52 |
+
answer = answer.lower()
|
| 53 |
+
assert isinstance(choices, dict)
|
| 54 |
+
for k in choices:
|
| 55 |
+
assert k in string.ascii_uppercase
|
| 56 |
+
choices[k] = str(choices[k]).lower()
|
| 57 |
+
cands = []
|
| 58 |
+
for k in choices:
|
| 59 |
+
if choices[k] in answer:
|
| 60 |
+
cands.append(k)
|
| 61 |
+
if len(cands) == 1:
|
| 62 |
+
return cands[0]
|
| 63 |
+
return False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def can_infer(answer, choices):
|
| 67 |
+
answer = str(answer)
|
| 68 |
+
copt = can_infer_option(answer, choices)
|
| 69 |
+
return copt if copt else can_infer_text(answer, choices)
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/mp_util.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from multiprocessing import Pool
|
| 2 |
+
import os
|
| 3 |
+
from typing import Callable, Iterable, Sized
|
| 4 |
+
|
| 5 |
+
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
|
| 6 |
+
TaskProgressColumn, TextColumn, TimeRemainingColumn)
|
| 7 |
+
from rich.text import Text
|
| 8 |
+
import os.path as osp
|
| 9 |
+
import time
|
| 10 |
+
import portalocker
|
| 11 |
+
from ..smp import load, dump
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def track_progress_rich(
|
| 15 |
+
func: Callable,
|
| 16 |
+
tasks: Iterable = tuple(),
|
| 17 |
+
nproc: int = 1,
|
| 18 |
+
save=None,
|
| 19 |
+
keys=None,
|
| 20 |
+
**kwargs) -> list:
|
| 21 |
+
|
| 22 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
if save is not None:
|
| 25 |
+
assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
|
| 26 |
+
if not osp.exists(save):
|
| 27 |
+
dump({}, save)
|
| 28 |
+
if keys is not None:
|
| 29 |
+
assert len(keys) == len(tasks)
|
| 30 |
+
if not callable(func):
|
| 31 |
+
raise TypeError('func must be a callable object')
|
| 32 |
+
if not isinstance(tasks, Iterable):
|
| 33 |
+
raise TypeError(
|
| 34 |
+
f'tasks must be an iterable object, but got {type(tasks)}')
|
| 35 |
+
assert nproc > 0, 'nproc must be a positive number'
|
| 36 |
+
res = load(save) if save is not None else {}
|
| 37 |
+
results = [None for _ in range(len(tasks))]
|
| 38 |
+
|
| 39 |
+
with ThreadPoolExecutor(max_workers=nproc) as executor:
|
| 40 |
+
futures = []
|
| 41 |
+
|
| 42 |
+
for inputs in tasks:
|
| 43 |
+
if not isinstance(inputs, (tuple, list, dict)):
|
| 44 |
+
inputs = (inputs, )
|
| 45 |
+
if isinstance(inputs, dict):
|
| 46 |
+
future = executor.submit(func, **inputs)
|
| 47 |
+
else:
|
| 48 |
+
future = executor.submit(func, *inputs)
|
| 49 |
+
futures.append(future)
|
| 50 |
+
|
| 51 |
+
unfinished = set(range(len(tasks)))
|
| 52 |
+
pbar = tqdm(total=len(unfinished))
|
| 53 |
+
while len(unfinished):
|
| 54 |
+
new_finished = set()
|
| 55 |
+
for idx in unfinished:
|
| 56 |
+
if futures[idx].done():
|
| 57 |
+
results[idx] = futures[idx].result()
|
| 58 |
+
new_finished.add(idx)
|
| 59 |
+
if keys is not None:
|
| 60 |
+
res[keys[idx]] = results[idx]
|
| 61 |
+
if len(new_finished):
|
| 62 |
+
if save is not None:
|
| 63 |
+
dump(res, save)
|
| 64 |
+
pbar.update(len(new_finished))
|
| 65 |
+
for k in new_finished:
|
| 66 |
+
unfinished.remove(k)
|
| 67 |
+
time.sleep(0.1)
|
| 68 |
+
pbar.close()
|
| 69 |
+
|
| 70 |
+
if save is not None:
|
| 71 |
+
dump(res, save)
|
| 72 |
+
return results
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..smp import *
|
| 2 |
+
from ..dataset.utils.judge_util import build_judge
|
| 3 |
+
from ..dataset.utils.multiple_choice import extract_answer_from_item
|
| 4 |
+
from .matching_util import can_infer
|
| 5 |
+
from .mp_util import track_progress_rich
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def MMMU_result_transfer(result_path):
|
| 9 |
+
res = {}
|
| 10 |
+
result_data = load(result_path)
|
| 11 |
+
mcq = result_data['A'].notna()
|
| 12 |
+
lt = len(result_data)
|
| 13 |
+
for i in range(lt):
|
| 14 |
+
line = result_data.iloc[i]
|
| 15 |
+
if mcq[i]:
|
| 16 |
+
options = {
|
| 17 |
+
cand: line[cand]
|
| 18 |
+
for cand in string.ascii_uppercase
|
| 19 |
+
if cand in line and not pd.isna(line[cand])
|
| 20 |
+
}
|
| 21 |
+
prediction = line['prediction']
|
| 22 |
+
infer_prediction = can_infer(prediction, options)
|
| 23 |
+
res[line['id']] = infer_prediction
|
| 24 |
+
else:
|
| 25 |
+
res[line['id']] = line['prediction']
|
| 26 |
+
result_json = result_path.replace('.xlsx', '.json')
|
| 27 |
+
dump(res, result_json)
|
| 28 |
+
return result_json
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def MMTBench_result_transfer(eval_file, dataset='default', **judge_kwargs):
|
| 32 |
+
logger = get_logger('Evaluation')
|
| 33 |
+
nproc = judge_kwargs.pop('nproc', 4)
|
| 34 |
+
|
| 35 |
+
rd.seed(2680)
|
| 36 |
+
suffix = eval_file.split('.')[-1]
|
| 37 |
+
model = judge_kwargs['model']
|
| 38 |
+
assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
|
| 39 |
+
name_str_map = {
|
| 40 |
+
'chatgpt-0125': 'openai',
|
| 41 |
+
'gpt-4-0125': 'gpt4'
|
| 42 |
+
}
|
| 43 |
+
name_str = name_str_map[model] if model in name_str_map else model
|
| 44 |
+
|
| 45 |
+
if model == 'exact_matching':
|
| 46 |
+
model = None
|
| 47 |
+
elif gpt_key_set():
|
| 48 |
+
model = build_judge(**judge_kwargs)
|
| 49 |
+
if not model.working():
|
| 50 |
+
logger.error('The OPENAI API is not working properly, will use exact matching for evaluation')
|
| 51 |
+
model = None
|
| 52 |
+
else:
|
| 53 |
+
logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
|
| 54 |
+
model = None
|
| 55 |
+
|
| 56 |
+
logger.info(f'Evaluating {eval_file}')
|
| 57 |
+
result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_option.pkl')
|
| 58 |
+
result = {}
|
| 59 |
+
if osp.exists(result_file):
|
| 60 |
+
result = load(result_file)
|
| 61 |
+
|
| 62 |
+
data = load(eval_file)
|
| 63 |
+
assert 'index' in data, 'Essentail columns missing in the eval_file.'
|
| 64 |
+
|
| 65 |
+
data = data.sort_values(by='index')
|
| 66 |
+
data['prediction'] = [str(x) for x in data['prediction']]
|
| 67 |
+
for k in data.keys():
|
| 68 |
+
data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
|
| 69 |
+
|
| 70 |
+
idx2lines = {data.iloc[i]['index']: data.iloc[i] for i in range(len(data))}
|
| 71 |
+
idx2lines = {k: v for k, v in idx2lines.items() if k not in result}
|
| 72 |
+
|
| 73 |
+
indices = list(idx2lines.keys())
|
| 74 |
+
lines = [idx2lines[i] for i in indices]
|
| 75 |
+
tups = [(model, line) for line in lines]
|
| 76 |
+
res = track_progress_rich(
|
| 77 |
+
extract_answer_from_item,
|
| 78 |
+
tups,
|
| 79 |
+
nproc=nproc,
|
| 80 |
+
chunksize=nproc,
|
| 81 |
+
save=result_file,
|
| 82 |
+
keys=indices)
|
| 83 |
+
|
| 84 |
+
for i, r in zip(indices, res):
|
| 85 |
+
if i in result:
|
| 86 |
+
assert result[i]['opt'] == r['opt'] and result[i]['log'] == r['log']
|
| 87 |
+
else:
|
| 88 |
+
result[i] = r
|
| 89 |
+
|
| 90 |
+
indices = list(data['index'])
|
| 91 |
+
data['opt'] = [result[i]['opt'] for i in data['index']]
|
| 92 |
+
data['log'] = [result[i]['log'] for i in data['index']]
|
| 93 |
+
|
| 94 |
+
# load split
|
| 95 |
+
output_path = eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv')
|
| 96 |
+
dump(data, eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv'))
|
| 97 |
+
return output_path
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
torch.set_grad_enabled(False)
|
| 4 |
+
torch.manual_seed(1234)
|
| 5 |
+
from .base import BaseModel
|
| 6 |
+
from .minicpm_v import MiniCPM_V, MiniCPM_Llama3_V, MiniCPM_V_2_6, MiniCPM_o_2_6
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/base.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..smp import *
|
| 2 |
+
from ..dataset import img_root_map, DATASET_TYPE
|
| 3 |
+
from abc import abstractmethod
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class BaseModel:
|
| 7 |
+
|
| 8 |
+
INTERLEAVE = False
|
| 9 |
+
allowed_types = ['text', 'image', 'video']
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
self.dump_image_func = None
|
| 13 |
+
|
| 14 |
+
def use_custom_prompt(self, dataset):
|
| 15 |
+
"""Whether to use custom prompt for the given dataset.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
dataset (str): The name of the dataset.
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
bool: Whether to use custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
|
| 22 |
+
Default to False.
|
| 23 |
+
"""
|
| 24 |
+
return False
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def build_prompt(self, line, dataset):
|
| 28 |
+
"""Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
line (line of pd.DataFrame): The raw input line.
|
| 32 |
+
dataset (str): The name of the dataset.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
str: The built message.
|
| 36 |
+
"""
|
| 37 |
+
raise NotImplementedError
|
| 38 |
+
|
| 39 |
+
def set_dump_image(self, dump_image_func):
|
| 40 |
+
self.dump_image_func = dump_image_func
|
| 41 |
+
|
| 42 |
+
def dump_image(self, line, dataset):
|
| 43 |
+
return self.dump_image_func(line)
|
| 44 |
+
|
| 45 |
+
@abstractmethod
|
| 46 |
+
def generate_inner(self, message, dataset=None):
|
| 47 |
+
raise NotImplementedError
|
| 48 |
+
|
| 49 |
+
def check_content(self, msgs):
|
| 50 |
+
"""Check the content type of the input. Four types are allowed: str, dict, liststr, listdict.
|
| 51 |
+
"""
|
| 52 |
+
if isinstance(msgs, str):
|
| 53 |
+
return 'str'
|
| 54 |
+
if isinstance(msgs, dict):
|
| 55 |
+
return 'dict'
|
| 56 |
+
if isinstance(msgs, list):
|
| 57 |
+
types = [self.check_content(m) for m in msgs]
|
| 58 |
+
if all(t == 'str' for t in types):
|
| 59 |
+
return 'liststr'
|
| 60 |
+
if all(t == 'dict' for t in types):
|
| 61 |
+
return 'listdict'
|
| 62 |
+
return 'unknown'
|
| 63 |
+
|
| 64 |
+
def preproc_content(self, inputs):
|
| 65 |
+
"""Convert the raw input messages to a list of dicts.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
inputs: raw input messages.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
list(dict): The preprocessed input messages. Will return None if failed to preprocess the input.
|
| 72 |
+
"""
|
| 73 |
+
if self.check_content(inputs) == 'str':
|
| 74 |
+
return [dict(type='text', value=inputs)]
|
| 75 |
+
elif self.check_content(inputs) == 'dict':
|
| 76 |
+
assert 'type' in inputs and 'value' in inputs
|
| 77 |
+
return [inputs]
|
| 78 |
+
elif self.check_content(inputs) == 'liststr':
|
| 79 |
+
res = []
|
| 80 |
+
for s in inputs:
|
| 81 |
+
mime, pth = parse_file(s)
|
| 82 |
+
if mime is None or mime == 'unknown':
|
| 83 |
+
res.append(dict(type='text', value=s))
|
| 84 |
+
else:
|
| 85 |
+
res.append(dict(type=mime.split('/')[0], value=pth))
|
| 86 |
+
return res
|
| 87 |
+
elif self.check_content(inputs) == 'listdict':
|
| 88 |
+
for item in inputs:
|
| 89 |
+
assert 'type' in item and 'value' in item
|
| 90 |
+
mime, s = parse_file(item['value'])
|
| 91 |
+
if mime is None:
|
| 92 |
+
assert item['type'] == 'text'
|
| 93 |
+
else:
|
| 94 |
+
assert mime.split('/')[0] == item['type']
|
| 95 |
+
item['value'] = s
|
| 96 |
+
return inputs
|
| 97 |
+
else:
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
def generate(self, message, dataset=None):
|
| 101 |
+
"""Generate the output message.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
message (list[dict]): The input message.
|
| 105 |
+
dataset (str, optional): The name of the dataset. Defaults to None.
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
str: The generated message.
|
| 109 |
+
"""
|
| 110 |
+
assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
|
| 111 |
+
message = self.preproc_content(message)
|
| 112 |
+
assert message is not None and self.check_content(message) == 'listdict'
|
| 113 |
+
for item in message:
|
| 114 |
+
assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
|
| 115 |
+
return self.generate_inner(message, dataset)
|
| 116 |
+
|
| 117 |
+
def chat(self, messages, dataset=None):
|
| 118 |
+
"""The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
|
| 119 |
+
assert hasattr(self, 'chat_inner'), 'The API model should has the `chat_inner` method. '
|
| 120 |
+
for msg in messages:
|
| 121 |
+
assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
|
| 122 |
+
assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
|
| 123 |
+
msg['content'] = self.preproc_content(msg['content'])
|
| 124 |
+
|
| 125 |
+
while len(messages):
|
| 126 |
+
try:
|
| 127 |
+
return self.chat_inner(messages, dataset=dataset)
|
| 128 |
+
except Exception as e:
|
| 129 |
+
logging.info(f'{type(e)}: {e}')
|
| 130 |
+
messages = messages[1:]
|
| 131 |
+
while len(messages) and messages[0]['role'] != 'user':
|
| 132 |
+
messages = messages[1:]
|
| 133 |
+
continue
|
| 134 |
+
return 'Chat Mode: Failed with all possible conversation turns.'
|
| 135 |
+
|
| 136 |
+
def message_to_promptimg(self, message, dataset=None):
|
| 137 |
+
assert not self.INTERLEAVE
|
| 138 |
+
model_name = self.__class__.__name__
|
| 139 |
+
warnings.warn(
|
| 140 |
+
f'Model {model_name} does not support interleaved input. '
|
| 141 |
+
'Will use the first image and aggregated texts as prompt. ')
|
| 142 |
+
num_images = len([x for x in message if x['type'] == 'image'])
|
| 143 |
+
if num_images == 0:
|
| 144 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
| 145 |
+
image = None
|
| 146 |
+
else:
|
| 147 |
+
prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
|
| 148 |
+
images = [x['value'] for x in message if x['type'] == 'image']
|
| 149 |
+
if 'BLINK' == dataset:
|
| 150 |
+
image = concat_images_vlmeval(images, target_size=512)
|
| 151 |
+
else:
|
| 152 |
+
image = images[0]
|
| 153 |
+
return prompt, image
|
| 154 |
+
|
| 155 |
+
def message_to_promptvideo(self, message):
    """Collapse a message into (prompt, video): newline-joined texts plus the
    first video path, or None when the message carries no video.

    Raises NotImplementedError when the model is not a video LLM.
    """
    if not self.VIDEO_LLM:
        logging.critical('Model does not support video input.')
        raise NotImplementedError
    texts = [x['value'] for x in message if x['type'] == 'text']
    videos = [x['value'] for x in message if x['type'] == 'video']
    prompt = '\n'.join(texts)
    video = videos[0] if videos else None
    return prompt, video
|
| 168 |
+
|
| 169 |
+
def message_to_promptvideo_withrole(self, message, dataset=None):
    """Split a message into role-keyed text ('system'/'user'/'assistant') plus
    a single video path.

    Text items are concatenated per role (items without a role count as user
    text). For MCQ datasets with an empty assistant turn, the assistant text
    is seeded with 'Best Option: (' to steer decoding; otherwise an empty
    assistant entry is removed.

    Returns:
        tuple: (role->text dict, first video path or None when absent).

    Raises NotImplementedError when the model is not a video LLM.
    """
    if not self.VIDEO_LLM:
        logging.critical('Model does not support video input.')
        raise NotImplementedError
    system, user, assistant, video_list = '', '', '', []
    for msg in message:
        if msg['type'] == 'text':
            if 'role' in msg and msg['role'] == 'system':
                system += msg['value']
            elif 'role' in msg and msg['role'] == 'assistant':
                assistant += msg['value']
            else:
                user += msg['value']
        elif msg['type'] == 'video':
            video_list.append(msg['value'])
    question = {
        'system': system,
        'user': user,
        'assistant': assistant
    }
    if assistant == '':
        if listinstr(['MCQ'], DATASET_TYPE(dataset)):
            question['assistant'] = 'Best Option: ('
        else:
            del question['assistant']
    if len(video_list) > 1:
        print('VLMEvalKit only support single video as input, take first video as input')
    # BUGFIX: the original indexed video_list[0] unconditionally and raised
    # IndexError for video-free messages; return None instead (matching
    # message_to_promptvideo's behavior).
    video = video_list[0] if video_list else None
    return question, video
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/minicpm_v.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from transformers import AutoModel, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from .base import BaseModel
|
| 9 |
+
from ..smp import *
|
| 10 |
+
from ..dataset import DATASET_TYPE, DATASET_MODALITY
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MiniCPM_V(BaseModel):
    """VLMEvalKit wrapper for the original MiniCPM-V checkpoints.

    Non-interleaved: exactly one image plus aggregated text per query
    (see BaseModel.message_to_promptimg).
    """

    INSTALL_REQ = False
    INTERLEAVE = False

    def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
        """Load the model and tokenizer from *model_path* (HF id or local dir)
        onto the GPU in bfloat16. Extra kwargs override generation defaults."""
        assert model_path is not None
        self.model_path = model_path
        print(f'load from {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.bfloat16)
        self.model.eval().cuda()
        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3

    def use_custom_prompt(self, dataset):
        """Custom prompt building is disabled for this model.

        NOTE(review): the original special-cased MMDU / MME-RealWorld(-CN)
        but returned False on both paths — dead branch collapsed to a single
        return; behavior unchanged.
        """
        assert dataset is not None
        return False

    def build_prompt(self, line, dataset=None):
        """Build an MCQ-style message (hint + question + lettered options).

        Currently unreachable while use_custom_prompt() always returns False;
        kept for interface compatibility with BaseModel.
        """
        assert dataset is None or isinstance(dataset, str)
        assert self.use_custom_prompt(dataset)
        tgt_path = self.dump_image(line, dataset)

        question = line['question']
        options = {
            cand: line[cand]
            for cand in string.ascii_uppercase
            if cand in line and not pd.isna(line[cand])
        }
        options_prompt = 'Options:\n'
        for key, item in options.items():
            options_prompt += f'{key}. {item}\n'
        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
        prompt = ''
        if hint is not None:
            prompt += f'Hint: {hint}\n'
        prompt += f'{question}\n'
        if len(options):
            prompt += options_prompt
            prompt = 'Study the image carefully and pick the option associated with the correct answer. \
Focus solely on selecting the option and avoid including any other content.\n' + prompt
        message = [dict(type='text', value=prompt)]
        message.extend([dict(type='image', value=p) for p in tgt_path])

        return message

    def generate_inner(self, message, dataset=None):
        """Run one generation turn; the token budget depends on dataset type
        (terse for MCQ / yes-no, 1024 for open-ended)."""
        prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
        image = Image.open(image_path).convert('RGB')
        msgs = [{'role': 'user', 'content': prompt}]
        if DATASET_TYPE(dataset) == 'MCQ':
            max_new_tokens = 20
        elif DATASET_TYPE(dataset) == 'Y/N':
            max_new_tokens = 100
        else:
            max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams
        )
        default_kwargs.update(self.kwargs)
        # This checkpoint's chat() returns (answer, context, generation_config);
        # only the answer text is needed.
        res, _, _ = self.model.chat(
            image=image,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )
        return res
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class MiniCPM_Llama3_V(BaseModel):
    """VLMEvalKit wrapper for MiniCPM-Llama3-V 2.5 (interleaved image/text)."""

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='openbmb/MiniCPM-Llama3-V-2_5', **kwargs):
        """Load model + tokenizer from *model_path* onto the GPU in fp16 and
        set up the per-dataset-type system prompts."""
        assert model_path is not None
        self.model_path = model_path
        print(f'load from {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.float16)
        self.model.eval().cuda()
        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3
        # System prompts keyed by question style.
        self.options_system_prompt = ('Carefully read the following question and select the letter corresponding '
                                      'to the correct answer. Highlight the applicable choices without giving '
                                      'explanations.')
        self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
        self.detail_system_prompt = 'Answer this question in detail.'
        self.vqa_prompt = 'Answer the question using a single word or phrase.'

    def use_custom_prompt(self, dataset):
        """Use build_prompt() for MCQ / VQA dataset types and HallusionBench."""
        if listinstr(['MCQ', 'VQA'], DATASET_TYPE(dataset)):
            return True
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            return True
        return False

    def build_prompt(self, line, dataset=None):
        """Build the dataset-specific message: optional system prompt, the
        image(s), then the question text (with options for MCQ)."""
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        system_prompt = ''

        question = line['question']
        if DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            prompt = ''
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                system_prompt = self.options_system_prompt + '\nPlease just indicate your choice.'
            else:
                system_prompt = self.wo_options_system_prompt
            if 'MMMU' in dataset:  # Corner Case: MMMU wants the instruction inline, not as system prompt
                prompt = system_prompt + '\n' + prompt
                system_prompt = ''
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question = line['question'] + ' Yes or No?'
            prompt = question
        elif dataset is not None and listinstr(['MME'], dataset):
            question = line['question'] + ' Yes or No?'
            prompt = question
        elif dataset is not None and listinstr(['OCRBench'], dataset):
            system_prompt = self.vqa_prompt
            question = line['question']
            prompt = question
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['LLaVABench', 'MMLongBench_DOC'], dataset):
                system_prompt = ''
                prompt = question
            elif listinstr(['MMVet'], dataset):
                system_prompt = self.detail_system_prompt
                prompt = question
            else:
                system_prompt = self.vqa_prompt
                prompt = question
        else:
            # ROBUSTNESS: the original left `prompt` unbound here (NameError)
            # for any dataset outside the branches above.
            prompt = question

        msgs = []
        if system_prompt:
            msgs.append(dict(type='text', value=system_prompt))
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            # BUGFIX: the original rebound `msgs = [dict(...)]` here, silently
            # dropping the system prompt whenever dump_image returned a single
            # path instead of a list.
            msgs.append(dict(type='image', value=tgt_path))
        msgs.append(dict(type='text', value=prompt))
        return msgs

    def generate_inner(self, message, dataset=None):
        """Single-turn generation; token budget depends on dataset type."""
        if DATASET_TYPE(dataset) == 'MCQ':
            max_new_tokens = 200
        elif DATASET_TYPE(dataset) == 'Y/N':
            max_new_tokens = 3
        else:
            max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        # Flatten the message into the [str | PIL.Image, ...] content list the
        # remote-code chat() expects.
        content = []
        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                content.append(image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            msgs=msgs,
            context=None,
            image=None,
            tokenizer=self.tokenizer,
            **default_kwargs
        )

        # Some checkpoint revisions return a tuple; the answer is its first item.
        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        return res

    def chat_inner(self, message, dataset=None):
        """Multi-turn generation: convert each turn's content list, keeping
        text-only turns as plain strings."""
        max_new_tokens = 1024

        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        msgs = []
        for msg in message:
            content = []
            # A single-text turn is passed through as a bare string.
            if len(msg['content']) == 1 and msg['content'][0]['type'] == 'text':
                msg_new = {'role': msg['role'], 'content': msg['content'][0]['value']}
                msgs.append(msg_new)
                continue

            for x in msg['content']:
                if x['type'] == 'text':
                    content.append(x['value'])
                elif x['type'] == 'image':
                    image = Image.open(x['value']).convert('RGB')
                    content.append(image)
            msg_new = {'role': msg['role'], 'content': content}
            msgs.append(msg_new)

        res = self.model.chat(
            msgs=msgs,
            context=None,
            image=None,
            tokenizer=self.tokenizer,
            **default_kwargs)

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]
        return res
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class MiniCPM_V_2_6(BaseModel):
    """VLMEvalKit wrapper for MiniCPM-V 2.6 (interleaved image/text, video-capable)."""

    INSTALL_REQ = False
    INTERLEAVE = True

    def __init__(self, model_path='openbmb/MiniCPM-V-2_6', **kwargs):
        """Load the model in bfloat16 with fixed RNG seeds (the image-upsize
        step below uses random.randint and must stay reproducible)."""
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed_all(0)

        assert model_path is not None
        self.model_path = model_path
        print(f'load from path {self.model_path}')
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(dtype=torch.bfloat16)
        self.model.eval().cuda()

        self.kwargs = kwargs
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        torch.cuda.empty_cache()
        self.num_beams = 3

        self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
        self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
        self.detail_system_prompt = 'Answer this question in detail.'
        self.vqa_prompt = 'Answer the question using a single word or phrase.'

        self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
                                        '''by step and finally pick the option associated with the correct '''
                                        '''answer in the format of "Answer: selected option\n\n''')
        self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
                                     '''then output the final answer in the format of "Answer: single number '''
                                     '''or single word or phrase".\n\n''')

    def use_custom_prompt(self, dataset=None):
        """Use build_prompt() for MCQ / VQA / Y/N dataset types."""
        if dataset is None:
            return False
        if DATASET_TYPE(dataset) in ['MCQ', 'VQA', 'Y/N']:
            return True
        return False

    def use_cot(self, dataset=None):
        """Enable chain-of-thought prompting only for reasoning-heavy benchmarks.

        NOTE(review): the original also enumerated datasets with CoT explicitly
        disabled, but that elif and the fallback both returned False — the dead
        branch is collapsed; behavior unchanged.
        """
        if dataset is None:
            return False
        if listinstr(['MMMU', 'HallusionBench', 'OCRBench', 'ChartQA'], dataset):
            return True
        return False

    def use_upsize(self, dataset=None):
        """Randomly upsize small images before inference for these benchmarks."""
        if dataset is None:
            return False
        if listinstr(['MMVet', 'MMBench', 'MMStar', 'AI2D', 'OCRBench'], dataset):
            return True
        return False

    def build_prompt(self, line, dataset=None):
        """Build the message for *line*: a direct-answer prompt by default, or a
        CoT-prefixed prompt when use_cot(dataset) is True."""
        if isinstance(line, int):
            line = self.data.iloc[line]

        tgt_path = self.dump_image(line, dataset)
        system_prompt, prompt = '', ''

        question = line['question']

        if not self.use_cot(dataset):
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = 'Options:\n'
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'Question: {question}\n'
                if len(options):
                    prompt += options_prompt
                    prompt += self.options_suffix_prompt
                else:
                    system_prompt = self.wo_options_system_prompt

                if 'MMMU' in dataset:  # MMMU wants the instruction inline, not as system prompt
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            elif dataset is not None and listinstr(['HallusionBench'], dataset):
                question += ' Yes or No?'
                prompt = question
            elif dataset is not None and listinstr(['OCRBench'], dataset):
                system_prompt = self.vqa_prompt
                prompt = question
            elif DATASET_TYPE(dataset) == 'VQA':
                if listinstr(['LLaVABench'], dataset):
                    system_prompt = ''
                elif listinstr(['MMVet'], dataset):
                    system_prompt = self.detail_system_prompt
                else:
                    system_prompt = self.vqa_prompt
                prompt = question
            else:
                prompt = question
        else:
            has_options = True
            if DATASET_TYPE(dataset) == 'MCQ':
                options = {
                    cand: line[cand]
                    for cand in string.ascii_uppercase
                    if cand in line and not pd.isna(line[cand])
                }
                options_prompt = ''
                for key, item in options.items():
                    options_prompt += f'{key}. {item}\n'
                hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
                if hint is not None:
                    prompt += f'Hint: {hint}\n'
                prompt += f'{question}\n'

                if len(options):
                    prompt += options_prompt
                else:
                    has_options = False

                if 'MMMU' in dataset:
                    if len(system_prompt) > 0:
                        prompt = system_prompt + '\n' + prompt
                        system_prompt = ''
            else:
                prompt = question

            if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
                # NOTE(review): the original spelled out Y/N and the fallback as
                # separate branches that both prepended short_ans_cot_prompt;
                # collapsed — behavior unchanged.
                if DATASET_TYPE(dataset) == 'MCQ' and has_options:
                    prompt = self.multi_choice_cot_prompt + prompt
                else:
                    prompt = self.short_ans_cot_prompt + prompt

        msgs = []
        if system_prompt:
            msgs.append(dict(type='text', value=system_prompt))
        if isinstance(tgt_path, list):
            msgs.extend([dict(type='image', value=p) for p in tgt_path])
        else:
            # BUGFIX: the original rebound `msgs = [dict(...)]` here, silently
            # dropping the system prompt for a single (non-list) image path.
            msgs.append(dict(type='image', value=tgt_path))
        msgs.append(dict(type='text', value=prompt))

        return msgs

    def generate_inner(self, message, dataset=None):
        """Single-turn generation; video datasets use one slice per frame, no
        image ids, and a larger input-length budget."""
        if DATASET_MODALITY(dataset) == 'VIDEO':
            max_slice_nums = 1
            use_image_id = False
            max_inp_length = 2048 * 10
        else:
            max_slice_nums = None
            use_image_id = True
            max_inp_length = 8192

        max_new_tokens = 2048
        default_kwargs = dict(
            max_new_tokens=max_new_tokens,
            sampling=False,
            num_beams=self.num_beams,
        )
        default_kwargs.update(self.kwargs)

        content = []

        for x in message:
            if x['type'] == 'text':
                content.append(x['value'])
            elif x['type'] == 'image':
                image = Image.open(x['value']).convert('RGB')
                if not self.use_upsize(dataset):
                    content.append(image)
                else:
                    img_width, img_height = image.width, image.height
                    if (img_width * img_height) >= (1344 * 1344):
                        content.append(image)
                    else:
                        # Upsize small images toward 1344x1344 area; the random
                        # target width is reproducible via the seeds in __init__.
                        ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
                        max_img_width = int(img_width * ratio)
                        new_img_width = random.randint(img_width, max_img_width)
                        new_img_height = int(new_img_width / img_width * img_height)
                        resized_image = image.resize((new_img_width, new_img_height))
                        content.append(resized_image)
        msgs = [{'role': 'user', 'content': content}]

        res = self.model.chat(
            image=None,
            msgs=msgs,
            context=None,
            tokenizer=self.tokenizer,
            max_inp_length=max_inp_length,
            use_image_id=use_image_id,
            max_slice_nums=max_slice_nums,
            **default_kwargs
        )

        if isinstance(res, tuple) and len(res) > 0:
            res = res[0]

        return res
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class MiniCPM_o_2_6(BaseModel):
|
| 477 |
+
INSTALL_REQ = False
|
| 478 |
+
INTERLEAVE = True
|
| 479 |
+
|
| 480 |
+
def __init__(self, model_path='openbmb/MiniCPM-o-2_6', **kwargs):
    """Load MiniCPM-o 2.6 (vision head only; audio/TTS disabled) and set up
    the prompt templates and generation knobs."""
    # Fix every RNG so image upsizing and decoding are reproducible.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(0)

    assert model_path is not None
    self.model_path = model_path
    print(f'load from path {self.model_path}')
    self.model = AutoModel.from_pretrained(
        self.model_path,
        trust_remote_code=True,
        attn_implementation='sdpa',
        torch_dtype=torch.bfloat16,
        init_vision=True,
        init_audio=False,
        init_tts=False
    )

    self.model.eval().cuda()

    self.kwargs = kwargs
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
    torch.cuda.empty_cache()

    # NUM_BEAMS may override the beam width, except for the official checkpoint
    # which always uses 3.
    env_beams = int(os.getenv("NUM_BEAMS", "3"))
    self.num_beams = 3 if self.model_path == 'openbmb/MiniCPM-o-2_6' else env_beams

    # Repetition penalty is tunable through the PENALTY environment variable.
    self.repetition_penalty = float(os.getenv("PENALTY", "1.2"))

    self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
    self.wo_options_system_prompt = 'Carefully read the following question Answer the question directly.'
    self.detail_system_prompt = 'Answer this question in detail.'
    self.vqa_prompt = 'Answer the question using a single word or phrase.'

    self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
                                    '''by step and finally pick the option associated with the correct '''
                                    '''answer in the format of "Answer: selected option\n\n''')
    self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
                                 '''then output the final answer in the format of "Answer: single number '''
                                 '''or single word or phrase".\n\n''')
|
| 522 |
+
|
| 523 |
+
def use_custom_prompt(self, dataset=None):
    """Use build_prompt() for MCQ / VQA / Y/N dataset types."""
    if dataset is None:
        return False
    dataset_type = DATASET_TYPE(dataset)
    if listinstr(['MCQ', 'VQA', 'Y/N'], dataset_type):
        return True
    return False
|
| 529 |
+
|
| 530 |
+
def use_cot(self, dataset=None):
    """Enable chain-of-thought prompting only for reasoning-heavy benchmarks.

    NOTE(review): the original also enumerated datasets with CoT explicitly
    disabled, but that elif and the fallback both returned False — the dead
    branch is collapsed; behavior unchanged.
    """
    if dataset is None:
        return False
    if listinstr(['MMMU', 'MathVista', 'OCRBench', 'ChartQA', 'MathVision',
                  'MathVerse_MINI_Vision_Only'], dataset):
        return True
    return False
|
| 540 |
+
|
| 541 |
+
def use_upsize(self, dataset=None):
    """Whether small images should be randomly upsized before inference."""
    if dataset is None:
        return False
    upsized_datasets = ['MathVista', 'MMBench_TEST_CN', 'MMStar', 'AI2D', 'OCRBench', 'DynaMath']
    if listinstr(upsized_datasets, dataset):
        return True
    return False
|
| 548 |
+
|
| 549 |
+
def build_prompt(self, line, dataset=None):
    """Build the message for *line*: a direct-answer prompt by default, or a
    CoT-prefixed prompt when self.use_cot(dataset) is True.

    Returns a list of dict(type='text'|'image', value=...) items: optional
    system prompt, then image(s), then the question text.
    """
    if isinstance(line, int):
        line = self.data.iloc[line]

    tgt_path = self.dump_image(line, dataset)
    system_prompt, prompt = '', ''

    question = line['question']

    if not self.use_cot(dataset):
        if DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = 'Options:\n'
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'Question: {question}\n'
            if len(options):
                prompt += options_prompt
                prompt += self.options_suffix_prompt
            else:
                system_prompt = self.wo_options_system_prompt

            if 'MMMU' in dataset:  # MMMU wants the instruction inline, not as system prompt
                if len(system_prompt) > 0:
                    prompt = system_prompt + '\n' + prompt
                    system_prompt = ''
        elif dataset is not None and listinstr(['HallusionBench'], dataset):
            question += ' Yes or No?'
            prompt = question
        elif dataset is not None and listinstr(['OCRBench'], dataset):
            system_prompt = self.vqa_prompt
            prompt = question
        elif DATASET_TYPE(dataset) == 'VQA':
            if listinstr(['LLaVABench'], dataset):
                system_prompt = ''
            elif listinstr(['MMVet'], dataset):
                system_prompt = self.detail_system_prompt
            else:
                system_prompt = self.vqa_prompt
            prompt = question
        else:
            prompt = question
    else:
        has_options = True
        if DATASET_TYPE(dataset) == 'MCQ':
            options = {
                cand: line[cand]
                for cand in string.ascii_uppercase
                if cand in line and not pd.isna(line[cand])
            }
            options_prompt = ''
            for key, item in options.items():
                options_prompt += f'{key}. {item}\n'
            hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
            if hint is not None:
                prompt += f'Hint: {hint}\n'
            prompt += f'{question}\n'

            if len(options):
                prompt += options_prompt
            else:
                has_options = False

            if 'MMMU' in dataset:
                if len(system_prompt) > 0:
                    prompt = system_prompt + '\n' + prompt
                    system_prompt = ''
        else:
            prompt = question

        if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
            # NOTE(review): the original spelled out Y/N and the fallback as
            # separate branches that both prepended short_ans_cot_prompt;
            # collapsed — behavior unchanged.
            if DATASET_TYPE(dataset) == 'MCQ' and has_options:
                prompt = self.multi_choice_cot_prompt + prompt
            else:
                prompt = self.short_ans_cot_prompt + prompt

    msgs = []
    if system_prompt:
        msgs.append(dict(type='text', value=system_prompt))
    if isinstance(tgt_path, list):
        msgs.extend([dict(type='image', value=p) for p in tgt_path])
    else:
        # BUGFIX: the original rebound `msgs = [dict(...)]` here, silently
        # dropping the system prompt for a single (non-list) image path.
        msgs.append(dict(type='image', value=tgt_path))
    msgs.append(dict(type='text', value=prompt))

    return msgs
|
| 647 |
+
|
| 648 |
+
def extract_answer(self, res, dataset=None):
    """Pull the final answer out of a chain-of-thought model response.

    When CoT prompting is active for *dataset*, the response is expected to
    end with an ``Answer: ...`` marker; the text after the last such marker
    is returned. Otherwise (or when nothing matches) the raw response is
    returned unchanged.
    """
    if dataset is None:
        return res
    if not self.use_cot(dataset):
        return res

    task_type = DATASET_TYPE(dataset)
    if task_type == 'MCQ':
        # Single option letter A-I (either case), not followed by more letters;
        # keep the last occurrence in case the reasoning mentions options.
        hits = re.findall(r'Answer:\s*([A-Ia-i])(?![A-Za-z])', res, re.DOTALL)
        return hits[-1].strip() if hits else res
    if task_type == 'VQA' and not listinstr(['OCRBench'], dataset):
        # Free-form answer: everything after the first "Answer:" marker.
        found = re.search(r'Answer:\s*(.*)\s*$', res, re.DOTALL)
        return found.group(1) if found else res

    return res
|
| 669 |
+
|
| 670 |
+
def generate_inner(self, message, dataset=None):
    """Run one chat-style generation for a multimodal *message* list.

    *message* is a list of dicts with ``type`` ('text' or 'image') and
    ``value`` (text string or image file path). Returns the model's text
    response, post-processed by ``extract_answer``.
    """
    # Video inputs: a single slice per frame, no per-image ids, and a much
    # larger input window to fit many frames.
    if DATASET_MODALITY(dataset) == 'VIDEO':
        max_slice_nums = 1
        use_image_id = False
        max_inp_length = 2048 * 10
    else:
        max_slice_nums = None
        use_image_id = True
        max_inp_length = 8192

    # Deterministic decoding by default (sampling=False); instance-level
    # kwargs may override any of these.
    max_new_tokens = 2048
    default_kwargs = dict(
        max_new_tokens=max_new_tokens,
        sampling=False,
        repetition_penalty=self.repetition_penalty,
        num_beams=self.num_beams,
    )
    default_kwargs.update(self.kwargs)

    content = []

    # Interleave text strings and PIL images in the order they appear.
    for x in message:
        if x['type'] == 'text':
            content.append(x['value'])
        elif x['type'] == 'image':
            image = Image.open(x['value']).convert('RGB')
            if not self.use_upsize(dataset):
                content.append(image)
            else:
                img_width, img_height = image.width, image.height
                if (img_width * img_height) >= (1344 * 1344):
                    content.append(image)
                else:
                    # Upsize small images toward ~1344x1344 total area,
                    # preserving aspect ratio. NOTE(review): the target width
                    # is drawn with random.randint, so results depend on the
                    # global RNG state — confirm this nondeterminism is
                    # intended for evaluation.
                    ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
                    max_img_width = int(img_width * ratio)
                    new_img_width = random.randint(img_width, max_img_width)
                    new_img_height = int(new_img_width / img_width * img_height)
                    resized_image = image.resize((new_img_width, new_img_height))
                    content.append(resized_image)
    msgs = [{'role': 'user', 'content': content}]

    res = self.model.chat(
        image=None,
        msgs=msgs,
        context=None,
        tokenizer=self.tokenizer,
        max_inp_length=max_inp_length,
        use_image_id=use_image_id,
        max_slice_nums=max_slice_nums,
        **default_kwargs
    )

    # Some model versions return (text, extra); keep only the text part.
    if isinstance(res, tuple) and len(res) > 0:
        res = res[0]

    # Strip chain-of-thought down to the final answer where applicable.
    res = self.extract_answer(res, dataset)

    return res
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# vqa-eval
|
| 2 |
+
|
| 3 |
+
Contains the VQA evaluation kit from the server.
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/__init__.py
ADDED
|
File without changes
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/vqa_dataset.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
|
| 6 |
+
def prompt_processor(prompt):
    """Extract the bare question (lower-cased) from a formatted VQA prompt.

    Supports three known prompt layouts:
      1. ``OCR tokens: ... Question: <q> Short answer:``
      2. A three-line prompt containing ``Reference OCR token: `` (the
         question is on the line before or after the OCR-token line).
      3. A two-line prompt whose first line is the question.

    Raises:
        ValueError: if *prompt* matches none of the known layouts, or the
            OCR-token layout is malformed.
    """
    if prompt.startswith('OCR tokens: '):
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        # Was an unchecked match.group(1) -> opaque AttributeError on bad input.
        if match is None:
            raise ValueError(f"Malformed OCR-token prompt: {prompt!r}")
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        question = prompt.split('\n')[0]
    else:
        # Was `assert False`, which is silently stripped under `python -O`.
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")

    return question.lower()
|
| 22 |
+
|
| 23 |
+
class textVQADataset(Dataset):
    """TextVQA evaluation dataset.

    Each item is a dict with ``question_id``, ``image_path`` (joined as
    ``<image_dir>/<image_id>.jpg``), ``question``, and ``gt_answers``
    (the list of reference answers from the annotation file).
    """

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        # Context manager so the annotation file handle is closed promptly
        # (the original `json.load(open(...))` leaked it).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return {
            "question_id": record['question_id'],
            "image_path": os.path.join(self.image_dir, f"{record['image_id']}.jpg"),
            "question": record['question'],
            "gt_answers": record['answers'],
        }
|
| 50 |
+
|
| 51 |
+
class docVQADataset(Dataset):
    """DocVQA validation dataset.

    Each item carries ``question_id``, ``image_path`` (annotation paths use a
    ``documents/`` prefix that is remapped to the local ``images/`` folder),
    ``question``, ``gt_answers``, and ``question_type``.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None
    ):
        # Context managers so file handles are closed promptly
        # (the original `json.load(open(...))` leaked them).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        # Default to None so the attribute always exists; the original only
        # defined it when ocr_token_path was given (AttributeError hazard).
        self.ocr_token_data = None
        if ocr_token_path:
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {item['image_id']: item for item in json.load(f)["data"]}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # Annotations reference "documents/<name>"; local copies live in "images/".
        corrected_relative_img_path = record['image'].replace("documents", "images")
        return {
            "question_id": record['questionId'],
            "image_path": os.path.join(self.image_dir, corrected_relative_img_path),
            "question": record['question'],
            "gt_answers": record['answers'],
            'question_type': record['question_types'],
        }
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class docVQATESTDataset(Dataset):
    """DocVQA *test* split (no ground truth published).

    Items mirror :class:`docVQADataset` but ``gt_answers`` and
    ``question_type`` are empty strings, since the test annotations carry
    no answers.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None
    ):
        # Context manager so the annotation file handle is closed promptly
        # (the original `json.load(open(...))` leaked it).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        # Annotations reference "documents/<name>"; local copies live in "images/".
        corrected_relative_img_path = record['image'].replace("documents", "images")
        return {
            "question_id": record['questionId'],
            "image_path": os.path.join(self.image_dir, corrected_relative_img_path),
            "question": record['question'],
            "gt_answers": "",
            'question_type': "",
        }
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import datetime
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
script_dir = os.path.dirname(os.path.realpath(__file__))
|
| 8 |
+
|
| 9 |
+
sys.path.append(os.path.join(script_dir, '..'))
|
| 10 |
+
|
| 11 |
+
from datasets.vqa_dataset import docVQADataset, docVQATESTDataset, textVQADataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
print(torch.__version__)
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from eval_utils.getargs import parse_args
|
| 19 |
+
from eval_utils.vqa_evaluate import *
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_model(args):
    """Instantiate the MiniCPM variant selected by ``args.model_name``.

    Recognized names: ``minicpmv``, ``minicpmv26``, ``minicpmo26``.

    Raises:
        Exception: if the model name is empty or not recognized.
    """
    if args.model_name == '':
        raise Exception('Model name cannot be empty str!')
    # Imported lazily so an argument error surfaces before the heavy model
    # modules are loaded.
    from models.MiniCPM.minicpmv import MiniCPM_V, MiniCPM_V_2_6, MiniCPM_o_2_6
    model_path = args.model_path
    ckpt = args.ckpt

    if args.model_name == 'minicpmv':
        model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)
    elif args.model_name == 'minicpmv26':
        model = MiniCPM_V_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
    elif args.model_name == 'minicpmo26':
        model = MiniCPM_o_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
    else:
        # Fixed typo in the error message ("Moedel" -> "Model").
        raise Exception(f"Unexpected Model Name {args.model_name}!")

    return model
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def main(args):
    """Run the selected VQA evaluations under torch.distributed and save results.

    Initializes a NCCL process group, builds the requested datasets, runs
    ``evaluate_VQA`` for each enabled task, and (on rank 0 only) writes an
    accuracy summary to ``<answer_path>/<model_name>/result.json``.
    """
    np.random.seed(0)  # fixed seed so any sampling during evaluation is reproducible
    max_sample_num = None  # debug knob: set to an int to evaluate only a prefix of each dataset

    # NOTE(review): assumes a torchrun-style launcher sets WORLD_SIZE/RANK/
    # LOCAL_RANK; falls back to a single-process group otherwise — confirm.
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
    print(f'Init Rank-{torch.distributed.get_rank()}')
    if torch.distributed.is_initialized():
        args.device = torch.device(f"cuda:{torch.cuda.current_device()}")

    model = get_model(args)

    result = {}
    # Timestamp tag for this run's answer files (local variable shadows the
    # stdlib `time` module name, but `time` is not imported here).
    time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    # Each block below: build dataset, optionally truncate, evaluate, record accuracy.
    if args.eval_textVQA or args.eval_all:
        dataset = textVQADataset(args.textVQA_image_dir, args.textVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time, \
            batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['textVQA'] = acc

    if args.eval_docVQA or args.eval_all:
        dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, \
            batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['docVQA'] = acc

    # Test split has no ground truth; evaluate_VQA presumably just dumps
    # predictions for submission — confirm against its implementation.
    if args.eval_docVQATest or args.eval_all:
        dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
        if max_sample_num is not None:
            dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
        acc = evaluate_VQA(model, dataset, args.model_name, 'docVQATest', time, \
            batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
        result['docVQATest'] = acc

    # Wait for all ranks to finish their evaluation work.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Only rank 0 writes the summary file.
    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return None

    result_path = os.path.join(os.path.join(args.answer_path, args.model_name), 'result.json')

    # result.json is written only when at least one task scored above zero;
    # an all-zero run is treated as failed and leaves no summary behind.
    output_flag = False
    for k, v in result.items():
        if v > 0.0:
            output_flag = True
            break

    if output_flag:
        with open(result_path, "w") as f:
            f.write(json.dumps(result, indent=4))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
    # Parse CLI options and launch the evaluation driver.
    main(parse_args())
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval_utils/cal_metric.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import glob
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
def has_word(sentence, word):
    """Return True if *word* occurs in *sentence* as a whole word.

    The word is regex-escaped, so it is matched literally, bounded by
    ``\\b`` word boundaries on both sides.
    """
    pattern = r"\b" + re.escape(word) + r"\b"
    # Return the boolean directly instead of the original if/else that
    # returned True/False by hand.
    return re.search(pattern, sentence) is not None
|
| 12 |
+
def remove_special_chars(s):
    """Strip every character that is not ASCII alphanumeric or whitespace."""
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
|
| 16 |
+
|
| 17 |
+
# Report whole-word accuracy for each OCR benchmark result file found under
# ./answer_save/<model>/<task>.json.
for model in glob.glob('./answer_save/*'):
    print(model, ':')
    result_list = sorted(glob.glob(f'{model}/*.json'))
    for task_result_path in result_list:
        # Task name = file name without extension.
        taskname = task_result_path.split('/')[-1]
        taskname = taskname.split('.')[0]
        if taskname not in ['IIIT5K', 'svt', 'IC13_857', 'IC15_1811', 'svtp', 'ct80',
                            'cocotext', 'ctw', 'totaltext', 'HOST']:
            continue

        correct = 0
        num = 0
        with open(task_result_path, 'r') as f:
            # Only the first 100 samples are scored, matching the original.
            # Renamed from `dict`, which shadowed the builtin.
            records = json.load(f)[:100]
        for record in records:
            gt_answers = remove_special_chars(record['gt_answers']).lower()
            answer = remove_special_chars(record['answer']).lower()
            if has_word(answer, gt_answers):
                correct += 1
            num += 1
        # Guard against an empty result file (was a ZeroDivisionError).
        if num:
            print(f'{taskname:10s}:{float(correct)/num*100:.2f}')
    print('=' * 32)
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/requirements.txt
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
aiohttp==3.8.4
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
async-timeout==4.0.2
|
| 5 |
+
attrs==22.2.0
|
| 6 |
+
bitsandbytes==0.37.0
|
| 7 |
+
cchardet==2.1.7
|
| 8 |
+
chardet==5.1.0
|
| 9 |
+
contourpy==1.0.7
|
| 10 |
+
cycler==0.11.0
|
| 11 |
+
filelock==3.9.0
|
| 12 |
+
fonttools==4.38.0
|
| 13 |
+
frozenlist==1.3.3
|
| 14 |
+
huggingface-hub==0.13.4
|
| 15 |
+
importlib-resources==5.12.0
|
| 16 |
+
kiwisolver==1.4.4
|
| 17 |
+
matplotlib==3.7.0
|
| 18 |
+
multidict==6.0.4
|
| 19 |
+
openai==0.27.0
|
| 20 |
+
packaging==23.0
|
| 21 |
+
psutil==5.9.4
|
| 22 |
+
pycocotools==2.0.6
|
| 23 |
+
pyparsing==3.0.9
|
| 24 |
+
python-dateutil==2.8.2
|
| 25 |
+
pyyaml==6.0
|
| 26 |
+
regex==2022.10.31
|
| 27 |
+
tokenizers==0.13.2
|
| 28 |
+
tqdm==4.64.1
|
| 29 |
+
transformers==4.44.2
|
| 30 |
+
timm==0.6.13
|
| 31 |
+
spacy==3.5.1
|
| 32 |
+
webdataset==0.2.48
|
| 33 |
+
scikit-learn==1.2.2
|
| 34 |
+
scipy==1.10.1
|
| 35 |
+
yarl==1.8.2
|
| 36 |
+
zipp==3.14.0
|
| 37 |
+
omegaconf==2.3.0
|
| 38 |
+
opencv-python==4.7.0.72
|
| 39 |
+
iopath==0.1.10
|
| 40 |
+
decord==0.6.0
|
| 41 |
+
tenacity==8.2.2
|
| 42 |
+
peft
|
| 43 |
+
pycocoevalcap
|
| 44 |
+
sentence-transformers
|
| 45 |
+
umap-learn
|
| 46 |
+
notebook
|
| 47 |
+
gradio==3.24.1
|
| 48 |
+
gradio-client==0.0.8
|
| 49 |
+
wandb
|
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/transform_docvqatest_for_submission.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
    # CLI: convert raw DocVQA-test predictions into the submission schema.
    parser = argparse.ArgumentParser()
    # Fixed typo in help text ("originial" -> "original").
    parser.add_argument("--input_file_path", type=str, default="", help="path to the original output json.")
    parser.add_argument("--output_file_path", type=str, default="", help="path to where you want to save the processed json.")
    args = parser.parse_args()

    with open(args.input_file_path, 'r') as f:
        data = json.load(f)

    # The DocVQA server expects {"questionId", "answer"}; strip stray EOS tokens.
    transformed_data = [{"questionId": item["question_id"], "answer": item["answer"].replace("</s>", "")} for item in data]

    with open(args.output_file_path, 'w') as f:
        json.dump(transformed_data, f)
|