Translation
LiteRT
Safetensors
English
Chinese
marian

Quantizing the model to INT8 can shrink its size while keeping the results almost unchanged.

#1
by greyovo - opened

Generated by Kimi 2.6 and tested. Clone this repository, then create a new script in its root directory containing the code below (the script loads the model from the directory it is saved in):

import os
import sys

import torch
from transformers import MarianMTModel, AutoTokenizer
from pathlib import Path


class MarianEncoderDecoderWrapper(torch.nn.Module):
    """Adapter that makes a seq2seq model return a single tensor.

    ``torch.onnx.export`` drives ``forward`` with positional tensors, so this
    wrapper pins the argument order (encoder ids, encoder mask, decoder ids)
    and reduces the model's tuple output to its first element, which is what
    the export labels ``logits``.

    Attributes:
        model: The wrapped model; it is called with keyword arguments and
            ``return_dict=False``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, decoder_input_ids):
        """Run the wrapped model and return the first element of its tuple output."""
        call_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            # Tuple output keeps the traced graph free of ModelOutput containers.
            return_dict=False,
        )
        model_outputs = self.model(**call_kwargs)
        first_output = model_outputs[0]
        return first_output


def export_to_onnx(
    model_dir,
    output_path,
    use_fp16=False,
    quantize=False,
    *,
    sequence_length=128,
    opset_version=14,
):
    """Export a MarianMT checkpoint to ONNX, optionally quantizing it afterwards.

    The exported graph has inputs ``input_ids``, ``attention_mask`` and
    ``decoder_input_ids`` and a single output ``logits``. Batch size and both
    sequence lengths are declared as dynamic axes, so ``sequence_length`` only
    sizes the dummy tensors used while tracing.

    Args:
        model_dir: Anything accepted by ``MarianMTModel.from_pretrained``
            (e.g. a local checkpoint directory as ``str`` or ``Path``).
        output_path: Destination ``.onnx`` file. Missing parent directories
            are created.
        use_fp16: Cast the model to float16 with ``model.half()`` before export.
        quantize: After export, run ``quantize_onnx`` on the written file.
            NOTE(review): dynamic INT8 quantization presumably expects an FP32
            graph; the ``__main__`` script exports a separate FP32 model for
            this purpose, so prefer ``use_fp16=False`` when quantizing.
        sequence_length: Encoder length of the dummy input used for tracing.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.

    Returns:
        The ``Path`` of the exported (unquantized) ONNX model.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """
    if sequence_length < 1:
        raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")

    output_path = Path(output_path)
    # torch.onnx.export does not create missing directories itself.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print("Loading model...")
    model = MarianMTModel.from_pretrained(model_dir)
    model.eval()

    if use_fp16:
        print("Converting model to float16...")
        model = model.half()

    print("Wrapping model for ONNX export...")
    wrapped_model = MarianEncoderDecoderWrapper(model)

    # Token values are placeholders; tracing records the ops for these shapes,
    # and the axes below are then marked dynamic.
    dummy_input_ids = torch.ones(1, sequence_length, dtype=torch.long)
    dummy_attention_mask = torch.ones(1, sequence_length, dtype=torch.long)
    dummy_decoder_input_ids = torch.ones(1, 1, dtype=torch.long)

    print("Exporting to ONNX...")
    # No gradients are needed for export; skipping autograd saves memory.
    with torch.no_grad():
        torch.onnx.export(
            wrapped_model,
            (dummy_input_ids, dummy_attention_mask, dummy_decoder_input_ids),
            str(output_path),
            input_names=["input_ids", "attention_mask", "decoder_input_ids"],
            output_names=["logits"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "decoder_input_ids": {0: "batch_size", 1: "decoder_sequence_length"},
                "logits": {0: "batch_size", 1: "decoder_sequence_length"},
            },
            opset_version=opset_version,
            do_constant_folding=True,
            # Use the legacy TorchScript-based exporter, which honours dynamic_axes.
            dynamo=False,
        )

    print(f"ONNX model exported to: {output_path}")
    print(f"File size: {output_path.stat().st_size / (1024 * 1024):.2f} MB")

    if quantize:
        quantize_onnx(output_path)

    return output_path


def quantize_onnx(model_path):
    """Dynamically quantize an ONNX model's weights to INT8 with onnxruntime.

    The quantized model is written next to the input. The final ``.onnx``
    suffix is replaced by ``.quantized.onnx``, so ``model.fp32.onnx`` becomes
    ``model.fp32.quantized.onnx``. Both file sizes and their ratio are printed.

    Args:
        model_path: Path (``str`` or ``Path``) of the ONNX model to quantize.

    Returns:
        The ``Path`` of the quantized model, or ``None`` when
        ``onnxruntime.quantization`` cannot be imported (quantization is then
        skipped with a message).
    """
    # Only the import is guarded: an ImportError raised while quantizing is a
    # real failure and must not be reported as "onnxruntime not installed".
    try:
        from onnxruntime.quantization import QuantType, quantize_dynamic
    except ImportError:
        # Best effort: the exported ONNX model is still usable without it.
        print("onnxruntime not installed, skipping quantization")
        return None

    model_path = Path(model_path)
    quantized_path = model_path.with_suffix(".quantized.onnx")

    print("Quantizing ONNX model to INT8...")
    quantize_dynamic(
        model_input=str(model_path),
        model_output=str(quantized_path),
        weight_type=QuantType.QInt8,
    )

    bytes_per_mb = 1024 * 1024
    original_size = model_path.stat().st_size / bytes_per_mb
    quantized_size = quantized_path.stat().st_size / bytes_per_mb
    print(f"Quantized model saved to: {quantized_path}")
    print(f"Original size: {original_size:.2f} MB")
    print(f"Quantized size: {quantized_size:.2f} MB")
    print(f"Compression ratio: {original_size / quantized_size:.2f}x")

    return quantized_path


if __name__ == "__main__":
    # The script is meant to live inside the cloned model repository, so the
    # checkpoint directory is the directory containing this file.
    model_dir = Path(__file__).parent
    separator = "=" * 50

    # Each step: (banner title, output file name, use_fp16, quantize).
    export_plan = [
        ("Exporting FP16 ONNX model...", "model.onnx", True, False),
        (
            "Exporting FP32 ONNX model for INT8 quantization...",
            "model.fp32.onnx",
            False,
            True,
        ),
    ]

    for step_index, (title, file_name, as_fp16, to_int8) in enumerate(export_plan):
        # Blank line between consecutive steps, none before the first.
        if step_index > 0:
            print()
        print(separator)
        print(title)
        print(separator)
        export_to_onnx(
            model_dir,
            model_dir / file_name,
            use_fp16=as_fp16,
            quantize=to_int8,
        )

Sign up or log in to comment