import os
import shutil

from transformers import AutoFeatureExtractor
from optimum.onnxruntime import ORTModelForAudioClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# --- CONFIG ---
MODEL_PATH = "models/wav2vec2-finetuned"
ONNX_PATH = "models/onnx"
QUANTIZED_PATH = "models/onnx_quantized"


def export_to_onnx():
    print("Exporting PyTorch model to ONNX...")

    # Load the feature extractor so it can be saved alongside the ONNX model
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)

    # Export to ONNX using Optimum (handles the export config automatically)
    model = ORTModelForAudioClassification.from_pretrained(MODEL_PATH, export=True)

    # Save the ONNX model and feature extractor, replacing any previous export
    if os.path.exists(ONNX_PATH):
        shutil.rmtree(ONNX_PATH)
    model.save_pretrained(ONNX_PATH)
    feature_extractor.save_pretrained(ONNX_PATH)
    print(f"ONNX model saved to: {ONNX_PATH}")


def quantize_onnx():
    print("Quantizing ONNX model to INT8...")

    # Load the exported ONNX model for quantization
    quantizer = ORTQuantizer.from_pretrained(ONNX_PATH, file_name="model.onnx")

    # Dynamic INT8 quantization config targeting ARM64 CPUs.
    # Only the listed operators are quantized; keeping Conv off the list
    # avoids 'initializer' errors in Wav2Vec2's convolutional feature encoder.
    qconfig = AutoQuantizationConfig.arm64(
        is_static=False,
        per_channel=False,
        operators_to_quantize=["MatMul", "Attention", "LSTM", "Gather", "Transpose", "EmbedLayerNormalization"],
    )

    # Apply quantization, replacing any previous output
    if os.path.exists(QUANTIZED_PATH):
        shutil.rmtree(QUANTIZED_PATH)
    quantizer.quantize(save_dir=QUANTIZED_PATH, quantization_config=qconfig)

    # Copy the feature extractor config so the quantized folder is self-contained
    feature_extractor = AutoFeatureExtractor.from_pretrained(ONNX_PATH)
    feature_extractor.save_pretrained(QUANTIZED_PATH)
    print(f"INT8 quantized model saved to: {QUANTIZED_PATH}")


if __name__ == "__main__":
    export_to_onnx()
    quantize_onnx()
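
# Usage (assumed script name): python export_quantize.py
#
# Design note: dynamic quantization (is_static=False) stores weights as INT8
# and computes activation scales at runtime, so no calibration dataset is
# needed; static quantization can reduce latency further but requires
# representative calibration data.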