redauzhang
Upload model fit for web attack payload classification; the model is based on codebert-base and the dataset used is open source
62c3b33
| #!/usr/bin/env python3 | |
| """ | |
| Export trained CodeBERT model to ONNX format with optional quantization. | |
| Supports both CPU and GPU inference. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| from transformers import RobertaTokenizer, RobertaModel | |
| import json | |
# Filesystem locations used by the export script.
# Inputs: the pretrained CodeBERT directory and the fine-tuned checkpoint.
CODEBERT_PATH = "/c1/huggingface/codebert-base"
MODEL_PATH = "/c1/new-models/best_model.pt"

# Outputs: everything the export produces is written under OUTPUT_DIR.
OUTPUT_DIR = "/c1/new-models"
ONNX_PATH = os.path.join(OUTPUT_DIR, "model.onnx")
ONNX_QUANTIZED_PATH = os.path.join(OUTPUT_DIR, "model_quantized.onnx")
class CodeBERTClassifier(nn.Module):
    """Sequence classifier made of a RoBERTa/CodeBERT encoder and a linear head.

    The submodule names (``codebert``, ``dropout``, ``classifier``) are the
    keys under which weights are looked up when a state dict is loaded, so
    they must stay in sync with whatever produced the checkpoint.
    """

    def __init__(self, model_path, num_labels=2, dropout=0.1):
        """Build the encoder from ``model_path`` and attach the head.

        Args:
            model_path: Location passed to ``RobertaModel.from_pretrained``.
            num_labels: Output width of the linear classification layer.
            dropout: Dropout probability applied to the pooled output.
        """
        super().__init__()
        encoder = RobertaModel.from_pretrained(model_path)
        self.codebert = encoder
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(encoder.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return the raw (un-normalised) logits for the given token batch."""
        encoded = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(self.dropout(encoded.pooler_output))
