redauzhang
Upload model for web attack payload classification; model based on codebert-base; dataset used is open source
62c3b33
| #!/usr/bin/env python3 | |
| """ | |
| Test ONNX model accuracy with 2000 samples from the dataset. | |
| """ | |
| import os | |
| import pandas as pd | |
| import numpy as np | |
| import requests | |
| import time | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report | |
# Configuration
API_URL = "http://localhost:8001"  # base URL of the running inference server
DATASET_PATH = "/c1/web-attack-detection/dataset.csv"  # CSV with 'Sentence' and 'Label' columns
NUM_SAMPLES = 2000  # 1000 malicious + 1000 benign
BATCH_SIZE = 50  # payloads sent per /batch_predict request
def _check_api_health():
    """Probe the inference API's /health endpoint; return True when reachable."""
    print("\n1. Checking API health...")
    try:
        resp = requests.get(f"{API_URL}/health", timeout=10)
        resp.raise_for_status()  # an HTTP error page is not a valid health payload
        health = resp.json()
        print(f" Status: {health['status']}")
        print(f" Device: {health['device']}")
        print(f" Provider: {health['provider']}")
        print(f" Model: {health['model_path']}")
        return True
    except Exception as e:
        print(f" Error: {e}")
        print(" Please ensure the server is running!")
        return False


def _load_test_samples():
    """Load the dataset CSV and return a shuffled, class-balanced test frame.

    Draws up to NUM_SAMPLES // 2 rows per class (fewer if a class is smaller),
    with a fixed random_state so the sample is reproducible across runs.
    """
    print("\n2. Loading dataset...")
    df = pd.read_csv(DATASET_PATH)
    df = df.dropna(subset=['Sentence', 'Label'])
    df['Sentence'] = df['Sentence'].astype(str)
    df['Label'] = df['Label'].astype(int)
    print(f" Total samples: {len(df)}")

    print("\n3. Sampling test data...")
    samples_per_class = NUM_SAMPLES // 2
    benign = df[df['Label'] == 0]
    malicious = df[df['Label'] == 1]
    benign_samples = benign.sample(n=min(samples_per_class, len(benign)), random_state=42)
    malicious_samples = malicious.sample(n=min(samples_per_class, len(malicious)), random_state=42)
    test_df = pd.concat([benign_samples, malicious_samples]).sample(frac=1, random_state=42).reset_index(drop=True)
    print(f" Test samples: {len(test_df)}")
    print(f" Benign: {len(test_df[test_df['Label'] == 0])}")
    print(f" Malicious: {len(test_df[test_df['Label'] == 1])}")
    return test_df


def _collect_predictions(test_df):
    """POST payloads in batches to /batch_predict.

    Returns (predictions, true_labels, total_time_s). Failed batches are
    skipped (best-effort), so the two lists stay aligned per batch.
    """
    print("\n4. Running predictions...")
    predictions = []
    true_labels = []
    total_time = 0.0
    for i in range(0, len(test_df), BATCH_SIZE):
        batch = test_df.iloc[i:i + BATCH_SIZE]
        payloads = batch['Sentence'].tolist()
        labels = batch['Label'].tolist()
        try:
            start = time.time()
            resp = requests.post(
                f"{API_URL}/batch_predict",
                json={"payloads": payloads},
                timeout=60
            )
            elapsed = time.time() - start
            total_time += elapsed
            resp.raise_for_status()  # don't parse an error page as predictions
            result = resp.json()
            batch_preds = [1 if p['prediction'] == 'malicious' else 0 for p in result['predictions']]
            # Extend labels only after a successful response so lists stay aligned.
            predictions.extend(batch_preds)
            true_labels.extend(labels)
            progress = min(i + BATCH_SIZE, len(test_df))
            print(f" Processed: {progress}/{len(test_df)} ({100*progress/len(test_df):.1f}%)", end='\r')
        except Exception as e:
            print(f"\n Error at batch {i}: {e}")
            continue
    return predictions, true_labels, total_time


def test_accuracy():
    """Evaluate the served ONNX model on a balanced sample of the dataset.

    Checks the API health endpoint, samples NUM_SAMPLES rows (half per class),
    runs batched predictions, and prints accuracy / precision / recall / F1
    plus a confusion matrix.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1', 'samples',
        'inference_time_s'; or None when the API is unreachable or no
        predictions were collected.
    """
    print("=" * 80)
    print("ONNX Model Accuracy Test")
    print("=" * 80)

    if not _check_api_health():
        return None

    test_df = _load_test_samples()
    predictions, true_labels, total_time = _collect_predictions(test_df)

    # Guard: every batch may have failed; the original code divided by
    # len(predictions) here and would raise ZeroDivisionError.
    if not predictions:
        print("\n No predictions collected — cannot compute metrics.")
        return None

    print(f"\n Total inference time: {total_time:.2f}s")
    print(f" Avg time per sample: {1000*total_time/len(predictions):.2f}ms")

    print("\n5. Calculating metrics...")
    accuracy = accuracy_score(true_labels, predictions)
    # zero_division=0 avoids UndefinedMetricWarning when a class is never predicted.
    precision = precision_score(true_labels, predictions, zero_division=0)
    recall = recall_score(true_labels, predictions, zero_division=0)
    f1 = f1_score(true_labels, predictions, zero_division=0)
    # Pin labels so the matrix is always 2x2 even if one class is absent.
    cm = confusion_matrix(true_labels, predictions, labels=[0, 1])

    print("\n" + "=" * 80)
    print("RESULTS")
    print("=" * 80)
    print(f"\nSamples tested: {len(predictions)}")
    print(f"\nMetrics:")
    print(f" Accuracy: {accuracy*100:.2f}%")
    print(f" Precision: {precision*100:.2f}%")
    print(f" Recall: {recall*100:.2f}%")
    print(f" F1 Score: {f1*100:.2f}%")
    print(f"\nConfusion Matrix:")
    print(f" Predicted")
    print(f" Benign Malicious")
    print(f" Actual Benign {cm[0][0]:5d} {cm[0][1]:5d}")
    print(f" Actual Malicious {cm[1][0]:5d} {cm[1][1]:5d}")
    print(f"\nDetailed Report:")
    print(classification_report(true_labels, predictions, labels=[0, 1], target_names=['Benign', 'Malicious'], zero_division=0))
    print("=" * 80)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'samples': len(predictions),
        'inference_time_s': total_time
    }
| if __name__ == "__main__": | |
| test_accuracy() | |