import argparse
import os
import shutil
from pathlib import Path
from typing import Any, Dict

import numpy as np
import onnx
import onnxruntime as ort
import torch
from onnx import numpy_helper
from onnxruntime.quantization import QuantType, quantize_dynamic
from transformers import AutoConfig, AutoModelForCausalLM
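
# Typical invocation (a sketch; the script filename and the model directory
# are illustrative, not fixed by this file):
#
#   python3 export_embedding_onnx.py \
#     --llm-config-path ./Qwen3-0.6B \
#     --output-filename embedding.onnx \
#     --verify
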
def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--llm-config-path",
        type=str,
        required=True,
        help="Path to LLM config directory (e.g., Qwen3-0.6B)",
    )
    parser.add_argument(
        "--opset-version",
        type=int,
        default=18,
        help="ONNX opset version",
    )
    parser.add_argument(
        "--output-filename",
        type=str,
        default="embedding.onnx",
        help="Output ONNX filename",
    )
    parser.add_argument(
        "--seq-length",
        type=int,
        default=512,
        help="Dummy sequence length for export",
    )
    parser.add_argument(
        "--verify",
        action="store_true",
        help="Verify the exported ONNX model",
    )
    return parser.parse_args()

def check_safetensors_file(llm_config_path: str) -> bool:
    """Check that the model directory contains loadable weight files."""
    config_path = Path(llm_config_path)
    model_files = list(config_path.glob("*.safetensors"))
    if model_files:
        print(f"Found safetensors file(s): {[str(f) for f in model_files]}")
        return True
    if (config_path / "model.safetensors.index.json").exists():
        print(f"Found sharded safetensors model: {config_path / 'model.safetensors.index.json'}")
        return True
    if (config_path / "pytorch_model.bin").exists():
        print("Found pytorch_model.bin (fallback)")
        return True
    print(f"Warning: No safetensors or pytorch_model.bin found in {config_path}")
    print(f"Available files: {list(config_path.glob('*'))}")
    return False

def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Replace the metadata_props of an ONNX model file in place."""
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    onnx.save(model, filename)

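# The metadata written by add_meta_data() can be read back at load time via
# onnxruntime's standard model-metadata API (a minimal sketch; the file name
# matches the --output-filename default):
#
#   import onnxruntime as ort
#   sess = ort.InferenceSession("embedding.onnx", providers=["CPUExecutionProvider"])
#   meta = sess.get_modelmeta().custom_metadata_map  # dict of str -> str
#   print(meta["vocab_size"], meta["hidden_size"])
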
@torch.no_grad()
def export_embedding_onnx(
    embedding_layer: torch.nn.Module,
    vocab_size: int,
    hidden_size: int,
    filename: str,
    seq_length: int,
    opset_version: int,
):
    embedding_layer.eval()
    batch_size = 1
    dummy_input_ids = torch.randint(0, vocab_size, (batch_size, seq_length), dtype=torch.int64)

    print("Exporting embedding layer to ONNX...")
    print(f" Input shape: {dummy_input_ids.shape}")
    print(f" Vocab size: {vocab_size}")
    print(f" Hidden size: {hidden_size}")
    print(f" Output shape: ({batch_size}, {seq_length}, {hidden_size})")

    # Prefer the legacy TorchScript-based exporter over the dynamo-based one.
    os.environ["TORCH_ONNX_DISABLE_DYNAMO"] = "1"
    torch.onnx.export(
        embedding_layer,
        dummy_input_ids,
        filename,
        opset_version=opset_version,
        input_names=["input_ids"],
        output_names=["embeddings"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "embeddings": {0: "batch_size", 1: "sequence_length"},
        },
        do_constant_folding=True,
        verbose=False,
        export_params=True,
    )
    print(f"ONNX model saved to: {filename}")

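# Minimal standalone inference sketch for the exported model (assumes the
# default output filename; both axes are dynamic, so any batch/sequence
# shape of int64 token ids works):
#
#   import numpy as np
#   import onnxruntime as ort
#   sess = ort.InferenceSession("embedding.onnx", providers=["CPUExecutionProvider"])
#   ids = np.array([[1, 2, 3]], dtype=np.int64)
#   (emb,) = sess.run(["embeddings"], {"input_ids": ids})
#   print(emb.shape)  # (1, 3, hidden_size)
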
def verify_onnx_model(
    onnx_filename: str,
    llm_config_path: str,
    seq_length: int = 100,
    num_tests: int = 3,
):
    """Compare ONNX Runtime outputs against the PyTorch embedding layer."""
    print("Verifying ONNX embedding model...")
    print("Loading PyTorch model for verification...")
    pytorch_model = AutoModelForCausalLM.from_pretrained(
        llm_config_path,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map=None,
    )
    pytorch_embedding = pytorch_model.get_input_embeddings()
    pytorch_embedding.eval()

    print("Loading ONNX model...")
    session_opts = ort.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    try:
        ort_session = ort.InferenceSession(
            onnx_filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )
    except Exception as e:
        print(f"Error loading ONNX model: {e}")
        return False

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX input name: {input_name}")
    print(f"ONNX output name: {output_name}")

    vocab_size = pytorch_embedding.num_embeddings
    tolerance = 1e-5
    all_passed = True
    max_diff = 0.0
    max_relative_diff = 0.0

    for test_idx in range(num_tests):
        input_ids = torch.randint(0, vocab_size, (1, seq_length), dtype=torch.int64)

        with torch.no_grad():
            pytorch_output = pytorch_embedding(input_ids).numpy()

        onnx_output = ort_session.run(
            [output_name],
            {input_name: input_ids.numpy()},
        )[0]

        diff = np.abs(pytorch_output - onnx_output)
        max_abs_diff = np.max(diff)
        mean_abs_diff = np.mean(diff)

        # Relative difference, guarded against division by (near-)zero.
        pytorch_abs = np.abs(pytorch_output)
        relative_diff = np.where(
            pytorch_abs > 1e-8,
            diff / (pytorch_abs + 1e-8),
            diff,
        )
        max_rel_diff = np.max(relative_diff)
        mean_rel_diff = np.mean(relative_diff)

        max_diff = max(max_diff, max_abs_diff)
        max_relative_diff = max(max_relative_diff, max_rel_diff)

        passed = max_abs_diff < tolerance

        status = "PASS" if passed else "FAIL"
        print(f"\nTest {test_idx + 1}/{num_tests}: {status}")
        print(f" Input shape: {input_ids.shape}")
        print(f" Output shape: {onnx_output.shape}")
        print(f" Max absolute difference: {max_abs_diff:.2e}")
        print(f" Mean absolute difference: {mean_abs_diff:.2e}")
        print(f" Max relative difference: {max_rel_diff:.2e}")
        print(f" Mean relative difference: {mean_rel_diff:.2e}")

        if not passed:
            all_passed = False
            print(f"Warning: Difference exceeds tolerance ({tolerance})")

    print("Verification Summary:")
    print(f"Overall status: {'PASSED' if all_passed else 'FAILED'}")
    print(f"Max absolute difference across all tests: {max_diff:.2e}")
    print(f"Max relative difference across all tests: {max_relative_diff:.2e}")

    if all_passed:
        print("\nAll tests passed! ONNX model matches PyTorch model.")
    else:
        print("\nSome tests failed. Please check the differences.")

    return all_passed

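# Quantization note for the function below: when the exported graph is a
# single Gather over the weight table, the table is quantized by hand to
# uint8 and a DequantizeLinear is spliced in after the Gather, using the
# usual asymmetric affine scheme:
#
#   w ≈ (q - zero_point) * scale
#   scale = (w_max - w_min) / 255,  zero_point = round(-w_min / scale)
#
# Worked example (illustrative numbers): for weights in [-0.5, 0.5],
# scale = 1/255 ≈ 3.9e-3 and zero_point = round(127.5) = 128, so the
# worst-case reconstruction error per element is about scale / 2 ≈ 2.0e-3.
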
def quantize_embedding_model(model, filename, filename_int8, original_size):
    """Quantize the ONNX embedding model and save it to filename_int8."""
    op_types = set(node.op_type for node in model.graph.node)
    print(f"Operations in model: {op_types}")

    if "Gather" not in op_types:
        print("Quantizing to INT8...")
        quantize_dynamic(
            model_input=filename,
            model_output=filename_int8,
            op_types_to_quantize=["MatMul"],
            weight_type=QuantType.QUInt8,
        )
        quantized_size = os.path.getsize(filename_int8) / (1024 * 1024)
        size_reduction = original_size - quantized_size
        print(f"Quantized model saved to: {filename_int8}")
        print(
            f"Quantized model size: {quantized_size:.2f}MB "
            f"(reduced by {size_reduction:.2f}MB, {size_reduction/original_size*100:.1f}%)"
        )
        return True

    print("Embedding model uses Gather, quantizing embedding weights...")
    gather_node = None
    embedding_weight_name = None

    for node in model.graph.node:
        if node.op_type == "Gather":
            gather_node = node
            embedding_weight_name = node.input[0]
            break

    if gather_node is None:
        print("Warning: Gather node not found")
        shutil.copy(filename, filename_int8)
        return False

    for initializer in model.graph.initializer:
        if initializer.name == embedding_weight_name and len(initializer.dims) >= 2:
            weight_array = numpy_helper.to_array(initializer)
            original_size_mb = weight_array.nbytes / (1024 * 1024)

            weight_min = np.min(weight_array)
            weight_max = np.max(weight_array)

            if weight_max <= weight_min:
                continue

            # Asymmetric uint8 quantization: w ≈ (q - zero_point) * scale.
            scale = (weight_max - weight_min) / 255.0
            zero_point = np.clip(np.round(-weight_min / scale), 0, 255).astype(np.uint8)
            quantized_weight = np.clip(
                np.round((weight_array - weight_min) / scale), 0, 255
            ).astype(np.uint8)

            quantized_size_mb = quantized_weight.nbytes / (1024 * 1024)
            saved_mb = original_size_mb - quantized_size_mb

            initializer.data_type = onnx.TensorProto.UINT8
            initializer.raw_data = quantized_weight.tobytes()

            scale_name = initializer.name + "_scale"
            zp_name = initializer.name + "_zero_point"

            scale_tensor = onnx.helper.make_tensor(scale_name, onnx.TensorProto.FLOAT, [], [float(scale)])
            zp_tensor = onnx.helper.make_tensor(zp_name, onnx.TensorProto.UINT8, [], [int(zero_point)])

            model.graph.initializer.append(scale_tensor)
            model.graph.initializer.append(zp_tensor)

            # Rewire: Gather now emits uint8 codes; DequantizeLinear restores floats.
            gather_output_name = gather_node.output[0]
            intermediate_output = gather_output_name + "_uint8"

            gather_node.output[0] = intermediate_output

            dequant_node = onnx.helper.make_node(
                "DequantizeLinear",
                inputs=[intermediate_output, scale_name, zp_name],
                outputs=[gather_output_name],
                name="DequantizeLinear_0",
            )

            node_idx = list(model.graph.node).index(gather_node)
            model.graph.node.insert(node_idx + 1, dequant_node)

            print(f" Quantized {initializer.name}: {initializer.dims}, saved {saved_mb:.2f}MB")

            try:
                onnx.checker.check_model(model)
                onnx.save(model, filename_int8)
                quantized_size = os.path.getsize(filename_int8) / (1024 * 1024)
                size_reduction = original_size - quantized_size
                print(f"Quantized model saved to: {filename_int8}")
                print(
                    f"Quantized model size: {quantized_size:.2f}MB "
                    f"(reduced by {size_reduction:.2f}MB, {size_reduction/original_size*100:.1f}%)"
                )
                return True
            except Exception as e:
                print(f"Error saving quantized model: {e}")
                print("Falling back to original model")
                shutil.copy(filename, filename_int8)
                return False

    print("Warning: No weights were quantized, copying original model")
    shutil.copy(filename, filename_int8)
    return False

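# main() below only runs verify_onnx_model() against the FP32 export; a
# quick spot-check of the int8 file can reuse the same pattern (a sketch,
# not wired to any flag; expect errors on the order of the dequantization
# step size, not 1e-5):
#
#   ids = np.array([[1, 2, 3]], dtype=np.int64)
#   sess8 = ort.InferenceSession("embedding.int8.onnx", providers=["CPUExecutionProvider"])
#   out8 = sess8.run(["embeddings"], {"input_ids": ids})[0]
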
@torch.no_grad()
def main():
    args = get_args()
    print(vars(args))

    llm_config_path = Path(args.llm_config_path)
    if not llm_config_path.exists():
        raise ValueError(f"LLM config path not found: {llm_config_path}")

    print(f"\nChecking for model files in {llm_config_path}...")
    has_model = check_safetensors_file(str(llm_config_path))
    if not has_model:
        print("Warning: No model files found. Model loading may fail.")

    print(f"\nLoading LLM config from {llm_config_path}...")
    config = AutoConfig.from_pretrained(str(llm_config_path), trust_remote_code=True)
    vocab_size = config.vocab_size
    hidden_size = config.hidden_size
    print(f"Model config: vocab_size={vocab_size}, hidden_size={hidden_size}")

    print("\nLoading PyTorch model...")
    model = AutoModelForCausalLM.from_pretrained(
        str(llm_config_path),
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map=None,
    )
    embedding_layer = model.get_input_embeddings()
    embedding_layer.eval()
    embedding_layer.to("cpu")

    print(f"Embedding layer: {type(embedding_layer).__name__}")
    print(f" Weight shape: {embedding_layer.weight.shape}")
    print(f" Num embeddings: {embedding_layer.num_embeddings}")
    print(f" Embedding dim: {embedding_layer.embedding_dim}")

    seq_length = args.seq_length
    opset_version = args.opset_version
    filename = args.output_filename

    export_embedding_onnx(
        embedding_layer=embedding_layer,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        filename=filename,
        seq_length=seq_length,
        opset_version=opset_version,
    )

    model_author = "FunAudioLLM"
    comment = os.environ.get("comment", "FunAudioLLM/Fun-ASR-Nano-2512")
    url = "https://huggingface.co/FunAudioLLM/Fun-ASR-Nano-2512"
    meta_data = {
        "model_type": "embedding_layer",
        "version": "1",
        "model_author": model_author,
        "vocab_size": vocab_size,
        "hidden_size": hidden_size,
        "comment": comment,
        "url": url,
    }
    add_meta_data(filename, meta_data)
    print("Metadata added to ONNX model.")

    filename_int8 = filename.replace(".onnx", ".int8.onnx")
    model = onnx.load(filename)
    original_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"Original model size: {original_size:.2f}MB")

    quantize_embedding_model(model, filename, filename_int8, original_size)
    add_meta_data(filename_int8, meta_data)

    if args.verify:
        verify_onnx_model(
            onnx_filename=filename,
            llm_config_path=str(llm_config_path),
            seq_length=min(seq_length, 200),
            num_tests=5,
        )
    else:
        print("\nNote: Use --verify flag to verify the exported model")

if __name__ == "__main__":
    torch.manual_seed(20251219)
    main()