class ONNXCodeBERTClassifier(nn.Module):
    """Export-time wrapper that turns the classifier's logits into probabilities.

    The wrapped model's dropout layer is bypassed in ``forward`` and its
    probability is additionally set to zero on construction.
    """

    def __init__(self, model):
        """Store ``model`` and zero its dropout probability.

        Note that this mutates the ``dropout`` attribute of the model passed in.
        """
        super().__init__()
        model.dropout.p = 0  # Disable dropout for inference
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Return softmax probabilities over the last dimension of the logits."""
        wrapped = self.model
        encoded = wrapped.codebert(input_ids=input_ids, attention_mask=attention_mask)
        scores = wrapped.classifier(encoded.pooler_output)
        return torch.softmax(scores, dim=-1)
def export_to_onnx():
    """Export the fine-tuned classifier to ONNX, quantize it and smoke-test it.

    The function runs these steps in order, printing progress as it goes:

    1. Load the tokenizer from ``CODEBERT_PATH``.
    2. Build ``CodeBERTClassifier`` and load weights from ``MODEL_PATH``.
    3. Tokenize a dummy input padded to ``max_length`` (256).
    4. Run one PyTorch forward pass through the softmax wrapper.
    5. Export the wrapper to ``ONNX_PATH`` with dynamic batch/sequence axes.
    6. Dynamically quantize to ``ONNX_QUANTIZED_PATH``. On failure the
       unquantized file is copied there instead and a warning is printed.
    7. Run ``onnx.checker`` on the exported file (failure is a warning).
    8. Run a few sample strings through ONNX Runtime (failure is a warning).
    9. Write ``onnx_config.json`` into ``OUTPUT_DIR``.

    Returns:
        None. All results are written to disk or printed.

    Raises:
        Errors from steps 1-5 and 9 (loading, export, config writing) are not
        caught and propagate to the caller.
    """
    print("=" * 80)
    print("ONNX Model Export")
    print("=" * 80)

    # Device - use CPU for export to avoid CUDA issues
    device = torch.device("cpu")
    print(f"Export Device: {device}")

    # Load tokenizer
    print("\n1. Loading tokenizer...")
    tokenizer = RobertaTokenizer.from_pretrained(CODEBERT_PATH)
    print(f" Tokenizer loaded: {type(tokenizer).__name__}")

    # Create model with same architecture as training
    print("\n2. Loading model...")
    model = CodeBERTClassifier(CODEBERT_PATH)

    # Load trained weights.
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. Consider passing weights_only=True on torch
    # versions that support it - TODO confirm the checkpoint loads that way.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    # The checkpoint is either a bare state dict or a dict that wraps it
    # under 'model_state_dict'.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    print(f" Model loaded from: {MODEL_PATH}")

    # Wrap for ONNX export
    onnx_model = ONNXCodeBERTClassifier(model)
    onnx_model.eval()
    onnx_model.to(device)

    # Create dummy input
    print("\n3. Creating dummy input...")
    max_length = 256
    dummy_text = "SELECT * FROM users WHERE id=1"
    inputs = tokenizer(
        dummy_text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    dummy_input_ids = inputs['input_ids'].to(device)
    dummy_attention_mask = inputs['attention_mask'].to(device)
    print(f" Input shape: {dummy_input_ids.shape}")

    # Test forward pass first
    print("\n4. Testing forward pass...")
    with torch.no_grad():
        test_output = onnx_model(dummy_input_ids, dummy_attention_mask)
    print(f" Output shape: {test_output.shape}")
    print(f" Output sample: {test_output[0].numpy()}")

    # Export to ONNX
    print("\n5. Exporting to ONNX...")
    torch.onnx.export(
        onnx_model,
        (dummy_input_ids, dummy_attention_mask),
        ONNX_PATH,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input_ids', 'attention_mask'],
        output_names=['probabilities'],
        dynamic_axes={
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'probabilities': {0: 'batch_size'}
        }
    )
    onnx_size = os.path.getsize(ONNX_PATH) / (1024 * 1024)
    print(f" ONNX model saved: {ONNX_PATH}")
    print(f" Size: {onnx_size:.2f} MB")

    # Quantize model
    print("\n6. Quantizing model (dynamic quantization)...")
    try:
        from onnxruntime.quantization import quantize_dynamic, QuantType

        # `optimize_model` is intentionally not passed: newer onnxruntime
        # releases no longer accept that keyword, and the resulting TypeError
        # was swallowed by the handler below, which then shipped an
        # unquantized copy under the "quantized" file name.
        quantize_dynamic(
            model_input=ONNX_PATH,
            model_output=ONNX_QUANTIZED_PATH,
            weight_type=QuantType.QUInt8
        )
        quantized_size = os.path.getsize(ONNX_QUANTIZED_PATH) / (1024 * 1024)
        print(f" Quantized model saved: {ONNX_QUANTIZED_PATH}")
        print(f" Size: {quantized_size:.2f} MB")
        print(f" Compression ratio: {onnx_size / quantized_size:.2f}x")
    except Exception as e:
        # Best effort: quantization is optional (onnxruntime may be missing),
        # so fall back to the full-precision model under the quantized name.
        print(f" Warning: Quantization failed: {e}")
        print(" Using non-quantized model.")
        import shutil
        shutil.copy(ONNX_PATH, ONNX_QUANTIZED_PATH)

    # Verify ONNX model
    print("\n7. Verifying ONNX model...")
    try:
        import onnx

        onnx_check = onnx.load(ONNX_PATH)
        onnx.checker.check_model(onnx_check)
        print(" ONNX model verification: PASSED")
    except Exception as e:
        print(f" Warning: ONNX verification failed: {e}")

    # Test inference with ONNX Runtime
    print("\n8. Testing ONNX Runtime inference...")
    try:
        import onnxruntime as ort
        import numpy as np

        # Try GPU first, fallback to CPU
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        available_providers = ort.get_available_providers()
        use_providers = [p for p in providers if p in available_providers]
        session = ort.InferenceSession(ONNX_PATH, providers=use_providers)
        actual_provider = session.get_providers()[0]
        print(f" Using provider: {actual_provider}")

        # Test inference
        test_texts = [
            "SELECT * FROM users WHERE id=1 OR 1=1",  # SQL injection
            "GET /index.html HTTP/1.1",  # Normal request
            "<script>alert('xss')</script>",  # XSS
            "Mozilla/5.0 (Windows NT 10.0; Win64)",  # Normal UA
        ]
        print("\n Test predictions:")
        for text in test_texts:
            inputs = tokenizer(
                text,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='np'
            )
            outputs = session.run(
                None,
                {
                    'input_ids': inputs['input_ids'].astype(np.int64),
                    'attention_mask': inputs['attention_mask'].astype(np.int64)
                }
            )
            # First output, first (only) row of the batch.
            probs = outputs[0][0]
            pred = np.argmax(probs)
            label = "Malicious" if pred == 1 else "Benign"
            conf = probs[pred] * 100
            print(f" - '{text[:40]:<40}' => {label:<10} ({conf:.1f}%)")
    except Exception as e:
        print(f" Warning: ONNX Runtime test failed: {e}")
        import traceback
        traceback.print_exc()

    # Save export config
    print("\n9. Saving export configuration...")
    export_config = {
        "model_path": ONNX_PATH,
        "quantized_model_path": ONNX_QUANTIZED_PATH,
        "max_length": max_length,
        "tokenizer_path": CODEBERT_PATH,
        "labels": {"0": "benign", "1": "malicious"},
        "input_names": ["input_ids", "attention_mask"],
        "output_names": ["probabilities"]
    }
    config_path = os.path.join(OUTPUT_DIR, "onnx_config.json")
    # Explicit encoding so the file does not depend on the platform default.
    with open(config_path, 'w', encoding="utf-8") as f:
        json.dump(export_config, f, indent=2)
    print(f" Config saved: {config_path}")

    print("\n" + "=" * 80)
    print("Export completed!")
    print("=" * 80)
    print(f"ONNX Model: {ONNX_PATH} ({onnx_size:.2f} MB)")
    if os.path.exists(ONNX_QUANTIZED_PATH):
        qsize = os.path.getsize(ONNX_QUANTIZED_PATH) / (1024 * 1024)
        print(f"Quantized Model: {ONNX_QUANTIZED_PATH} ({qsize:.2f} MB)")
    print("=" * 80)
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    export_to_onnx()