add politics
Browse files- LLaVA-MOSS2/add_extra_data.py +57 -0
- LLaVA-MOSS2/llava/serve/submit.py +124 -97
- LLaVA-MOSS2/read_political.py +40 -0
- LLaVA-MOSS2/scripts/finetune.sh +2 -2
- LLaVA-MOSS2/test.py +45 -73
- LLaVA-MOSS2/vote.py +35 -0
LLaVA-MOSS2/add_extra_data.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json

import pandas as pd

from llava.train.train import train  # NOTE(review): appears unused here -- confirm before removing

# Load the base LLaVA finetuning mixture; new samples get ids appended after it.
with open('./playground/data/llava_v1_5_mix665k.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Next free sample id (renamed from `len`, which shadowed the builtin).
next_id = len(data)

# AGIEval gaokao multiple-choice parquet shards to fold into the training data.
paths = ["/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-biology/snapshots/91e58112d4022523e02d07cfbc96a950eac9219f/data/test-00000-of-00001-de5aacbbef2a047d.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-geography/snapshots/bea8c9da6c3ecf4c07a423b36914f1daa1ba6a1e/data/test-00000-of-00001-82c3eb504d984b0c.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-chemistry/snapshots/2fb33cf46ce4aeea9409ea3600a3b1d7e5216536/data/test-00000-of-00001-79e3d766a5e30db5.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-chinese/snapshots/d47e4d2c79b7280a7fb9990a11b036dfb8cdd89b/data/test-00000-of-00001-cb21ebb290e0161f.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-english/snapshots/691d13566972917f1cdc82f4fa1bad1a5b197cab/data/test-00000-of-00001-8025cecb3b3c0c99.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-history/snapshots/41252f835bf3198590df5f4d488d64f78b6fd595/data/test-00000-of-00001-92728bc55381f2f3.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-mathqa/snapshots/54ede00f50d90dab8295c0163e16912ee52f8068/data/test-00000-of-00001-31399d80475862e0.parquet",
         "/root/.cache/huggingface/hub/datasets--dmayhem93--agieval-gaokao-physics/snapshots/3f82847f19ead1a682f0b27cc5c829ac964586bb/data/test-00000-of-00001-d34ffce230cd958a.parquet"]

extra_samples = []
for path in paths:
    df = pd.read_parquet(path)

    for _, row in df.iterrows():
        # `gold` holds 0-based option indices; render them as answer letters.
        answer = "答案是:" + ''.join(chr(ord('A') + option) for option in row['gold'])

        item = {
            'id': str(next_id),
            'image': "",  # text-only sample: no image attached
            'conversations': [
                {'from': 'human', 'value': row['query']},
                {'from': 'gpt', 'value': answer},
            ],
        }
        next_id += 1

        print(item)
        extra_samples.append(item)

data = data + extra_samples

with open('data_with_extra_data.json', 'w', encoding='utf-8') as file:
    # Write the merged dataset to disk.
    json.dump(data, file, ensure_ascii=False, indent=4)

# Keep every other sample for a half-sized training set.
new_data = data[::2]
with open('data_with_extra_data_half.json', 'w', encoding='utf-8') as file:
    # Write the half-sized dataset to disk.
    json.dump(new_data, file, ensure_ascii=False, indent=4)
|
LLaVA-MOSS2/llava/serve/submit.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import argparse
|
| 3 |
import torch
|
|
@@ -5,6 +6,8 @@ import json
|
|
| 5 |
import re
|
| 6 |
|
| 7 |
import sys
|
|
|
|
|
|
|
| 8 |
sys.path.append('/root/workspace/my-llava-moss2/LLaVA-MOSS2')
|
| 9 |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
| 10 |
from llava.conversation import conv_templates, SeparatorStyle
|
|
@@ -19,6 +22,18 @@ from PIL import Image
|
|
| 19 |
from io import BytesIO
|
| 20 |
from transformers import TextStreamer
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def load_image(image_file):
|
| 24 |
if image_file.startswith('http://') or image_file.startswith('https://'):
|
|
@@ -55,24 +70,6 @@ def get_prompt(key, question, len_of_pictures, image_token):
|
|
| 55 |
D.4
|
| 56 |
|
| 57 |
### 回答:
|
| 58 |
-
根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下:
|
| 59 |
-
|
| 60 |
-
1. 判断6和7是否相等:不相等。
|
| 61 |
-
2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。
|
| 62 |
-
3. 现在计算6和1的最大公约数。
|
| 63 |
-
4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。
|
| 64 |
-
5. 再计算5和1的最大公约数。
|
| 65 |
-
6. 5 > 1,用5减去1得到结果4。
|
| 66 |
-
7. 再计算4和1的最大公约数。
|
| 67 |
-
8. 4 > 1,用4减去1得到结果3。
|
| 68 |
-
9. 再计算3和1的最大公约数。
|
| 69 |
-
10. 3 > 1,用3减去1得到结果2。
|
| 70 |
-
11. 再计算2和1的最大公约数。
|
| 71 |
-
12. 2 > 1,用2减去1得到结果1。
|
| 72 |
-
13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。
|
| 73 |
-
|
| 74 |
-
因此,6和7的最大公约数是1。
|
| 75 |
-
|
| 76 |
答案是:A.
|
| 77 |
|
| 78 |
题目如下:
|
|
@@ -136,90 +133,120 @@ def main(args):
|
|
| 136 |
test_data_path = './playground/test'
|
| 137 |
questions_path = 'playground/test/questions.json'
|
| 138 |
with open(questions_path, 'r', encoding='utf-8') as file:
|
| 139 |
-
|
| 140 |
|
| 141 |
answer_dic = {'A':0, 'B':0, 'C':0, 'D':0}
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
image_tensor =
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
else:
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
# 使用json.dump()函数将字典写入文件
|
| 222 |
-
json.dump(
|
| 223 |
|
| 224 |
|
| 225 |
# while True:
|
|
|
|
| 1 |
+
import copy
|
| 2 |
import os
|
| 3 |
import argparse
|
| 4 |
import torch
|
|
|
|
| 6 |
import re
|
| 7 |
|
| 8 |
import sys
|
| 9 |
+
|
| 10 |
+
import tqdm
|
| 11 |
sys.path.append('/root/workspace/my-llava-moss2/LLaVA-MOSS2')
|
| 12 |
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
| 13 |
from llava.conversation import conv_templates, SeparatorStyle
|
|
|
|
| 22 |
from io import BytesIO
|
| 23 |
from transformers import TextStreamer
|
| 24 |
|
| 25 |
+
class SilentStreamer(TextStreamer):
    """A ``TextStreamer`` that discards every piece of finalized text.

    Lets ``model.generate`` be called with a ``streamer`` argument during
    batch evaluation without printing intermediate tokens to stdout.
    """

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # Intentionally swallow the output instead of printing it.
        pass
|
| 37 |
|
| 38 |
def load_image(image_file):
|
| 39 |
if image_file.startswith('http://') or image_file.startswith('https://'):
|
|
|
|
| 70 |
D.4
|
| 71 |
|
| 72 |
### 回答:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
答案是:A.
|
| 74 |
|
| 75 |
题目如下:
|
|
|
|
| 133 |
test_data_path = './playground/test'
|
| 134 |
questions_path = 'playground/test/questions.json'
|
| 135 |
with open(questions_path, 'r', encoding='utf-8') as file:
|
| 136 |
+
questions_origin = json.load(file)
|
| 137 |
|
| 138 |
answer_dic = {'A':0, 'B':0, 'C':0, 'D':0}
|
| 139 |
|
| 140 |
+
answers = []
|
| 141 |
+
|
| 142 |
+
for i in tqdm.tqdm(range(0, 5), desc="Voting Processing"):
|
| 143 |
+
questions = copy.deepcopy(questions_origin)
|
| 144 |
+
for subject in questions:
|
| 145 |
+
example = subject['example']
|
| 146 |
+
for question_itme in tqdm.tqdm(example, desc = f'output_{i}.json ' + subject['keyword'] + ' Processing'):
|
| 147 |
+
picture = question_itme['picture']
|
| 148 |
+
question = question_itme['question']
|
| 149 |
+
# print("question " + str(question_itme['index']) + ":\n")
|
| 150 |
+
|
| 151 |
+
images = [load_image(os.path.join(test_data_path, picture_path)) for picture_path in picture]
|
| 152 |
+
images_size = [image.size for image in images]
|
| 153 |
+
image_tensor = process_images(images, image_processor, model.config)
|
| 154 |
+
|
| 155 |
+
# image = load_image(args.image_file)
|
| 156 |
+
# image_size = image.size
|
| 157 |
+
# # Similar operation in model_worker.py
|
| 158 |
+
# image_tensor = process_images([image], image_processor, model.config)
|
| 159 |
+
if type(image_tensor) is list:
|
| 160 |
+
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
|
| 161 |
+
else:
|
| 162 |
+
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
| 163 |
+
|
| 164 |
+
if len(images) != 0:
|
| 165 |
+
# first message
|
| 166 |
+
|
| 167 |
+
if model.config.mm_use_im_start_end:
|
| 168 |
+
# if len(images) == 4:
|
| 169 |
+
# inp = question +'\nA.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nB.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nC.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nD.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
| 170 |
+
# elif len(images) == 5:
|
| 171 |
+
# inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + question +'\nA.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nB.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nC.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\nD.' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
|
| 172 |
+
# else:
|
| 173 |
+
# inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + question
|
| 174 |
+
inp = get_prompt(subject['keyword'], question, len(picture), DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN)
|
| 175 |
+
else:
|
| 176 |
+
# if len(images) == 4:
|
| 177 |
+
# inp = question +'\nA.' + DEFAULT_IMAGE_TOKEN + '\nB.' + DEFAULT_IMAGE_TOKEN + '\nC.' + DEFAULT_IMAGE_TOKEN + '\nD.' + DEFAULT_IMAGE_TOKEN
|
| 178 |
+
# elif len(images) == 5:
|
| 179 |
+
# inp = DEFAULT_IMAGE_TOKEN + question +'\nA.' + DEFAULT_IMAGE_TOKEN + '\nB.' + DEFAULT_IMAGE_TOKEN + '\nC.' + DEFAULT_IMAGE_TOKEN + '\nD.' + DEFAULT_IMAGE_TOKEN
|
| 180 |
+
# else:
|
| 181 |
+
# inp = DEFAULT_IMAGE_TOKEN + '\n' + question
|
| 182 |
+
inp = get_prompt(subject['keyword'], question, len(picture), DEFAULT_IMAGE_TOKEN)
|
| 183 |
+
|
| 184 |
+
images = None
|
| 185 |
+
|
| 186 |
+
conv = conv_templates[args.conv_mode].copy()
|
| 187 |
+
if "mpt" in model_name.lower():
|
| 188 |
+
roles = ('user', 'assistant')
|
| 189 |
else:
|
| 190 |
+
roles = conv.roles
|
| 191 |
+
conv.append_message(conv.roles[0], inp)
|
| 192 |
+
conv.append_message(conv.roles[1], None)
|
| 193 |
+
prompt = conv.get_prompt()
|
| 194 |
+
|
| 195 |
+
input_ids = tokenizer_image_token(inp, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
| 196 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 197 |
+
keywords = [stop_str]
|
| 198 |
+
streamer = SilentStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 199 |
+
|
| 200 |
+
with torch.inference_mode():
|
| 201 |
+
output_ids = model.generate(
|
| 202 |
+
input_ids,
|
| 203 |
+
images=image_tensor,
|
| 204 |
+
image_sizes=images_size,
|
| 205 |
+
do_sample=True if args.temperature > 0 else False,
|
| 206 |
+
temperature=args.temperature,
|
| 207 |
+
max_new_tokens=args.max_new_tokens,
|
| 208 |
+
streamer=streamer,
|
| 209 |
+
use_cache=True)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
outputs = tokenizer.decode(output_ids[0]).strip()
|
| 213 |
+
outputs = re.sub(r'\([^()]*\)', '', outputs)
|
| 214 |
+
outputs = re.sub(r'<s>|</s>', '', outputs)
|
| 215 |
+
outputs = extract(outputs, answer_dic)
|
| 216 |
+
conv.messages[-1][-1] = outputs
|
| 217 |
+
question_itme['model_answer'] = [outputs]
|
| 218 |
+
question_itme.pop('picture')
|
| 219 |
+
question_itme.pop('question')
|
| 220 |
+
|
| 221 |
+
# print(subject['keyword'] + "finished")
|
| 222 |
+
|
| 223 |
+
answers.append(questions)
|
| 224 |
+
|
| 225 |
+
final_ans = answers[0]
|
| 226 |
+
for ans in answers:
|
| 227 |
+
for i, sub in enumerate(ans):
|
| 228 |
+
example = sub['example']
|
| 229 |
+
for j, item in enumerate(example):
|
| 230 |
+
item_ans = item['model_answer']
|
| 231 |
+
index = ord(item_ans[0]) - 65
|
| 232 |
+
if 'count' not in final_ans:
|
| 233 |
+
final_ans[i]['example'][j]['count'] = [0] * 4
|
| 234 |
+
final_ans[i]['example'][j]['count'][index] += 1
|
| 235 |
+
|
| 236 |
+
for sub in final_ans:
|
| 237 |
+
example = sub['example']
|
| 238 |
+
for item in example:
|
| 239 |
+
max = 0
|
| 240 |
+
for i in range(1, 4):
|
| 241 |
+
if item['count'][i] > item['count'][max]:
|
| 242 |
+
max = i
|
| 243 |
+
item['model_answer'] = str(chr(max + 65))
|
| 244 |
+
item.pop('count')
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
with open('final_answer.json', 'w', encoding='utf-8') as file:
|
| 248 |
# 使用json.dump()函数将字典写入文件
|
| 249 |
+
json.dump(final_ans, file, ensure_ascii=False, indent=4)
|
| 250 |
|
| 251 |
|
| 252 |
# while True:
|
LLaVA-MOSS2/read_political.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import jsonlines
import json

# Load the already-merged dataset; new samples get ids appended after it.
with open('data_with_extra_data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Next free sample id (renamed from `len`, which shadowed the builtin).
next_id = len(data)

path = '/root/.cache/huggingface/hub/datasets--RUCAIBox--gaokao-bench/snapshots/49877cf53b6db9c24d7d285161fc12bba2f85d29/test/2010-2022_Political_Science_MCQs.jsonl'
# Read-only access is sufficient ('r+' needlessly required write permission).
with open(path, 'r', encoding='utf-8') as file:
    dict_list = []
    for line in jsonlines.Reader(file):
        dict_item = {
            'id': str(next_id),
            'image': "",  # text-only sample: no image attached
            'conversations': [
                {'from': 'human', 'value': line['question']},
                # Answer letters may be a list for multi-select questions; join them.
                {'from': 'gpt', 'value': line['analysis'] + "答案是:" + ''.join(line['answer'])},
            ],
        }
        next_id += 1

        print(dict_item)
        dict_list.append(dict_item)

final_data = dict_list + data

with open('political_data_with_extra_data.json', 'w', encoding='utf-8') as file:
    # BUG FIX: this previously dumped `data`, silently dropping every new
    # political-science sample collected in `dict_list`. Write the merged list.
    json.dump(final_data, file, ensure_ascii=False, indent=4)
|
| 40 |
+
|
LLaVA-MOSS2/scripts/finetune.sh
CHANGED
|
@@ -15,7 +15,7 @@ deepspeed llava/train/train_mem.py \
|
|
| 15 |
--deepspeed ./scripts/zero2.json \
|
| 16 |
--model_name_or_path /root/.cache/huggingface/hub/models--fnlp--moss2-2_5b-chat/snapshots/3eda5a066c519990bf5f9ba056f5f8ef81531c83 \
|
| 17 |
--version $PROMPT_VERSION \
|
| 18 |
-
--data_path ./
|
| 19 |
--image_folder ./playground/data \
|
| 20 |
--vision_tower openai/clip-vit-large-patch14 \
|
| 21 |
--pretrain_mm_mlp_adapter ./checkpoints/llava-moss2-2_5b-chat-pretrain/mm_projector.bin \
|
|
@@ -23,7 +23,7 @@ deepspeed llava/train/train_mem.py \
|
|
| 23 |
--mm_use_im_start_end False \
|
| 24 |
--mm_use_im_patch_token False \
|
| 25 |
--bf16 True \
|
| 26 |
-
--max_steps
|
| 27 |
--per_device_train_batch_size 2 \
|
| 28 |
--per_device_eval_batch_size 2 \
|
| 29 |
--gradient_accumulation_steps 2 \
|
|
|
|
| 15 |
--deepspeed ./scripts/zero2.json \
|
| 16 |
--model_name_or_path /root/.cache/huggingface/hub/models--fnlp--moss2-2_5b-chat/snapshots/3eda5a066c519990bf5f9ba056f5f8ef81531c83 \
|
| 17 |
--version $PROMPT_VERSION \
|
| 18 |
+
--data_path ./data_with_extra_data_half.json \
|
| 19 |
--image_folder ./playground/data \
|
| 20 |
--vision_tower openai/clip-vit-large-patch14 \
|
| 21 |
--pretrain_mm_mlp_adapter ./checkpoints/llava-moss2-2_5b-chat-pretrain/mm_projector.bin \
|
|
|
|
| 23 |
--mm_use_im_start_end False \
|
| 24 |
--mm_use_im_patch_token False \
|
| 25 |
--bf16 True \
|
| 26 |
+
--max_steps 40000 \
|
| 27 |
--per_device_train_batch_size 2 \
|
| 28 |
--per_device_eval_batch_size 2 \
|
| 29 |
--gradient_accumulation_steps 2 \
|
LLaVA-MOSS2/test.py
CHANGED
|
@@ -1,80 +1,52 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
# pattern = re.compile(r'\s([A-D]\.\s.*[^\n])')
|
| 4 |
-
# # 使用findall查找所有匹配的选项
|
| 5 |
-
# options = pattern.findall(quesiton)
|
| 6 |
-
# print(options)
|
| 7 |
-
|
| 8 |
-
# options = '\n'.join(f"{'ABCDEFG'[i]}. {'Image'}" for i in range(0, 4))
|
| 9 |
-
# print(options)
|
| 10 |
-
|
| 11 |
-
# def get_prompt(key, question, len_of_pictures, image_token):
|
| 12 |
-
|
| 13 |
-
# pattern = re.compile(r'\s([A-D]\.\s.*[^\n])')
|
| 14 |
-
# # 使用findall查找所有匹配的选项
|
| 15 |
-
# options = pattern.findall(question)
|
| 16 |
-
# if len(options) == 4 or len(options) == 5:
|
| 17 |
-
# options = '\n'.join(f"{'ABCDEFG'[i]}. {image_token}" for i in range(0, 4))
|
| 18 |
-
# else:
|
| 19 |
-
# options = '\n'.join(options)
|
| 20 |
-
# question = question.split('A.')[0]
|
| 21 |
-
# if len_of_pictures == 5 or len(options) == 1:
|
| 22 |
-
# question = image_token + question
|
| 23 |
-
# prompt = f"""你是一个{key}专家,擅长解决{key}问题。以下是一个{key}的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
|
| 24 |
-
|
| 25 |
-
# ### 问题:
|
| 26 |
-
# {question}
|
| 27 |
-
|
| 28 |
-
# ### 选项:
|
| 29 |
-
# {options}
|
| 30 |
-
# """
|
| 31 |
-
# return prompt
|
| 32 |
-
|
| 33 |
-
# def extract(input_text):
|
| 34 |
-
# ans_pattern = re.compile(r"答案是:(.)", re.S)
|
| 35 |
-
|
| 36 |
-
# problems = ans_pattern.findall(input_text)
|
| 37 |
-
# # print(problems)
|
| 38 |
-
# if(problems == ''):
|
| 39 |
-
# return 'A'
|
| 40 |
-
# return problems[0]
|
| 41 |
import json
|
| 42 |
-
import re
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
# print(problems)
|
| 49 |
-
answer = ''
|
| 50 |
-
if(len(problems) == 0 or problems == ''):
|
| 51 |
-
for char in input_text:
|
| 52 |
-
if char.isupper():
|
| 53 |
-
answer = "" + char
|
| 54 |
-
else:
|
| 55 |
-
answer = problems[0]
|
| 56 |
|
| 57 |
-
|
| 58 |
-
answer = 'A'
|
| 59 |
-
for option in ['B', 'C', 'D']:
|
| 60 |
-
if answer_dic[option] < answer_dic[answer]:
|
| 61 |
-
answer = option
|
| 62 |
-
answer_dic[answer] += 1
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
with open('
|
| 78 |
# 使用json.dump()函数将字典写入文��
|
| 79 |
-
json.dump(
|
| 80 |
-
|
|
|
|
| 1 |
+
import pandas as pd
import os
import json

# Load the already-merged dataset; new samples get ids appended after it.
with open('./political_data_with_extra_data.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

# Next free sample id (renamed from `len`, which shadowed the builtin).
next_id = len(data)

final_folder = 'playground/data/cmmlu'

files = os.listdir(final_folder)  # NOTE(review): unused beyond checking the folder exists

# CMMLU subject CSVs selected for inclusion in the training mixture.
selected_files = ['combined_anatomy.csv','combined_ancient_chinese.csv','combined_arts.csv','combined_chinese_civil_service_exam.csv','combined_chinese_foreign_policy.csv',
                  'combined_chinese_history.csv','combined_college_education.csv', 'combined_college_engineering_hydrology.csv', 'combined_college_mathematics.csv', 'combined_college_medicine.csv',
                  'combined_conceptual_physics.csv','combined_electrical_engineering.csv','combined_elementary_mathematics.csv','combined_food_science.csv',
                  'combined_genetics.csv', 'combined_high_school_biology.csv', 'combined_high_school_chemistry.csv','combined_high_school_geography.csv','combined_high_school_mathematics.csv',
                  'combined_high_school_physics.csv','combined_high_school_politics.csv','combined_legal_and_moral_basis.csv','combined_management.csv','combined_marxist_theory.csv',
                  'combined_modern_chinese.csv','combined_philosophy.csv','combined_virology.csv','combined_world_history.csv']
cmmlu_list = []
for file_name in selected_files:
    path = os.path.join(final_folder, file_name)
    df = pd.read_csv(path)

    for _, row in df.iterrows():
        # BUG FIX: option D was rendered as '\nD' (missing dot), inconsistent
        # with the 'A.', 'B.', 'C.' separators.
        question = (row['Question'] + '\nA.' + row['A'] + '\nB.' + row['B']
                    + '\nC.' + row['C'] + '\nD.' + row['D'] + '\n')

        dict_item = {
            'id': str(next_id),
            'image': "",  # text-only sample: no image attached
            'conversations': [
                {'from': 'human', 'value': question},
                {'from': 'gpt', 'value': "答案是:" + row['Answer']},
            ],
        }
        next_id += 1

        print(dict_item)
        cmmlu_list.append(dict_item)

data = cmmlu_list + data

with open('cmmlu_political_data_gaokao.json', 'w', encoding='utf-8') as file:
    # Write the combined dataset to disk.
    json.dump(data, file, ensure_ascii=False, indent=4)
|
|
|
LLaVA-MOSS2/vote.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json

# Load the per-run answer files produced by repeated sampling.
answers = []
for i in range(1, 5):
    filename = f'output_{i}.json'
    with open(filename, 'r', encoding='utf-8') as file:
        questions = json.load(file)
        answers.append(questions)

# Tally votes into the first run's structure (answers[0] aliases final_ans,
# so the first file's own answers are counted too).
final_ans = answers[0]
for ans in answers:
    for i, sub in enumerate(ans):
        example = sub['example']
        for j, item in enumerate(example):
            item_ans = item['model_answer']
            index = ord(item_ans[0]) - 65  # 'A' -> 0 ... 'D' -> 3
            target = final_ans[i]['example'][j]
            # BUG FIX: the check was `'count' not in final_ans`, which tests
            # membership in the outer *list* and is always True, resetting the
            # tally on every vote. Check the per-question dict instead.
            if 'count' not in target:
                target['count'] = [0] * 4
            target['count'][index] += 1

# Pick the majority answer for each question (ties favor the earlier letter).
for sub in final_ans:
    example = sub['example']
    for item in example:
        best = 0  # renamed from `max`, which shadowed the builtin
        for k in range(1, 4):
            if item['count'][k] > item['count'][best]:
                best = k
        item['model_answer'] = str(chr(best + 65))
        item.pop('count')


with open('final_answer.json', 'w', encoding='utf-8') as file:
    # Write the voted answers to disk.
    json.dump(final_ans, file, ensure_ascii=False, indent=4)
|
| 35 |
+
|