import os
import time

import numpy as np
import onnx
import torch
from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
def ensure_directory(path):
    """Create *path* (and any missing parents) if it does not already exist.

    Args:
        path (str): Directory path, relative or absolute.

    Returns:
        str: The absolute form of *path*.
    """
    abs_path = os.path.abspath(path)
    # EAFP instead of exists()-then-makedirs(): another process creating the
    # directory between the check and the mkdir can no longer crash us.
    try:
        os.makedirs(abs_path)
    except FileExistsError:
        pass
    else:
        print(f"Created directory: {abs_path}")
    return abs_path
|
|
def verify_file_exists(file_path, timeout=5, poll_interval=0.1):
    """Wait until *file_path* exists and is non-empty, or until *timeout* elapses.

    The file is always checked at least once, and once more when the timeout
    expires, so ``timeout=0`` performs a single immediate check.

    Args:
        file_path (str): Path of the file to wait for.
        timeout (float): Maximum number of seconds to wait.
        poll_interval (float): Seconds to sleep between checks.

    Returns:
        bool: True if the file exists with a size greater than zero, else False.
    """
    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    deadline = time.monotonic() + timeout
    while True:
        try:
            if os.path.getsize(file_path) > 0:
                return True
        except OSError:
            # Missing or transiently inaccessible file: keep polling.
            pass
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        time.sleep(min(poll_interval, remaining))
|
|
def export_to_onnx(model, tokenizer, save_path):
    """Trace *model* on a short sample sentence and write it to *save_path* as ONNX.

    NOTE: another ``export_to_onnx`` is defined further down in this module and
    rebinds the name when the module is executed, so this version is not the
    one ``main`` ends up calling.

    Returns:
        bool: True when the exported file is present and non-empty; False if the
        export raised or the file could not be verified.
    """
    try:
        sample = tokenizer("This is a sample input", return_tensors="pt")
        ids, mask = sample["input_ids"], sample["attention_mask"]

        # Batch dimension (axis 0) is dynamic for both inputs and the output.
        axes = {
            name: {0: "batch_size"}
            for name in ("input_ids", "attention_mask", "output")
        }
        torch.onnx.export(
            model,
            (ids, mask),
            save_path,
            opset_version=14,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes=axes,
        )

        exported = verify_file_exists(save_path)
        print(
            f"Successfully exported ONNX model to {save_path}"
            if exported
            else f"Failed to verify ONNX model at {save_path}"
        )
        return exported
    except Exception as exc:
        print(f"Error exporting to ONNX: {str(exc)}")
        return False
