# spam-detection-app / quantize_models.py
# Uploaded to the Hugging Face Hub by premmm ("Upload folder using huggingface_hub",
# commit 9930208, verified; 1.76 kB).
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer
import os
import logging
from pathlib import Path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fine-tuned checkpoints to quantize, keyed by display name. The defaults are
# the original machine-specific locations; each path can be overridden through
# an environment variable so the script can run on other machines unchanged.
MODELS = {
    "MiniLM": os.environ.get(
        "MINILM_MODEL_PATH",
        "/home/office-7/Downloads/minilm_v2/models/minilm",
    ),
    "XLM-Roberta": os.environ.get(
        "XLM_ROBERTA_MODEL_PATH",
        "/home/office-7/Downloads/xlm_roberta_v2/models/xlm_roberta",
    ),
}

# Root directory for the exported and quantized models (one sub-directory per
# model). Override with QUANT_OUTPUT_DIR; defaults to the original location.
OUTPUT_DIR = os.environ.get(
    "QUANT_OUTPUT_DIR",
    "/home/office-7/Desktop/spam-model/models_int8",
)
def quantize_model(name, model_path):
logger.info(f"--- Quantizing {name} ---")
output_path = Path(OUTPUT_DIR) / name.lower().replace("-", "_")
output_path.mkdir(parents=True, exist_ok=True)
# 1. Export to ONNX
logger.info(f"Exporting {name} to ONNX...")
model = ORTModelForSequenceClassification.from_pretrained(model_path, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Save the base ONNX model and tokenizer
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
# 2. Quantize
logger.info(f"Applying INT8 Dynamic Quantization to {name}...")
quantizer = ORTQuantizer.from_pretrained(model)
# ARM-64 or X86? Using dynamic quantization which is safe for both
dq_config = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer.quantize(
save_dir=output_path,
quantization_config=dq_config,
)
logger.info(f"Quantized {name} saved to {output_path}")
if __name__ == "__main__":
    # Quantize every configured model. One failure must not stop the others,
    # but it should still show up in the log (with traceback) and in the
    # process exit status.
    failed = []
    for name, path in MODELS.items():
        try:
            quantize_model(name, path)
        except Exception:
            # Top-level boundary: logger.exception records the full traceback,
            # not just str(e).
            logger.exception("Failed to quantize %s", name)
            failed.append(name)
    if failed:
        logger.error("Quantization failed for: %s", ", ".join(failed))
        raise SystemExit(1)