File size: 7,747 Bytes
8d28ba8
 
 
 
2b134bc
 
8d28ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b134bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d28ba8
 
2b134bc
 
 
 
 
 
 
 
 
 
 
 
 
 
8d28ba8
2b134bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d28ba8
 
2b134bc
 
 
8d28ba8
2b134bc
8d28ba8
2b134bc
 
8d28ba8
2b134bc
8d28ba8
2b134bc
 
8d28ba8
2b134bc
 
 
 
8d28ba8
 
 
 
 
 
 
 
2b134bc
8d28ba8
2b134bc
8d28ba8
 
 
 
 
 
 
 
 
 
2b134bc
8d28ba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b134bc
 
8d28ba8
2b134bc
8d28ba8
 
 
2b134bc
8d28ba8
 
2b134bc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#!/usr/bin/env python
# coding: utf-8

import os
import re
from typing import Optional, Set
from rknn.api import RKNN
from math import exp
from sys import exit
import argparse
import onnxscript
from onnxscript.rewriter import pattern
import onnx.numpy_helper as onh
import numpy as np
import onnx
import onnxruntime as ort
from rknn.utils import onnx_edit

# Run relative to this script's directory so sibling files (e.g. dataset.txt) resolve.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Fixed number of input speech frames baked into the converted model's static shape.
speech_length = 171

def _remove_file(path: str, *, keep: Optional[Set[str]] = None) -> None:
    if not path:
        return
    keep_paths: Set[str] = {os.path.abspath(item) for item in keep} if keep else set()
    normalized = os.path.abspath(path)
    if keep_paths and normalized in keep_paths:
        return
    if not os.path.exists(normalized):
        return
    try:
        os.remove(normalized)
        print(f'cleaned temp model: {normalized}')
    except OSError as err:
        print(f'warning: failed to remove {normalized}: {err}')

def _with_suffix(path: str, suffix: str) -> str:
    stem, ext = os.path.splitext(path)
    return f"{stem}{suffix}{ext}"

def _sanitize_name(name: str) -> str:
    return re.sub(r'[^0-9A-Za-z_]', '_', name)

def _insert_div_node(model: onnx.ModelProto, tensor_name: str, divisor: float = 16.0) -> bool:
    """Insert a Div-by-*divisor* immediately after the node producing *tensor_name*.

    The producer's matching output is renamed to ``<tensor_name>_pre_div`` and a
    Constant + Div pair re-emits the original name, so downstream consumers and
    graph outputs are untouched. Returns False when a Div already produces the
    tensor (safe to re-run), True when nodes were inserted. Raises RuntimeError
    if no node produces *tensor_name* (e.g. it is a graph input or initializer).
    """
    graph = model.graph

    # Idempotence guard: a Div already emitting this tensor means it was patched before.
    for node in graph.node:
        if node.op_type == 'Div' and tensor_name in node.output:
            return False

    # Locate the producer node and which of its output slots carries the tensor.
    producer_index = None
    output_index = None
    for idx, node in enumerate(graph.node):
        for out_idx, output in enumerate(node.output):
            if output == tensor_name:
                producer_index = idx
                output_index = out_idx
                producer_node = node
                break
        if producer_index is not None:
            break

    if producer_index is None:
        raise RuntimeError(f"Producer node for tensor {tensor_name} not found.")

    # Reroute the producer to a temporary name; the Div below restores the original.
    pre_div_output = f"{tensor_name}_pre_div"
    producer_node.output[output_index] = pre_div_output

    # Derive identifier-safe, collision-free node names from the tensor path.
    sanitized = _sanitize_name(tensor_name)
    const_output = f"{sanitized}_div_const"
    const_node_name = f"{sanitized}_DivConst"
    div_node_name = f"{sanitized}_Div"

    # Scalar FP32 constant holding the divisor (dims=[] -> rank-0 tensor).
    const_tensor = onnx.helper.make_tensor(
        name=f"{const_node_name}_value",
        data_type=onnx.TensorProto.FLOAT,
        dims=[],
        vals=[divisor],
    )

    const_node = onnx.helper.make_node(
        'Constant',
        inputs=[],
        outputs=[const_output],
        value=const_tensor,
        name=const_node_name,
    )

    div_node = onnx.helper.make_node(
        'Div',
        inputs=[pre_div_output, const_output],
        outputs=[tensor_name],
        name=div_node_name,
    )

    # Insert right after the producer so the graph stays topologically sorted.
    graph.node.insert(producer_index + 1, const_node)
    graph.node.insert(producer_index + 2, div_node)
    return True

def _scale_initializer(model: onnx.ModelProto, initializer_name: str, divisor: float = 16.0) -> bool:
    """Divide the named graph initializer's values by *divisor*, in place.

    Returns True when the initializer was found and rewritten as float32,
    False when no initializer has that name.
    """
    for pos, init in enumerate(model.graph.initializer):
        if init.name != initializer_name:
            continue
        values = onh.to_array(init).astype(np.float32, copy=False)
        rescaled = values / divisor
        model.graph.initializer[pos].CopyFrom(onh.from_array(rescaled, name=initializer_name))
        return True
    return False

