{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d42d2f51",
   "metadata": {},
   "source": [
    "# ONNX Quantization\n",
    "\n",
    "Quantization ONNX model ke INT8 dengan preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411d7be9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import onnx\n",
    "from onnxruntime.quantization import quantize_static, QuantType, CalibrationDataReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87dc761",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_size_mb(path):\n",
    "    \"\"\"Return the size of the file at ``path`` in mebibytes (bytes / 1024**2).\n",
    "\n",
    "    Propagates the ``OSError`` raised by ``os.path.getsize`` when the file\n",
    "    does not exist or cannot be accessed.\n",
    "    \"\"\"\n",
    "    return os.path.getsize(path) / (1024**2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3af9a45",
   "metadata": {},
   "outputs": [],
   "source": [
    "ONNX_MODEL = \"/content/drive/MyDrive/models/best_checkpoint.onnx\"\n",
    "# Derived artifacts are written next to the input model, sharing its stem.\n",
    "ONNX_PREPROCESSED = str(Path(ONNX_MODEL).with_suffix('')) + \"_preprocessed.onnx\"\n",
    "ONNX_INT8 = str(Path(ONNX_MODEL).with_suffix('')) + \"_int8.onnx\"\n",
    "\n",
    "# Fail early with a clear message instead of deep inside the quantization tooling.\n",
    "if not os.path.exists(ONNX_MODEL):\n",
    "    raise FileNotFoundError(f\"Model not found: {ONNX_MODEL}\")\n",
    "\n",
    "print(f\"Input: {Path(ONNX_MODEL).name} ({get_size_mb(ONNX_MODEL):.2f} MB)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bde9da41",
   "metadata": {},
   "source": [
    "## 1. Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f89cfbb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the CLI with the kernel's own interpreter (a bare `python` may resolve to a\n",
    "# different environment) and quote the paths so names with spaces survive the shell.\n",
    "!\"{sys.executable}\" -m onnxruntime.quantization.preprocess --input \"{ONNX_MODEL}\" --output \"{ONNX_PREPROCESSED}\"\n",
    "\n",
    "# The shell escape does not raise on a non-zero exit code, so check the artifact explicitly.\n",
    "if not os.path.exists(ONNX_PREPROCESSED):\n",
    "    raise FileNotFoundError(f\"Preprocessing did not produce: {ONNX_PREPROCESSED}\")\n",
    "\n",
    "print(f\"Preprocessed: {get_size_mb(ONNX_PREPROCESSED):.2f} MB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "433bb1f9",
   "metadata": {},
   "source": [
    "## 2. 
Calibration Data Reader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec3cb386",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CalibrationReader(CalibrationDataReader):\n",
    "    \"\"\"Calibration data reader that serves synthetic Gaussian-noise tensors.\n",
    "\n",
    "    The tensors are random noise shaped like the model's first graph input,\n",
    "    not real samples. NOTE(review): activation ranges calibrated on noise may\n",
    "    not match real inputs -- replace with representative data if accuracy matters.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model_path, num_samples=10, seed=0):\n",
    "        \"\"\"Generate ``num_samples`` float32 tensors for the model's first input.\n",
    "\n",
    "        Args:\n",
    "            model_path: Path of the ONNX model whose input name/shape is read.\n",
    "            num_samples: Number of calibration tensors to generate.\n",
    "            seed: Seed for the random generator, so repeated runs calibrate\n",
    "                on identical data.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the graph has no input that is not an initializer.\n",
    "        \"\"\"\n",
    "        model = onnx.load(model_path)\n",
    "        # Older exports also list initializers (weights) under graph.input; skip\n",
    "        # them so the first entry is an input that must actually be fed.\n",
    "        initializer_names = {init.name for init in model.graph.initializer}\n",
    "        graph_inputs = [inp for inp in model.graph.input if inp.name not in initializer_names]\n",
    "        if not graph_inputs:\n",
    "            raise ValueError(f\"No graph input found in: {model_path}\")\n",
    "        first_input = graph_inputs[0]\n",
    "        # Symbolic/unknown dimensions report dim_value == 0; substitute 1 so the\n",
    "        # generated tensors are never empty.\n",
    "        input_shape = tuple(\n",
    "            d.dim_value if d.dim_value > 0 else 1\n",
    "            for d in first_input.type.tensor_type.shape.dim\n",
    "        )\n",
    "        rng = np.random.default_rng(seed)\n",
    "        # assumes the model input is float32 -- TODO confirm against the exported model\n",
    "        self.data = [\n",
    "            rng.standard_normal(input_shape).astype(np.float32)\n",
    "            for _ in range(num_samples)\n",
    "        ]\n",
    "        self.input_name = first_input.name\n",
    "        self.enum_index = 0\n",
    "\n",
    "    def get_next(self):\n",
    "        \"\"\"Return the next ``{input_name: tensor}`` feed, or ``None`` when exhausted.\"\"\"\n",
    "        if self.enum_index >= len(self.data):\n",
    "            return None\n",
    "        input_dict = {self.input_name: self.data[self.enum_index]}\n",
    "        self.enum_index += 1\n",
    "        return input_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1080068",
   "metadata": {},
   "source": [
    "## 3. INT8 Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66501947",
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect  # local import keeps this cell self-contained\n",
    "\n",
    "calibration_reader = CalibrationReader(ONNX_PREPROCESSED)\n",
    "\n",
    "quantize_kwargs = dict(\n",
    "    model_input=ONNX_PREPROCESSED,\n",
    "    model_output=ONNX_INT8,\n",
    "    calibration_data_reader=calibration_reader,\n",
    "    weight_type=QuantType.QUInt8,\n",
    ")\n",
    "# `optimize_model` is not accepted by every onnxruntime release; passing an unknown\n",
    "# keyword raises TypeError, so forward it only when the installed signature has it.\n",
    "if \"optimize_model\" in inspect.signature(quantize_static).parameters:\n",
    "    quantize_kwargs[\"optimize_model\"] = False\n",
    "\n",
    "quantize_static(**quantize_kwargs)\n",
    "\n",
    "original_mb = get_size_mb(ONNX_MODEL)\n",
    "int8_mb = get_size_mb(ONNX_INT8)\n",
    "print(f\"INT8: {int8_mb:.2f} MB ({(1 - int8_mb/original_mb) * 100:.1f}% reduction)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a20622e",
   "metadata": {},
   "source": [
    "## 4. 
Summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d66ea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "original_mb = get_size_mb(ONNX_MODEL)\n",
    "# (label, path) of every artifact compared against the original model.\n",
    "artifacts = [(\"Preprocessed\", ONNX_PREPROCESSED), (\"INT8\", ONNX_INT8)]\n",
    "\n",
    "print(\"=\" * 50)\n",
    "print(f\"{'Model':<15} {'Size (MB)':<15} {'Reduksi':<10}\")\n",
    "print(\"-\" * 50)\n",
    "# Sizes are padded to the header's 15-character column so rows line up\n",
    "# regardless of how many digits the size has.\n",
    "print(f\"{'Original':<15} {original_mb:<15.2f} baseline\")\n",
    "for label, path in artifacts:\n",
    "    size_mb = get_size_mb(path)\n",
    "    reduction_pct = (1 - size_mb / original_mb) * 100\n",
    "    print(f\"{label:<15} {size_mb:<15.2f} {reduction_pct:.1f}%\")\n",
    "print(\"=\" * 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}