import os
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification
from optimum.onnxruntime import ORTModelForAudioClassification
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from optimum.onnxruntime import ORTQuantizer
import shutil
# --- CONFIG ---
# Path to the fine-tuned PyTorch Wav2Vec2 checkpoint to convert.
MODEL_PATH = "models/wav2vec2-finetuned"
# Destination directory for the FP32 ONNX export.
ONNX_PATH = "models/onnx"
# Destination directory for the INT8-quantized ONNX model.
QUANTIZED_PATH = "models/onnx_quantized"
def export_to_onnx():
    """Export the fine-tuned PyTorch model at MODEL_PATH to ONNX.

    Loads the feature extractor and the model, converts the model to
    ONNX via Optimum, and saves both artifacts into ONNX_PATH
    (replacing any previous export) so the folder is self-contained.
    """
    # Plain string: no placeholders, so no f-prefix needed (F541).
    print("Exporting PyTorch model to ONNX...")
    # Load the feature extractor so it can be saved alongside the model.
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
    # Optimum handles the ONNX export configuration automatically.
    model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, export=True)
    # Remove any stale export so the directory holds only fresh files.
    if os.path.exists(ONNX_PATH):
        shutil.rmtree(ONNX_PATH)
    model.save_pretrained(ONNX_PATH)
    feature_extractor.save_pretrained(ONNX_PATH)
    print(f"ONNX model saved to: {ONNX_PATH}")
def quantize_onnx():
    """Apply INT8 dynamic quantization to the exported ONNX model.

    Reads the ONNX model from ONNX_PATH, quantizes the listed operator
    types, and writes the result plus the feature extractor config to
    QUANTIZED_PATH so the quantized folder is self-contained.
    """
    # Plain string: no placeholders, so no f-prefix needed (F541).
    print("Quantizing ONNX model to INT8...")
    # Load the exported ONNX graph for quantization.
    quantizer = ORTQuantizer.from_pretrained(ONNX_PATH, file_name="model.onnx")
    # Dynamic (is_static=False) INT8 quantization; Conv layers are
    # deliberately excluded to prevent 'initializer' errors in Wav2Vec2.
    qconfig = AutoQuantizationConfig.arm64(
        is_static=False,
        per_channel=False,
        operators_to_quantize=[
            "MatMul", "Attention", "LSTM",
            "Gather", "Transpose", "EmbedLayerNormalization",
        ],
    )
    # Remove any stale quantized output before writing fresh files.
    if os.path.exists(QUANTIZED_PATH):
        shutil.rmtree(QUANTIZED_PATH)
    quantizer.quantize(save_dir=QUANTIZED_PATH, quantization_config=qconfig)
    # Copy the feature extractor config into the quantized folder so it
    # can be loaded on its own, without the FP32 export directory.
    feature_extractor = AutoFeatureExtractor.from_pretrained(ONNX_PATH)
    feature_extractor.save_pretrained(QUANTIZED_PATH)
    print(f"INT8 Quantized model saved to: {QUANTIZED_PATH}")
if __name__ == "__main__":
    # Run the full conversion pipeline: export to ONNX first, then
    # quantize the exported model to INT8.
    export_to_onnx()
    quantize_onnx()