|
|
def create_calibration_dataset(tokenizer, max_length=512, samples=None):
    """Tokenize calibration sentences into fixed-length tensors for static quantization.

    Every sample is padded and truncated to exactly *max_length* tokens, so all
    calibration inputs share one sequence length.

    Args:
        tokenizer: Hugging Face-style tokenizer; called once per sentence with
            ``padding='max_length'``, ``truncation=True`` and ``return_tensors="pt"``.
        max_length (int): Sequence length to pad/truncate every sample to.
        samples (Iterable[str], optional): Sentences to encode. When omitted, a
            small built-in multilingual set (English, German, French, Spanish,
            Chinese, Japanese) is used.

    Returns:
        list[dict]: One ``{'input_ids': ..., 'attention_mask': ...}`` dict per
        sample, holding the tokenizer's returned tensors.
    """
    if samples is None:
        samples = (
            "This is an English sentence.",
            "Dies ist ein deutscher Satz.",
            "C'est une phrase française.",
            "Esta es una frase en español.",
            "这是一个中文句子。",
            "これは日本語の文章です。",
        )

    encoded_samples = []
    for text in samples:
        encoded = tokenizer(
            text,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        # Keep only the two inputs the exported graph is given; tokenizers may
        # return extra keys (e.g. token_type_ids) that the model does not take.
        encoded_samples.append({
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        })
    return encoded_samples
|
|
class CalibrationLoader(CalibrationDataReader):
    """Hand pre-tokenized samples to ONNX Runtime's static-quantization calibrator.

    ``get_next`` yields one feed dict per stored sample, mapping the graph input
    names ``input_ids`` and ``attention_mask`` to numpy arrays, and returns
    ``None`` once every sample has been consumed. ``rewind`` restarts from the
    first sample.
    """

    def __init__(self, calibration_data):
        # Indexable sequence of dicts; the values are expected to expose
        # ``.numpy()`` (torch tensors as built by create_calibration_dataset).
        self.calibration_data = calibration_data
        # Position of the next sample to hand out.
        self.current_index = 0

    def get_next(self):
        position = self.current_index
        if position < len(self.calibration_data):
            sample = self.calibration_data[position]
            self.current_index = position + 1
            # The calibrator feeds numpy arrays, so convert each tensor.
            return {
                key: sample[key].numpy()
                for key in ('input_ids', 'attention_mask')
            }
        return None

    def rewind(self):
        self.current_index = 0
|
|
def export_to_onnx(model, tokenizer, save_path, max_length=512):
    """Export *model* to ONNX with a fixed sequence length and a dynamic batch size.

    The dummy input is padded/truncated to exactly *max_length* tokens, so the
    traced graph's sequence dimension is fixed at *max_length*. Only axis 0
    (batch) of the inputs and of the output is declared dynamic.

    Args:
        model: PyTorch module to trace; called with (input_ids, attention_mask).
        tokenizer: Tokenizer used to build the dummy input.
        save_path (str): Destination ``.onnx`` file.
        max_length (int): Fixed sequence length baked into the exported graph.

    Returns:
        bool: True when the exported file is present and non-empty; False if the
        export raised or the file could not be verified.
    """
    try:
        dummy_input = tokenizer(
            "This is a sample input",
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        torch.onnx.export(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            save_path,
            opset_version=14,
            input_names=["input_ids", "attention_mask"],
            output_names=["output"],
            dynamic_axes={
                "input_ids": {0: "batch_size"},
                "attention_mask": {0: "batch_size"},
                # Without this entry the output's batch dimension is recorded as
                # the traced size, contradicting the dynamic input batch axis.
                "output": {0: "batch_size"},
            },
        )

        if verify_file_exists(save_path):
            print(f"Successfully exported ONNX model to {save_path}")
            return True
        print(f"Failed to verify ONNX model at {save_path}")
        return False
    except Exception as e:
        print(f"Error exporting to ONNX: {str(e)}")
        return False
|
|
def quantize_model(base_onnx_path, onnx_dir, config_name, calibration_dataset=None):
    """
    Quantize ONNX model using either dynamic or static quantization.

    The result is written to ``<onnx_dir>/model_<config_name>_quantized.onnx``
    and the original/quantized file sizes are printed.

    Args:
        base_onnx_path (str): Path to the base ONNX model
        onnx_dir (str): Directory to save quantized models
        config_name (str): Type of quantization ('dynamic' or 'static')
        calibration_dataset (list, optional): Non-empty dataset for static
            quantization calibration; required when config_name == 'static'.

    Returns:
        bool: True if the quantized model was written and verified, else False.
        All exceptions are caught, reported, and turned into False.
    """
    try:
        quantized_model_path = os.path.join(onnx_dir, f"model_{config_name}_quantized.onnx")

        if config_name == "dynamic":
            print("\nPerforming dynamic quantization...")
            quantize_dynamic(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                weight_type=QuantType.QUInt8,
            )
        elif config_name == "static":
            if not calibration_dataset:
                print("Static quantization requires a non-empty calibration_dataset.")
                return False
            print("\nPerforming static quantization...")
            calibration_loader = CalibrationLoader(calibration_dataset)
            quantize_static(
                model_input=base_onnx_path,
                model_output=quantized_model_path,
                calibration_data_reader=calibration_loader,
                # quant_format picks the graph layout (QDQ vs QOperator), not the
                # integer type; the integer types go in activation_type/weight_type.
                quant_format=QuantFormat.QDQ,
                activation_type=QuantType.QUInt8,
                weight_type=QuantType.QUInt8,
            )
        else:
            print(f"Invalid quantization configuration: {config_name}")
            return False

        if not verify_file_exists(quantized_model_path):
            print(f"Failed to verify quantized model at {quantized_model_path}")
            return False

        print(f"Successfully created {config_name} quantized model at {quantized_model_path}")

        base_size = os.path.getsize(base_onnx_path) / (1024 * 1024)
        quantized_size = os.path.getsize(quantized_model_path) / (1024 * 1024)

        print(f"Original model size: {base_size:.2f} MB")
        print(f"Quantized model size: {quantized_size:.2f} MB")
        # Guard against an empty base file to avoid ZeroDivisionError.
        if base_size > 0:
            print(f"Size reduction: {((base_size - quantized_size) / base_size * 100):.2f}%")

        return True

    except Exception as e:
        print(f"Error during {config_name} quantization: {str(e)}")
        return False
|
|
|
|
def main():
    """Export the language-detection model to ONNX, then build quantized variants.

    Files are written under ``<cwd>/onnx``: ``model.onnx`` plus one
    ``model_<config>_quantized.onnx`` per quantization configuration.
    """
    # Used when the tokenizer does not report a usable maximum length.
    default_max_length = 512
    # Anything above this is treated as "no limit configured"; transformers
    # uses a huge sentinel value (int(1e30)) for unset model_max_length.
    max_sane_length = 1_000_000

    current_dir = os.path.abspath(os.getcwd())
    onnx_dir = ensure_directory(os.path.join(current_dir, "onnx"))
    base_onnx_path = os.path.join(onnx_dir, "model.onnx")

    print(f"Working directory: {current_dir}")
    print(f"ONNX directory: {onnx_dir}")
    print(f"Base ONNX model path: {base_onnx_path}")

    print("\nLoading model and tokenizer...")
    model_name = "alexneakameni/language_detection"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Padding to an unbounded sentinel length would exhaust memory, so fall
    # back to a fixed length when the reported value is unusable.
    max_length = tokenizer.model_max_length
    if not isinstance(max_length, int) or not 0 < max_length <= max_sane_length:
        print(f"Tokenizer reports no usable model_max_length ({max_length}); "
              f"using {default_max_length}.")
        max_length = default_max_length

    if not export_to_onnx(model, tokenizer, base_onnx_path, max_length):
        print("Failed to export base ONNX model. Exiting.")
        return

    # onnx.load only parses the protobuf; the checker also validates the graph.
    try:
        print(f"Verifying ONNX model at: {base_onnx_path}")
        onnx.checker.check_model(base_onnx_path)
        print("Successfully verified ONNX model")
    except Exception as e:
        print(f"Error verifying ONNX model: {str(e)}")
        return

    calibration_dataset = create_calibration_dataset(tokenizer, max_length)

    print("\nCreating quantized versions...")

    dynamic_ok = quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="dynamic",
    )

    static_ok = quantize_model(
        base_onnx_path=base_onnx_path,
        onnx_dir=onnx_dir,
        config_name="static",
        calibration_dataset=calibration_dataset,
    )

    # Surface failures at the end instead of leaving them buried in the log.
    print("\nQuantization summary:")
    print(f"  dynamic: {'ok' if dynamic_ok else 'FAILED'}")
    print(f"  static:  {'ok' if static_ok else 'FAILED'}")


if __name__ == "__main__":
    main()
|
|