1f committed on
Commit 1ccf6d6 · verified · 1 Parent(s): fa29beb

Add files using upload-large-folder tool

Files changed (20)
  1. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py +106 -0
  2. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/__init__.py +4 -0
  3. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/file.py +344 -0
  4. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/log.py +47 -0
  5. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/misc.py +291 -0
  6. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/vlm.py +179 -0
  7. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/__init__.py +7 -0
  8. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/matching_util.py +69 -0
  9. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/mp_util.py +72 -0
  10. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py +97 -0
  11. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/__init__.py +6 -0
  12. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/base.py +198 -0
  13. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/minicpm_v.py +727 -0
  14. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/README.md +3 -0
  15. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/__init__.py +0 -0
  16. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/vqa_dataset.py +116 -0
  17. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval.py +106 -0
  18. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval_utils/cal_metric.py +40 -0
  19. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/requirements.txt +49 -0
  20. r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/transform_docvqatest_for_submission.py +16 -0
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/dataset/utils/ccocr_evaluator/ocr_evaluator.py ADDED
@@ -0,0 +1,106 @@
+ import os
+ import sys
+ import json
+ import re
+ from collections import Counter
+
+ # local import
+ from .common import BaseMetric
+
+
+ def token_normalize(token_text, is_lower=False, is_alphanum_only=False):
+     """Normalize a single token: optionally lowercase it and keep alphanumeric characters only."""
+     if is_lower:
+         token_text = token_text.lower()
+     if is_alphanum_only:
+         token_text = re.sub('[^A-Za-z0-9]+', '', token_text)
+     return token_text
+
+
+ def text_normalize_and_tokenize(text, is_keep_blank=True, is_lower=True, is_alphanum_only=False):
+     text = text.replace("\t", " ").replace("\n", " ").replace("###", "").replace("***", "")
+     text = re.sub(r'\s+', ' ', text)
+     if not is_keep_blank:
+         text = text.replace(" ", "")
+     text_tokens = text.split(" ") if is_keep_blank else list(text)
+     text_token_normalized = [token_normalize(t, is_lower, is_alphanum_only) for t in text_tokens]
+     text_token_normalized = [x for x in text_token_normalized if len(x) > 0]
+     return text_token_normalized
+
+
+ def evaluate_single_sample(gts, preds):
+     right_num = 0
+     gt_counter_info = dict(Counter(gts))
+     pdt_counter_info = dict(Counter(preds))
+     for gt_token, gt_count in gt_counter_info.items():
+         pred_count = pdt_counter_info.get(gt_token, 0)
+         right_num += min(gt_count, pred_count)
+     return right_num
+
+
+ def calculate_metrics(response_info, gt_info, is_verbose=False):
+     """Compute macro- and micro-averaged recall, precision and F1 over all samples."""
+     macro_recall_list, macro_precision_list, macro_f1_list = [], [], []
+     total_gt_num, total_pred_num, total_right_num = 0, 0, 0
+     for file_name, fullbox_gts in gt_info.items():
+         fullbox_preds = response_info.get(file_name, [])
+         right_num = evaluate_single_sample(fullbox_gts, fullbox_preds)
+         total_right_num += right_num
+         total_gt_num += len(fullbox_gts)
+         total_pred_num += len(fullbox_preds)
+
+         macro_recall = right_num / (len(fullbox_gts) + 1e-9)
+         macro_precision = right_num / (len(fullbox_preds) + 1e-9)
+         macro_f1 = 2 * macro_recall * macro_precision / (macro_recall + macro_precision + 1e-9)
+         macro_recall_list.append(macro_recall)
+         macro_precision_list.append(macro_precision)
+         macro_f1_list.append(macro_f1)
+
+     # macro
+     final_macro_recall = sum(macro_recall_list) / (len(macro_recall_list) + 1e-9)
+     final_macro_precision = sum(macro_precision_list) / (len(macro_precision_list) + 1e-9)
+     final_macro_f1 = sum(macro_f1_list) / (len(macro_f1_list) + 1e-9)
+
+     # micro
+     recall_acc = total_right_num / (total_gt_num + 1e-9)
+     preci_acc = total_right_num / (total_pred_num + 1e-9)
+     hmean = 2 * recall_acc * preci_acc / (recall_acc + preci_acc + 1e-9)
+     vbs_eval_result = {
+         'macro_recall': final_macro_recall, 'macro_precision': final_macro_precision, 'macro_f1_score': final_macro_f1,
+         'micro_recall': recall_acc, 'micro_precision': preci_acc, 'micro_f1_score': hmean
+     }
+     eval_result = vbs_eval_result if is_verbose else {'macro_f1_score': final_macro_f1, 'micro_f1_score': hmean}
+     return eval_result
+
+
+ class OcrEvaluator(BaseMetric):
+     def response_post_func(self, response_text, **kwargs):
+         return response_text
+
+     def evaluate(self, response_info, gt_info, **kwargs):
+         # hard code here
+         dataset_name = kwargs['dataset']
+         is_word_level, is_lower, is_alphanum_only = True, True, False
+         if dataset_name in ["Arabic", "Japanese", "Korean"] or "zh" in dataset_name:
+             is_word_level = False
+         if "multi_scene_ocr" in self.group_name and is_word_level:
+             is_alphanum_only = True
+         eval_config = {"word_level": is_word_level, "alphanum_only": is_alphanum_only, "lowercase": is_lower}
+
+         image_pdt_info, image_gt_info = {}, {}
+         for file_name, gt_src in gt_info.items():
+             pred_src = response_info.get(file_name, "")
+             pdt_token_list = text_normalize_and_tokenize(
+                 str(pred_src).strip(), is_word_level, is_lower, is_alphanum_only)
+             gt_token_list = text_normalize_and_tokenize(
+                 str(gt_src).strip(), is_word_level, is_lower, is_alphanum_only)
+             image_pdt_info[file_name] = pdt_token_list
+             image_gt_info[file_name] = gt_token_list
+         eval_result = calculate_metrics(image_pdt_info, image_gt_info, is_verbose=False)
+         return {"summary": eval_result, "metric_config": eval_config}
+
+
+ if __name__ == '__main__':
+     pass
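A minimal usage sketch for the metrics above (toy, illustrative data; assumes `calculate_metrics` is importable from this module):

    gts = {'img_1.jpg': ['hello', 'world'], 'img_2.jpg': ['foo', 'bar', 'baz']}
    preds = {'img_1.jpg': ['hello', 'world'], 'img_2.jpg': ['foo', 'qux']}
    print(calculate_metrics(preds, gts, is_verbose=True))
    # micro_recall = 3/5 and micro_precision = 3/4 here; the macro scores average the per-image values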
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .file import *
+ from .vlm import *
+ from .misc import *
+ from .log import *
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/file.py ADDED
@@ -0,0 +1,344 @@
+ import json
+ import pickle
+ import pandas as pd
+ import os
+ import csv
+ import hashlib
+ import os.path as osp
+ import time
+ import numpy as np
+ import validators
+ import mimetypes
+ import multiprocessing as mp
+ from .misc import toliststr
+ from .vlm import decode_base64_to_image_file
+
+
+ def decode_img_omni(tup):
+     root, im, p = tup
+     images = toliststr(im)
+     paths = toliststr(p)
+     if len(images) > 1 and len(paths) == 1:
+         paths = [osp.splitext(p)[0] + f'_{i}' + osp.splitext(p)[1] for i in range(len(images))]
+
+     assert len(images) == len(paths)
+     paths = [osp.join(root, p) for p in paths]
+     for p, im in zip(paths, images):
+         if osp.exists(p):
+             continue
+         if isinstance(im, str) and len(im) > 64:
+             decode_base64_to_image_file(im, p)
+     return paths
+
+
+ def localize_df(data, dname, nproc=32):
+     assert 'image' in data
+     indices = list(data['index'])
+     indices_str = [str(x) for x in indices]
+     images = list(data['image'])
+     image_map = {x: y for x, y in zip(indices_str, images)}
+
+     root = LMUDataRoot()
+     root = osp.join(root, 'images', dname)
+     os.makedirs(root, exist_ok=True)
+
+     if 'image_path' in data:
+         img_paths = list(data['image_path'])
+     else:
+         img_paths = []
+         for i in indices_str:
+             if len(image_map[i]) <= 64:
+                 idx = image_map[i]
+                 assert idx in image_map and len(image_map[idx]) > 64
+                 img_paths.append(f'{idx}.jpg')
+             else:
+                 img_paths.append(f'{i}.jpg')
+
+     tups = [(root, im, p) for p, im in zip(img_paths, images)]
+
+     pool = mp.Pool(nproc)
+     ret = pool.map(decode_img_omni, tups)
+     pool.close()
+     data.pop('image')
+     if 'image_path' not in data:
+         data['image_path'] = [x[0] if len(x) == 1 else x for x in ret]
+     return data
+
+
+ def LMUDataRoot():
+     if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
+         return os.environ['LMUData']
+     home = osp.expanduser('~')
+     root = osp.join(home, 'LMUData')
+     os.makedirs(root, exist_ok=True)
+     return root
+
+
+ def HFCacheRoot():
+     cache_list = ['HUGGINGFACE_HUB_CACHE', 'HF_HOME']
+     for cache_name in cache_list:
+         if cache_name in os.environ and osp.exists(os.environ[cache_name]):
+             if os.environ[cache_name].split('/')[-1] == 'hub':
+                 return os.environ[cache_name]
+             else:
+                 return osp.join(os.environ[cache_name], 'hub')
+     home = osp.expanduser('~')
+     root = osp.join(home, '.cache', 'huggingface', 'hub')
+     os.makedirs(root, exist_ok=True)
+     return root
+
+
+ def MMBenchOfficialServer(dataset_name):
+     root = LMUDataRoot()
+
+     if dataset_name in ['MMBench', 'MMBench_V11', 'MMBench_CN', 'MMBench_CN_V11']:
+         ans_file = f'{root}/{dataset_name}.tsv'
+         if osp.exists(ans_file):
+             data = load(ans_file)
+             if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
+                 return True
+
+     if dataset_name in ['MMBench_TEST_EN', 'MMBench_TEST_CN', 'MMBench_TEST_EN_V11', 'MMBench_TEST_CN_V11']:
+         ans_file1 = f'{root}/{dataset_name}.tsv'
+         mapp = {
+             'MMBench_TEST_EN': 'MMBench', 'MMBench_TEST_CN': 'MMBench_CN',
+             'MMBench_TEST_EN_V11': 'MMBench_V11', 'MMBench_TEST_CN_V11': 'MMBench_CN_V11',
+         }
+         ans_file2 = f'{root}/{mapp[dataset_name]}.tsv'
+         for f in [ans_file1, ans_file2]:
+             if osp.exists(f):
+                 data = load(f)
+                 if 'answer' in data and sum([pd.isna(x) for x in data['answer']]) == 0:
+                     return True
+     return False
+
+
+ class NumpyEncoder(json.JSONEncoder):
+     def default(self, obj):
+         if isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
+                             np.int16, np.int32, np.int64, np.uint8,
+                             np.uint16, np.uint32, np.uint64)):
+             return int(obj)
+         elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
+             return float(obj)
+         elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
+             return {'real': obj.real, 'imag': obj.imag}
+         elif isinstance(obj, (np.ndarray,)):
+             return obj.tolist()
+         elif isinstance(obj, np.bool_):
+             return bool(obj)
+         elif isinstance(obj, np.void):
+             return None
+         return json.JSONEncoder.default(self, obj)
+
+
+ # LOAD & DUMP
+ def dump(data, f, **kwargs):
+     def dump_pkl(data, pth, **kwargs):
+         pickle.dump(data, open(pth, 'wb'))
+
+     def dump_json(data, pth, **kwargs):
+         json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False, cls=NumpyEncoder)
+
+     def dump_jsonl(data, f, **kwargs):
+         lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
+         with open(f, 'w', encoding='utf8') as fout:
+             fout.write('\n'.join(lines))
+
+     def dump_xlsx(data, f, **kwargs):
+         data.to_excel(f, index=False, engine='xlsxwriter')
+
+     def dump_csv(data, f, quoting=csv.QUOTE_ALL):
+         data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
+
+     def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
+         data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
+
+     handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
+     suffix = f.split('.')[-1]
+     return handlers[suffix](data, f, **kwargs)
+
+
+ def load(f, fmt=None):
+     def load_pkl(pth):
+         return pickle.load(open(pth, 'rb'))
+
+     def load_json(pth):
+         return json.load(open(pth, 'r', encoding='utf-8'))
+
+     def load_jsonl(f):
+         lines = open(f, encoding='utf-8').readlines()
+         lines = [x.strip() for x in lines]
+         if lines[-1] == '':
+             lines = lines[:-1]
+         data = [json.loads(x) for x in lines]
+         return data
+
+     def load_xlsx(f):
+         return pd.read_excel(f)
+
+     def load_csv(f):
+         return pd.read_csv(f)
+
+     def load_tsv(f):
+         return pd.read_csv(f, sep='\t')
+
+     handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
+     if fmt is not None:
+         return handlers[fmt](f)
+
+     suffix = f.split('.')[-1]
+     return handlers[suffix](f)
+
+
+ def download_file(url, filename=None):
+     import urllib.request
+     from tqdm import tqdm
+
+     class DownloadProgressBar(tqdm):
+         def update_to(self, b=1, bsize=1, tsize=None):
+             if tsize is not None:
+                 self.total = tsize
+             self.update(b * bsize - self.n)
+
+     if filename is None:
+         filename = url.split('/')[-1]
+
+     try:
+         with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
+             urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)
+     except Exception as e:
+         import logging
+         logging.warning(f'{type(e)}: {e}')
+         # Handle failed downloads from huggingface.co by retrying via the mirror
+         if 'huggingface.co' in url:
+             url_new = url.replace('huggingface.co', 'hf-mirror.com')
+             try:
+                 download_file(url_new, filename)
+                 return filename
+             except Exception as e:
+                 logging.warning(f'{type(e)}: {e}')
+                 raise Exception(f'Failed to download {url}')
+         else:
+             raise Exception(f'Failed to download {url}')
+
+     return filename
+
+
+ def ls(dirname='.', match=[], mode='all', level=1):
+     if isinstance(level, str):
+         assert '+' in level
+         level = int(level[:-1])
+         res = []
+         for i in range(1, level + 1):
+             res.extend(ls(dirname, match=match, mode=mode, level=i))
+         return res
+
+     if dirname == '.':
+         ans = os.listdir(dirname)
+     else:
+         ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
+     assert mode in ['all', 'dir', 'file']
+     assert level >= 1 and isinstance(level, int)
+     if level == 1:
+         if isinstance(match, str):
+             match = [match]
+         for m in match:
+             if len(m) == 0:
+                 continue
+             if m[0] != '!':
+                 ans = [x for x in ans if m in x]
+             else:
+                 ans = [x for x in ans if m[1:] not in x]
+         if mode == 'dir':
+             ans = [x for x in ans if osp.isdir(x)]
+         elif mode == 'file':
+             ans = [x for x in ans if not osp.isdir(x)]
+         return ans
+     else:
+         dirs = [x for x in ans if osp.isdir(x)]
+         res = []
+         for d in dirs:
+             res.extend(ls(d, match=match, mode=mode, level=level - 1))
+         return res
+
+
+ def mrlines(fname, sp='\n'):
+     f = open(fname).read().split(sp)
+     while f != [] and f[-1] == '':
+         f = f[:-1]
+     return f
+
+
+ def mwlines(lines, fname):
+     with open(fname, 'w') as fout:
+         fout.write('\n'.join(lines))
+
+
+ def md5(s):
+     hash = hashlib.new('md5')
+     if osp.exists(s):
+         with open(s, 'rb') as f:
+             for chunk in iter(lambda: f.read(2**20), b''):
+                 hash.update(chunk)
+     else:
+         hash.update(s.encode('utf-8'))
+     return str(hash.hexdigest())
+
+
+ def last_modified(pth):
+     stamp = osp.getmtime(pth)
+     m_ti = time.ctime(stamp)
+     t_obj = time.strptime(m_ti)
+     t = time.strftime('%Y%m%d%H%M%S', t_obj)[2:]
+     return t
+
+
+ def parse_file(s):
+     if osp.exists(s) and s != '.':
+         assert osp.isfile(s)
+         suffix = osp.splitext(s)[1].lower()
+         mime = mimetypes.types_map.get(suffix, 'unknown')
+         return (mime, s)
+     elif s.startswith('data:image/'):
+         # To be compatible with the OpenAI base64 format
+         content = s[11:]
+         mime = content.split(';')[0]
+         content = ';'.join(content.split(';')[1:])
+         dname = osp.join(LMUDataRoot(), 'files')
+         assert content.startswith('base64,')
+         b64 = content[7:]
+         os.makedirs(dname, exist_ok=True)
+         tgt = osp.join(dname, md5(b64) + '.png')
+         decode_base64_to_image_file(b64, tgt)
+         return parse_file(tgt)
+     elif validators.url(s):
+         suffix = osp.splitext(s)[1].lower()
+         if suffix in mimetypes.types_map:
+             mime = mimetypes.types_map[suffix]
+             dname = osp.join(LMUDataRoot(), 'files')
+             os.makedirs(dname, exist_ok=True)
+             tgt = osp.join(dname, md5(s) + suffix)
+             download_file(s, tgt)
+             return (mime, tgt)
+         else:
+             return ('url', s)
+     else:
+         return (None, s)
+
+
+ def file_size(f, unit='GB'):
+     stats = os.stat(f)
+     div_map = {
+         'GB': 2 ** 30,
+         'MB': 2 ** 20,
+         'KB': 2 ** 10,
+     }
+     return stats.st_size / div_map[unit]
+
+
+ def parquet_to_tsv(file_path):
+     data = pd.read_parquet(file_path)
+     pth = '/'.join(file_path.split('/')[:-1])
+     data_name = file_path.split('/')[-1].split('.')[0]
+     data.to_csv(osp.join(pth, f'{data_name}.tsv'), sep='\t', index=False)
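A quick sketch of the suffix-dispatched `load`/`dump` round trip (file names below are illustrative):

    records = [{'index': 0, 'answer': 'A'}, {'index': 1, 'answer': 'B'}]
    dump(records, 'demo.jsonl')              # handler chosen from the .jsonl suffix
    assert load('demo.jsonl') == records
    dump(pd.DataFrame(records), 'demo.tsv')  # DataFrame handlers: xlsx / csv / tsv
    print(load('demo.tsv').shape)            # (2, 2)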
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/log.py ADDED
@@ -0,0 +1,47 @@
+ import logging
+
+ logging.basicConfig(
+     format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
+     datefmt='%Y-%m-%d %H:%M:%S')
+
+ logger_initialized = {}
+
+
+ def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+     logger = logging.getLogger(name)
+     if name in logger_initialized:
+         return logger
+
+     for logger_name in logger_initialized:
+         if name.startswith(logger_name):
+             return logger
+
+     stream_handler = logging.StreamHandler()
+     handlers = [stream_handler]
+
+     try:
+         import torch.distributed as dist
+         if dist.is_available() and dist.is_initialized():
+             rank = dist.get_rank()
+         else:
+             rank = 0
+     except ImportError:
+         rank = 0
+
+     if rank == 0 and log_file is not None:
+         file_handler = logging.FileHandler(log_file, file_mode)
+         handlers.append(file_handler)
+
+     formatter = logging.Formatter(
+         '[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s')
+     for handler in handlers:
+         handler.setFormatter(formatter)
+         handler.setLevel(log_level)
+         logger.addHandler(handler)
+
+     if rank == 0:
+         logger.setLevel(log_level)
+     else:
+         logger.setLevel(logging.ERROR)
+
+     logger_initialized[name] = True
+     return logger
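Usage sketch (logger name and file are illustrative); a second call with the same name returns the cached logger without duplicating handlers:

    logger = get_logger('vlmeval.demo', log_file='run.log')  # file handler attached only on rank 0
    logger.info('evaluation started')
    assert get_logger('vlmeval.demo') is logger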
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/misc.py ADDED
@@ -0,0 +1,291 @@
+ # flake8: noqa: F401, F403
+ import abc
+ import argparse
+ import csv
+ import multiprocessing as mp
+ import os
+ import os.path as osp
+ from pathlib import Path
+ import copy as cp
+ import random as rd
+ import requests
+ import shutil
+ import subprocess
+ import warnings
+ import pandas as pd
+ from collections import OrderedDict, defaultdict
+ from multiprocessing import Pool, current_process
+ from tqdm import tqdm
+ import datetime
+ import matplotlib.pyplot as plt
+ from tabulate import tabulate
+ from json import JSONDecoder
+ from huggingface_hub import scan_cache_dir
+ from huggingface_hub.utils._cache_manager import _scan_cached_repo
+ from sty import fg, bg, ef, rs
+
+
+ def modelscope_flag_set():
+     return os.environ.get('VLMEVALKIT_USE_MODELSCOPE', None) in ['1', 'True']
+
+
+ def process_punctuation(inText):
+     import re
+     outText = inText
+     punct = [
+         ';', r'/', '[', ']', '"', '{', '}', '(', ')', '=', '+', '\\', '_', '-',
+         '>', '<', '@', '`', ',', '?', '!'
+     ]
+     commaStrip = re.compile(r'(\d)(,)(\d)')
+     periodStrip = re.compile(r'(?<!\d)(\.)(?!\d)')
+     for p in punct:
+         if (p + ' ' in inText or ' ' + p in inText) or (re.search(
+                 commaStrip, inText) is not None):
+             outText = outText.replace(p, '')
+         else:
+             outText = outText.replace(p, ' ')
+     outText = periodStrip.sub('', outText)
+     return outText
+
+
+ def h2r(value):
+     if value[0] == '#':
+         value = value[1:]
+     assert len(value) == 6
+     return tuple(int(value[i:i + 2], 16) for i in range(0, 6, 2))
+
+
+ def r2h(rgb):
+     return '#%02x%02x%02x' % rgb
+
+
+ def colored(s, color):
+     if isinstance(color, str):
+         if hasattr(fg, color):
+             return getattr(fg, color) + s + fg.rs
+         color = h2r(color)
+     return fg(*color) + s + fg.rs
+
+
+ def istype(s, type):
+     if isinstance(s, type):
+         return True
+     try:
+         return isinstance(eval(s), type)
+     except Exception as _:
+         return False
+
+
+ def bincount(lst):
+     bins = defaultdict(lambda: 0)
+     for item in lst:
+         bins[item] += 1
+     return bins
+
+
+ def get_cache_path(repo_id, branch='main', repo_type='datasets'):
+     try:
+         if modelscope_flag_set():
+             from modelscope.hub.file_download import create_temporary_directory_and_cache
+             if repo_type == 'datasets':
+                 repo_type = 'dataset'
+             _, cache = create_temporary_directory_and_cache(model_id=repo_id, repo_type=repo_type)
+             cache_path = cache.get_root_location()
+             return cache_path
+         else:
+             from .file import HFCacheRoot
+             cache_path = HFCacheRoot()
+             org, repo_name = repo_id.split('/')
+             repo_path = Path(osp.join(cache_path, f'{repo_type}--{org}--{repo_name}/'))
+             hf_cache_info = _scan_cached_repo(repo_path=repo_path)
+             revs = {r.refs: r for r in hf_cache_info.revisions}
+             if branch is not None:
+                 revs = {refs: r for refs, r in revs.items() if branch in refs}
+             rev2keep = max(revs.values(), key=lambda r: r.last_modified)
+             return str(rev2keep.snapshot_path)
+     except Exception as e:
+         import logging
+         logging.warning(f'{type(e)}: {e}')
+         return None
+
+
+ def proxy_set(s):
+     import os
+     for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
+         os.environ[key] = s
+
+
+ def get_rank_and_world_size():
+     rank = int(os.environ.get('RANK', 0))
+     world_size = int(os.environ.get('WORLD_SIZE', 1))
+     return rank, world_size
+
+
+ def splitlen(s, sym='/'):
+     return len(s.split(sym))
+
+
+ def listinstr(lst, s):
+     assert isinstance(lst, list)
+     for item in lst:
+         if item in s:
+             return True
+     return False
+
+
+ def d2df(D):
+     return pd.DataFrame({x: [D[x]] for x in D})
+
+
+ def cn_string(s):
+     import re
+     if re.search(u'[\u4e00-\u9fff]', s):
+         return True
+     return False
+
+
+ try:
+     import decord
+ except ImportError:
+     pass
+
+
+ def timestr(granularity='second'):
+     s = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
+     assert granularity in ['second', 'minute', 'hour', 'day']
+     if granularity == 'second':
+         return s
+     elif granularity == 'minute':
+         return s[:-2]
+     elif granularity == 'hour':
+         return s[:-4]
+     elif granularity == 'day':
+         return s[:-6]
+
+
+ def _minimal_ext_cmd(cmd, cwd=None):
+     env = {}
+     for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+         v = os.environ.get(k)
+         if v is not None:
+             env[k] = v
+     env['LANGUAGE'] = 'C'
+     env['LANG'] = 'C'
+     env['LC_ALL'] = 'C'
+     out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env, cwd=cwd).communicate()[0]
+     return out
+
+
+ def githash(fallback='unknown', digits=8):
+     if digits is not None and not isinstance(digits, int):
+         raise TypeError('digits must be None or an integer')
+     try:
+         import vlmeval
+     except ImportError as e:
+         import logging
+         logging.error(f'ImportError: {str(e)}')
+         return fallback
+     try:
+         out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'], cwd=vlmeval.__path__[0])
+         sha = out.strip().decode('ascii')
+         if digits is not None:
+             sha = sha[:digits]
+     except OSError:
+         sha = fallback
+     return sha
+
+
+ def dict_merge(dct, merge_dct):
+     for k, _ in merge_dct.items():
+         if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)):  # noqa
+             dict_merge(dct[k], merge_dct[k])
+         else:
+             dct[k] = merge_dct[k]
+
+
+ def youtube_dl(idx):
+     cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
+     os.system(cmd)
+
+
+ def run_command(cmd):
+     if isinstance(cmd, str):
+         cmd = cmd.split()
+     return subprocess.check_output(cmd).decode()
+
+
+ def load_env():
+     import logging
+     logging.basicConfig(
+         format='[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S')
+
+     try:
+         import vlmeval
+     except ImportError:
+         logging.error('VLMEval is not installed. Failed to import environment variables from .env file. ')
+         return
+     pth = osp.realpath(vlmeval.__path__[0])
+     pth = osp.join(pth, '../.env')
+     pth = osp.realpath(pth)
+     if not osp.exists(pth):
+         logging.error(f'Did not detect the .env file at {pth}, failed to load. ')
+         return
+
+     from dotenv import dotenv_values
+     values = dotenv_values(pth)
+     for k, v in values.items():
+         if v is not None and len(v):
+             os.environ[k] = v
+     logging.info(f'API Keys successfully loaded from {pth}')
+
+
+ def pip_install_robust(package):
+     import sys
+     retry = 3
+     while retry > 0:
+         try:
+             package_base = package.split('=')[0]
+             __import__(package_base)
+             return True
+         except ImportError:
+             subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+             retry -= 1
+     return False
+
+
+ def version_cmp(v1, v2, op='eq'):
+     from packaging import version
+     import operator
+     op_func = getattr(operator, op)
+     return op_func(version.parse(v1), version.parse(v2))
+
+
+ def toliststr(s):
+     if isinstance(s, str) and (s[0] == '[') and (s[-1] == ']'):
+         return [str(x) for x in eval(s)]
+     elif isinstance(s, str):
+         return [s]
+     elif isinstance(s, list):
+         return [str(x) for x in s]
+     raise NotImplementedError
+
+
+ def extract_json_objects(text, decoder=JSONDecoder()):
+     pos = 0
+     while True:
+         match = text.find('{', pos)
+         if match == -1:
+             break
+         try:
+             result, index = decoder.raw_decode(text[match:])
+             yield result
+             pos = match + index
+         except ValueError:
+             pos = match + 1
+
+
+ def get_gpu_memory():
+     import subprocess
+     try:
+         command = "nvidia-smi --query-gpu=memory.free --format=csv"
+         memory_free_info = subprocess.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
+         memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
+         return memory_free_values
+     except Exception as e:
+         print(f'{type(e)}: {str(e)}')
+         return []
+
+
+ def auto_split_flag():
+     flag = os.environ.get('AUTO_SPLIT', '0')
+     if flag == '1':
+         return True
+     _, world_size = get_rank_and_world_size()
+     try:
+         import torch
+         device_count = torch.cuda.device_count()
+         if device_count > world_size and device_count % world_size == 0:
+             return True
+         else:
+             return False
+     except Exception:
+         return False
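A few behavior notes on the small helpers above, as a sketch (inputs are illustrative):

    assert listinstr(['MMBench', 'MMMU'], 'MMBench_TEST_EN_V11')   # any-substring match
    assert toliststr("['a.jpg', 'b.jpg']") == ['a.jpg', 'b.jpg']   # stringified list -> list of str
    assert splitlen('a/b/c') == 3
    print(timestr('day'))                                          # e.g. '20250101'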
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/smp/vlm.py ADDED
@@ -0,0 +1,179 @@
+ import os
+ import io
+ import pandas as pd
+ import numpy as np
+ import string
+ from uuid import uuid4
+ import os.path as osp
+ import base64
+ from PIL import Image
+ import sys
+
+ Image.MAX_IMAGE_PIXELS = 1e9
+
+
+ def rescale_img(img, tgt=None):
+     assert isinstance(tgt, tuple) and -1 in tgt
+     w, h = img.size
+     if tgt[0] != -1:
+         new_w, new_h = tgt[0], int(tgt[0] / w * h)
+     elif tgt[1] != -1:
+         new_w, new_h = int(tgt[1] / h * w), tgt[1]
+     img = img.resize((new_w, new_h))
+     return img
+
+
+ def concat_images_vlmeval(images, target_size=-1, mode='h', return_image=False):
+     from .file import md5
+
+     ims = [Image.open(im) for im in images]
+     if target_size != -1:
+         ims = [
+             rescale_img(im, (-1, target_size) if mode == 'h' else (target_size, -1))
+             for im in ims
+         ]
+
+     ws, hs = [x.width for x in ims], [x.height for x in ims]
+     if mode == 'h':
+         new_w, new_h = sum(ws), max(hs)
+         dst = Image.new('RGB', (new_w, new_h))
+         for i, im in enumerate(ims):
+             dst.paste(im, (sum(ws[:i]), 0))
+     elif mode == 'v':
+         new_w, new_h = max(ws), sum(hs)
+         dst = Image.new('RGB', (new_w, new_h))
+         for i, im in enumerate(ims):
+             dst.paste(im, (0, sum(hs[:i])))
+     if return_image:
+         return dst
+     else:
+         _str = '\n'.join(images)
+         str_md5 = md5(_str)
+         tgt = osp.join('/tmp', str_md5 + '.jpg')
+         dst.save(tgt)
+         return tgt
+
+
+ def mmqa_display(question, target_size=512):
+     question = {k.lower(): v for k, v in question.items()}
+     keys = list(question.keys())
+     keys = [k for k in keys if k not in ['index', 'image']]
+
+     images = question['image']
+     if isinstance(images, str):
+         images = [images]
+
+     idx = question.pop('index', 'XXX')
+     print(f'INDEX: {idx}')
+
+     for im in images:
+         image = decode_base64_to_image(im, target_size=target_size)
+         display(image)  # noqa: F821
+
+     for k in keys:
+         try:
+             if not pd.isna(question[k]):
+                 print(f'{k.upper()}. {question[k]}')
+         except ValueError:
+             if False in pd.isna(question[k]):
+                 print(f'{k.upper()}. {question[k]}')
+
+
+ def encode_image_to_base64(img, target_size=-1, fmt='JPEG'):
+     # if target_size == -1, do not resize; otherwise, bound the image to (target_size, target_size)
+     if img.mode in ('RGBA', 'P'):
+         img = img.convert('RGB')
+     if target_size > 0:
+         img.thumbnail((target_size, target_size))
+     img_buffer = io.BytesIO()
+     img.save(img_buffer, format=fmt)
+     image_data = img_buffer.getvalue()
+     ret = base64.b64encode(image_data).decode('utf-8')
+     return ret
+
+
+ def encode_image_file_to_base64(image_path, target_size=-1):
+     image = Image.open(image_path)
+     return encode_image_to_base64(image, target_size=target_size)
+
+
+ def decode_base64_to_image(base64_string, target_size=-1):
+     image_data = base64.b64decode(base64_string)
+     image = Image.open(io.BytesIO(image_data))
+     if image.mode in ('RGBA', 'P'):
+         image = image.convert('RGB')
+     if target_size > 0:
+         image.thumbnail((target_size, target_size))
+     return image
+
+
+ def decode_base64_to_image_file(base64_string, image_path, target_size=-1):
+     image = decode_base64_to_image(base64_string, target_size=target_size)
+     image.save(image_path)
+
+
+ def build_option_str(option_dict):
+     s = 'There are several options: \n'
+     for c, content in option_dict.items():
+         if not pd.isna(content):
+             s += f'{c}. {content}\n'
+     return s
+
+
+ def isimg(s):
+     return osp.exists(s) or s.startswith('http')
+
+
+ def read_ok(img_path):
+     if not osp.exists(img_path):
+         return False
+     try:
+         im = Image.open(img_path)
+         assert im.size[0] > 0 and im.size[1] > 0
+         return True
+     except Exception:
+         return False
+
+
+ def gpt_key_set():
+     openai_key = os.environ.get('OPENAI_API_KEY', None)
+     return isinstance(openai_key, str) and openai_key.startswith('sk-')
+
+
+ def apiok(wrapper):
+     s = wrapper.generate('Hello!')
+     return wrapper.fail_msg not in s
+
+
+ def circular_pred(df, extract_func=None):
+     if extract_func is None:
+         extract_func = lambda x: x  # noqa: E731
+     df = df.sort_values('index')
+     from vlmeval.utils import can_infer_option
+
+     shift = int(1e6)
+
+     choices = [extract_func(x) for x in df['prediction']]
+     pred_map = {i: c for i, c in zip(df['index'], choices)}
+     flag_map = {i: True for i in pred_map if i < 1e6}
+     valid_map = {i: True for i in pred_map if i < 1e6}
+     for i in df['index']:
+         if i >= shift and pred_map[i] and pred_map[i - shift]:
+             if (pred_map[i] not in list(string.ascii_uppercase)
+                     or pred_map[i - shift] not in list(string.ascii_uppercase)):
+                 valid_map[i % shift] = False
+                 continue
+             if (ord(pred_map[i]) - ord(pred_map[i - shift])) % 4 == 1:
+                 continue
+             else:
+                 flag_map[i % shift] = False
+     flag_map = {k: v for k, v in flag_map.items() if valid_map[k]}
+     flags = list(flag_map.values())
+     return np.mean(flags)
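A base64 round-trip sketch for the codec helpers above (sizes and paths are illustrative):

    img = Image.new('RGB', (64, 64), 'red')
    b64 = encode_image_to_base64(img, target_size=32)   # thumbnail to fit within (32, 32)
    restored = decode_base64_to_image(b64)
    assert max(restored.size) <= 32
    decode_base64_to_image_file(b64, '/tmp/demo.png')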
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .matching_util import can_infer, can_infer_option, can_infer_text
+ from .mp_util import track_progress_rich
+
+
+ __all__ = [
+     'can_infer', 'can_infer_option', 'can_infer_text', 'track_progress_rich',
+ ]
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/matching_util.py ADDED
@@ -0,0 +1,69 @@
+ import string
+ import copy as cp
+ import os
+ from ..smp import *
+
+
+ def can_infer_option(answer, choices):
+     verbose = os.environ.get('VERBOSE', 0)
+     # `choices` is a dictionary mapping option letters to option contents
+     if 'Failed to obtain answer via API' in answer:
+         return False
+
+     reject_to_answer = [
+         "Sorry, I can't help with images of people yet.",
+         "I can't process this file.",
+         "I'm sorry, but without the image provided",
+         'Cannot determine the answer'
+     ]
+     for err in reject_to_answer:
+         if err in answer:
+             return 'Z'
+
+     def count_choice(splits, choices, prefix='', suffix=''):
+         cnt = 0
+         for c in choices:
+             if prefix + c + suffix in splits:
+                 cnt += 1
+         return cnt
+
+     answer_mod = cp.copy(answer)
+     chars = '.()[],:;!*#{}'
+     for c in chars:
+         answer_mod = answer_mod.replace(c, ' ')
+
+     splits = [x.strip() for x in answer_mod.split()]
+     count = count_choice(splits, choices)
+
+     if count == 1:
+         for ch in choices:
+             if 'A' in splits and len(splits) > 3 and verbose:
+                 logger = get_logger('Evaluation')
+                 logger.info(f'A might be a quantifier in the string: {answer}.')
+                 return False
+             if ch in splits:
+                 return ch
+     elif count == 0 and count_choice(splits, {'Z', ''}) == 1:
+         return 'Z'
+     return False
+
+
+ def can_infer_text(answer, choices):
+     answer = answer.lower()
+     assert isinstance(choices, dict)
+     for k in choices:
+         assert k in string.ascii_uppercase
+         choices[k] = str(choices[k]).lower()
+     cands = []
+     for k in choices:
+         if choices[k] in answer:
+             cands.append(k)
+     if len(cands) == 1:
+         return cands[0]
+     return False
+
+
+ def can_infer(answer, choices):
+     answer = str(answer)
+     copt = can_infer_option(answer, choices)
+     return copt if copt else can_infer_text(answer, choices)
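How the two inference paths compose in `can_infer` (toy choices, illustrative):

    choices = {'A': 'cat', 'B': 'dog', 'C': 'bird'}
    assert can_infer('The answer is B.', choices) == 'B'      # option letter found in the text
    assert can_infer('It is clearly a dog.', choices) == 'B'  # falls back to matching option contents
    assert can_infer('no idea', choices) is False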
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/mp_util.py ADDED
@@ -0,0 +1,72 @@
+ import os.path as osp
+ import time
+ from typing import Callable, Iterable
+
+ from ..smp import load, dump
+
+
+ def track_progress_rich(
+         func: Callable,
+         tasks: Iterable = tuple(),
+         nproc: int = 1,
+         save=None,
+         keys=None,
+         **kwargs) -> list:
+
+     from concurrent.futures import ThreadPoolExecutor
+     from tqdm import tqdm
+     if save is not None:
+         assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
+         if not osp.exists(save):
+             dump({}, save)
+     if keys is not None:
+         assert len(keys) == len(tasks)
+     if not callable(func):
+         raise TypeError('func must be a callable object')
+     if not isinstance(tasks, Iterable):
+         raise TypeError(
+             f'tasks must be an iterable object, but got {type(tasks)}')
+     assert nproc > 0, 'nproc must be a positive number'
+     res = load(save) if save is not None else {}
+     results = [None for _ in range(len(tasks))]
+
+     with ThreadPoolExecutor(max_workers=nproc) as executor:
+         futures = []
+
+         for inputs in tasks:
+             if not isinstance(inputs, (tuple, list, dict)):
+                 inputs = (inputs, )
+             if isinstance(inputs, dict):
+                 future = executor.submit(func, **inputs)
+             else:
+                 future = executor.submit(func, *inputs)
+             futures.append(future)
+
+         unfinished = set(range(len(tasks)))
+         pbar = tqdm(total=len(unfinished))
+         while len(unfinished):
+             new_finished = set()
+             for idx in unfinished:
+                 if futures[idx].done():
+                     results[idx] = futures[idx].result()
+                     new_finished.add(idx)
+                     if keys is not None:
+                         res[keys[idx]] = results[idx]
+             if len(new_finished):
+                 if save is not None:
+                     dump(res, save)
+                 pbar.update(len(new_finished))
+                 for k in new_finished:
+                     unfinished.remove(k)
+             time.sleep(0.1)
+         pbar.close()
+
+     if save is not None:
+         dump(res, save)
+     return results
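A usage sketch (the worker and cache file are illustrative); results come back in task order, and finished items are checkpointed to `save` keyed by `keys`:

    def square(x):
        time.sleep(0.1)  # simulate work
        return x * x

    out = track_progress_rich(square, tasks=list(range(8)), nproc=4,
                              save='cache.json', keys=[str(i) for i in range(8)])
    assert out == [i * i for i in range(8)]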
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/utils/result_transfer.py ADDED
@@ -0,0 +1,97 @@
+ from ..smp import *
+ from ..dataset.utils.judge_util import build_judge
+ from ..dataset.utils.multiple_choice import extract_answer_from_item
+ from .matching_util import can_infer
+ from .mp_util import track_progress_rich
+
+
+ def MMMU_result_transfer(result_path):
+     res = {}
+     result_data = load(result_path)
+     mcq = result_data['A'].notna()
+     lt = len(result_data)
+     for i in range(lt):
+         line = result_data.iloc[i]
+         if mcq[i]:
+             options = {
+                 cand: line[cand]
+                 for cand in string.ascii_uppercase
+                 if cand in line and not pd.isna(line[cand])
+             }
+             prediction = line['prediction']
+             infer_prediction = can_infer(prediction, options)
+             res[line['id']] = infer_prediction
+         else:
+             res[line['id']] = line['prediction']
+     result_json = result_path.replace('.xlsx', '.json')
+     dump(res, result_json)
+     return result_json
+
+
+ def MMTBench_result_transfer(eval_file, dataset='default', **judge_kwargs):
+     logger = get_logger('Evaluation')
+     nproc = judge_kwargs.pop('nproc', 4)
+
+     rd.seed(2680)
+     suffix = eval_file.split('.')[-1]
+     model = judge_kwargs['model']
+     assert model in ['chatgpt-0125', 'exact_matching', 'gpt-4-0125']
+     name_str_map = {
+         'chatgpt-0125': 'openai',
+         'gpt-4-0125': 'gpt4'
+     }
+     name_str = name_str_map[model] if model in name_str_map else model
+
+     if model == 'exact_matching':
+         model = None
+     elif gpt_key_set():
+         model = build_judge(**judge_kwargs)
+         if not model.working():
+             logger.error('The OPENAI API is not working properly, will use exact matching for evaluation')
+             model = None
+     else:
+         logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
+         model = None
+
+     logger.info(f'Evaluating {eval_file}')
+     result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_option.pkl')
+     result = {}
+     if osp.exists(result_file):
+         result = load(result_file)
+
+     data = load(eval_file)
+     assert 'index' in data, 'Essential columns missing in the eval_file.'
+
+     data = data.sort_values(by='index')
+     data['prediction'] = [str(x) for x in data['prediction']]
+     for k in data.keys():
+         data[k.lower() if k not in list(string.ascii_uppercase) else k] = data.pop(k)
+
+     idx2lines = {data.iloc[i]['index']: data.iloc[i] for i in range(len(data))}
+     idx2lines = {k: v for k, v in idx2lines.items() if k not in result}
+
+     indices = list(idx2lines.keys())
+     lines = [idx2lines[i] for i in indices]
+     tups = [(model, line) for line in lines]
+     res = track_progress_rich(
+         extract_answer_from_item,
+         tups,
+         nproc=nproc,
+         chunksize=nproc,
+         save=result_file,
+         keys=indices)
+
+     for i, r in zip(indices, res):
+         if i in result:
+             assert result[i]['opt'] == r['opt'] and result[i]['log'] == r['log']
+         else:
+             result[i] = r
+
+     indices = list(data['index'])
+     data['opt'] = [result[i]['opt'] for i in data['index']]
+     data['log'] = [result[i]['log'] for i in data['index']]
+
+     # dump the per-question options and logs as the submission file
+     output_path = eval_file.replace(f'.{suffix}', f'_{name_str}_submission.tsv')
+     dump(data, output_path)
+     return output_path
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/__init__.py ADDED
@@ -0,0 +1,6 @@
+ import torch
+
+ torch.set_grad_enabled(False)
+ torch.manual_seed(1234)
+ from .base import BaseModel
+ from .minicpm_v import MiniCPM_V, MiniCPM_Llama3_V, MiniCPM_V_2_6, MiniCPM_o_2_6
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/base.py ADDED
@@ -0,0 +1,198 @@
+ from ..smp import *
+ from ..dataset import img_root_map, DATASET_TYPE
+ from abc import abstractmethod
+
+
+ class BaseModel:
+
+     INTERLEAVE = False
+     allowed_types = ['text', 'image', 'video']
+
+     def __init__(self):
+         self.dump_image_func = None
+
+     def use_custom_prompt(self, dataset):
+         """Whether to use a custom prompt for the given dataset.
+
+         Args:
+             dataset (str): The name of the dataset.
+
+         Returns:
+             bool: Whether to use a custom prompt. If True, will call `build_prompt` of the VLM to build the prompt.
+                 Defaults to False.
+         """
+         return False
+
+     @abstractmethod
+     def build_prompt(self, line, dataset):
+         """Build custom prompts for a specific dataset. Called only if `use_custom_prompt` returns True.
+
+         Args:
+             line (line of pd.DataFrame): The raw input line.
+             dataset (str): The name of the dataset.
+
+         Returns:
+             str: The built message.
+         """
+         raise NotImplementedError
+
+     def set_dump_image(self, dump_image_func):
+         self.dump_image_func = dump_image_func
+
+     def dump_image(self, line, dataset):
+         return self.dump_image_func(line)
+
+     @abstractmethod
+     def generate_inner(self, message, dataset=None):
+         raise NotImplementedError
+
+     def check_content(self, msgs):
+         """Check the content type of the input. Four types are allowed: str, dict, liststr, listdict."""
+         if isinstance(msgs, str):
+             return 'str'
+         if isinstance(msgs, dict):
+             return 'dict'
+         if isinstance(msgs, list):
+             types = [self.check_content(m) for m in msgs]
+             if all(t == 'str' for t in types):
+                 return 'liststr'
+             if all(t == 'dict' for t in types):
+                 return 'listdict'
+         return 'unknown'
+
+     def preproc_content(self, inputs):
+         """Convert the raw input messages to a list of dicts.
+
+         Args:
+             inputs: raw input messages.
+
+         Returns:
+             list(dict): The preprocessed input messages. Will return None if it fails to preprocess the input.
+         """
+         if self.check_content(inputs) == 'str':
+             return [dict(type='text', value=inputs)]
+         elif self.check_content(inputs) == 'dict':
+             assert 'type' in inputs and 'value' in inputs
+             return [inputs]
+         elif self.check_content(inputs) == 'liststr':
+             res = []
+             for s in inputs:
+                 mime, pth = parse_file(s)
+                 if mime is None or mime == 'unknown':
+                     res.append(dict(type='text', value=s))
+                 else:
+                     res.append(dict(type=mime.split('/')[0], value=pth))
+             return res
+         elif self.check_content(inputs) == 'listdict':
+             for item in inputs:
+                 assert 'type' in item and 'value' in item
+                 mime, s = parse_file(item['value'])
+                 if mime is None:
+                     assert item['type'] == 'text'
+                 else:
+                     assert mime.split('/')[0] == item['type']
+                     item['value'] = s
+             return inputs
+         else:
+             return None
+
+     def generate(self, message, dataset=None):
+         """Generate the output message.
+
+         Args:
+             message (list[dict]): The input message.
+             dataset (str, optional): The name of the dataset. Defaults to None.
+
+         Returns:
+             str: The generated message.
+         """
+         assert self.check_content(message) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {message}'
+         message = self.preproc_content(message)
+         assert message is not None and self.check_content(message) == 'listdict'
+         for item in message:
+             assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'
+         return self.generate_inner(message, dataset)
+
+     def chat(self, messages, dataset=None):
+         """The main function for multi-turn chatting. Will call `chat_inner` with the preprocessed input messages."""
+         assert hasattr(self, 'chat_inner'), 'The API model should have the `chat_inner` method. '
+         for msg in messages:
+             assert isinstance(msg, dict) and 'role' in msg and 'content' in msg, msg
+             assert self.check_content(msg['content']) in ['str', 'dict', 'liststr', 'listdict'], msg
+             msg['content'] = self.preproc_content(msg['content'])
+
+         while len(messages):
+             try:
+                 return self.chat_inner(messages, dataset=dataset)
+             except Exception as e:
+                 logging.info(f'{type(e)}: {e}')
+                 messages = messages[1:]
+                 while len(messages) and messages[0]['role'] != 'user':
+                     messages = messages[1:]
+                 continue
+         return 'Chat Mode: Failed with all possible conversation turns.'
+
+     def message_to_promptimg(self, message, dataset=None):
+         assert not self.INTERLEAVE
+         model_name = self.__class__.__name__
+         warnings.warn(
+             f'Model {model_name} does not support interleaved input. '
+             'Will use the first image and aggregated texts as prompt. ')
+         num_images = len([x for x in message if x['type'] == 'image'])
+         if num_images == 0:
+             prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+             image = None
+         else:
+             prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+             images = [x['value'] for x in message if x['type'] == 'image']
+             if dataset == 'BLINK':
+                 image = concat_images_vlmeval(images, target_size=512)
+             else:
+                 image = images[0]
+         return prompt, image
+
+     def message_to_promptvideo(self, message):
+         if self.VIDEO_LLM:
+             num_videos = len([x for x in message if x['type'] == 'video'])
+             if num_videos == 0:
+                 prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+                 video = None
+             else:
+                 prompt = '\n'.join([x['value'] for x in message if x['type'] == 'text'])
+                 video = [x['value'] for x in message if x['type'] == 'video'][0]
+             return prompt, video
+         else:
+             logging.critical('Model does not support video input.')
+             raise NotImplementedError
+
+     def message_to_promptvideo_withrole(self, message, dataset=None):
+         if self.VIDEO_LLM:
+             system, user, assistant, video_list = '', '', '', []
+             for msg in message:
+                 if msg['type'] == 'text':
+                     if 'role' in msg and msg['role'] == 'system':
+                         system += msg['value']
+                     elif 'role' in msg and msg['role'] == 'assistant':
+                         assistant += msg['value']
+                     else:
+                         user += msg['value']
+                 elif msg['type'] == 'video':
+                     video_list.append(msg['value'])
+             question = {
+                 'system': system,
+                 'user': user,
+                 'assistant': assistant
+             }
+             if assistant == '':
+                 if listinstr(['MCQ'], DATASET_TYPE(dataset)):
+                     question['assistant'] = 'Best Option: ('
+                 else:
+                     del question['assistant']
+             if len(video_list) > 1:
+                 print('VLMEvalKit only supports a single video as input; using the first video.')
+             video = video_list[0]
+             return question, video
+         else:
+             logging.critical('Model does not support video input.')
+             raise NotImplementedError
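To illustrate the message-normalization contract, a sketch with a hypothetical subclass (`DummyModel` is not part of the codebase):

    class DummyModel(BaseModel):
        def generate_inner(self, message, dataset=None):
            return ' | '.join(x['value'] for x in message)

    m = DummyModel()
    print(m.preproc_content('hello'))   # [{'type': 'text', 'value': 'hello'}]
    print(m.generate('hello'))          # 'hello'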
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vlmevalkit/vlmeval/vlm/minicpm_v.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import random
4
+ import numpy as np
5
+ from PIL import Image
6
+ from transformers import AutoModel, AutoTokenizer
7
+
8
+ from .base import BaseModel
9
+ from ..smp import *
10
+ from ..dataset import DATASET_TYPE, DATASET_MODALITY
11
+
12
+ import re
13
+
14
+
15
+ class MiniCPM_V(BaseModel):
16
+
17
+ INSTALL_REQ = False
18
+ INTERLEAVE = False
19
+
20
+ def __init__(self, model_path='openbmb/MiniCPM-V', **kwargs):
21
+ assert model_path is not None
22
+ self.model_path = model_path
23
+ print(f'load from {self.model_path}')
24
+ self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
25
+ self.model = self.model.to(dtype=torch.bfloat16)
26
+ self.model.eval().cuda()
27
+ self.kwargs = kwargs
28
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
29
+ torch.cuda.empty_cache()
30
+ self.num_beams = 3
31
+
32
+ def use_custom_prompt(self, dataset):
33
+ assert dataset is not None
34
+ if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
35
+ # For Multi-Turn we don't have custom prompt
36
+ return False
37
+ return False
38
+
39
+ def build_prompt(self, line, dataset=None):
40
+ assert dataset is None or isinstance(dataset, str)
41
+ assert self.use_custom_prompt(dataset)
42
+ tgt_path = self.dump_image(line, dataset)
43
+
44
+ question = line['question']
45
+ options = {
46
+ cand: line[cand]
47
+ for cand in string.ascii_uppercase
48
+ if cand in line and not pd.isna(line[cand])
49
+ }
50
+ options_prompt = 'Options:\n'
51
+ for key, item in options.items():
52
+ options_prompt += f'{key}. {item}\n'
53
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
54
+ prompt = ''
55
+ if hint is not None:
56
+ prompt += f'Hint: {hint}\n'
57
+ prompt += f'{question}\n'
58
+ if len(options):
59
+ prompt += options_prompt
60
+ prompt = 'Study the image carefully and pick the option associated with the correct answer. \
61
+ Focus solely on selecting the option and avoid including any other content.\n' + prompt
62
+ message = [dict(type='text', value=prompt)]
63
+ message.extend([dict(type='image', value=p) for p in tgt_path])
64
+
65
+ return message
66
+
67
+ def generate_inner(self, message, dataset=None):
68
+ prompt, image_path = self.message_to_promptimg(message, dataset=dataset)
69
+ image = Image.open(image_path).convert('RGB')
70
+ msgs = [{'role': 'user', 'content': prompt}]
71
+ if DATASET_TYPE(dataset) == 'MCQ':
72
+ max_new_tokens = 20
73
+ elif DATASET_TYPE(dataset) == 'Y/N':
74
+ max_new_tokens = 100
75
+ else:
76
+ max_new_tokens = 1024
77
+
78
+ default_kwargs = dict(
79
+ max_new_tokens=max_new_tokens,
80
+ sampling=False,
81
+ num_beams=self.num_beams
82
+ )
83
+ default_kwargs.update(self.kwargs)
84
+ res, _, _ = self.model.chat(
85
+ image=image,
86
+ msgs=msgs,
87
+ context=None,
88
+ tokenizer=self.tokenizer,
89
+ **default_kwargs
90
+ )
91
+ return res
92
+
93
+
94
+ class MiniCPM_Llama3_V(BaseModel):
95
+
96
+ INSTALL_REQ = False
97
+ INTERLEAVE = True
98
+
99
+ def __init__(self, model_path='openbmb/MiniCPM-Llama3-V-2_5', **kwargs):
100
+ assert model_path is not None
101
+ self.model_path = model_path
102
+ print(f'load from {self.model_path}')
103
+ self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
104
+ self.model = self.model.to(dtype=torch.float16)
105
+ self.model.eval().cuda()
106
+ self.kwargs = kwargs
107
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
108
+ torch.cuda.empty_cache()
109
+ self.num_beams = 3
110
+ self.options_system_prompt = ('Carefully read the following question and select the letter corresponding '
111
+ 'to the correct answer. Highlight the applicable choices without giving '
112
+ 'explanations.')
113
+ self.wo_options_system_prompt = 'Carefully read the following question. Answer the question directly.'
114
+ self.detail_system_prompt = 'Answer this question in detail.'
115
+ self.vqa_prompt = 'Answer the question using a single word or phrase.'
116
+
117
+ def use_custom_prompt(self, dataset):
118
+ if listinstr(['MCQ', 'VQA'], DATASET_TYPE(dataset)):
119
+ return True
120
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
121
+ return True
122
+ return False
123
+
124
+ def build_prompt(self, line, dataset=None):
125
+ if isinstance(line, int):
126
+ line = self.data.iloc[line]
127
+
128
+ tgt_path = self.dump_image(line, dataset)
129
+ system_prompt = ''
130
+
131
+ question = line['question']
132
+ if DATASET_TYPE(dataset) == 'MCQ':
133
+ options = {
134
+ cand: line[cand]
135
+ for cand in string.ascii_uppercase
136
+ if cand in line and not pd.isna(line[cand])
137
+ }
138
+ options_prompt = 'Options:\n'
139
+ for key, item in options.items():
140
+ options_prompt += f'{key}. {item}\n'
141
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
142
+ prompt = ''
143
+ if hint is not None:
144
+ prompt += f'Hint: {hint}\n'
145
+ prompt += f'Question: {question}\n'
146
+ if len(options):
147
+ prompt += options_prompt
148
+ system_prompt = self.options_system_prompt + '\nPlease just indicate your choice.'
149
+ else:
150
+ system_prompt = self.wo_options_system_prompt
151
+ if 'MMMU' in dataset: # Corner Case
152
+ prompt = system_prompt + '\n' + prompt
153
+ system_prompt = ''
154
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
155
+ question = line['question'] + ' Yes or No?'
156
+ prompt = question
157
+ elif dataset is not None and listinstr(['MME'], dataset):
158
+ question = line['question'] + ' Yes or No?'
159
+ prompt = question
160
+ elif dataset is not None and listinstr(['OCRBench'], dataset):
161
+ system_prompt = self.vqa_prompt
162
+ question = line['question']
163
+ prompt = question
164
+ elif DATASET_TYPE(dataset) == 'VQA':
165
+ if listinstr(['LLaVABench', 'MMLongBench_DOC'], dataset):
166
+ system_prompt = ''
167
+ prompt = question
168
+ elif listinstr(['MMVet'], dataset):
169
+ system_prompt = self.detail_system_prompt
170
+ prompt = question
171
+ else:
172
+ system_prompt = self.vqa_prompt
173
+ prompt = question
174
+
175
+ msgs = []
176
+ if system_prompt:
177
+ msgs.append(dict(type='text', value=system_prompt))
178
+ if isinstance(tgt_path, list):
179
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
180
+ else:
181
+ msgs.append(dict(type='image', value=tgt_path))  # append, so any system prompt added above is preserved
182
+ msgs.append(dict(type='text', value=prompt))
183
+ return msgs
184
+
185
+ def generate_inner(self, message, dataset=None):
186
+ if DATASET_TYPE(dataset) == 'MCQ':
187
+ max_new_tokens = 200
188
+ elif DATASET_TYPE(dataset) == 'Y/N':
189
+ max_new_tokens = 3
190
+ else:
191
+ max_new_tokens = 1024
192
+
193
+ default_kwargs = dict(
194
+ max_new_tokens=max_new_tokens,
195
+ sampling=False,
196
+ num_beams=self.num_beams,
197
+ )
198
+ default_kwargs.update(self.kwargs)
199
+
200
+ content = []
201
+ for x in message:
202
+ if x['type'] == 'text':
203
+ content.append(x['value'])
204
+ elif x['type'] == 'image':
205
+ image = Image.open(x['value']).convert('RGB')
206
+ content.append(image)
207
+ msgs = [{'role': 'user', 'content': content}]
208
+
209
+ res = self.model.chat(
210
+ msgs=msgs,
211
+ context=None,
212
+ image=None,
213
+ tokenizer=self.tokenizer,
214
+ **default_kwargs
215
+ )
216
+
217
+ if isinstance(res, tuple) and len(res) > 0:
218
+ res = res[0]
219
+ return res
220
+
221
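+ # Multi-turn variant of generate_inner: preserves the role of every turn and
+ # loads referenced images so the whole conversation is replayed via model.chat().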
+ def chat_inner(self, message, dataset=None):
222
+ max_new_tokens = 1024
223
+
224
+ default_kwargs = dict(
225
+ max_new_tokens=max_new_tokens,
226
+ sampling=False,
227
+ num_beams=self.num_beams,
228
+ )
229
+ default_kwargs.update(self.kwargs)
230
+
231
+ msgs = []
232
+ for msg in message:
233
+ content = []
234
+ if len(msg['content']) == 1 and msg['content'][0]['type'] == 'text':
235
+ msg_new = {'role': msg['role'], 'content': msg['content'][0]['value']}
236
+ msgs.append(msg_new)
237
+ continue
238
+
239
+ for x in msg['content']:
240
+ if x['type'] == 'text':
241
+ content.append(x['value'])
242
+ elif x['type'] == 'image':
243
+ image = Image.open(x['value']).convert('RGB')
244
+ content.append(image)
245
+ msg_new = {'role': msg['role'], 'content': content}
246
+ msgs.append(msg_new)
247
+
248
+ res = self.model.chat(
249
+ msgs=msgs,
250
+ context=None,
251
+ image=None,
252
+ tokenizer=self.tokenizer,
253
+ **default_kwargs)
254
+
255
+ if isinstance(res, tuple) and len(res) > 0:
256
+ res = res[0]
257
+ return res
258
+
259
+
260
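+ # Wrapper for MiniCPM-V 2.6: adds optional chain-of-thought prompting
+ # (use_cot) and random image upsizing (use_upsize), with all RNG seeds fixed
+ # so the upsizing is reproducible across runs.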
+ class MiniCPM_V_2_6(BaseModel):
261
+ INSTALL_REQ = False
262
+ INTERLEAVE = True
263
+
264
+ def __init__(self, model_path='openbmb/MiniCPM-V-2_6', **kwargs):
265
+ random.seed(0)
266
+ np.random.seed(0)
267
+ torch.manual_seed(0)
268
+ torch.cuda.manual_seed_all(0)
269
+
270
+ assert model_path is not None
271
+ self.model_path = model_path
272
+ print(f'load from path {self.model_path}')
273
+ self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True)
274
+ self.model = self.model.to(dtype=torch.bfloat16)
275
+ self.model.eval().cuda()
276
+
277
+ self.kwargs = kwargs
278
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
279
+ torch.cuda.empty_cache()
280
+ self.num_beams = 3
281
+
282
+ self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
283
+ self.wo_options_system_prompt = 'Carefully read the following question. Answer the question directly.'
284
+ self.detail_system_prompt = 'Answer this question in detail.'
285
+ self.vqa_prompt = 'Answer the question using a single word or phrase.'
286
+
287
+ self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
288
+ '''by step and finally pick the option associated with the correct '''
289
+ '''answer in the format of "Answer: selected option".\n\n''')
290
+ self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
291
+ '''then output the final answer in the format of "Answer: single number '''
292
+ '''or single word or phrase".\n\n''')
293
+
294
+ def use_custom_prompt(self, dataset=None):
295
+ if dataset is None:
296
+ return False
297
+ if DATASET_TYPE(dataset) in ['MCQ', 'VQA', 'Y/N']:
298
+ return True
299
+ return False
300
+
301
+ def use_cot(self, dataset=None):
302
+ if dataset is None:
303
+ return False
304
+ if listinstr(['MMMU', 'HallusionBench', 'OCRBench', 'ChartQA'], dataset):
305
+ return True
306
+ elif listinstr(['MathVista', 'MMVet', 'MMBench', 'MMStar', 'AI2D', 'RealWorldQA',
307
+ 'POPE', 'ScienceQA', 'TextVQA', 'DocVQA'], dataset):
308
+ return False
309
+ else:
310
+ return False
311
+
312
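+ # Datasets whose images get enlarged before inference; see generate_inner,
+ # which upscales anything below 1344x1344 pixels in area so the slicing
+ # scheme receives more detail.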
+ def use_upsize(self, dataset=None):
313
+ if dataset is None:
314
+ return False
315
+ if listinstr(['MMVet', 'MMBench', 'MMStar', 'AI2D', 'OCRBench'], dataset):
316
+ return True
317
+ else:
318
+ return False
319
+
320
+ def build_prompt(self, line, dataset=None):
321
+ if isinstance(line, int):
322
+ line = self.data.iloc[line]
323
+
324
+ tgt_path = self.dump_image(line, dataset)
325
+ system_prompt, prompt = '', ''
326
+
327
+ question = line['question']
328
+
329
+ if not self.use_cot(dataset):
330
+ if DATASET_TYPE(dataset) == 'MCQ':
331
+ options = {
332
+ cand: line[cand]
333
+ for cand in string.ascii_uppercase
334
+ if cand in line and not pd.isna(line[cand])
335
+ }
336
+ options_prompt = 'Options:\n'
337
+ for key, item in options.items():
338
+ options_prompt += f'{key}. {item}\n'
339
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
340
+ if hint is not None:
341
+ prompt += f'Hint: {hint}\n'
342
+ prompt += f'Question: {question}\n'
343
+ if len(options):
344
+ prompt += options_prompt
345
+ prompt += self.options_suffix_prompt
346
+ else:
347
+ system_prompt = self.wo_options_system_prompt
348
+
349
+ if 'MMMU' in dataset:
350
+ if len(system_prompt) > 0:
351
+ prompt = system_prompt + '\n' + prompt
352
+ system_prompt = ''
353
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
354
+ question += ' Yes or No?'
355
+ prompt = question
356
+ elif dataset is not None and listinstr(['OCRBench'], dataset):
357
+ system_prompt = self.vqa_prompt
358
+ prompt = question
359
+ elif DATASET_TYPE(dataset) == 'VQA':
360
+ if listinstr(['LLaVABench'], dataset):
361
+ system_prompt = ''
362
+ elif listinstr(['MMVet'], dataset):
363
+ system_prompt = self.detail_system_prompt
364
+ else:
365
+ system_prompt = self.vqa_prompt
366
+ prompt = question
367
+ else:
368
+ prompt = question
369
+ else:
370
+ has_options = True
371
+ if DATASET_TYPE(dataset) == 'MCQ':
372
+ options = {
373
+ cand: line[cand]
374
+ for cand in string.ascii_uppercase
375
+ if cand in line and not pd.isna(line[cand])
376
+ }
377
+ options_prompt = ''
378
+ for key, item in options.items():
379
+ options_prompt += f'{key}. {item}\n'
380
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
381
+ if hint is not None:
382
+ prompt += f'Hint: {hint}\n'
383
+ prompt += f'{question}\n'
384
+
385
+ if len(options):
386
+ prompt += options_prompt
387
+ else:
388
+ has_options = False
389
+
390
+ if 'MMMU' in dataset:
391
+ if len(system_prompt) > 0:
392
+ prompt = system_prompt + '\n' + prompt
393
+ system_prompt = ''
394
+ else:
395
+ prompt = question
396
+
397
+ if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
398
+ if DATASET_TYPE(dataset) == 'MCQ':
399
+ if has_options:
400
+ prompt = self.multi_choice_cot_prompt + prompt
401
+ else:
402
+ prompt = self.short_ans_cot_prompt + prompt
403
+ elif DATASET_TYPE(dataset) == 'Y/N':
404
+ prompt = self.short_ans_cot_prompt + prompt
405
+ else:
406
+ prompt = self.short_ans_cot_prompt + prompt
407
+
408
+ msgs = []
409
+ if system_prompt:
410
+ msgs.append(dict(type='text', value=system_prompt))
411
+ if isinstance(tgt_path, list):
412
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
413
+ else:
414
+ msgs.append(dict(type='image', value=tgt_path))  # append, so any system prompt added above is preserved
415
+ msgs.append(dict(type='text', value=prompt))
416
+
417
+ return msgs
418
+
419
+ def generate_inner(self, message, dataset=None):
420
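+ # Video inputs arrive as many frames: disable slicing and image ids and raise
+ # the input-length cap so long frame sequences fit into the context window.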
+ if DATASET_MODALITY(dataset) == 'VIDEO':
421
+ max_slice_nums = 1
422
+ use_image_id = False
423
+ max_inp_length = 2048 * 10
424
+ else:
425
+ max_slice_nums = None
426
+ use_image_id = True
427
+ max_inp_length = 8192
428
+
429
+ max_new_tokens = 2048
430
+ default_kwargs = dict(
431
+ max_new_tokens=max_new_tokens,
432
+ sampling=False,
433
+ num_beams=self.num_beams,
434
+ )
435
+ default_kwargs.update(self.kwargs)
436
+
437
+ content = []
438
+
439
+ for x in message:
440
+ if x['type'] == 'text':
441
+ content.append(x['value'])
442
+ elif x['type'] == 'image':
443
+ image = Image.open(x['value']).convert('RGB')
444
+ if not self.use_upsize(dataset):
445
+ content.append(image)
446
+ else:
447
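+ # Random upsize: choose a new width between the original and the width that
+ # brings the area up to 1344x1344, keeping the aspect ratio intact.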
+ img_width, img_height = image.width, image.height
448
+ if (img_width * img_height) >= (1344 * 1344):
449
+ content.append(image)
450
+ else:
451
+ ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
452
+ max_img_width = int(img_width * ratio)
453
+ new_img_width = random.randint(img_width, max_img_width)
454
+ new_img_height = int(new_img_width / img_width * img_height)
455
+ resized_image = image.resize((new_img_width, new_img_height))
456
+ content.append(resized_image)
457
+ msgs = [{'role': 'user', 'content': content}]
458
+
459
+ res = self.model.chat(
460
+ image=None,
461
+ msgs=msgs,
462
+ context=None,
463
+ tokenizer=self.tokenizer,
464
+ max_inp_length=max_inp_length,
465
+ use_image_id=use_image_id,
466
+ max_slice_nums=max_slice_nums,
467
+ **default_kwargs
468
+ )
469
+
470
+ if isinstance(res, tuple) and len(res) > 0:
471
+ res = res[0]
472
+
473
+ return res
474
+
475
+
476
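+ # Wrapper for MiniCPM-o 2.6 in vision-only mode (audio and TTS heads are not
+ # initialized). The official hub checkpoint is pinned to 3 beams; NUM_BEAMS
+ # applies to other checkpoints, and PENALTY sets the repetition penalty.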
+ class MiniCPM_o_2_6(BaseModel):
477
+ INSTALL_REQ = False
478
+ INTERLEAVE = True
479
+
480
+ def __init__(self, model_path='openbmb/MiniCPM-o-2_6', **kwargs):
481
+ random.seed(0)
482
+ np.random.seed(0)
483
+ torch.manual_seed(0)
484
+ torch.cuda.manual_seed_all(0)
485
+
486
+ assert model_path is not None
487
+ self.model_path = model_path
488
+ print(f'load from path {self.model_path}')
489
+ self.model = AutoModel.from_pretrained(
490
+ self.model_path,
491
+ trust_remote_code=True,
492
+ attn_implementation='sdpa',
493
+ torch_dtype=torch.bfloat16,
494
+ init_vision=True,
495
+ init_audio=False,
496
+ init_tts=False
497
+ )
498
+
499
+ self.model.eval().cuda()
500
+
501
+ self.kwargs = kwargs
502
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
503
+ torch.cuda.empty_cache()
504
+
505
+ num_beams = int(os.getenv("NUM_BEAMS", "3"))
506
+ self.num_beams = 3 if self.model_path == 'openbmb/MiniCPM-o-2_6' else num_beams
507
+
508
+ repetition_penalty = float(os.getenv("PENALTY", "1.2"))
509
+ self.repetition_penalty = repetition_penalty
510
+
511
+ self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.'''
512
+ self.wo_options_system_prompt = 'Carefully read the following question. Answer the question directly.'
513
+ self.detail_system_prompt = 'Answer this question in detail.'
514
+ self.vqa_prompt = 'Answer the question using a single word or phrase.'
515
+
516
+ self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step '''
517
+ '''by step and finally pick the option associated with the correct '''
518
+ '''answer in the format of "Answer: selected option".\n\n''')
519
+ self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and '''
520
+ '''then output the final answer in the format of "Answer: single number '''
521
+ '''or single word or phrase".\n\n''')
522
+
523
+ def use_custom_prompt(self, dataset=None):
524
+ if dataset is None:
525
+ return False
526
+ if listinstr(['MCQ', 'VQA', 'Y/N'], DATASET_TYPE(dataset)):
527
+ return True
528
+ return False
529
+
530
+ def use_cot(self, dataset=None):
531
+ if dataset is None:
532
+ return False
533
+ if listinstr(['MMMU', 'MathVista', 'OCRBench', 'ChartQA', 'MathVision', 'MathVerse_MINI_Vision_Only'], dataset):
534
+ return True
535
+ elif listinstr(['MMVet', 'MMBench', 'MMStar', 'HallusionBench', 'AI2D', 'RealWorldQA',
536
+ 'POPE', 'ScienceQA', 'TextVQA', 'DocVQA'], dataset):
537
+ return False
538
+ else:
539
+ return False
540
+
541
+ def use_upsize(self, dataset=None):
542
+ if dataset is None:
543
+ return False
544
+ if listinstr(['MathVista', 'MMBench_TEST_CN', 'MMStar', 'AI2D', 'OCRBench', 'DynaMath'], dataset):
545
+ return True
546
+ else:
547
+ return False
548
+
549
+ def build_prompt(self, line, dataset=None):
550
+ if isinstance(line, int):
551
+ line = self.data.iloc[line]
552
+
553
+ tgt_path = self.dump_image(line, dataset)
554
+ system_prompt, prompt = '', ''
555
+
556
+ question = line['question']
557
+
558
+ if not self.use_cot(dataset):
559
+ if DATASET_TYPE(dataset) == 'MCQ':
560
+ options = {
561
+ cand: line[cand]
562
+ for cand in string.ascii_uppercase
563
+ if cand in line and not pd.isna(line[cand])
564
+ }
565
+ options_prompt = 'Options:\n'
566
+ for key, item in options.items():
567
+ options_prompt += f'{key}. {item}\n'
568
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
569
+ if hint is not None:
570
+ prompt += f'Hint: {hint}\n'
571
+ prompt += f'Question: {question}\n'
572
+ if len(options):
573
+ prompt += options_prompt
574
+ prompt += self.options_suffix_prompt
575
+ else:
576
+ system_prompt = self.wo_options_system_prompt
577
+
578
+ if 'MMMU' in dataset:
579
+ if len(system_prompt) > 0:
580
+ prompt = system_prompt + '\n' + prompt
581
+ system_prompt = ''
582
+ elif dataset is not None and listinstr(['HallusionBench'], dataset):
583
+ question += ' Yes or No?'
584
+ prompt = question
585
+ elif dataset is not None and listinstr(['OCRBench'], dataset):
586
+ system_prompt = self.vqa_prompt
587
+ prompt = question
588
+ elif DATASET_TYPE(dataset) == 'VQA':
589
+ if listinstr(['LLaVABench'], dataset):
590
+ system_prompt = ''
591
+ elif listinstr(['MMVet'], dataset):
592
+ system_prompt = self.detail_system_prompt
593
+ else:
594
+ system_prompt = self.vqa_prompt
595
+ prompt = question
596
+ else:
597
+ prompt = question
598
+ else:
599
+ has_options = True
600
+ if DATASET_TYPE(dataset) == 'MCQ':
601
+ options = {
602
+ cand: line[cand]
603
+ for cand in string.ascii_uppercase
604
+ if cand in line and not pd.isna(line[cand])
605
+ }
606
+ options_prompt = ''
607
+ for key, item in options.items():
608
+ options_prompt += f'{key}. {item}\n'
609
+ hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
610
+ if hint is not None:
611
+ prompt += f'Hint: {hint}\n'
612
+ prompt += f'{question}\n'
613
+
614
+ if len(options):
615
+ prompt += options_prompt
616
+ else:
617
+ has_options = False
618
+
619
+ if 'MMMU' in dataset:
620
+ if len(system_prompt) > 0:
621
+ prompt = system_prompt + '\n' + prompt
622
+ system_prompt = ''
623
+ else:
624
+ prompt = question
625
+
626
+ if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']:
627
+ if DATASET_TYPE(dataset) == 'MCQ':
628
+ if has_options:
629
+ prompt = self.multi_choice_cot_prompt + prompt
630
+ else:
631
+ prompt = self.short_ans_cot_prompt + prompt
632
+ elif DATASET_TYPE(dataset) == 'Y/N':
633
+ prompt = self.short_ans_cot_prompt + prompt
634
+ else:
635
+ prompt = self.short_ans_cot_prompt + prompt
636
+
637
+ msgs = []
638
+ if system_prompt:
639
+ msgs.append(dict(type='text', value=system_prompt))
640
+ if isinstance(tgt_path, list):
641
+ msgs.extend([dict(type='image', value=p) for p in tgt_path])
642
+ else:
643
+ msgs.append(dict(type='image', value=tgt_path))  # append, so any system prompt added above is preserved
644
+ msgs.append(dict(type='text', value=prompt))
645
+
646
+ return msgs
647
+
648
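+ # Pulls the short final answer out of a chain-of-thought response by matching
+ # the trailing "Answer: ..." pattern requested by the CoT prompts above,
+ # e.g. "...so the option is clear. Answer: B" -> "B".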
+ def extract_answer(self, res, dataset=None):
649
+ if dataset is None:
650
+ return res
651
+ if self.use_cot(dataset):
652
+ if DATASET_TYPE(dataset) == 'MCQ':
653
+ pattern = r'Answer:\s*([A-Ia-i])(?![A-Za-z])'
654
+ matches = re.findall(pattern, res, re.DOTALL)
655
+ if matches:
656
+ extracted_res = matches[-1].strip()
657
+ else:
658
+ extracted_res = res
659
+ return extracted_res
660
+ elif DATASET_TYPE(dataset) == 'VQA' and not listinstr(['OCRBench'], dataset):
661
+ pattern = r'Answer:\s*(.*)\s*$'
662
+ match = re.search(pattern, res, re.DOTALL)
663
+ if match:
664
+ extracted_res = match.group(1)
665
+ else:
666
+ extracted_res = res
667
+ return extracted_res
668
+ return res
669
+
670
+ def generate_inner(self, message, dataset=None):
671
+ if DATASET_MODALITY(dataset) == 'VIDEO':
672
+ max_slice_nums = 1
673
+ use_image_id = False
674
+ max_inp_length = 2048 * 10
675
+ else:
676
+ max_slice_nums = None
677
+ use_image_id = True
678
+ max_inp_length = 8192
679
+
680
+ max_new_tokens = 2048
681
+ default_kwargs = dict(
682
+ max_new_tokens=max_new_tokens,
683
+ sampling=False,
684
+ repetition_penalty=self.repetition_penalty,
685
+ num_beams=self.num_beams,
686
+ )
687
+ default_kwargs.update(self.kwargs)
688
+
689
+ content = []
690
+
691
+ for x in message:
692
+ if x['type'] == 'text':
693
+ content.append(x['value'])
694
+ elif x['type'] == 'image':
695
+ image = Image.open(x['value']).convert('RGB')
696
+ if not self.use_upsize(dataset):
697
+ content.append(image)
698
+ else:
699
+ img_width, img_height = image.width, image.height
700
+ if (img_width * img_height) >= (1344 * 1344):
701
+ content.append(image)
702
+ else:
703
+ ratio = math.sqrt((1344 * 1344) / (img_width * img_height))
704
+ max_img_width = int(img_width * ratio)
705
+ new_img_width = random.randint(img_width, max_img_width)
706
+ new_img_height = int(new_img_width / img_width * img_height)
707
+ resized_image = image.resize((new_img_width, new_img_height))
708
+ content.append(resized_image)
709
+ msgs = [{'role': 'user', 'content': content}]
710
+
711
+ res = self.model.chat(
712
+ image=None,
713
+ msgs=msgs,
714
+ context=None,
715
+ tokenizer=self.tokenizer,
716
+ max_inp_length=max_inp_length,
717
+ use_image_id=use_image_id,
718
+ max_slice_nums=max_slice_nums,
719
+ **default_kwargs
720
+ )
721
+
722
+ if isinstance(res, tuple) and len(res) > 0:
723
+ res = res[0]
724
+
725
+ res = self.extract_answer(res, dataset)
726
+
727
+ return res
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/README.md ADDED
@@ -0,0 +1,3 @@
1
+ # vqa-eval
2
+
3
+ Contains the VQA evaluation kit (TextVQA / DocVQA / DocVQA-test) from the server.
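+
+ A minimal invocation sketch (flag names are inferred from how eval.py reads its arguments; the full set is defined in eval_utils/getargs.py, and all paths are placeholders):
+
+ ```bash
+ torchrun --nproc_per_node=1 eval.py \
+     --model_name minicpmv26 \
+     --model_path openbmb/MiniCPM-V-2_6 \
+     --eval_textVQA
+ ```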
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/__init__.py ADDED
File without changes
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/datasets/vqa_dataset.py ADDED
@@ -0,0 +1,116 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from torch.utils.data import Dataset
5
+
6
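+ # Recovers the bare question from the known prompt templates, e.g.
+ # "OCR tokens: ... Question: what does the sign say? Short answer:"
+ # -> "what does the sign say?".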
+ def prompt_processor(prompt):
7
+ if prompt.startswith('OCR tokens: '):
8
+ pattern = r"Question: (.*?) Short answer:"
9
+ match = re.search(pattern, prompt, re.DOTALL)
10
+ question = match.group(1)
11
+ elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
12
+ if prompt.startswith('Reference OCR token:'):
13
+ question = prompt.split('\n')[1]
14
+ else:
15
+ question = prompt.split('\n')[0]
16
+ elif len(prompt.split('\n')) == 2:
17
+ question = prompt.split('\n')[0]
18
+ else:
19
+ assert False
20
+
21
+ return question.lower()
22
+
23
+ class textVQADataset(Dataset):
24
+ def __init__(
25
+ self,
26
+ image_dir="./downloads/TextVQA/train_images",
27
+ ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
28
+ ):
29
+ self.data = json.load(open(ann_path, "r"))["data"]
30
+ self.image_dir = image_dir
31
+
32
+ def __len__(self):
33
+ return len(self.data)
34
+
35
+ def __getitem__(self, idx):
36
+ question = self.data[idx]['question']
37
+ answers = self.data[idx]['answers']
38
+ img_id = self.data[idx]['image_id']
39
+ qid = self.data[idx]['question_id']
40
+ img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
41
+
42
+ item = {
43
+ "question_id": qid,
44
+ "image_path": img_path,
45
+ "question": question,
46
+ "gt_answers": answers
47
+ }
48
+
49
+ return item
50
+
51
+ class docVQADataset(Dataset):
52
+ def __init__(
53
+ self,
54
+ image_dir="./downloads/DocVQA/spdocvqa_images",
55
+ ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
56
+ ocr_token_path=None
57
+ ):
58
+
59
+ self.data = json.load(open(ann_path, "r"))["data"]
60
+ self.image_dir = image_dir
61
+ self.ann_path = ann_path
62
+ if ocr_token_path:
63
+ self.ocr_token_data = {item['image_id']: item for item in json.load(open(ocr_token_path, "r"))["data"]}
64
+
65
+ def __len__(self):
66
+ return len(self.data)
67
+
68
+ def __getitem__(self, idx):
69
+ question_id = self.data[idx]['questionId']
70
+ relative_img_path = self.data[idx]['image']
71
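+ # The annotations reference "documents/<page>" while the local folder is
+ # named "images", so swap the directory component before joining.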
+ corrected_relative_img_path = relative_img_path.replace("documents", "images")
72
+ img_path = os.path.join(self.image_dir, corrected_relative_img_path)
73
+ question = self.data[idx]['question']
74
+ answers = self.data[idx]['answers']
75
+
76
+ question_type = self.data[idx]['question_types']
77
+
78
+ return {
79
+ "question_id": question_id,
80
+ "image_path": img_path,
81
+ "question": question,
82
+ "gt_answers": answers,
83
+ 'question_type': question_type,
84
+ }
85
+
86
+
87
+ class docVQATESTDataset(Dataset):
88
+ def __init__(
89
+ self,
90
+ image_dir="./downloads/DocVQA/spdocvqa_images",
91
+ ann_path="./downloads/DocVQA/test_v1.0.json",
92
+ ocr_token_path=None
93
+ ):
94
+
95
+ self.data = json.load(open(ann_path, "r"))["data"]
96
+ self.image_dir = image_dir
97
+ self.ann_path = ann_path
98
+
99
+ def __len__(self):
100
+ return len(self.data)
101
+
102
+ def __getitem__(self, idx):
103
+ question_id = self.data[idx]['questionId']
104
+ relative_img_path = self.data[idx]['image']
105
+ corrected_relative_img_path = relative_img_path.replace("documents", "images")
106
+ img_path = os.path.join(self.image_dir, corrected_relative_img_path)
107
+ question = self.data[idx]['question']
108
+
109
+
110
+ return {
111
+ "question_id": question_id,
112
+ "image_path": img_path,
113
+ "question": question,
114
+ "gt_answers": "",
115
+ 'question_type': "",
116
+ }
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval.py ADDED
@@ -0,0 +1,106 @@
1
+ import sys
2
+ import datetime
3
+ import json
4
+ import os
5
+ import torch
6
+
7
+ script_dir = os.path.dirname(os.path.realpath(__file__))
8
+
9
+ sys.path.append(os.path.join(script_dir, '..'))
10
+
11
+ from datasets.vqa_dataset import docVQADataset, docVQATESTDataset, textVQADataset
12
+
13
+
14
+ print(torch.__version__)
15
+
16
+ import numpy as np
17
+
18
+ from eval_utils.getargs import parse_args
19
+ from eval_utils.vqa_evaluate import *
20
+
21
+
22
+ def get_model(args):
23
+ if args.model_name == '':
24
+ raise Exception('Model name cannot be an empty string!')
25
+ from models.MiniCPM.minicpmv import MiniCPM_V, MiniCPM_V_2_6, MiniCPM_o_2_6
26
+ model_path = args.model_path
27
+ ckpt = args.ckpt
28
+
29
+ if args.model_name == 'minicpmv':
30
+ model = MiniCPM_V(model_path=model_path, ckpt=ckpt, device=args.device)
31
+ elif args.model_name == 'minicpmv26':
32
+ model = MiniCPM_V_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
33
+ elif args.model_name == 'minicpmo26':
34
+ model = MiniCPM_o_2_6(model_path=model_path, ckpt=ckpt, device=args.device)
35
+ else:
36
+ raise Exception(f"Unexpected model name {args.model_name}!")
37
+
38
+ return model
39
+
40
+
41
+ def main(args):
42
+ np.random.seed(0)
43
+ max_sample_num = None
44
+
45
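+ # Initialize NCCL from the WORLD_SIZE / RANK / LOCAL_RANK variables exported
+ # by torchrun (or a similar launcher); only rank 0 writes result.json below.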
+ torch.distributed.init_process_group(
46
+ backend='nccl',
47
+ world_size=int(os.getenv('WORLD_SIZE', '1')),
48
+ rank=int(os.getenv('RANK', '0')),
49
+ )
50
+ torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
51
+ print(f'Init Rank-{torch.distributed.get_rank()}')
52
+ if torch.distributed.is_initialized():
53
+ args.device = torch.device(f"cuda:{torch.cuda.current_device()}")
54
+
55
+ model = get_model(args)
56
+
57
+ result = {}
58
+ time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
59
+
60
+ if args.eval_textVQA or args.eval_all:
61
+ dataset = textVQADataset(args.textVQA_image_dir, args.textVQA_ann_path)
62
+ if max_sample_num is not None:
63
+ dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
64
+ acc = evaluate_VQA(model, dataset, args.model_name, 'textVQA', time, \
65
+ batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
66
+ result['textVQA'] = acc
67
+
68
+ if args.eval_docVQA or args.eval_all:
69
+ dataset = docVQADataset(args.docVQA_image_dir, args.docVQA_ann_path)
70
+ if max_sample_num is not None:
71
+ dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
72
+ acc = evaluate_VQA(model, dataset, args.model_name, 'docVQA', time, \
73
+ batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
74
+ result['docVQA'] = acc
75
+
76
+ if args.eval_docVQATest or args.eval_all:
77
+ dataset = docVQATESTDataset(args.docVQATest_image_dir, args.docVQATest_ann_path)
78
+ if max_sample_num is not None:
79
+ dataset = torch.utils.data.Subset(dataset, range(max_sample_num))
80
+ acc = evaluate_VQA(model, dataset, args.model_name, 'docVQATest', time, \
81
+ batch_size=args.batchsize, generate_method=args.generate_method, answer_path=args.answer_path)
82
+ result['docVQATest'] = acc
83
+
84
+ if torch.distributed.is_initialized():
85
+ torch.distributed.barrier()
86
+
87
+ if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
88
+ return None
89
+
90
+ result_path = os.path.join(args.answer_path, args.model_name, 'result.json')
91
+
92
+ output_flag = False
93
+ for k, v in result.items():
94
+ if v > 0.0:
95
+ output_flag = True
96
+ break
97
+
98
+ if output_flag:
99
+ with open(result_path, "w") as f:
100
+ f.write(json.dumps(result, indent=4))
101
+
102
+
103
+ if __name__ == "__main__":
104
+ args = parse_args()
105
+
106
+ main(args)
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/eval_utils/cal_metric.py ADDED
@@ -0,0 +1,40 @@
1
+ import json
2
+ import glob
3
+ import re
4
+
5
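+ # True iff `word` occurs in `sentence` as a whole word (regex word
+ # boundaries), e.g. has_word('the cat sat', 'cat') is True while
+ # has_word('concatenate', 'cat') is False.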
+ def has_word(sentence, word):
6
+ pattern = r"\b" + re.escape(word) + r"\b"
7
+ match = re.search(pattern, sentence)
8
+ if match:
9
+ return True
10
+ else:
11
+ return False
12
+
+ def remove_special_chars(s):
13
+ pattern = r"[^a-zA-Z0-9\s]"
14
+ s = re.sub(pattern, "", s)
15
+ return s
16
+
17
+ for model in glob.glob('./answer_save/*'):
18
+ print(model, ':')
19
+ result_list = sorted(glob.glob(f'{model}/*.json'))
20
+ for task_result_path in result_list:
21
+ taskname = task_result_path.split('/')[-1]
22
+ taskname = taskname.split('.')[0]
23
+ if taskname not in ['IIIT5K', 'svt', 'IC13_857', 'IC15_1811', 'svtp', 'ct80',
24
+ 'cocotext', 'ctw', 'totaltext', 'HOST']:
25
+ continue
26
+
27
+ correct = 0
28
+ num = 0
29
+ with open(task_result_path, 'r') as f:
30
+ results = json.load(f)[:100]  # only the first 100 predictions are scored
31
+ for i in range(len(results)):
32
+ gt_answers = results[i]['gt_answers']
33
+ answer = results[i]['answer']
34
+ gt_answers = remove_special_chars(gt_answers).lower()
35
+ answer = remove_special_chars(answer).lower()
36
+ if has_word(answer, gt_answers):
37
+ correct += 1
38
+ num += 1
39
+ print(f'{taskname:10s}:{float(correct)/num*100:.2f}')
40
+ print('=' * 32)
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/requirements.txt ADDED
@@ -0,0 +1,49 @@
1
+ accelerate
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ async-timeout==4.0.2
5
+ attrs==22.2.0
6
+ bitsandbytes==0.37.0
7
+ cchardet==2.1.7
8
+ chardet==5.1.0
9
+ contourpy==1.0.7
10
+ cycler==0.11.0
11
+ filelock==3.9.0
12
+ fonttools==4.38.0
13
+ frozenlist==1.3.3
14
+ huggingface-hub==0.13.4
15
+ importlib-resources==5.12.0
16
+ kiwisolver==1.4.4
17
+ matplotlib==3.7.0
18
+ multidict==6.0.4
19
+ openai==0.27.0
20
+ packaging==23.0
21
+ psutil==5.9.4
22
+ pycocotools==2.0.6
23
+ pyparsing==3.0.9
24
+ python-dateutil==2.8.2
25
+ pyyaml==6.0
26
+ regex==2022.10.31
27
+ tokenizers==0.13.2
28
+ tqdm==4.64.1
29
+ transformers==4.44.2
30
+ timm==0.6.13
31
+ spacy==3.5.1
32
+ webdataset==0.2.48
33
+ scikit-learn==1.2.2
34
+ scipy==1.10.1
35
+ yarl==1.8.2
36
+ zipp==3.14.0
37
+ omegaconf==2.3.0
38
+ opencv-python==4.7.0.72
39
+ iopath==0.1.10
40
+ decord==0.6.0
41
+ tenacity==8.2.2
42
+ peft
43
+ pycocoevalcap
44
+ sentence-transformers
45
+ umap-learn
46
+ notebook
47
+ gradio==3.24.1
48
+ gradio-client==0.0.8
49
+ wandb
r1-a/response_generation/minicpm/MiniCPM-o/eval_mm/vqaeval/transform_docvqatest_for_submission.py ADDED
@@ -0,0 +1,16 @@
1
+ import argparse
2
+ import json
3
+
4
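+ # Reshape raw DocVQA-test answers into the leaderboard submission format: a
+ # list of {"questionId", "answer"} records with any "</s>" markers stripped.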
+ if __name__ == "__main__":
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("--input_file_path", type=str, default="", help="path to the originial output json.")
7
+ parser.add_argument("--output_file_path", type=str, default="", help="path to where you want to save the processed json.")
8
+ args = parser.parse_args()
9
+
10
+ with open(args.input_file_path, 'r') as f:
11
+ data = json.load(f)
12
+
13
+ transformed_data = [{"questionId": item["question_id"], "answer": item["answer"].replace("</s>", "")} for item in data]
14
+
15
+ with open(args.output_file_path, 'w') as f:
16
+ json.dump(transformed_data, f)