import os
import logging

import torch
import onnx
from scipy.stats import spearmanr
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BitsAndBytesConfig
from onnxconverter_common import float16
from onnxruntime.quantization import quantize_dynamic, QuantType

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelQuantizer:
    """Handles model quantization and comparison operations."""

    @staticmethod
    def quantize_model(model_class, model_name, quantization_type):
        """Quantize a model with the specified quantization type."""
        try:
            if quantization_type == "4-bit":
                quantization_config = BitsAndBytesConfig(load_in_4bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "8-bit":
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                model = model_class.from_pretrained(model_name, quantization_config=quantization_config)
            elif quantization_type == "16-bit-float":
                # fp16 is a plain dtype cast rather than bitsandbytes quantization
                model = model_class.from_pretrained(model_name)
                model = model.to(torch.float16)
            else:
                raise ValueError(f"Unsupported quantization type: {quantization_type}")
            return model
        except Exception as e:
            logger.error(f"Quantization failed: {e}")
            raise
    @staticmethod
    def get_model_size(model):
        """Return the model's on-disk size in MB by serializing its state dict."""
        try:
            torch.save(model.state_dict(), "temp.pth")
            size = os.path.getsize("temp.pth") / (1024 * 1024)
            os.remove("temp.pth")
            return size
        except Exception as e:
            logger.error(f"Failed to get model size: {e}")
            raise
    @staticmethod
    def compare_model_outputs(original_outputs, quantized_outputs):
        """Compare logits between the original and quantized models."""
        try:
            if original_outputs is None or quantized_outputs is None:
                return None
            if hasattr(original_outputs, 'logits') and hasattr(quantized_outputs, 'logits'):
                original_logits = original_outputs.logits.detach().cpu().numpy()
                quantized_logits = quantized_outputs.logits.detach().cpu().numpy()
                metrics = {
                    'mse': ((original_logits - quantized_logits) ** 2).mean(),
                    'spearman_corr': spearmanr(original_logits.flatten(), quantized_logits.flatten())[0],
                    'cosine_sim': cosine_similarity(
                        original_logits.reshape(1, -1), quantized_logits.reshape(1, -1)
                    )[0][0],
                }
                return metrics
            return None
        except Exception as e:
            logger.error(f"Output comparison failed: {e}")
            raise
def quantize_onnx_model(model_dir, quantization_type):
    """Quantize every ONNX model in the specified directory, in place."""
    logger.info(f"Quantizing ONNX model in: {model_dir}")
    for filename in os.listdir(model_dir):
        if not filename.endswith('.onnx'):
            continue
        input_model_path = os.path.join(model_dir, filename)
        output_model_path = os.path.join(model_dir, f"quantized_{filename}")
        try:
            if quantization_type == "16-bit-float":
                model = onnx.load(input_model_path)
                model_fp16 = float16.convert_float_to_float16(model)
                onnx.save(model_fp16, output_model_path)
            elif quantization_type in ("8-bit", "16-bit-int"):
                quant_type_mapping = {
                    "8-bit": QuantType.QInt8,
                    # QInt16 requires a recent onnxruntime with 16-bit quantization support
                    "16-bit-int": QuantType.QInt16,
                }
                quantize_dynamic(
                    model_input=input_model_path,
                    model_output=output_model_path,
                    weight_type=quant_type_mapping[quantization_type],
                )
            else:
                logger.error(f"Unsupported quantization type: {quantization_type}")
                continue
            # Replace the original model with its quantized version
            os.remove(input_model_path)
            os.rename(output_model_path, input_model_path)
            logger.info(f"Quantized ONNX model saved to: {input_model_path}")
        except Exception as e:
            logger.error(f"Error during ONNX quantization: {e}")
            # Clean up any partial output so the directory is left consistent
            if os.path.exists(output_model_path):
                os.remove(output_model_path)
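

# --- Usage sketch (illustrative assumption, not part of the original module) ---
# Everything below is hypothetical: the checkpoint name, the AutoModel class,
# and the commented-out ONNX directory are placeholders showing how the helpers
# above might be wired together.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "distilbert-base-uncased"  # hypothetical example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer("Quantization sanity check.", return_tensors="pt")

    # Compare the fp32 baseline against an fp16 copy of the same checkpoint.
    # fp16 inference is most reliable on GPU; some CPU ops lack half-precision kernels.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    original = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
    quantized = ModelQuantizer.quantize_model(
        AutoModelForSequenceClassification, model_name, "16-bit-float"
    ).to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        metrics = ModelQuantizer.compare_model_outputs(original(**inputs), quantized(**inputs))
    logger.info(f"fp32 size: {ModelQuantizer.get_model_size(original):.1f} MB, "
                f"fp16 size: {ModelQuantizer.get_model_size(quantized):.1f} MB")
    logger.info(f"Output drift metrics: {metrics}")

    # For a directory of exported ONNX models (hypothetical path), quantize in place:
    # quantize_onnx_model("onnx_models", "8-bit")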