| | import torch |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import onnx |
| | from onnxruntime.quantization import quantize_dynamic, QuantType |
| | import os |
| | import logging |
| | from typing import Optional, Dict, Any |
| |
|
class ONNXModelConverter:
    """Export a Hugging Face causal LM to ONNX, verify it, and quantize it.

    Construction loads the tokenizer and the model (float32, eval mode) and
    creates ``output_dir`` if needed. ``convert()`` runs the full pipeline and
    writes every artifact under ``output_dir``.
    """

    def __init__(self, model_name: str, output_dir: str):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face Hub id or local path of the model.
            output_dir: Directory for exported artifacts; created if missing.

        Note:
            ``trust_remote_code=True`` executes Python code shipped with the
            model repository. Only use it with repositories you trust.
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading model {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        # Load full-precision weights; quantization later starts from float32.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        )
        # Switch off training-mode behaviour (e.g. dropout) before tracing.
        self.model.eval()

    def setup_logging(self):
        """Point ``self.logger`` at the module logger, emitting INFO and above.

        The logger returned by ``logging.getLogger(__name__)`` is shared by all
        instances. A stream handler is therefore attached only when the logger
        has none yet. Otherwise every new converter would add another handler
        and duplicate each log line.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self):
        """Tokenize a short sample sentence to use as tracing input.

        Returns:
            dict with the ``'input_ids'`` and ``'attention_mask'`` tensors
            produced by the tokenizer with ``return_tensors="pt"``.
        """
        # A single sequence never needs padding. Requesting it anyway makes
        # tokenizers that have no pad token raise a ValueError.
        encoded = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            truncation=True,
            max_length=128,
        )
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }

    def export_to_onnx(self):
        """Export the model's logits computation to ``<output_dir>/model.onnx``.

        The model is wrapped so that the graph takes ``(input_ids,
        attention_mask)`` and returns only ``logits``. The batch and sequence
        axes are declared dynamic.

        Returns:
            Path of the written ONNX file.

        Raises:
            Exception: whatever ``torch.onnx.export`` raises, after logging it.
        """
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        class ModelWrapper(torch.nn.Module):
            """Expose only ``logits`` so the ONNX graph has a single output."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.logits

        wrapped_model = ModelWrapper(self.model)

        try:
            # Tracing for inference needs no autograd bookkeeping.
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model,
                    (inputs['input_ids'], inputs['attention_mask']),
                    output_path,
                    export_params=True,
                    opset_version=14,
                    do_constant_folding=True,
                    input_names=['input_ids', 'attention_mask'],
                    output_names=['logits'],
                    dynamic_axes=dynamic_axes,
                    verbose=False,
                )
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise

        self.logger.info(f"Model exported to {output_path}")
        return output_path

    def verify_model(self, model_path: str):
        """Run the ONNX checker on the model stored at ``model_path``.

        Returns:
            True if the model passes the checker. Otherwise False, after
            logging the reason.
        """
        try:
            # Pass the path rather than a loaded ModelProto so that the checker
            # can also handle models larger than the 2 GB protobuf limit.
            onnx.checker.check_model(model_path)
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False
        self.logger.info("ONNX model verification successful")
        return True

    def quantize_model(self, model_path: str):
        """Dynamically quantize ``model_path`` once per weight type.

        Writes ``<output_dir>/model_<type>.onnx`` for int4, int8, uint4, uint8,
        uint16 and int16, in that order.

        Returns:
            List of the written file paths, in the order above.

        Raises:
            Exception: the first quantization failure, after logging which
            weight type failed. The remaining types are not attempted.
        """
        # NOTE(review): the 4- and 16-bit QuantType members need a recent
        # onnxruntime, and dynamic quantization may not support them for all
        # ops. Confirm against the installed version.
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'uint16': QuantType.QUInt16,
            'int16': QuantType.QInt16,
        }
        quantized_paths = []
        for suffix, quant_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{suffix}.onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=quant_type,
                )
            except Exception as e:
                self.logger.error(f"Quantization to {suffix} failed: {str(e)}")
                raise
            self.logger.info(f"Model quantized and saved to {quantized_path}")
            quantized_paths.append(quantized_path)

        return quantized_paths

    def convert(self):
        """Run the full pipeline: export, verify, quantize, save tokenizer.

        Returns:
            dict with keys ``'onnx_model'`` (path of the float model),
            ``'quantized_model'`` (list of quantized model paths) and
            ``'tokenizer'`` (directory the tokenizer was saved to).

        Raises:
            RuntimeError: if the exported model fails verification.
            Exception: any error raised by the individual steps. Each is
                logged and re-raised.
        """
        try:
            onnx_path = self.export_to_onnx()

            if not self.verify_model(onnx_path):
                raise RuntimeError("Model verification failed")

            quantized_paths = self.quantize_model(onnx_path)

            tokenizer_path = os.path.join(self.output_dir, "tokenizer")
            self.tokenizer.save_pretrained(tokenizer_path)
            self.logger.info(f"Tokenizer saved to {tokenizer_path}")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise

        return {
            'onnx_model': onnx_path,
            'quantized_model': quantized_paths,
            'tokenizer': tokenizer_path,
        }
| |
|
if __name__ == "__main__":
    # Script entry point: convert the default model and report where the
    # artifacts were written.
    MODEL_NAME = "SmallDoge/Doge-60M-Instruct"
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()
    except Exception as e:
        print(f"Conversion failed: {str(e)}")
        # Exit non-zero so that shells and CI see the failure rather than a
        # silent success.
        raise SystemExit(1)

    print("\nConversion completed successfully!")
    print(f"ONNX model path: {results['onnx_model']}")
    print(f"Quantized model path: {results['quantized_model']}")
    print(f"Tokenizer path: {results['tokenizer']}")
| |
|