#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import os
import json
import copy
import argparse

import torch

from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import get_model_name_from_path


def export(args):
    # Load model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
                                                                           args.model_base,
                                                                           model_name,
                                                                           device="cpu")

    # Save extra metadata that is not written out during LLaVA training but is
    # required by HF for auto-loading the model and by mlx-vlm for preprocessing.

    # Save image processing config
    setattr(image_processor, "processor_class", "LlavaProcessor")
    output_path = os.path.join(model_path, "preprocessor_config.json")
    image_processor.to_json_file(output_path)

    # Create processor config
    processor_config = dict()
    processor_config["image_token"] = "<image>"
    processor_config["num_additional_image_tokens"] = 0
    processor_config["processor_class"] = "LlavaProcessor"
    processor_config["patch_size"] = 64
    output_path = os.path.join(model_path, "processor_config.json")
    json.dump(processor_config, open(output_path, "w"), indent=2)
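    # For reference, json.dump with indent=2 writes the config above as:
    # {
    #   "image_token": "<image>",
    #   "num_additional_image_tokens": 0,
    #   "processor_class": "LlavaProcessor",
    #   "patch_size": 64
    # }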

    # Modify tokenizer to include <image> special token.
    tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
    tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
    token_ids = list()
    image_token_id = None
    for k, v in tokenizer_config['added_tokens_decoder'].items():
        token_ids.append(int(k))
        if v["content"] == "<image>":
            image_token_id = int(k)
            token_ids.pop()

    # Append the <image> token only if it is not already present; otherwise
    # reuse its existing id so config.json stays consistent with the tokenizer.
    # (The original `max(token_ids) + 1` was wrong whenever <image> existed
    # but was not the highest-id added token.)
    if image_token_id is None:
        image_token_id = max(token_ids) + 1
        tokenizer_config['added_tokens_decoder'][f'{image_token_id}'] = copy.deepcopy(
            tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
        tokenizer_config['added_tokens_decoder'][f'{image_token_id}']["content"] = "<image>"
        json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)

    # Record the <image> token id in the model config
    config_path = os.path.join(model_path, "config.json")
    model_config = json.load(open(config_path, 'r'))
    model_config["image_token_index"] = image_token_id
    json.dump(model_config, open(config_path, 'w'), indent=2)

    # Export the vision encoder to ONNX
    image_res = image_processor.to_dict()['size']['shortest_edge']
    dummy_vision_input = torch.rand(1, 3, image_res, image_res).float() # Dummy input tensor

    vision_model = model.get_vision_tower()
    # Ensure model is on CPU, in float precision, and in evaluation mode for ONNX export
    vision_model = vision_model.cpu().float().eval()

    onnx_vision_model_path = os.path.join(model_path, "fastvithd.onnx")
    
    print(f"Exporting vision model to {onnx_vision_model_path}...")
    torch.onnx.export(
        vision_model,
        dummy_vision_input, # Pass the dummy input tensor
        onnx_vision_model_path,
        input_names=['pixel_values'],          # input node name in the ONNX graph
        output_names=['last_hidden_state'],    # output node name in the ONNX graph
        # dynamic_axes={
        #     'pixel_values': {0: 'batch_size'},      # dim 0 of 'pixel_values' is a dynamic batch size
        #     'last_hidden_state': {0: 'batch_size'}  # dim 0 of 'last_hidden_state' is a dynamic batch size
        # },
        opset_version=17,                   # ONNX opset version
        export_params=True,                 # store the trained parameter weights inside the model file
        do_constant_folding=True            # fold constant subexpressions at export time
    )
    print(f"Vision model ONNX export complete: {onnx_vision_model_path}")

    # Generate dummy input for mm_projector by passing dummy_vision_input through vision_model
    # This ensures the mm_projector receives input with the correct shape and characteristics
    with torch.no_grad():
        dummy_mm_projector_input = vision_model(dummy_vision_input)
    
    # Ensure the input is on CPU and in float32 precision for the projector
    dummy_mm_projector_input = dummy_mm_projector_input.cpu().float()

    # Export the mm_projector to ONNX
    # model.get_model() gives the underlying base model (e.g., LlavaLlamaModel)
    # which contains the mm_projector attribute.
    mm_projector = model.get_model().mm_projector
    mm_projector = mm_projector.cpu().float().eval()

    onnx_mm_projector_path = os.path.join(model_path, "mm_projector.onnx")

    print(f"Exporting mm_projector to {onnx_mm_projector_path}...")
    torch.onnx.export(
        mm_projector,
        dummy_mm_projector_input,
        onnx_mm_projector_path,
        input_names=['last_hidden_state'],
        output_names=['projected_image_features'],
        opset_version=17,
        export_params=True,
        do_constant_folding=True
    )
    print(f"mm_projector ONNX export complete: {onnx_mm_projector_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default="qwen_2")

    args = parser.parse_args()

    export(args)
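
# Example invocation (script and checkpoint names below are placeholders):
#   python export_onnx.py --model-path ./checkpoints/llava-fastvithd_0.5b_stage3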