# File size: 6,287 Bytes
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import logging
from typing import Optional, Dict, Any
class ONNXModelConverter:
    """Export a Hugging Face causal-LM checkpoint to ONNX and produce
    dynamically quantized variants of the exported graph.

    Typical use:
        converter = ONNXModelConverter("org/model", "onnx")
        results = converter.convert()
    """

    def __init__(self, model_name: str, output_dir: str):
        """Load *model_name* (tokenizer + fp32 weights) and prepare *output_dir*.

        Args:
            model_name: Hugging Face hub id or local checkpoint path.
            output_dir: Directory that receives all ONNX artifacts
                (created if missing).
        """
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading model {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # fp32 keeps the traced graph numerically faithful; quantization
        # happens later in quantize_model().
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float32
        )
        self.model.eval()

    def setup_logging(self):
        """Attach a stream handler to this module's logger (idempotent).

        The handler is only added when the logger has none yet: the previous
        unconditional addHandler() duplicated every log line whenever more
        than one converter was created in the same process, because
        getLogger(__name__) returns a shared module-level logger.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

    def prepare_dummy_inputs(self):
        """Tokenize a short sample sentence to drive the ONNX tracer.

        Returns:
            Dict with only 'input_ids' and 'attention_mask' — the two inputs
            the exported graph consumes (other tokenizer outputs are dropped).
        """
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        return {key: dummy_input[key] for key in ('input_ids', 'attention_mask')}

    def export_to_onnx(self):
        """Trace the model and write it to <output_dir>/model.onnx.

        Returns:
            Path of the exported ONNX file.

        Raises:
            Re-raises any exception from torch.onnx.export after logging it.
        """
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        # Let batch size and sequence length vary at inference time.
        dynamic_axes = {
            name: {0: 'batch_size', 1: 'sequence_length'}
            for name in ('input_ids', 'attention_mask', 'logits')
        }

        class ModelWrapper(torch.nn.Module):
            """Expose only the logits so the graph has a single clean output."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.logits

        wrapped_model = ModelWrapper(self.model)
        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask']),
                output_path,
                export_params=True,
                opset_version=14,
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask'],
                output_names=['logits'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise

    def verify_model(self, model_path: str):
        """Run the ONNX structural checker on *model_path*.

        Returns:
            True when the model passes the check, False otherwise
            (verification failures are logged, never raised).
        """
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False

    def quantize_model(self, model_path: str):
        """Dynamically quantize *model_path* once per supported weight type.

        Returns:
            List of paths to the quantized models, one per weight type.

        Raises:
            Re-raises the first quantization failure after logging which
            weight type failed.
        """
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'uint16': QuantType.QUInt16,
            'int16': QuantType.QInt16,
        }
        all_quantized_paths = []
        # Iterate items() instead of keys() + repeated lookup.
        for type_name, quant_type in weight_types.items():
            quantized_path = os.path.join(self.output_dir, f"model_{type_name}.onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=quant_type
                )
            except Exception as e:
                # Name the failing weight type; the old message did not say
                # which of the six quantizations broke.
                self.logger.error(f"Quantization failed for {type_name}: {str(e)}")
                raise
            self.logger.info(f"Model quantized and saved to {quantized_path}")
            all_quantized_paths.append(quantized_path)
        return all_quantized_paths

    def convert(self):
        """Run the full pipeline: export -> verify -> quantize -> save tokenizer.

        Returns:
            Dict with keys 'onnx_model' (path), 'quantized_model' (LIST of
            quantized model paths) and 'tokenizer' (directory path).

        Raises:
            RuntimeError: when the exported model fails verification.
            Any exception from the export/quantization steps (logged first).
        """
        try:
            onnx_path = self.export_to_onnx()
            # Guard clause: do not quantize a model that failed verification.
            if not self.verify_model(onnx_path):
                raise RuntimeError("Model verification failed")
            # Plural name: quantize_model returns one path per weight type.
            quantized_paths = self.quantize_model(onnx_path)
            tokenizer_path = os.path.join(self.output_dir, "tokenizer")
            self.tokenizer.save_pretrained(tokenizer_path)
            self.logger.info(f"Tokenizer saved to {tokenizer_path}")
            return {
                'onnx_model': onnx_path,
                'quantized_model': quantized_paths,
                'tokenizer': tokenizer_path
            }
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise
if __name__ == "__main__":
    MODEL_NAME = "SmallDoge/Doge-60M-Instruct"
    OUTPUT_DIR = "onnx"
    try:
        # Build the converter and run the export/verify/quantize pipeline.
        results = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR).convert()
    except Exception as e:
        print(f"Conversion failed: {str(e)}")
    else:
        # Report every artifact produced by the pipeline.
        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Quantized model path: {results['quantized_model']}")
        print(f"Tokenizer path: {results['tokenizer']}")