#!/usr/bin/env python3
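"""Export the input-embedding layer of a causal LM (e.g. Qwen3-0.6B) to ONNX.

Loads the model with transformers, exports its embedding layer with dynamic
batch/sequence axes, attaches metadata, writes a uint8-quantized variant next
to the float32 model, and can verify the export against PyTorch.

Example:
    python3 export_embedding_onnx.py --llm-config-path ./Qwen3-0.6B --verify
"""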
import argparse
import os
import shutil
from pathlib import Path
from typing import Any, Dict
import numpy as np
import onnx
from onnx import numpy_helper
import onnxruntime as ort
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from transformers import AutoModelForCausalLM, AutoConfig


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--llm-config-path",
type=str,
required=True,
help="Path to LLM config directory (e.g., Qwen3-0.6B)",
)
parser.add_argument(
"--opset-version",
type=int,
default=18,
help="ONNX opset version",
)
parser.add_argument(
"--output-filename",
type=str,
default="embedding.onnx",
help="Output ONNX filename",
)
parser.add_argument(
"--seq-length",
type=int,
default=512,
help="Dummy sequence length for export",
)
parser.add_argument(
"--verify",
action="store_true",
help="Verify the exported ONNX model",
)
return parser.parse_args()


def check_safetensors_file(llm_config_path: str) -> bool:
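    """Report whether the model directory contains weight files.

    Looks for *.safetensors files, then a sharded safetensors index
    (model.safetensors.index.json), then pytorch_model.bin as a fallback.
    """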
config_path = Path(llm_config_path)
model_files = list(config_path.glob("*.safetensors"))
if model_files:
print(f"Found safetensors file(s): {[str(f) for f in model_files]}")
return True
if (config_path / "model.safetensors.index.json").exists():
print(f"Found sharded safetensors model: {config_path / 'model.safetensors.index.json'}")
return True
if (config_path / "pytorch_model.bin").exists():
print(f"Found pytorch_model.bin (fallback)")
return True
print(f"Warning: No safetensors or pytorch_model.bin found in {config_path}")
print(f"Available files: {list(config_path.glob('*'))}")
return False


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
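    """Replace all existing metadata_props of an ONNX model with meta_data."""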
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)


@torch.no_grad()
def export_embedding_onnx(
embedding_layer: torch.nn.Module,
vocab_size: int,
hidden_size: int,
filename: str,
seq_length: int,
opset_version: int,
):
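    """Trace the embedding layer on a dummy (1, seq_length) int64 input and
    export it to ONNX with dynamic batch_size/sequence_length axes."""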
embedding_layer.eval()
batch_size = 1
dummy_input_ids = torch.randint(0, vocab_size, (batch_size, seq_length), dtype=torch.int64)
print(f"Exporting embedding layer to ONNX...")
print(f" Input shape: {dummy_input_ids.shape}")
print(f" Vocab size: {vocab_size}")
print(f" Hidden size: {hidden_size}")
print(f" Output shape: ({batch_size}, {seq_length}, {hidden_size})")
os.environ["TORCH_ONNX_DISABLE_DYNAMO"] = "1"
torch.onnx.export(
embedding_layer,
dummy_input_ids,
filename,
opset_version=opset_version,
input_names=["input_ids"],
output_names=["embeddings"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "sequence_length"},
"embeddings": {0: "batch_size", 1: "sequence_length"},
},
do_constant_folding=True,
verbose=False,
export_params=True,
)
print(f"ONNX model saved to: {filename}")


def verify_onnx_model(
onnx_filename: str,
llm_config_path: str,
seq_length: int = 100,
num_tests: int = 3,
):
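    """Run num_tests random-token comparisons between the ONNX model and the
    PyTorch embedding layer; return True if every test stays within tolerance."""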
print("Verifying ONNX embedding model...")
print("Loading PyTorch model for verification...")
pytorch_model = AutoModelForCausalLM.from_pretrained(
llm_config_path,
trust_remote_code=True,
torch_dtype=torch.float32,
device_map=None,
)
pytorch_embedding = pytorch_model.get_input_embeddings()
pytorch_embedding.eval()
print("Loading ONNX model...")
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
try:
ort_session = ort.InferenceSession(
onnx_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
except Exception as e:
print(f"Error loading ONNX model: {e}")
return False
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
print(f"ONNX input name: {input_name}")
print(f"ONNX output name: {output_name}")
all_passed = True
max_diff = 0.0
max_relative_diff = 0.0
for test_idx in range(num_tests):
vocab_size = pytorch_embedding.num_embeddings
input_ids = torch.randint(0, vocab_size, (1, seq_length), dtype=torch.int64)
with torch.no_grad():
pytorch_output = pytorch_embedding(input_ids).numpy()
onnx_output = ort_session.run(
[output_name],
{input_name: input_ids.numpy()},
)[0]
diff = np.abs(pytorch_output - onnx_output)
max_abs_diff = np.max(diff)
mean_abs_diff = np.mean(diff)
pytorch_abs = np.abs(pytorch_output)
relative_diff = np.where(
pytorch_abs > 1e-8,
diff / (pytorch_abs + 1e-8),
diff
)
max_rel_diff = np.max(relative_diff)
mean_rel_diff = np.mean(relative_diff)
max_diff = max(max_diff, max_abs_diff)
max_relative_diff = max(max_relative_diff, max_rel_diff)
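        # An embedding lookup is a pure gather, so ONNX and PyTorch outputs
        # should match almost exactly; the 1e-5 tolerance leaves headroom.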
tolerance = 1e-5
passed = max_abs_diff < tolerance
status = "PASS" if passed else "FAIL"
print(f"\nTest {test_idx + 1}/{num_tests}: {status}")
print(f" Input shape: {input_ids.shape}")
print(f" Output shape: {onnx_output.shape}")
print(f" Max absolute difference: {max_abs_diff:.2e}")
print(f" Mean absolute difference: {mean_abs_diff:.2e}")
print(f" Max relative difference: {max_rel_diff:.2e}")
print(f" Mean relative difference: {mean_rel_diff:.2e}")
if not passed:
all_passed = False
print(f"Warning: Difference exceeds tolerance ({tolerance})")
print("Verification Summary:")
print(f"Overall status: {'PASSED' if all_passed else 'FAILED'}")
print(f"Max absolute difference across all tests: {max_diff:.2e}")
print(f"Max relative difference across all tests: {max_relative_diff:.2e}")
if all_passed:
print("\nAll tests passed! ONNX model matches PyTorch model.")
else:
print("\nSome tests failed. Please check the differences.")
return all_passed


def quantize_embedding_model(model, filename, filename_int8, original_size):
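    """Write a uint8-quantized copy of the embedding model to filename_int8.

    If the graph has no Gather node, fall back to onnxruntime dynamic
    quantization of MatMul ops. Otherwise quantize the Gather's embedding
    weight to uint8 in place and insert a DequantizeLinear node after the
    Gather so downstream consumers still see float32. Returns True on
    success, False if the original model was copied unchanged.
    """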
op_types = set(node.op_type for node in model.graph.node)
print(f"Operations in model: {op_types}")
if "Gather" not in op_types:
print("Quantizing to INT8...")
quantize_dynamic(
model_input=filename,
model_output=filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QUInt8,
)
quantized_size = os.path.getsize(filename_int8) / (1024 * 1024)
size_reduction = original_size - quantized_size
print(f"Quantized model saved to: {filename_int8}")
print(f"Quantized model size: {quantized_size:.2f}MB (reduced by {size_reduction:.2f}MB, {size_reduction/original_size*100:.1f}%)")
return True
print("Embedding model uses Gather, quantizing embedding weights...")
gather_node = None
embedding_weight_name = None
for node in model.graph.node:
if node.op_type == "Gather":
gather_node = node
embedding_weight_name = node.input[0]
break
if gather_node is None:
print("Warning: Gather node not found")
shutil.copy(filename, filename_int8)
return False
for initializer in model.graph.initializer:
if initializer.name == embedding_weight_name and len(initializer.dims) >= 2:
weight_array = numpy_helper.to_array(initializer)
original_size_mb = weight_array.nbytes / (1024 * 1024)
weight_min = np.min(weight_array)
weight_max = np.max(weight_array)
if weight_max <= weight_min:
continue
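            # Asymmetric affine quantization to uint8:
            #   scale = (max - min) / 255
            #   q     = round((w - min) / scale)
            #   zp    = round(-min / scale), so (q - zp) * scale ≈ w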
scale = (weight_max - weight_min) / 255.0
zero_point = np.clip(np.round(-weight_min / scale), 0, 255).astype(np.uint8)
quantized_weight = np.clip(
np.round((weight_array - weight_min) / scale), 0, 255
).astype(np.uint8)
quantized_size_mb = quantized_weight.nbytes / (1024 * 1024)
saved_mb = original_size_mb - quantized_size_mb
initializer.data_type = onnx.TensorProto.UINT8
initializer.raw_data = quantized_weight.tobytes()
scale_name = initializer.name + "_scale"
zp_name = initializer.name + "_zero_point"
scale_tensor = onnx.helper.make_tensor(scale_name, onnx.TensorProto.FLOAT, [], [float(scale)])
zp_tensor = onnx.helper.make_tensor(zp_name, onnx.TensorProto.UINT8, [], [int(zero_point)])
model.graph.initializer.append(scale_tensor)
model.graph.initializer.append(zp_tensor)
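            # Rewire the graph: Gather now emits uint8 values under a
            # temporary name; DequantizeLinear restores float32 on the
            # original output name for downstream nodes.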
gather_output_name = gather_node.output[0]
intermediate_output = gather_output_name + "_uint8"
gather_node.output[0] = intermediate_output
dequant_node = onnx.helper.make_node(
"DequantizeLinear",
inputs=[intermediate_output, scale_name, zp_name],
outputs=[gather_output_name],
name="DequantizeLinear_0"
)
node_idx = list(model.graph.node).index(gather_node)
model.graph.node.insert(node_idx + 1, dequant_node)
print(f" Quantized {initializer.name}: {initializer.dims}, saved {saved_mb:.2f}MB")
try:
onnx.checker.check_model(model)
onnx.save(model, filename_int8)
quantized_size = os.path.getsize(filename_int8) / (1024 * 1024)
size_reduction = original_size - quantized_size
print(f"Quantized model saved to: {filename_int8}")
print(f"Quantized model size: {quantized_size:.2f}MB (reduced by {size_reduction:.2f}MB, {size_reduction/original_size*100:.1f}%)")
return True
except Exception as e:
print(f"Error saving quantized model: {e}")
print("Falling back to original model")
shutil.copy(filename, filename_int8)
return False
print("Warning: No weights were quantized, copying original model")
shutil.copy(filename, filename_int8)
return False


@torch.no_grad()
def main():
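    """Export the embedding layer, attach metadata, write a uint8-quantized
    copy, and optionally verify the export against PyTorch."""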
args = get_args()
print(vars(args))
llm_config_path = Path(args.llm_config_path)
if not llm_config_path.exists():
raise ValueError(f"LLM config path not found: {llm_config_path}")
print(f"\nChecking for model files in {llm_config_path}...")
has_model = check_safetensors_file(str(llm_config_path))
if not has_model:
print("Warning: No model files found. Model loading may fail.")
print(f"\nLoading LLM config from {llm_config_path}...")
config = AutoConfig.from_pretrained(str(llm_config_path), trust_remote_code=True)
vocab_size = config.vocab_size
hidden_size = config.hidden_size
print(f"Model config: vocab_size={vocab_size}, hidden_size={hidden_size}")
print(f"\nLoading PyTorch model...")
model = AutoModelForCausalLM.from_pretrained(
str(llm_config_path),
trust_remote_code=True,
torch_dtype=torch.float32,
device_map=None,
)
embedding_layer = model.get_input_embeddings()
embedding_layer.eval()
embedding_layer.to("cpu")
print(f"Embedding layer: {type(embedding_layer).__name__}")
print(f" Weight shape: {embedding_layer.weight.shape}")
print(f" Num embeddings: {embedding_layer.num_embeddings}")
print(f" Embedding dim: {embedding_layer.embedding_dim}")
seq_length = args.seq_length
opset_version = args.opset_version
filename = args.output_filename
export_embedding_onnx(
embedding_layer=embedding_layer,
vocab_size=vocab_size,
hidden_size=hidden_size,
filename=filename,
seq_length=seq_length,
opset_version=opset_version,
)
model_author = "FunAudioLLM"
comment = os.environ.get("comment", "FunAudioLLM/Fun-ASR-Nano-2512")
url = "https://huggingface.co/FunAudioLLM/Fun-ASR-Nano-2512"
meta_data = {
"model_type": "embedding_layer",
"version": "1",
"model_author": model_author,
"vocab_size": vocab_size,
"hidden_size": hidden_size,
"comment": comment,
"url": url,
}
add_meta_data(filename, meta_data)
print("Metadata added to ONNX model.")
filename_int8 = filename.replace(".onnx", ".int8.onnx")
    onnx_model = onnx.load(filename)
    original_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"Original model size: {original_size:.2f}MB")
    quantize_embedding_model(onnx_model, filename, filename_int8, original_size)
add_meta_data(filename_int8, meta_data)
if args.verify:
verify_onnx_model(
onnx_filename=filename,
llm_config_path=str(llm_config_path),
seq_length=min(seq_length, 200),
num_tests=5,
)
else:
print("\nNote: Use --verify flag to verify the exported model")


if __name__ == "__main__":
torch.manual_seed(20251219)
main()