import os
import gc
import sys
import time
import logging
import traceback
import torch
import warnings
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnxruntime.quantization import quantize_dynamic, QuantType


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Silence noisy-but-harmless warnings emitted during ONNX tracing.
warnings.filterwarnings("ignore", category=UserWarning, message=".*The shape of the input dimension.*")
warnings.filterwarnings("ignore", category=UserWarning, message=".*Converting a tensor to a Python.*")

# Small causal LMs with a track record of clean ONNX exports.
RELIABLE_MODELS = [
    {
        "id": "facebook/opt-350m",
        "description": "Well-balanced model (350M) for RAG and chatbots"
    },
    {
        "id": "gpt2",
        "description": "Very reliable model (124M) with excellent ONNX compatibility"
    },
    {
        "id": "distilgpt2",
        "description": "Lightweight model (82M) with good performance"
    }
]


class ModelWrapper(torch.nn.Module):
    """
    Wrapper to handle ONNX export compatibility issues.
    This wrapper specifically:
    1. Disables the key/value cache, whose variable-shaped outputs break tracing
    2. Simplifies the forward pass to return a single static logits tensor
    """
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        # use_cache=False and return_dict=False keep the traced graph free of
        # dynamic cache tensors; [0] selects the logits.
        with torch.no_grad():
            return self.model(input_ids=input_ids, use_cache=False, return_dict=False)[0]
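
# Illustrative sketch (not executed here): the wrapper turns the HF model into
# a plain tensor-in/tensor-out module, which torch.onnx.export traces reliably:
#
#     wrapped = ModelWrapper(AutoModelForCausalLM.from_pretrained("distilgpt2")).eval()
#     logits = wrapped(torch.ones(1, 8, dtype=torch.long))  # shape (1, 8, vocab_size)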


def convert_model(model_id, output_dir, quantize=True):
    """Convert a model to ONNX format with maximum compatibility."""
    start_time = time.time()

    logger.info(f"\n{'=' * 60}")
    logger.info(f"Converting {model_id} to ONNX")
    logger.info(f"{'=' * 60}")

    # Each model gets its own subdirectory, named after the checkpoint.
    model_name = model_id.split("/")[-1]
    model_dir = os.path.join(output_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)
    try:
        logger.info("Step 1/5: Loading tokenizer...")

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Decoder-only checkpoints often ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None and tokenizer.eos_token is not None:
            logger.info("Adding pad_token = eos_token")
            tokenizer.pad_token = tokenizer.eos_token

        tokenizer.save_pretrained(model_dir)
        logger.info(f"✓ Tokenizer saved to {model_dir}")

        logger.info("Step 2/5: Loading model with memory optimizations...")

        # Free as much memory as possible before the weights are loaded.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Load in full precision: half-precision kernels are unreliable for
        # CPU-side ONNX tracing, and the int8 quantization step below
        # recovers the size savings anyway.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )

        model.config.save_pretrained(model_dir)
        logger.info(f"✓ Model config saved to {model_dir}")

        logger.info("Step 3/5: Preparing for export...")

        # Wrap the model so the traced forward pass is cache-free.
        wrapped_model = ModelWrapper(model)
        wrapped_model.eval()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info("Step 4/5: Exporting to ONNX format...")
        onnx_path = os.path.join(model_dir, "model.onnx")

        # A tiny dummy batch is enough to trace the graph; the dynamic axes
        # below keep the exported shapes flexible.
        batch_size = 1
        seq_length = 8
        dummy_input = torch.ones(batch_size, seq_length, dtype=torch.long)

        torch.onnx.export(
            wrapped_model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=14,
            do_constant_folding=True,
            input_names=['input_ids'],
            output_names=['logits'],
            dynamic_axes={
                'input_ids': {0: 'batch_size', 1: 'sequence'},
                'logits': {0: 'batch_size', 1: 'sequence'}
            }
        )
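        # With the dynamic axes above, the traced (1, 8) dummy shape is not
        # baked into the graph: any (batch, sequence) input is accepted.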

        # Release the PyTorch model before the optional quantization pass.
        del model
        del wrapped_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
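
        # Optional sanity check (a minimal sketch): run the exported graph once
        # with ONNX Runtime; the package is already installed if the quantization
        # import at the top succeeded. Failure here is logged but non-fatal.
        try:
            import onnxruntime as ort
            session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
            (onnx_logits,) = session.run(None, {"input_ids": dummy_input.numpy()})
            logger.info(f"✓ Sanity check passed, logits shape: {onnx_logits.shape}")
        except Exception as sanity_err:
            logger.warning(f"⚠ Sanity check failed (non-fatal): {sanity_err}")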

        # Verify the export actually produced a file before continuing.
        if os.path.exists(onnx_path):
            size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
            logger.info(f"✓ ONNX model saved to {onnx_path}")
            logger.info(f"✓ Original size: {size_mb:.2f} MB")

            if quantize:
                logger.info("Step 5/5: Applying int8 quantization...")
                quant_path = onnx_path.replace(".onnx", "_quantized.onnx")

                try:
                    # Dynamic quantization: int8 weights on disk, float
                    # activations computed at runtime.
                    quantize_dynamic(
                        model_input=onnx_path,
                        model_output=quant_path,
                        per_channel=False,
                        reduce_range=False,
                        weight_type=QuantType.QInt8
                    )

                    if os.path.exists(quant_path):
                        quant_size = os.path.getsize(quant_path) / (1024 * 1024)
                        logger.info(f"✓ Quantized size: {quant_size:.2f} MB")
                        logger.info(f"✓ Size reduction: {(1 - quant_size/size_mb) * 100:.1f}%")

                        # Keep a single model.onnx so downstream loaders need
                        # no special-casing.
                        os.replace(quant_path, onnx_path)
                        logger.info("✓ Replaced original with quantized version")
                    else:
                        logger.warning("⚠ Quantized file not created, using original")
                except Exception as e:
                    logger.error(f"⚠ Quantization error: {str(e)}")
                    logger.info("⚠ Using original model without quantization")
            else:
                logger.info("Step 5/5: Skipping quantization (not requested)")

            end_time = time.time()
            duration = end_time - start_time
            logger.info(f"✓ Conversion completed in {duration:.2f} seconds")

            return {
                "success": True,
                "model_id": model_id,
                "size_mb": os.path.getsize(onnx_path) / (1024 * 1024),
                "duration_seconds": duration,
                "output_dir": model_dir
            }
        else:
            logger.error(f"× ONNX file not created at {onnx_path}")
            return {
                "success": False,
                "model_id": model_id,
                "error": "ONNX file not created"
            }

    except Exception as e:
        logger.error(f"× Error converting model: {str(e)}")
        logger.error(traceback.format_exc())

        return {
            "success": False,
            "model_id": model_id,
            "error": str(e)
        }


def main():
    """Convert all reliable models."""

    logger.info("\nGUARANTEED ONNX CONVERTER")
    logger.info("=========================")
    logger.info("Using reliable models with proven ONNX compatibility")

    output_dir = "./onnx_models"
    os.makedirs(output_dir, exist_ok=True)

    # A model id on the command line overrides the built-in list.
    if len(sys.argv) > 1:
        model_id = sys.argv[1]
        logger.info(f"Converting single model: {model_id}")
        convert_model(model_id, output_dir)
        return

    results = []
    for model_info in RELIABLE_MODELS:
        model_id = model_info["id"]
        logger.info(f"Processing model: {model_id}")
        logger.info(f"Description: {model_info['description']}")

        result = convert_model(model_id, output_dir)
        results.append(result)

    logger.info("\n" + "=" * 60)
    logger.info("CONVERSION SUMMARY")
    logger.info("=" * 60)

    success_count = 0
    for result in results:
        if result.get("success", False):
            success_count += 1
            size_info = f" - Size: {result.get('size_mb', 0):.2f} MB"
            time_info = f" - Time: {result.get('duration_seconds', 0):.2f}s"
            logger.info(f"✓ SUCCESS: {result['model_id']}{size_info}{time_info}")
        else:
            logger.info(f"× FAILED: {result['model_id']} - Error: {result.get('error', 'Unknown error')}")

    logger.info(f"\nSuccessfully converted {success_count}/{len(RELIABLE_MODELS)} models")
    logger.info(f"Models saved to: {os.path.abspath(output_dir)}")

    if success_count > 0:
        logger.info("\nThe models are ready for RAG and chatbot applications!")


if __name__ == "__main__":
    main()