| | import os |
| | import json |
| | import shutil |
| |
|
| | from optimum.exporters.onnx import main_export |
| | import onnx |
| | from onnxconverter_common import float16 |
| | import onnxruntime as rt |
| | from onnxruntime.tools.onnx_model_utils import * |
| | from onnxruntime.quantization import quantize_dynamic, QuantType |
| | import huggingface_hub |
| |
|
def add_mean_pooling(input_model, output_model, op, IR, output_embeddings_number):
    """Append a mean-pooling + L2-normalization head to a transformer ONNX model.

    Loads *input_model*, rebuilds it with the requested IR version and opset,
    then adds a subgraph that converts the ``last_hidden_state`` output into a
    single ``sentence_embedding`` output: the attention-mask-weighted mean of
    the token embeddings, L2-normalized. The original ``last_hidden_state``
    output is removed and the result is written to *output_model*.

    Args:
        input_model: path to the source ONNX file; assumed to expose an
            ``attention_mask`` input and a ``last_hidden_state`` output
            (standard for feature-extraction exports).
        output_model: path the modified model is saved to.
        op: ``onnx.OperatorSetIdProto`` naming the opset to target.
        IR: ONNX IR version for the rebuilt model.
        output_embeddings_number: embedding width declared in the shape of
            the ``sentence_embedding`` output.
    """
    model = onnx.load(input_model)
    # Rebuild once with the caller's IR version; all edits below mutate this
    # model's graph in place, so no second make_model() pass is needed.
    pooled = onnx.helper.make_model(model.graph, ir_version=IR, opset_imports=[op])

    # Constant initializers used by the pooling subgraph.
    pooled.graph.initializer.append(onnx.helper.make_tensor(
        name="minus_one_axis",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[-1]))

    # Lower bound for the mask sum, preventing division by zero for
    # all-padding rows.
    pooled.graph.initializer.append(onnx.helper.make_tensor(
        name="mask_clip_lower_limit",
        data_type=onnx.TensorProto.FLOAT,
        dims=[1],
        vals=[1e-9]))

    pooled.graph.initializer.append(onnx.helper.make_tensor(
        name="sum_one_axis",
        data_type=onnx.TensorProto.INT64,
        dims=[1],
        vals=[1]))

    # attention_mask is integer-typed; cast so it can scale the embeddings.
    pooled.graph.node.append(onnx.helper.make_node(
        "Cast",
        inputs=["attention_mask"],
        outputs=["attention_mask_fp32"],
        to=onnx.TensorProto.FLOAT))

    # (batch, seq) -> (batch, seq, 1) so the mask can broadcast over the
    # hidden dimension.
    pooled.graph.node.append(onnx.helper.make_node(
        "Unsqueeze",
        inputs=["attention_mask_fp32", "minus_one_axis"],
        outputs=["unsqueezed_attention_mask"]))

    pooled.graph.node.append(onnx.helper.make_node(
        "Shape",
        inputs=["last_hidden_state"],
        outputs=["last_hidden_state_shape"]))

    # Materialize the mask at the full hidden-state shape.
    pooled.graph.node.append(onnx.helper.make_node(
        "Expand",
        inputs=["unsqueezed_attention_mask", "last_hidden_state_shape"],
        outputs=["expanded_attention_mask"]))

    # Zero out embeddings of padding tokens.
    pooled.graph.node.append(onnx.helper.make_node(
        "Mul",
        inputs=["last_hidden_state", "expanded_attention_mask"],
        outputs=["last_hidden_state_x_expanded_attention_mask"]))

    # Sum the masked embeddings over the sequence axis (axis 1).
    pooled.graph.node.append(onnx.helper.make_node(
        "ReduceSum",
        inputs=["last_hidden_state_x_expanded_attention_mask", "sum_one_axis"],
        outputs=["sum_last_hidden_state_x_expanded_attention_mask"]))

    # Count of real (non-padding) tokens, broadcast over the hidden dim.
    pooled.graph.node.append(onnx.helper.make_node(
        "ReduceSum",
        inputs=["expanded_attention_mask", "sum_one_axis"],
        outputs=["sum_expanded_attention_mask"]))

    pooled.graph.node.append(onnx.helper.make_node(
        "Clip",
        inputs=["sum_expanded_attention_mask", "mask_clip_lower_limit"],
        outputs=["clipped_sum_expanded_attention_mask"]))

    # Mean = masked sum / token count.
    pooled.graph.node.append(onnx.helper.make_node(
        "Div",
        inputs=["sum_last_hidden_state_x_expanded_attention_mask",
                "clipped_sum_expanded_attention_mask"],
        outputs=["pooled_embeddings"]))

    # Drop the kept-dim from ReduceSum: (batch, 1, dim) -> (batch, dim).
    pooled.graph.node.append(onnx.helper.make_node(
        "Squeeze",
        inputs=["pooled_embeddings", "sum_one_axis"],
        outputs=["squeezed_pooled_embeddings"]))

    pooled.graph.node.append(onnx.helper.make_node(
        "Normalizer",
        domain="ai.onnx.ml",
        inputs=["squeezed_pooled_embeddings"],
        outputs=["sentence_embedding"],
        norm="L2"))

    pooled.graph.output.append(onnx.helper.make_tensor_value_info(
        "sentence_embedding",
        onnx.TensorProto.FLOAT,
        shape=["batch_size", output_embeddings_number]))

    # Remove the raw hidden-state output. Iterate a snapshot: removing from
    # the repeated field while iterating it skips elements.
    for graph_output in list(pooled.graph.output):
        if graph_output.name == "last_hidden_state":
            pooled.graph.output.remove(graph_output)

    onnx.save(pooled, output_model, save_as_external_data=False)