def convert_encoder(model_path: str) -> None:
    """Patch an encoder ONNX model and convert it to an RKNN model for rk3588.

    Pipeline: insert Div nodes to dodge an overflow in encoder layer 48,
    pre-run a subgraph that breaks RKNN's constant folding and feed its
    result in as a constant input, strip the trailing transpose, then
    config/load/build/export via the RKNN toolkit. Intermediate .onnx files
    are cleaned up as each stage succeeds; the source model is preserved.
    Exits the process on any failure.
    """
    rknn = RKNN(verbose=True)

    ONNX_MODEL = os.path.abspath(model_path)
    if not os.path.isfile(ONNX_MODEL):
        print(f'Model file not found: {model_path}')
        exit(1)
    if not ONNX_MODEL.lower().endswith('.onnx'):
        print(f'Model file must be an ONNX file: {model_path}')
        exit(1)

    RKNN_MODEL = os.path.splitext(ONNX_MODEL)[0] + ".rknn"
    DATASET = "dataset.txt"
    QUANTIZE = False
    original_model = ONNX_MODEL
    # The user's source model must never be deleted by the temp-file cleanup below.
    preserve_files: Set[str] = {original_model}


    print('--> Patching model to avoid overflow issue')
    base_model = onnx.load(ONNX_MODEL)
    modified = False
    # Only encoder layer 48 is patched; halve these two activations to avoid overflow.
    for layer_idx in range(48, 49): 
        for target in [
            f'/encoders.{layer_idx}/feed_forward/activation/Relu_output_0',
            f'/encoders.{layer_idx}/norm2/Cast_output_0',
        ]:
            modified |= _insert_div_node(base_model, target, divisor=2.0)
    bias_scaled = False
    if modified:
        # Compensate the halved activation by halving the following layer's bias too.
        for layer_idx in range(48, 49): 
            bias_scaled |= _scale_initializer(base_model, f'model.encoders.{layer_idx}.feed_forward.w_2.bias', divisor=2.0)
    div_model_path = _with_suffix(ONNX_MODEL, "_div")
    onnx.save(base_model, div_model_path)
    if os.path.exists(div_model_path):
        previous_model = ONNX_MODEL
        ONNX_MODEL = div_model_path
        _remove_file(previous_model, keep=preserve_files)
    if modified:
        if bias_scaled:
            print('done (created div-adjusted model and scaled bias)')
        else:
            print('done (created div-adjusted model; bias initializer not found)')
    else:
        print('done (div nodes already present)')

    # RKNN's very first constant-folding pass errors out inside this subgraph,
    # so extract it, run it once with onnxruntime, and feed the saved result
    # to RKNN as a constant input instead.
    extract_model_path = os.path.join(os.getcwd(), "extract_model.onnx")
    onnx.utils.extract_model(ONNX_MODEL, extract_model_path, ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
    sess = ort.InferenceSession(extract_model_path, providers=['CPUExecutionProvider'])
    extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
    _remove_file(extract_model_path)

    # Drop the redundant transpose at the model's tail; inference went from 365ms to 259ms.
    edited_model_path = _with_suffix(ONNX_MODEL, "_edited")
    ret = onnx_edit(model = ONNX_MODEL,
        export_path = edited_model_path,
        # # 1, len, 25055 -> 1, 25055, 1, len   # NOTE: the 3-dim variant below is broken, hence the 4-dim form.
        outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
        # outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
    )
    if os.path.exists(edited_model_path):
        previous_model = ONNX_MODEL
        ONNX_MODEL = edited_model_path
        _remove_file(previous_model, keep=preserve_files)

    # pre-process config
    print('--> Config model')
    rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
    print('done')

    # Load ONNX model; the pad-mask input is pinned to the precomputed subgraph result.
    print("--> Loading model")
    current_model_path = ONNX_MODEL
    ret = rknn.load_onnx(
        model=current_model_path,
        inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
        input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
        input_initial_val=[None, extract_result],
        # outputs=["output"]
    )

    if ret != 0:
        print('Load model failed!')
        exit(ret)
    print('done')
    _remove_file(current_model_path, keep=preserve_files)

    # Build model
    print('--> Building model')
    ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
    if ret != 0:
        print('Build model failed!')
        exit(ret)
    print('done')

    # export
    print('--> Export RKNN model')
    ret = rknn.export_rknn(RKNN_MODEL)
    if ret != 0:
        print('Export RKNN model failed!')
        exit(ret)
    print('done')
    # accuracy analysis (optional)
    # rknn.accuracy_analysis(inputs=["input_content.npy"], target="rk3588", device_id=None)

# usage: python convert_rknn.py path/to/model.onnx

if __name__ == "__main__":
    # CLI entry point: one positional argument naming the source ONNX model.
    cli = argparse.ArgumentParser()
    cli.add_argument("model_path", type=str, help="path to source ONNX model")
    parsed = cli.parse_args()

    convert_encoder(parsed.model_path)