update model scripts
Browse files- README.md +10 -1
- inference.py +8 -4
README.md
CHANGED
|
@@ -1,3 +1,12 @@
|
|
| 1 |
# MMAlaya
|
| 2 |
-
MMAlaya
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
| 1 |
# MMAlaya
|
| 2 |
+
MMAlaya是基于大语言模型[Alaya](https://github.com/DataCanvasIO/Alaya)的多模态模型。
|
| 3 |
+
|
| 4 |
+
MMAlaya包含以下三个模块:
|
| 5 |
+
<br>1,大语言模型Alaya。
|
| 6 |
+
<br>2,图像文本特征编码器[blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)
|
| 7 |
+
<br>3,图像文本特征到大预言模型的线性投影器。
|
| 8 |
+
|
| 9 |
+
模型的训练主要基于[LLaVA]架构(https://github.com/haotian-liu/LLaVA)
|
| 10 |
+
|
| 11 |
+
2024.01.23 最终在[MMBench](https://mmbench.opencompass.org.cn)线上测试中文测试集分数为56.9,总排名为第20名,7B模型的第9名。英文测试集分数为59.8,总排名为第29名,7B模型的第12名。
|
| 12 |
|
inference.py
CHANGED
|
@@ -12,10 +12,10 @@ import argparse
|
|
| 12 |
|
| 13 |
def main(args):
|
| 14 |
disable_torch_init()
|
| 15 |
-
|
| 16 |
conv_mode = "mmalaya_llama"
|
| 17 |
model_path = args.model_path
|
| 18 |
-
|
|
|
|
| 19 |
model_path=model_path,
|
| 20 |
)
|
| 21 |
prompts = [
|
|
@@ -27,21 +27,25 @@ def main(args):
|
|
| 27 |
time1 = time.time()
|
| 28 |
|
| 29 |
for prompt in prompts:
|
|
|
|
| 30 |
conv = conv_templates[conv_mode].copy()
|
| 31 |
inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
|
| 32 |
conv.append_message(conv.roles[0], inp)
|
| 33 |
conv.append_message(conv.roles[1], None)
|
| 34 |
prompt = conv.get_prompt()
|
|
|
|
| 35 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
|
|
|
| 36 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 37 |
if conv_mode == 'mmalaya_llama':
|
| 38 |
stop_str = conv.sep2
|
| 39 |
keywords = [stop_str]
|
| 40 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 41 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
|
| 42 |
-
|
| 43 |
image = Image.open('./data/chang_chen.jpg').convert("RGB")
|
| 44 |
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
|
|
|
|
| 45 |
with torch.inference_mode():
|
| 46 |
generate_ids = model.generate(
|
| 47 |
inputs=input_ids,
|
|
@@ -55,7 +59,7 @@ def main(args):
|
|
| 55 |
use_cache=True,
|
| 56 |
stopping_criteria=[stopping_criteria],
|
| 57 |
)
|
| 58 |
-
#
|
| 59 |
input_token_len = input_ids.shape[1]
|
| 60 |
output = tokenizer.batch_decode(
|
| 61 |
generate_ids[:, input_token_len:],
|
|
|
|
| 12 |
|
| 13 |
def main(args):
|
| 14 |
disable_torch_init()
|
|
|
|
| 15 |
conv_mode = "mmalaya_llama"
|
| 16 |
model_path = args.model_path
|
| 17 |
+
# 加载model,tokenizer,image_processor
|
| 18 |
+
tokenizer, model, image_processor, _ = load_pretrained_model(
|
| 19 |
model_path=model_path,
|
| 20 |
)
|
| 21 |
prompts = [
|
|
|
|
| 27 |
time1 = time.time()
|
| 28 |
|
| 29 |
for prompt in prompts:
|
| 30 |
+
# 加载对话模板
|
| 31 |
conv = conv_templates[conv_mode].copy()
|
| 32 |
inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
|
| 33 |
conv.append_message(conv.roles[0], inp)
|
| 34 |
conv.append_message(conv.roles[1], None)
|
| 35 |
prompt = conv.get_prompt()
|
| 36 |
+
# 对prompt进行分词
|
| 37 |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
|
| 38 |
+
# 加载generate stop条件
|
| 39 |
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 40 |
if conv_mode == 'mmalaya_llama':
|
| 41 |
stop_str = conv.sep2
|
| 42 |
keywords = [stop_str]
|
| 43 |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 44 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=20.0)
|
| 45 |
+
# 加载图像
|
| 46 |
image = Image.open('./data/chang_chen.jpg').convert("RGB")
|
| 47 |
image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()
|
| 48 |
+
# 推理
|
| 49 |
with torch.inference_mode():
|
| 50 |
generate_ids = model.generate(
|
| 51 |
inputs=input_ids,
|
|
|
|
| 59 |
use_cache=True,
|
| 60 |
stopping_criteria=[stopping_criteria],
|
| 61 |
)
|
| 62 |
+
# 截断generate_ids中的input_ids,然后解码为文本
|
| 63 |
input_token_len = input_ids.shape[1]
|
| 64 |
output = tokenizer.batch_decode(
|
| 65 |
generate_ids[:, input_token_len:],
|