| |
|
| | |
| |
|
# Load the conversion settings that drive the export pipeline below.
with open('conversion_config.json') as json_file:
    conversion_config = json.load(json_file)

model_id = conversion_config["model_id"]
number_of_generated_embeddings = conversion_config["number_of_generated_embeddings"]
precision_to_filename_map = conversion_config["precision_to_filename_map"]
opset = conversion_config["opset"]
IR = conversion_config["IR"]

# Operator set shared by every exported model variant.
op = onnx.OperatorSetIdProto()
op.version = opset

# exist_ok avoids the check-then-create race of exists() + makedirs().
os.makedirs("onnx", exist_ok=True)
| |
|
print("Exporting the main model version")
try:
    # Preferred path: export the model to ONNX locally via Optimum.
    main_export(model_name_or_path=model_id, output="./", opset=opset,
                trust_remote_code=True, task="feature-extraction", dtype="fp32")
except Exception as export_error:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; surface the reason before using the fallback.
    print(f"main_export failed ({export_error}); downloading prebuilt model.onnx")
    huggingface_hub.hf_hub_download(repo_id=model_id, filename="model.onnx", local_dir="./")

if "fp32" in precision_to_filename_map:
    print("Exporting the fp32 onnx file...")
    # add_mean_pooling writes the complete destination file itself, so no
    # prior copy of model.onnx is needed.
    add_mean_pooling("model.onnx", precision_to_filename_map["fp32"], op, IR,
                     number_of_generated_embeddings)
    print("Done\n\n")

if "int8" in precision_to_filename_map:
    print("Quantizing fp32 model to int8...")
    quantize_dynamic("model.onnx", precision_to_filename_map["int8"],
                     weight_type=QuantType.QInt8)
    # Pooling head is added after quantization, in place on the int8 file.
    add_mean_pooling(precision_to_filename_map["int8"],
                     precision_to_filename_map["int8"], op, IR,
                     number_of_generated_embeddings)
    print("Done\n\n")

if "uint8" in precision_to_filename_map:
    print("Quantizing fp32 model to uint8...")
    quantize_dynamic("model.onnx", precision_to_filename_map["uint8"],
                     weight_type=QuantType.QUInt8)
    add_mean_pooling(precision_to_filename_map["uint8"],
                     precision_to_filename_map["uint8"], op, IR,
                     number_of_generated_embeddings)
    print("Done\n\n")

# The un-pooled intermediate export is no longer needed.
os.remove("model.onnx")
| |
|