Quantizing the model to int8 can reduce its size while keeping the results almost unchanged.
#1
by greyovo - opened
Generated by Kimi 2.6 and tested. Clone this repository and create a new script containing the code below:
import os
import sys
import torch
from transformers import MarianMTModel, AutoTokenizer
from pathlib import Path
class MarianEncoderDecoderWrapper(torch.nn.Module):
    """Adapter that gives a seq2seq model a three-tensor, single-output forward.

    The wrapped model is called with keyword arguments and
    ``return_dict=False``; only the first element of what it returns is
    passed on, so the wrapper has exactly one output.
    """

    def __init__(self, model):
        """Store ``model`` as the ``model`` attribute of this module."""
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        """Run the wrapped model and return element 0 of its result.

        Args:
            input_ids: Forwarded to the model as ``input_ids``.
            attention_mask: Forwarded to the model as ``attention_mask``.
            decoder_input_ids: Forwarded to the model as
                ``decoder_input_ids``.

        Returns:
            The first item of the model's return value.
            NOTE(review): presumably the LM logits for MarianMTModel when no
            labels are passed -- confirm against the transformers docs.
        """
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            # Ask for a plain tuple instead of a ModelOutput object.
            "return_dict": False,
        }
        return self.model(**model_kwargs)[0]
def export_to_onnx(model_dir, output_path, use_fp16=False, quantize=False):
    """Export the Marian model found in ``model_dir`` to an ONNX file.

    Args:
        model_dir: Directory handed to ``MarianMTModel.from_pretrained``.
        output_path: Where the ONNX file is written. Missing parent
            directories are created.
        use_fp16: When true, the model is converted with ``model.half()``
            before it is exported.
        quantize: When true, ``quantize_onnx`` is run on the exported file
            afterwards.

    Returns:
        None.
    """
    model_dir = Path(model_dir)
    output_path = Path(output_path)

    # No tokenizer is loaded here: the export is traced with synthetic token
    # ids, so only the model weights are required.
    print("Loading model...")
    model = MarianMTModel.from_pretrained(model_dir)
    model.eval()

    if use_fp16:
        print("Converting model to float16...")
        model = model.half()

    print("Wrapping model for ONNX export...")
    wrapped_model = MarianEncoderDecoderWrapper(model)

    # Example inputs for the export: batch of 1, 128 encoder positions and a
    # single decoder position. These dimensions are declared dynamic below.
    dummy_input_ids = torch.ones(1, 128, dtype=torch.long)
    dummy_attention_mask = torch.ones(1, 128, dtype=torch.long)
    dummy_decoder_input_ids = torch.ones(1, 1, dtype=torch.long)

    # The exporter writes straight to output_path, so its directory must exist.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print("Exporting to ONNX...")
    torch.onnx.export(
        wrapped_model,
        (dummy_input_ids, dummy_attention_mask, dummy_decoder_input_ids),
        str(output_path),
        input_names=["input_ids", "attention_mask", "decoder_input_ids"],
        output_names=["logits"],
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
            "decoder_input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
            "logits": {0: "batch_size", 1: "decoder_sequence_length"},
        },
        opset_version=14,
        do_constant_folding=True,
        dynamo=False,
    )

    print(f"ONNX model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / (1024 * 1024):.2f} MB")

    if quantize:
        quantize_onnx(output_path)
def quantize_onnx(model_path):
    """Write a dynamically INT8-quantized copy of an ONNX model.

    The copy is saved next to the input, with the input's final suffix
    replaced by ``.quantized.onnx`` (``model.onnx`` becomes
    ``model.quantized.onnx``). The sizes of both files and their ratio are
    printed.

    Args:
        model_path: Path (``str`` or ``Path``) of the ONNX model to quantize.

    Returns:
        The ``Path`` of the quantized model, or ``None`` when the
        ``onnxruntime.quantization`` import fails and quantization is skipped.

    Raises:
        Whatever ``quantize_dynamic`` or the file-size lookups raise; only
        the import itself is treated as an optional dependency.
    """
    # Guard the import alone, so a failure during quantization is not
    # misreported as a missing package and silently skipped.
    try:
        from onnxruntime.quantization import QuantType, quantize_dynamic
    except ImportError as exc:
        # Include the reason: the import can also fail because of a missing
        # dependency of the quantization tools, not only onnxruntime itself.
        print(f"onnxruntime quantization unavailable ({exc}), skipping quantization")
        return None

    model_path = Path(model_path)
    quantized_path = model_path.with_suffix(".quantized.onnx")

    print("Quantizing ONNX model to INT8...")
    quantize_dynamic(
        model_input=str(model_path),
        model_output=str(quantized_path),
        weight_type=QuantType.QInt8,
    )

    bytes_per_mb = 1024 * 1024
    original_size = model_path.stat().st_size / bytes_per_mb
    quantized_size = quantized_path.stat().st_size / bytes_per_mb

    print(f"Quantized model saved to: {quantized_path}")
    print(f"Original size: {original_size:.2f} MB")
    print(f"Quantized size: {quantized_size:.2f} MB")
    print(f"Compression ratio: {original_size / quantized_size:.2f}x")
    return quantized_path
if __name__ == "__main__":
    # Both exports read from, and write into, the directory holding this script.
    script_dir = Path(__file__).parent
    rule = "=" * 50

    # Pass 1: half-precision export, no quantization.
    print(rule, "Exporting FP16 ONNX model...", rule, sep="\n")
    export_to_onnx(
        script_dir,
        script_dir / "model.onnx",
        use_fp16=True,
        quantize=False,
    )

    print()

    # Pass 2: full-precision export, which is then quantized to INT8.
    print(rule, "Exporting FP32 ONNX model for INT8 quantization...", rule, sep="\n")
    export_to_onnx(
        script_dir,
        script_dir / "model.fp32.onnx",
        use_fp16=False,
        quantize=True,
    )