#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse
import torch
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path


def export(args):
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, args.model_base, model_name, device="cpu")

    # Save extra metadata that is not saved during LLaVA training;
    # required by HF for auto-loading the model and for mlx-vlm preprocessing.

    # Save the image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    output_path = os.path.join(model_path, "preprocessor_config.json")
    image_processor.to_json_file(output_path)

    # Create the processor config
    processor_config = dict()
    processor_config["image_token"] = "<image>"
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    processor_config["patch_size"] = 64
    output_path = os.path.join(model_path, "processor_config.json")
    with open(output_path, "w") as f:
        json.dump(processor_config, f, indent=2)
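    # Note: patch_size here is the effective patch size seen by the language
    # model, i.e. the vision tower's total spatial downsampling factor (64 for
    # FastViTHD), so a 1024x1024 input yields (1024 / 64)^2 = 256 image tokens.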

    # Modify the tokenizer to include the <image> special token.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    with open(tokenizer_config_path, "r") as f:
        tokenizer_config = json.load(f)
    token_ids = list()
    image_token_id = None
    for k, v in tokenizer_config["added_tokens_decoder"].items():
        if v["content"] == "<image>":
            image_token_id = int(k)
        else:
            token_ids.append(int(k))
    # Append the token only if <image> is not already present; reusing an
    # existing entry as a template keeps fields like special/normalized intact.
    if image_token_id is None:
        image_token_id = max(token_ids) + 1
        tokenizer_config["added_tokens_decoder"][f"{image_token_id}"] = copy.deepcopy(
            tokenizer_config["added_tokens_decoder"][f"{token_ids[0]}"])
        tokenizer_config["added_tokens_decoder"][f"{image_token_id}"]["content"] = "<image>"
    with open(tokenizer_config_path, "w") as f:
        json.dump(tokenizer_config, f, indent=2)

    # Modify the model config to contain the token id for <image>
    config_path = os.path.join(model_path, "config.json")
    with open(config_path, "r") as f:
        model_config = json.load(f)
    model_config["image_token_index"] = image_token_id
    with open(config_path, "w") as f:
        json.dump(model_config, f, indent=2)
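
    # Optional sanity check (sketch, assuming a standard HF tokenizer reload;
    # not run by default):
    # from transformers import AutoTokenizer
    # tok = AutoTokenizer.from_pretrained(model_path)
    # assert tok.convert_tokens_to_ids("<image>") == model_config["image_token_index"]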

    # Export the vision encoder to ONNX
    image_res = image_processor.to_dict()["size"]["shortest_edge"]
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float()  # Dummy input tensor
    vision_model = model.get_vision_tower()
    # Ensure the model is on CPU, in float precision, and in evaluation mode for ONNX export
    vision_model = vision_model.cpu().float().eval()
    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
    print(f"Exporting vision model to {onnx_vision_model_path}...")
    torch.onnx.export(
        vision_model,
        dummy_vision_input,  # Pass the dummy input tensor
        onnx_vision_model_path,
        input_names=["pixel_values"],  # Name of the input node in the ONNX graph
        output_names=["last_hidden_state"],  # Name of the output node in the ONNX graph
        # dynamic_axes={
        #     "pixel_values": {0: "batch_size"},  # Dim 0 of input "pixel_values" is a dynamic batch_size
        #     "last_hidden_state": {0: "batch_size"},  # Dim 0 of output "last_hidden_state" is a dynamic batch_size
        # },
        opset_version=17,  # ONNX opset version
        export_params=True,  # Store the trained parameter weights in the model file
        do_constant_folding=True,  # Apply constant-folding optimization
    )
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Generate dummy input for mm_projector by passing dummy_vision_input through
    # vision_model. This ensures the mm_projector receives input with the correct
    # shape and characteristics.
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input)
    # Ensure the input is on CPU and in float32 precision for the projector
    dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()

    # Export the mm_projector to ONNX.
    # model.get_model() gives the underlying base model (e.g., LlavaLlamaModel),
    # which contains the mm_projector attribute.
    mm_projector = model.get_model().mm_projector
    mm_projector = mm_projector.cpu().float().eval()
    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")
    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    torch.onnx.export(
        mm_projector,
        dummy_mm_projector_input,
        onnx_mm_projector_path,
        input_names=["last_hidden_state"],
        output_names=["projected_image_features"],
        opset_version=17,
        export_params=True,
        do_constant_folding=True,
    )
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="qwen_2")
    args = parser.parse_args()
    export(args)
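
# Example invocation (sketch; the script name and checkpoint path are
# placeholders for a local setup):
#   python export_onnx.py --model-path ./checkpoints/llava-fastvithd_0.5b_stage3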