redauzhang committed
Commit 62c3b33 · 1 Parent(s): 955416a

Upload model for web attack payload classification, based on codebert-base; trained on an open-source dataset
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.onnx filter=lfs diff=lfs merge=lfs -text
+ model_quantized.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ft filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -5,4 +5,361 @@ language:
  base_model:
  - microsoft/codebert-base
  pipeline_tag: text-classification
- ---
+ ---
+ # Web Attack Detection Model
+
+ A CodeBERT-based deep learning model for detecting malicious web requests and payloads. This model can identify SQL injection, XSS, path traversal, command injection, and other common web attack patterns.
+
+ ## Model Description
+
+ This model is fine-tuned from [microsoft/codebert-base](https://huggingface.co/microsoft/codebert-base) for binary classification of web requests as either **benign** or **malicious**.
+
+ ### Model Architecture
+
+ - **Base Model**: CodeBERT (RoBERTa-base architecture)
+ - **Task**: Binary Text Classification
+ - **Parameters**: 124.6M
+ - **Max Sequence Length**: 256 tokens
+
+ ### Performance Metrics
+
+ | Metric | Training Set | Test Set (125K) | 2000-Sample Test |
+ |--------|-------------|-----------------|------------------|
+ | **Accuracy** | 99.30% | 99.38% | 99.60% |
+ | **Precision** | - | 99.47% | 99.80% |
+ | **Recall** | - | 99.21% | 99.40% |
+ | **F1 Score** | - | 99.34% | 99.60% |
+
+ ### Confusion Matrix (Test Set)
+
+ | | Predicted Benign | Predicted Malicious |
+ |--|------------------|---------------------|
+ | **Actual Benign** | 65,914 | 312 |
+ | **Actual Malicious** | 464 | 58,491 |
+
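The headline test-set metrics can be reproduced directly from this confusion matrix. A quick sanity check, using only the four counts above:

```python
# Recompute the reported test-set metrics from the confusion matrix above.
tn, fp = 65_914, 312      # actual benign row
fn, tp = 464, 58_491      # actual malicious row

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"{accuracy:.2%} {precision:.2%} {recall:.2%} {f1:.2%}")
# 99.38% 99.47% 99.21% 99.34%
```

The recomputed values match the Performance Metrics table, so the two are internally consistent.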
+ ## Training Details
+
+ ### Dataset
+
+ - **Total Samples**: 625,904
+ - **Training Samples**: 500,722 (80%)
+ - **Test Samples**: 125,181 (20%)
+ - **Class Distribution**: Balanced (47% malicious, 53% benign)
+ - **Sampling Strategy**: Balanced sampling with WeightedRandomSampler
+
+ ### Training Configuration
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Epochs | 3 |
+ | Batch Size | 8 |
+ | Gradient Accumulation Steps | 4 |
+ | Effective Batch Size | 32 |
+ | Learning Rate | 2e-5 |
+ | Warmup Steps | 500 |
+ | Weight Decay | 0.01 |
+ | Max Sequence Length | 256 |
+ | Optimizer | AdamW |
+
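The "effective batch size" row follows from gradient accumulation: averaging the gradients of 4 micro-batches of 8 equals the gradient of a single batch of 32. A minimal numeric sketch (illustrative only, not the actual training script), using a least-squares gradient:

```python
import numpy as np

# One "effective batch" of 32 samples with 4 features.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 4))
y = rng.normal(size=(32,))
w = np.zeros(4)

def grad(Xb, yb, w):
    # Gradient of 0.5 * mean((Xb @ w - yb)**2) with respect to w.
    return Xb.T @ (Xb @ w - yb) / len(yb)

full = grad(X, y, w)                 # one batch of 32
accum = np.zeros(4)
for i in range(4):                   # 4 micro-batches of 8
    accum += grad(X[i * 8:(i + 1) * 8], y[i * 8:(i + 1) * 8], w)
accum /= 4                           # average over accumulation steps

assert np.allclose(full, accum)      # same update direction as batch size 32
```

This is why batch size 8 with 4 accumulation steps fits a 16 GB GPU while behaving like batch size 32.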
+ ### Training Progress
+
+ | Epoch | Train Loss | Train Acc | Test Loss | Test Acc | F1 Score |
+ |-------|------------|-----------|-----------|----------|----------|
+ | 1 | 0.0289 | 98.84% | 0.0192 | 99.09% | 0.9904 |
+ | 2 | 0.0201 | 99.24% | 0.0169 | 99.08% | 0.9903 |
+ | 3 | 0.0175 | 99.30% | 0.0274 | 99.38% | 0.9934 |
+
+ ### Hardware
+
+ - **GPU**: NVIDIA Tesla T4 (16GB)
+ - **Training Time**: ~24 hours
+
+ ## Model Files
+
+ | File | Size | Description |
+ |------|------|-------------|
+ | `best_model.pt` | 1.4 GB | PyTorch checkpoint (full precision) |
+ | `model.onnx` | 476 MB | ONNX model (full precision) |
+ | `model_quantized.onnx` | 120 MB | ONNX model (INT8 quantized) |
+
+ ## Usage
+
+ ### Quick Start with ONNX Runtime
+
+ ```python
+ import numpy as np
+ import onnxruntime as ort
+ from transformers import RobertaTokenizer
+
+ # Load tokenizer and model (ONNX Runtime falls back to CPU if CUDA is unavailable)
+ tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
+ session = ort.InferenceSession("model_quantized.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+ def predict(payload: str) -> dict:
+     inputs = tokenizer(
+         payload,
+         max_length=256,
+         padding='max_length',
+         truncation=True,
+         return_tensors='np'
+     )
+
+     outputs = session.run(
+         None,
+         {
+             'input_ids': inputs['input_ids'].astype(np.int64),
+             'attention_mask': inputs['attention_mask'].astype(np.int64)
+         }
+     )
+
+     # The exported graph already applies softmax, so these are probabilities
+     probs = outputs[0][0]
+     pred_idx = int(np.argmax(probs))
+
+     return {
+         "prediction": "malicious" if pred_idx == 1 else "benign",
+         "confidence": float(probs[pred_idx]),
+         "probabilities": {
+             "benign": float(probs[0]),
+             "malicious": float(probs[1])
+         }
+     }
+
+ # Example usage
+ result = predict("SELECT * FROM users WHERE id=1 OR 1=1--")
+ print(result)
+ # {'prediction': 'malicious', 'confidence': 0.9355, 'probabilities': {'benign': 0.0645, 'malicious': 0.9355}}
+ ```
133
+
134
+ ### Using PyTorch
135
+
136
+ ```python
137
+ import torch
138
+ import torch.nn as nn
139
+ from transformers import RobertaTokenizer, RobertaModel
140
+
141
+ class CodeBERTClassifier(nn.Module):
142
+ def __init__(self, model_path="microsoft/codebert-base", num_labels=2, dropout=0.1):
143
+ super().__init__()
144
+ self.codebert = RobertaModel.from_pretrained(model_path)
145
+ self.dropout = nn.Dropout(dropout)
146
+ self.classifier = nn.Linear(self.codebert.config.hidden_size, num_labels)
147
+
148
+ def forward(self, input_ids, attention_mask):
149
+ outputs = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
150
+ pooled_output = outputs.pooler_output
151
+ pooled_output = self.dropout(pooled_output)
152
+ logits = self.classifier(pooled_output)
153
+ return logits
154
+
155
+ # Load model
156
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
157
+ model = CodeBERTClassifier()
158
+ model.load_state_dict(torch.load("best_model.pt", map_location=device))
159
+ model.eval()
160
+ model.to(device)
161
+
162
+ # Load tokenizer
163
+ tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
164
+
165
+ # Predict
166
+ def predict(payload: str) -> dict:
167
+ inputs = tokenizer(
168
+ payload,
169
+ max_length=256,
170
+ padding='max_length',
171
+ truncation=True,
172
+ return_tensors='pt'
173
+ ).to(device)
174
+
175
+ with torch.no_grad():
176
+ logits = model(inputs['input_ids'], inputs['attention_mask'])
177
+ probs = torch.softmax(logits, dim=-1)[0]
178
+
179
+ pred_idx = torch.argmax(probs).item()
180
+
181
+ return {
182
+ "prediction": "malicious" if pred_idx == 1 else "benign",
183
+ "confidence": probs[pred_idx].item()
184
+ }
185
+
186
+ # Example
187
+ result = predict("<script>alert('xss')</script>")
188
+ print(result)
189
+ # {'prediction': 'malicious', 'confidence': 0.9998}
190
+ ```
+
+ ## FastAPI Server
+
+ ### Installation
+
+ ```bash
+ pip install onnxruntime-gpu transformers fastapi uvicorn pydantic numpy
+ ```
+
+ ### Start Server
+
+ ```bash
+ # GPU mode (recommended)
+ python server_onnx.py --device gpu --quantized --port 8000
+
+ # CPU mode
+ python server_onnx.py --device cpu --quantized --port 8000
+ ```
+
+ ### API Endpoints
+
+ #### Health Check
+ ```bash
+ curl http://localhost:8000/health
+ ```
+
+ #### Single Prediction
+ ```bash
+ curl -X POST http://localhost:8000/predict \
+   -H "Content-Type: application/json" \
+   -d '{"payload": "SELECT * FROM users WHERE id=1 OR 1=1--"}'
+ ```
+
+ Response:
+ ```json
+ {
+   "payload": "SELECT * FROM users WHERE id=1 OR 1=1--",
+   "prediction": "malicious",
+   "confidence": 0.9355,
+   "probabilities": {"benign": 0.0645, "malicious": 0.9355},
+   "inference_time_ms": 15.23
+ }
+ ```
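Client code can treat this response as ordinary JSON. A minimal sketch of the invariants a client may rely on (the sample values are copied from the response shown above):

```python
import json

# Sample response body, copied verbatim from the documentation above.
raw = '''{
  "payload": "SELECT * FROM users WHERE id=1 OR 1=1--",
  "prediction": "malicious",
  "confidence": 0.9355,
  "probabilities": {"benign": 0.0645, "malicious": 0.9355},
  "inference_time_ms": 15.23
}'''

resp = json.loads(raw)
assert resp["prediction"] in ("benign", "malicious")
# The two class probabilities sum to 1, and confidence is their maximum.
assert abs(sum(resp["probabilities"].values()) - 1.0) < 1e-9
assert resp["confidence"] == max(resp["probabilities"].values())
```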
234
+
235
+ #### Batch Prediction
236
+ ```bash
237
+ curl -X POST http://localhost:8000/batch_predict \
238
+ -H "Content-Type: application/json" \
239
+ -d '{"payloads": ["<script>alert(1)</script>", "GET /api/users HTTP/1.1"]}'
240
+ ```
241
+
242
+ ## Docker Deployment
243
+
244
+ ### GPU Version
245
+
246
+ ```dockerfile
247
+ # Dockerfile
248
+ FROM nvidia/cuda:11.8-cudnn8-runtime-ubuntu22.04
249
+
250
+ RUN apt-get update && apt-get install -y python3 python3-pip
251
+ RUN pip3 install onnxruntime-gpu transformers fastapi uvicorn pydantic numpy
252
+
253
+ WORKDIR /app
254
+ COPY model_quantized.onnx ./models/
255
+ COPY server_onnx.py .
256
+
257
+ EXPOSE 8000
258
+ CMD ["python3", "server_onnx.py", "--device", "gpu", "--quantized"]
259
+ ```
260
+
261
+ ### CPU Version
262
+
263
+ ```dockerfile
264
+ # Dockerfile.cpu
265
+ FROM python:3.10-slim
266
+
267
+ RUN pip install onnxruntime transformers fastapi uvicorn pydantic numpy
268
+
269
+ WORKDIR /app
270
+ COPY model_quantized.onnx ./models/
271
+ COPY server_onnx.py .
272
+
273
+ EXPOSE 8000
274
+ CMD ["python", "server_onnx.py", "--device", "cpu", "--quantized"]
275
+ ```
276
+
277
+ ### Docker Compose
278
+
279
+ ```yaml
280
+ version: '3.8'
281
+ services:
282
+ web-attack-detector:
283
+ build: .
284
+ ports:
285
+ - "8000:8000"
286
+ deploy:
287
+ resources:
288
+ reservations:
289
+ devices:
290
+ - driver: nvidia
291
+ count: 1
292
+ capabilities: [gpu]
293
+ ```
+
+ ## Attack Types Detected
+
+ This model can detect various web attack patterns, including:
+
+ | Attack Type | Example |
+ |-------------|---------|
+ | **SQL Injection** | `' OR '1'='1' --` |
+ | **Cross-Site Scripting (XSS)** | `<script>alert(document.cookie)</script>` |
+ | **Path Traversal** | `../../etc/passwd` |
+ | **Command Injection** | `; cat /etc/passwd` |
+ | **LDAP Injection** | `*)(uid=*))(|(uid=*` |
+ | **XML Injection** | `<?xml version="1.0"?><!DOCTYPE foo>` |
+ | **Server-Side Template Injection** | `{{7*7}}` |
+
+ ## Limitations
+
+ - The model is trained on specific attack patterns and may not detect novel or obfuscated attacks
+ - Maximum input length is 256 tokens; longer payloads are truncated
+ - The model may produce false positives on legitimate requests that resemble attack patterns
+ - Performance may vary across different types of web applications
+
+ ## Ethical Considerations
+
+ This model is intended for **defensive security purposes only**, including:
+ - Web Application Firewalls (WAF)
+ - Intrusion Detection Systems (IDS)
+ - Security monitoring and alerting
+ - Penetration testing and security assessments
+
+ **Do not use this model for malicious purposes.**
+
+ ## License
+
+ This model is released under the MIT License.
+
+ ## Citation
+
+ If you use this model in your research or application, please cite:
+
+ ```bibtex
+ @misc{web-attack-detection-codebert,
+   author       = {Your Name},
+   title        = {Web Attack Detection Model based on CodeBERT},
+   year         = {2024},
+   publisher    = {Hugging Face},
+   howpublished = {\url{https://huggingface.co/your-username/web-attack-detection}},
+   note         = {Fine-tuned CodeBERT model for detecting malicious web requests}
+ }
+
+ @inproceedings{feng2020codebert,
+   title     = {CodeBERT: A Pre-Trained Model for Programming and Natural Languages},
+   author    = {Feng, Zhangyin and Guo, Daya and Tang, Duyu and Duan, Nan and Feng, Xiaocheng and Gong, Ming and Shou, Linjun and Qin, Bing and Liu, Ting and Jiang, Daxin and Zhou, Ming},
+   booktitle = {Findings of the Association for Computational Linguistics: EMNLP 2020},
+   year      = {2020},
+   pages     = {1536--1547},
+   doi       = {10.18653/v1/2020.findings-emnlp.139}
+ }
+
+ @article{liu2019roberta,
+   title   = {RoBERTa: A Robustly Optimized BERT Pretraining Approach},
+   author  = {Liu, Yinhan and Ott, Myle and Goyal, Naman and Du, Jingfei and Joshi, Mandar and Chen, Danqi and Levy, Omer and Lewis, Mike and Zettlemoyer, Luke and Stoyanov, Veselin},
+   journal = {arXiv preprint arXiv:1907.11692},
+   year    = {2019}
+ }
+ ```
+
+ ## Acknowledgments
+
+ - [Microsoft CodeBERT](https://github.com/microsoft/CodeBERT) for the pre-trained model
+ - [Hugging Face Transformers](https://huggingface.co/transformers/) for the model framework
+ - [ONNX Runtime](https://onnxruntime.ai/) for efficient inference
best_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8c9bce29b63361bf2b8c0554f4a51b7161c4cccad0ac784f06d4e0a435116f3d
+ size 1495970402
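The large files in this commit are Git LFS pointers, not the weights themselves. A small sketch of how such a pointer parses (the values are copied from the pointer above; the format is simply `key value` pairs, one per line):

```python
# Parse a Git LFS pointer file (spec v1).
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:8c9bce29b63361bf2b8c0554f4a51b7161c4cccad0ac784f06d4e0a435116f3d
size 1495970402"""

fields = dict(line.split(" ", 1) for line in pointer.splitlines())
size_gib = int(fields["size"]) / 1024**3

assert fields["oid"].startswith("sha256:")
print(f"{size_gib:.1f} GiB")  # ~1.4 GiB, matching the Model Files table
```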
export_onnx_quantized.py ADDED
@@ -0,0 +1,248 @@
+ #!/usr/bin/env python3
+ """
+ Export trained CodeBERT model to ONNX format with optional quantization.
+ Supports both CPU and GPU inference.
+ """
+
+ import os
+ import json
+
+ import torch
+ import torch.nn as nn
+ from transformers import RobertaTokenizer, RobertaModel
+
+ # Paths
+ MODEL_PATH = "/c1/new-models/best_model.pt"
+ CODEBERT_PATH = "/c1/huggingface/codebert-base"
+ OUTPUT_DIR = "/c1/new-models"
+ ONNX_PATH = os.path.join(OUTPUT_DIR, "model.onnx")
+ ONNX_QUANTIZED_PATH = os.path.join(OUTPUT_DIR, "model_quantized.onnx")
+
+
+ class CodeBERTClassifier(nn.Module):
+     """CodeBERT-based classifier for web attack detection - matches the training script."""
+
+     def __init__(self, model_path, num_labels=2, dropout=0.1):
+         super().__init__()
+         self.codebert = RobertaModel.from_pretrained(model_path)
+         self.dropout = nn.Dropout(dropout)
+         self.classifier = nn.Linear(self.codebert.config.hidden_size, num_labels)
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
+         pooled_output = self.dropout(outputs.pooler_output)
+         return self.classifier(pooled_output)
+
+
+ class ONNXCodeBERTClassifier(nn.Module):
+     """Wrapper for ONNX export with softmax output."""
+
+     def __init__(self, model):
+         super().__init__()
+         self.model = model
+         self.model.dropout.p = 0  # Disable dropout for inference
+
+     def forward(self, input_ids, attention_mask):
+         outputs = self.model.codebert(input_ids=input_ids, attention_mask=attention_mask)
+         logits = self.model.classifier(outputs.pooler_output)
+         return torch.softmax(logits, dim=-1)
+
+
+ def export_to_onnx():
+     """Export the model to ONNX format."""
+     print("=" * 80)
+     print("ONNX Model Export")
+     print("=" * 80)
+
+     # Use CPU for export to avoid CUDA issues
+     device = torch.device("cpu")
+     print(f"Export Device: {device}")
+
+     # Load tokenizer
+     print("\n1. Loading tokenizer...")
+     tokenizer = RobertaTokenizer.from_pretrained(CODEBERT_PATH)
+     print(f"   Tokenizer loaded: {type(tokenizer).__name__}")
+
+     # Create model with the same architecture as training
+     print("\n2. Loading model...")
+     model = CodeBERTClassifier(CODEBERT_PATH)
+
+     # Load trained weights (raw state dict or training checkpoint)
+     checkpoint = torch.load(MODEL_PATH, map_location=device)
+     if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+         model.load_state_dict(checkpoint['model_state_dict'])
+     else:
+         model.load_state_dict(checkpoint)
+
+     model.eval()
+     model.to(device)
+     print(f"   Model loaded from: {MODEL_PATH}")
+
+     # Wrap for ONNX export
+     onnx_model = ONNXCodeBERTClassifier(model)
+     onnx_model.eval()
+     onnx_model.to(device)
+
+     # Create dummy input
+     print("\n3. Creating dummy input...")
+     max_length = 256
+     dummy_text = "SELECT * FROM users WHERE id=1"
+     inputs = tokenizer(
+         dummy_text,
+         max_length=max_length,
+         padding='max_length',
+         truncation=True,
+         return_tensors='pt'
+     )
+
+     dummy_input_ids = inputs['input_ids'].to(device)
+     dummy_attention_mask = inputs['attention_mask'].to(device)
+     print(f"   Input shape: {dummy_input_ids.shape}")
+
+     # Test the forward pass first
+     print("\n4. Testing forward pass...")
+     with torch.no_grad():
+         test_output = onnx_model(dummy_input_ids, dummy_attention_mask)
+     print(f"   Output shape: {test_output.shape}")
+     print(f"   Output sample: {test_output[0].numpy()}")
+
+     # Export to ONNX
+     print("\n5. Exporting to ONNX...")
+     torch.onnx.export(
+         onnx_model,
+         (dummy_input_ids, dummy_attention_mask),
+         ONNX_PATH,
+         export_params=True,
+         opset_version=14,
+         do_constant_folding=True,
+         input_names=['input_ids', 'attention_mask'],
+         output_names=['probabilities'],
+         dynamic_axes={
+             'input_ids': {0: 'batch_size', 1: 'sequence_length'},
+             'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
+             'probabilities': {0: 'batch_size'}
+         }
+     )
+
+     onnx_size = os.path.getsize(ONNX_PATH) / (1024 * 1024)
+     print(f"   ONNX model saved: {ONNX_PATH}")
+     print(f"   Size: {onnx_size:.2f} MB")
+
+     # Quantize the model
+     print("\n6. Quantizing model (dynamic quantization)...")
+     try:
+         from onnxruntime.quantization import quantize_dynamic, QuantType
+
+         quantize_dynamic(
+             model_input=ONNX_PATH,
+             model_output=ONNX_QUANTIZED_PATH,
+             weight_type=QuantType.QUInt8,
+             optimize_model=True
+         )
+
+         quantized_size = os.path.getsize(ONNX_QUANTIZED_PATH) / (1024 * 1024)
+         print(f"   Quantized model saved: {ONNX_QUANTIZED_PATH}")
+         print(f"   Size: {quantized_size:.2f} MB")
+         print(f"   Compression ratio: {onnx_size / quantized_size:.2f}x")
+
+     except Exception as e:
+         print(f"   Warning: Quantization failed: {e}")
+         print("   Using the non-quantized model.")
+         import shutil
+         shutil.copy(ONNX_PATH, ONNX_QUANTIZED_PATH)
+
+     # Verify the ONNX model
+     print("\n7. Verifying ONNX model...")
+     try:
+         import onnx
+         onnx_check = onnx.load(ONNX_PATH)
+         onnx.checker.check_model(onnx_check)
+         print("   ONNX model verification: PASSED")
+     except Exception as e:
+         print(f"   Warning: ONNX verification failed: {e}")
+
+     # Test inference with ONNX Runtime
+     print("\n8. Testing ONNX Runtime inference...")
+     try:
+         import onnxruntime as ort
+         import numpy as np
+
+         # Try GPU first, fall back to CPU
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         available_providers = ort.get_available_providers()
+         use_providers = [p for p in providers if p in available_providers]
+
+         session = ort.InferenceSession(ONNX_PATH, providers=use_providers)
+         actual_provider = session.get_providers()[0]
+         print(f"   Using provider: {actual_provider}")
+
+         # Test inference
+         test_texts = [
+             "SELECT * FROM users WHERE id=1 OR 1=1",  # SQL injection
+             "GET /index.html HTTP/1.1",               # Normal request
+             "<script>alert('xss')</script>",          # XSS
+             "Mozilla/5.0 (Windows NT 10.0; Win64)",   # Normal UA
+         ]
+
+         print("\n   Test predictions:")
+         for text in test_texts:
+             inputs = tokenizer(
+                 text,
+                 max_length=max_length,
+                 padding='max_length',
+                 truncation=True,
+                 return_tensors='np'
+             )
+
+             outputs = session.run(
+                 None,
+                 {
+                     'input_ids': inputs['input_ids'].astype(np.int64),
+                     'attention_mask': inputs['attention_mask'].astype(np.int64)
+                 }
+             )
+
+             probs = outputs[0][0]
+             pred = np.argmax(probs)
+             label = "Malicious" if pred == 1 else "Benign"
+             conf = probs[pred] * 100
+             print(f"   - '{text[:40]:<40}' => {label:<10} ({conf:.1f}%)")
+
+     except Exception as e:
+         print(f"   Warning: ONNX Runtime test failed: {e}")
+         import traceback
+         traceback.print_exc()
+
+     # Save export config
+     print("\n9. Saving export configuration...")
+     export_config = {
+         "model_path": ONNX_PATH,
+         "quantized_model_path": ONNX_QUANTIZED_PATH,
+         "max_length": max_length,
+         "tokenizer_path": CODEBERT_PATH,
+         "labels": {"0": "benign", "1": "malicious"},
+         "input_names": ["input_ids", "attention_mask"],
+         "output_names": ["probabilities"]
+     }
+
+     config_path = os.path.join(OUTPUT_DIR, "onnx_config.json")
+     with open(config_path, 'w') as f:
+         json.dump(export_config, f, indent=2)
+     print(f"   Config saved: {config_path}")
+
+     print("\n" + "=" * 80)
+     print("Export completed!")
+     print("=" * 80)
+     print(f"ONNX Model: {ONNX_PATH} ({onnx_size:.2f} MB)")
+     if os.path.exists(ONNX_QUANTIZED_PATH):
+         qsize = os.path.getsize(ONNX_QUANTIZED_PATH) / (1024 * 1024)
+         print(f"Quantized Model: {ONNX_QUANTIZED_PATH} ({qsize:.2f} MB)")
+     print("=" * 80)
+
+
+ if __name__ == "__main__":
+     export_to_onnx()
model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cbb6ba2f597dd1ec8cfd8baf6ed4df932c5fda6a20680dd093dc35c6480f712
+ size 498886238
model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0ac1755d9c22be619022d59ec863c361f44646c133d7986b22e16a616fadfaba
+ size 125354544
onnx_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "model_path": "/c1/new-models/model.onnx",
+   "quantized_model_path": "/c1/new-models/model_quantized.onnx",
+   "max_length": 256,
+   "tokenizer_path": "/c1/huggingface/codebert-base",
+   "labels": {
+     "0": "benign",
+     "1": "malicious"
+   },
+   "input_names": [
+     "input_ids",
+     "attention_mask"
+   ],
+   "output_names": [
+     "probabilities"
+   ]
+ }
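One detail of this config worth noting: the `labels` map uses string keys, because JSON objects cannot have integer keys. Client code that computes an integer argmax index therefore needs to stringify it before the lookup. A minimal sketch (using a literal copy of the relevant fields):

```python
import json

# Subset of onnx_config.json above; note the string keys in "labels".
config = json.loads('{"labels": {"0": "benign", "1": "malicious"}, "max_length": 256}')

pred_idx = 1  # e.g. the result of np.argmax(probs)
label = config["labels"][str(pred_idx)]  # str() is required: "1", not 1
assert label == "malicious"
```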
requirements_onnx.txt ADDED
@@ -0,0 +1,17 @@
+ # ONNX Runtime Inference Requirements
+ # For GPU: onnxruntime-gpu
+ # For CPU: onnxruntime
+
+ # Core
+ onnxruntime-gpu==1.16.3
+ transformers==4.35.0
+ tokenizers>=0.14.0
+
+ # Web framework
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ pydantic>=2.0.0
+
+ # Utils
+ numpy>=1.24.0
+ requests>=2.31.0
requirements_onnx_cpu.txt ADDED
@@ -0,0 +1,15 @@
+ # ONNX Runtime Inference Requirements (CPU only)
+
+ # Core - CPU version
+ onnxruntime==1.16.3
+ transformers==4.35.0
+ tokenizers>=0.14.0
+
+ # Web framework
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ pydantic>=2.0.0
+
+ # Utils
+ numpy>=1.24.0
+ requests>=2.31.0
server_onnx.py ADDED
@@ -0,0 +1,357 @@
+ #!/usr/bin/env python3
+ """
+ FastAPI server for Web Attack Detection using ONNX Runtime.
+ Supports both CPU and GPU inference.
+
+ Usage:
+     python server_onnx.py --host 0.0.0.0 --port 8000 --device gpu
+     python server_onnx.py --host 0.0.0.0 --port 8000 --device cpu
+     python server_onnx.py --quantized  # Use quantized model (smaller, faster)
+ """
+
+ import os
+ import time
+ import argparse
+ import numpy as np
+ from typing import List
+ from contextlib import asynccontextmanager
+
+ import onnxruntime as ort
+ from transformers import RobertaTokenizer
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+
+ # Configuration
+ ONNX_MODEL_PATH = "/c1/new-models/model.onnx"
+ ONNX_QUANTIZED_PATH = "/c1/new-models/model_quantized.onnx"
+ TOKENIZER_PATH = "/c1/huggingface/codebert-base"
+ MAX_LENGTH = 256
+
+
+ class PredictRequest(BaseModel):
+     """Single prediction request."""
+     payload: str = Field(..., description="The payload/request to classify")
+
+
+ class BatchPredictRequest(BaseModel):
+     """Batch prediction request."""
+     payloads: List[str] = Field(..., description="List of payloads to classify")
+
+
+ class PredictResponse(BaseModel):
+     """Prediction response."""
+     payload: str
+     prediction: str  # "malicious" or "benign"
+     confidence: float
+     probabilities: dict
+     inference_time_ms: float
+
+
+ class BatchPredictResponse(BaseModel):
+     """Batch prediction response."""
+     predictions: List[PredictResponse]
+     total_inference_time_ms: float
+     avg_inference_time_ms: float
+
+
+ class HealthResponse(BaseModel):
+     """Health check response."""
+     status: str
+     model_loaded: bool
+     device: str
+     provider: str
+     model_path: str
+     version: str
+
+
+ # Global state
+ tokenizer = None
+ ort_session = None
+ device_type = "cpu"
+ model_path = ONNX_MODEL_PATH
+
+
+ def load_model(use_gpu: bool = True, use_quantized: bool = False):
+     """Load the ONNX model and tokenizer."""
+     global tokenizer, ort_session, device_type, model_path
+
+     print("Loading model...")
+
+     # Load tokenizer
+     print(f"  Loading tokenizer from: {TOKENIZER_PATH}")
+     tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)
+
+     # Select model; fall back to the full-precision model if the
+     # quantized file is missing
+     model_path = ONNX_QUANTIZED_PATH if use_quantized else ONNX_MODEL_PATH
+     if not os.path.exists(model_path):
+         model_path = ONNX_MODEL_PATH
+
+     print(f"  Loading ONNX model from: {model_path}")
+
+     # Configure providers
+     providers = []
+     if use_gpu:
+         if 'CUDAExecutionProvider' in ort.get_available_providers():
+             providers.append('CUDAExecutionProvider')
+             device_type = "gpu"
+         else:
+             print("  Warning: CUDA not available, falling back to CPU")
+
+     providers.append('CPUExecutionProvider')
+     if device_type != "gpu":
+         device_type = "cpu"
+
+     # Create session
+     sess_options = ort.SessionOptions()
+     sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+     ort_session = ort.InferenceSession(
+         model_path,
+         sess_options=sess_options,
+         providers=providers
+     )
+
+     actual_provider = ort_session.get_providers()[0]
+     print("  Model loaded successfully!")
+     print(f"  Provider: {actual_provider}")
+     print(f"  Device: {device_type}")
+
+     return ort_session
+
+
+ def predict_single(payload: str) -> dict:
+     """Make a prediction for a single payload."""
+     start_time = time.time()
+
+     # Tokenize
+     inputs = tokenizer(
+         payload,
+         max_length=MAX_LENGTH,
+         padding='max_length',
+         truncation=True,
+         return_tensors='np'
+     )
+
+     # Run inference
+     outputs = ort_session.run(
+         None,
+         {
+             'input_ids': inputs['input_ids'].astype(np.int64),
+             'attention_mask': inputs['attention_mask'].astype(np.int64)
+         }
+     )
+
+     # Process results
+     probs = outputs[0][0]
+     pred_idx = int(np.argmax(probs))
+     confidence = float(probs[pred_idx])
+     prediction = "malicious" if pred_idx == 1 else "benign"
+
+     inference_time = (time.time() - start_time) * 1000
+
+     return {
+         "payload": payload[:100] + "..." if len(payload) > 100 else payload,
+         "prediction": prediction,
+         "confidence": round(confidence, 4),
+         "probabilities": {
+             "benign": round(float(probs[0]), 4),
+             "malicious": round(float(probs[1]), 4)
+         },
+         "inference_time_ms": round(inference_time, 2)
+     }
+
+
+ def predict_batch(payloads: List[str]) -> dict:
+     """Make predictions for a batch of payloads."""
+     start_time = time.time()
+
+     # Tokenize the batch
+     inputs = tokenizer(
+         payloads,
+         max_length=MAX_LENGTH,
+         padding='max_length',
+         truncation=True,
+         return_tensors='np'
+     )
+
+     # Run inference
+     outputs = ort_session.run(
+         None,
+         {
+             'input_ids': inputs['input_ids'].astype(np.int64),
+             'attention_mask': inputs['attention_mask'].astype(np.int64)
+         }
+     )
+
+     total_time = (time.time() - start_time) * 1000
+
+     # Process results
+     predictions = []
+     probs_batch = outputs[0]
+
+     for payload, probs in zip(payloads, probs_batch):
+         pred_idx = int(np.argmax(probs))
+         confidence = float(probs[pred_idx])
+         prediction = "malicious" if pred_idx == 1 else "benign"
+
+         predictions.append({
+             "payload": payload[:100] + "..." if len(payload) > 100 else payload,
+             "prediction": prediction,
+             "confidence": round(confidence, 4),
+             "probabilities": {
+                 "benign": round(float(probs[0]), 4),
+                 "malicious": round(float(probs[1]), 4)
+             },
+             "inference_time_ms": round(total_time / len(payloads), 2)
+         })
+
+     return {
+         "predictions": predictions,
+         "total_inference_time_ms": round(total_time, 2),
+         "avg_inference_time_ms": round(total_time / len(payloads), 2)
+     }
+
+
+ # Startup/shutdown events
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Load the model on startup
+     use_gpu = getattr(app.state, 'use_gpu', True)
+     use_quantized = getattr(app.state, 'use_quantized', False)
+     load_model(use_gpu=use_gpu, use_quantized=use_quantized)
+     yield
+     # Clean up on shutdown
+     print("Shutting down...")
+
+
+ # Create the FastAPI app
+ app = FastAPI(
+     title="Web Attack Detection API",
+     description="CodeBERT-based web attack detection using ONNX Runtime. Supports CPU and GPU inference.",
+     version="2.0.0",
+     lifespan=lifespan
+ )
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+
+ @app.get("/", response_model=dict)
+ async def root():
+     """API root endpoint."""
+     return {
+         "name": "Web Attack Detection API",
+         "version": "2.0.0",
+         "model": "CodeBERT + ONNX Runtime",
+         "endpoints": {
+             "/predict": "POST - Single payload prediction",
+             "/batch_predict": "POST - Batch payload prediction",
+             "/health": "GET - Health check"
+         }
+     }
+
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health():
+     """Health check endpoint."""
+     return {
+         "status": "healthy" if ort_session is not None else "unhealthy",
+         "model_loaded": ort_session is not None,
+         "device": device_type,
+         "provider": ort_session.get_providers()[0] if ort_session else "none",
+         "model_path": model_path,
+         "version": "2.0.0"
+     }
+
+
+ @app.post("/predict", response_model=PredictResponse)
+ async def predict(request: PredictRequest):
+     """
+     Predict whether a single payload is malicious or benign.
+
+     - **payload**: The HTTP request/payload string to analyze
+     """
+     if not ort_session:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+
+     try:
+         result = predict_single(request.payload)
+         return result
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+
+ @app.post("/batch_predict", response_model=BatchPredictResponse)
+ async def batch_predict(request: BatchPredictRequest):
+     """
+     Predict whether multiple payloads are malicious or benign.
+
+     - **payloads**: List of HTTP request/payload strings to analyze
+     """
+     if not ort_session:
+         raise HTTPException(status_code=503, detail="Model not loaded")
+
+     if len(request.payloads) == 0:
+         raise HTTPException(status_code=400, detail="Empty payload list")
+
+     if len(request.payloads) > 100:
+         raise HTTPException(status_code=400, detail="Maximum batch size is 100")
+
+     try:
+         result = predict_batch(request.payloads)
315
+ return result
316
+ except Exception as e:
317
+ raise HTTPException(status_code=500, detail=str(e))
318
+
319
+
320
+ def main():
321
+ """Main entry point."""
322
+ parser = argparse.ArgumentParser(description="Web Attack Detection API Server")
323
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
324
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
325
+ parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"],
326
+ help="Device to use for inference")
327
+ parser.add_argument("--quantized", action="store_true",
328
+ help="Use quantized model (smaller, potentially faster)")
329
+ parser.add_argument("--workers", type=int, default=1, help="Number of workers")
330
+
331
+ args = parser.parse_args()
332
+
333
+ # Store config in app state
334
+ app.state.use_gpu = (args.device == "gpu")
335
+ app.state.use_quantized = args.quantized
336
+
337
+ print("=" * 60)
338
+ print("Web Attack Detection API Server")
339
+ print("=" * 60)
340
+ print(f"Host: {args.host}")
341
+ print(f"Port: {args.port}")
342
+ print(f"Device: {args.device}")
343
+ print(f"Quantized: {args.quantized}")
344
+ print("=" * 60)
345
+
346
+ import uvicorn
347
+ uvicorn.run(
348
+ app,
349
+ host=args.host,
350
+ port=args.port,
351
+ workers=args.workers,
352
+ log_level="info"
353
+ )
354
+
355
+
356
+ if __name__ == "__main__":
357
+ main()
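
The server above treats `outputs[0]` directly as class probabilities, which assumes the exported ONNX graph ends in a softmax. If the exported graph instead emits raw logits, a small post-processing step can normalize them before the argmax. The sketch below illustrates that; `postprocess` is a hypothetical helper name, not part of this repo:

```python
import numpy as np

def postprocess(logits: np.ndarray) -> list:
    """Softmax each row of raw logits, then map the argmax index to a label."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    results = []
    for row in probs:
        idx = int(np.argmax(row))
        results.append({
            "prediction": "malicious" if idx == 1 else "benign",
            "confidence": round(float(row[idx]), 4),
        })
    return results

# Two fabricated logit rows: one benign-leaning, one malicious-leaning
print(postprocess(np.array([[2.0, -1.0], [-3.0, 4.0]])))
```

Since argmax is invariant under softmax, the predicted labels are identical either way; only the reported `confidence`/`probabilities` fields change.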
test_onnx_accuracy.py ADDED
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+"""
+Test ONNX model accuracy with 2000 samples from the dataset.
+"""
+
+import time
+
+import pandas as pd
+import requests
+from sklearn.metrics import (accuracy_score, precision_score, recall_score,
+                             f1_score, confusion_matrix, classification_report)
+
+# Configuration
+API_URL = "http://localhost:8001"
+DATASET_PATH = "/c1/web-attack-detection/dataset.csv"
+NUM_SAMPLES = 2000  # 1000 malicious + 1000 benign
+BATCH_SIZE = 50
+
+
+def test_accuracy():
+    print("=" * 80)
+    print("ONNX Model Accuracy Test")
+    print("=" * 80)
+
+    # Check API health
+    print("\n1. Checking API health...")
+    try:
+        resp = requests.get(f"{API_URL}/health", timeout=10)
+        health = resp.json()
+        print(f"   Status: {health['status']}")
+        print(f"   Device: {health['device']}")
+        print(f"   Provider: {health['provider']}")
+        print(f"   Model: {health['model_path']}")
+    except Exception as e:
+        print(f"   Error: {e}")
+        print("   Please ensure the server is running!")
+        return
+
+    # Load dataset
+    print("\n2. Loading dataset...")
+    df = pd.read_csv(DATASET_PATH)
+    df = df.dropna(subset=['Sentence', 'Label'])
+    df['Sentence'] = df['Sentence'].astype(str)
+    df['Label'] = df['Label'].astype(int)
+    print(f"   Total samples: {len(df)}")
+
+    # Sample data
+    print("\n3. Sampling test data...")
+    samples_per_class = NUM_SAMPLES // 2
+
+    benign_samples = df[df['Label'] == 0].sample(
+        n=min(samples_per_class, len(df[df['Label'] == 0])), random_state=42)
+    malicious_samples = df[df['Label'] == 1].sample(
+        n=min(samples_per_class, len(df[df['Label'] == 1])), random_state=42)
+
+    test_df = pd.concat([benign_samples, malicious_samples]).sample(
+        frac=1, random_state=42).reset_index(drop=True)
+    print(f"   Test samples: {len(test_df)}")
+    print(f"   Benign: {len(test_df[test_df['Label'] == 0])}")
+    print(f"   Malicious: {len(test_df[test_df['Label'] == 1])}")
+
+    # Run predictions
+    print("\n4. Running predictions...")
+    predictions = []
+    true_labels = []
+    total_time = 0
+
+    for i in range(0, len(test_df), BATCH_SIZE):
+        batch = test_df.iloc[i:i + BATCH_SIZE]
+        payloads = batch['Sentence'].tolist()
+        labels = batch['Label'].tolist()
+
+        try:
+            start = time.time()
+            resp = requests.post(
+                f"{API_URL}/batch_predict",
+                json={"payloads": payloads},
+                timeout=60
+            )
+            total_time += time.time() - start
+
+            result = resp.json()
+            batch_preds = [1 if p['prediction'] == 'malicious' else 0
+                           for p in result['predictions']]
+            predictions.extend(batch_preds)
+            true_labels.extend(labels)
+
+            # Progress
+            progress = min(i + BATCH_SIZE, len(test_df))
+            print(f"   Processed: {progress}/{len(test_df)} "
+                  f"({100 * progress / len(test_df):.1f}%)", end='\r')
+
+        except Exception as e:
+            print(f"\n   Error at batch {i}: {e}")
+            continue
+
+    print(f"\n   Total inference time: {total_time:.2f}s")
+    print(f"   Avg time per sample: {1000 * total_time / len(predictions):.2f}ms")
+
+    # Calculate metrics
+    print("\n5. Calculating metrics...")
+    accuracy = accuracy_score(true_labels, predictions)
+    precision = precision_score(true_labels, predictions)
+    recall = recall_score(true_labels, predictions)
+    f1 = f1_score(true_labels, predictions)
+    cm = confusion_matrix(true_labels, predictions)
+
+    print("\n" + "=" * 80)
+    print("RESULTS")
+    print("=" * 80)
+    print(f"\nSamples tested: {len(predictions)}")
+    print("\nMetrics:")
+    print(f"  Accuracy:  {accuracy * 100:.2f}%")
+    print(f"  Precision: {precision * 100:.2f}%")
+    print(f"  Recall:    {recall * 100:.2f}%")
+    print(f"  F1 Score:  {f1 * 100:.2f}%")
+
+    print("\nConfusion Matrix:")
+    print("                    Predicted")
+    print("                    Benign  Malicious")
+    print(f"  Actual Benign     {cm[0][0]:5d}   {cm[0][1]:5d}")
+    print(f"  Actual Malicious  {cm[1][0]:5d}   {cm[1][1]:5d}")
+
+    print("\nDetailed Report:")
+    print(classification_report(true_labels, predictions,
+                                target_names=['Benign', 'Malicious']))
+
+    print("=" * 80)
+
+    # Return results
+    return {
+        'accuracy': accuracy,
+        'precision': precision,
+        'recall': recall,
+        'f1': f1,
+        'samples': len(predictions),
+        'inference_time_s': total_time
+    }
+
+
+if __name__ == "__main__":
+    test_accuracy()
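
The batching loop above (stepping through the frame in `BATCH_SIZE` slices) can be factored into a small generator. This is an illustrative sketch; the `chunks` helper is not part of the repo:

```python
def chunks(seq, size):
    """Yield consecutive slices of at most `size` items, like the BATCH_SIZE loop."""
    for i in range(0, len(seq), size):
        yield seq[i:i + size]

payloads = [f"payload-{n}" for n in range(7)]
batches = list(chunks(payloads, 3))
print([len(b) for b in batches])  # → [3, 3, 1]
```

The final batch is simply shorter, which matches how `test_df.iloc[i:i+BATCH_SIZE]` behaves at the end of the frame.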
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "model_name": "/c1/huggingface/codebert-base",
+  "max_length": 256
+}
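
An inference script can pick up these two fields before tokenizing. A minimal sketch, reading the same JSON from an in-memory string rather than the file:

```python
import json

# Same shape as tokenizer_config.json above
raw = '{"model_name": "/c1/huggingface/codebert-base", "max_length": 256}'
cfg = json.loads(raw)
print(cfg["model_name"], cfg["max_length"])  # → /c1/huggingface/codebert-base 256
```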
train_new_model.py ADDED
@@ -0,0 +1,451 @@
+#!/usr/bin/env python3
+"""
+Train CodeBERT-based model for web attack detection
+Dataset: /c1/web-attack-detection/dataset.csv
+Output: /c1/new-models/
+"""
+
+import json
+import os
+import random
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's implementation
+from torch.utils.data import Dataset, DataLoader
+from transformers import RobertaTokenizer, RobertaModel, get_linear_schedule_with_warmup
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
+                             classification_report, confusion_matrix)
+from tqdm import tqdm
+
+
+# Set random seeds for reproducibility
+def set_seed(seed=42):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+set_seed(42)
+
+
+# Configuration
+class Config:
+    # Paths
+    data_path = "/c1/web-attack-detection/dataset.csv"
+    model_base_path = "/c1/huggingface/codebert-base"
+    output_dir = "/c1/new-models"
+
+    # Training parameters
+    max_length = 256                 # Reduced from 512
+    batch_size = 8                   # Reduced from 32
+    gradient_accumulation_steps = 4  # Effective batch size = 8 * 4 = 32
+    epochs = 3
+    learning_rate = 2e-5
+    warmup_steps = 500
+    weight_decay = 0.01
+
+    # Data split
+    train_size = 0.8
+    test_size = 0.2
+
+    # Sampling strategy
+    use_sampling = True
+    sampling_strategy = "balanced"   # Options: "balanced", "oversample", "undersample", "none"
+
+    # GPU
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Early stopping
+    early_stopping_patience = 2
+
+config = Config()
+
+print("=" * 80)
+print("Web Attack Detection Model Training")
+print("=" * 80)
+print(f"Device: {config.device}")
+print(f"Data path: {config.data_path}")
+print(f"Model base: {config.model_base_path}")
+print(f"Output dir: {config.output_dir}")
+print(f"Sampling strategy: {config.sampling_strategy}")
+print("=" * 80)
+
+# Create output directory
+os.makedirs(config.output_dir, exist_ok=True)
+
+# Load data
+print("\n1. Loading dataset...")
+df = pd.read_csv(config.data_path)
+print(f"Total samples: {len(df)}")
+print("\nLabel distribution:")
+print(df['Label'].value_counts())
+print("\nLabel proportions:")
+print(df['Label'].value_counts(normalize=True))
+
+# Clean data
+print("\n2. Cleaning data...")
+df = df.dropna(subset=['Sentence', 'Label'])
+df['Sentence'] = df['Sentence'].astype(str)
+df['Label'] = df['Label'].astype(int)
+print(f"Samples after cleaning: {len(df)}")
+
+# Split data
+print("\n3. Splitting data (80% train, 20% test)...")
+train_df, test_df = train_test_split(
+    df,
+    test_size=config.test_size,
+    random_state=42,
+    stratify=df['Label']
+)
+
+print(f"Train samples: {len(train_df)}")
+print(f"Test samples: {len(test_df)}")
+print("\nTrain label distribution:")
+print(train_df['Label'].value_counts())
+print("\nTest label distribution:")
+print(test_df['Label'].value_counts())
+
+
+# Apply sampling strategy
+def apply_sampling(df, strategy="balanced"):
+    """Apply sampling strategy to balance the dataset"""
+    if strategy == "none":
+        return df
+
+    label_counts = df['Label'].value_counts()
+    print(f"\nOriginal distribution: {dict(label_counts)}")
+
+    if strategy == "balanced":
+        # Balanced: make both classes equal to the average count
+        target_count = int(label_counts.mean())
+        print(f"Target count per class: {target_count}")
+    elif strategy == "oversample":
+        # Oversample minority to match majority
+        target_count = label_counts.max()
+        print(f"Target count per class (oversample): {target_count}")
+    elif strategy == "undersample":
+        # Undersample majority to match minority
+        target_count = label_counts.min()
+        print(f"Target count per class (undersample): {target_count}")
+    else:
+        raise ValueError(f"Unknown sampling strategy: {strategy}")
+
+    balanced_dfs = []
+    for label in [0, 1]:
+        label_df = df[df['Label'] == label]
+        current_count = len(label_df)
+
+        if current_count < target_count:
+            # Oversample (with replacement)
+            sampled = label_df.sample(n=target_count, replace=True, random_state=42)
+        elif current_count > target_count:
+            # Undersample (without replacement)
+            sampled = label_df.sample(n=target_count, replace=False, random_state=42)
+        else:
+            sampled = label_df
+
+        balanced_dfs.append(sampled)
+
+    balanced_df = pd.concat(balanced_dfs, ignore_index=True)
+    balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)  # Shuffle
+
+    print(f"After sampling: {dict(balanced_df['Label'].value_counts())}")
+    return balanced_df
+
+
+if config.use_sampling:
+    print(f"\n4. Applying sampling strategy: {config.sampling_strategy}...")
+    train_df = apply_sampling(train_df, config.sampling_strategy)
+    print(f"Final train samples: {len(train_df)}")
+else:
+    print("\n4. Skipping sampling (using original distribution)...")
+
+# Load tokenizer
+print("\n5. Loading CodeBERT tokenizer...")
+tokenizer = RobertaTokenizer.from_pretrained(config.model_base_path)
+print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")
+
+
+# Dataset class
+class WebAttackDataset(Dataset):
+    def __init__(self, dataframe, tokenizer, max_length):
+        self.data = dataframe.reset_index(drop=True)
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        text = str(self.data.loc[idx, 'Sentence'])
+        label = int(self.data.loc[idx, 'Label'])
+
+        encoding = self.tokenizer(
+            text,
+            add_special_tokens=True,
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt'
+        )
+
+        return {
+            'input_ids': encoding['input_ids'].flatten(),
+            'attention_mask': encoding['attention_mask'].flatten(),
+            'label': torch.tensor(label, dtype=torch.long)
+        }
+
+
+# Create datasets
+print("\n6. Creating datasets...")
+train_dataset = WebAttackDataset(train_df, tokenizer, config.max_length)
+test_dataset = WebAttackDataset(test_df, tokenizer, config.max_length)
+
+train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
+test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
+
+print(f"Train batches: {len(train_loader)}")
+print(f"Test batches: {len(test_loader)}")
+
+
+# Model class
+class CodeBERTClassifier(nn.Module):
+    def __init__(self, model_path, num_labels=2, dropout=0.1):
+        super().__init__()
+        self.codebert = RobertaModel.from_pretrained(model_path)
+        self.dropout = nn.Dropout(dropout)
+        self.classifier = nn.Linear(self.codebert.config.hidden_size, num_labels)
+
+    def forward(self, input_ids, attention_mask):
+        outputs = self.codebert(input_ids=input_ids, attention_mask=attention_mask)
+        pooled_output = self.dropout(outputs.pooler_output)
+        logits = self.classifier(pooled_output)
+        return logits
+
+
+# Load model
+print("\n7. Loading CodeBERT model...")
+model = CodeBERTClassifier(config.model_base_path)
+model.to(config.device)
+print(f"Model loaded and moved to {config.device}")
+
+# Count parameters
+total_params = sum(p.numel() for p in model.parameters())
+trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+print(f"Total parameters: {total_params:,}")
+print(f"Trainable parameters: {trainable_params:,}")
+
+# Optimizer and scheduler
+optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
+# scheduler.step() runs once per accumulation cycle, not once per batch
+total_steps = len(train_loader) * config.epochs // config.gradient_accumulation_steps
+scheduler = get_linear_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=config.warmup_steps,
+    num_training_steps=total_steps
+)
+
+criterion = nn.CrossEntropyLoss()
+
+
+# Training function
+def train_epoch(model, dataloader, optimizer, scheduler, criterion, device,
+                gradient_accumulation_steps=4):
+    model.train()
+    total_loss = 0
+    predictions = []
+    true_labels = []
+
+    optimizer.zero_grad()
+
+    progress_bar = tqdm(dataloader, desc="Training")
+    for idx, batch in enumerate(progress_bar):
+        input_ids = batch['input_ids'].to(device)
+        attention_mask = batch['attention_mask'].to(device)
+        labels = batch['label'].to(device)
+
+        logits = model(input_ids, attention_mask)
+        loss = criterion(logits, labels)
+        loss = loss / gradient_accumulation_steps  # Normalize loss
+
+        loss.backward()
+
+        if (idx + 1) % gradient_accumulation_steps == 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad()
+
+        total_loss += loss.item() * gradient_accumulation_steps
+
+        preds = torch.argmax(logits, dim=1)
+        predictions.extend(preds.cpu().numpy())
+        true_labels.extend(labels.cpu().numpy())
+
+        progress_bar.set_postfix({'loss': loss.item() * gradient_accumulation_steps})
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(true_labels, predictions)
+
+    return avg_loss, accuracy
+
+
+# Evaluation function
+def evaluate(model, dataloader, criterion, device):
+    model.eval()
+    total_loss = 0
+    predictions = []
+    true_labels = []
+
+    with torch.no_grad():
+        for batch in tqdm(dataloader, desc="Evaluating"):
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            labels = batch['label'].to(device)
+
+            logits = model(input_ids, attention_mask)
+            loss = criterion(logits, labels)
+
+            total_loss += loss.item()
+
+            preds = torch.argmax(logits, dim=1)
+            predictions.extend(preds.cpu().numpy())
+            true_labels.extend(labels.cpu().numpy())
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(true_labels, predictions)
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        true_labels, predictions, average='binary'
+    )
+
+    return avg_loss, accuracy, precision, recall, f1, predictions, true_labels
+
+
+# Training loop
+print("\n8. Starting training...")
+print("=" * 80)
+
+best_accuracy = 0
+best_f1 = 0
+patience_counter = 0
+training_history = []
+
+for epoch in range(config.epochs):
+    print(f"\nEpoch {epoch + 1}/{config.epochs}")
+    print("-" * 80)
+
+    # Train
+    train_loss, train_acc = train_epoch(
+        model, train_loader, optimizer, scheduler, criterion,
+        config.device, config.gradient_accumulation_steps
+    )
+
+    # Evaluate
+    test_loss, test_acc, test_precision, test_recall, test_f1, predictions, true_labels = evaluate(
+        model, test_loader, criterion, config.device
+    )
+
+    # Log results
+    print(f"\nTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
+    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
+    print(f"Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")
+
+    # Save history
+    training_history.append({
+        'epoch': epoch + 1,
+        'train_loss': train_loss,
+        'train_acc': train_acc,
+        'test_loss': test_loss,
+        'test_acc': test_acc,
+        'precision': test_precision,
+        'recall': test_recall,
+        'f1': test_f1
+    })
+
+    # Save best model
+    if test_f1 > best_f1:
+        best_f1 = test_f1
+        best_accuracy = test_acc
+        patience_counter = 0
+
+        # Save PyTorch model
+        model_save_path = os.path.join(config.output_dir, 'best_model.pt')
+        torch.save({
+            'epoch': epoch + 1,
+            'model_state_dict': model.state_dict(),
+            'optimizer_state_dict': optimizer.state_dict(),
+            'test_acc': test_acc,
+            'test_f1': test_f1,
+            # vars() on the instance is empty; the settings live on the class
+            'config': {k: str(v) for k, v in vars(Config).items() if not k.startswith('__')}
+        }, model_save_path)
+        print(f"\n✓ Best model saved! (F1: {test_f1:.4f})")
+    else:
+        patience_counter += 1
+        print(f"\nNo improvement. Patience: {patience_counter}/{config.early_stopping_patience}")
+
+    # Early stopping
+    if patience_counter >= config.early_stopping_patience:
+        print(f"\nEarly stopping triggered after {epoch + 1} epochs")
+        break
+
+print("\n" + "=" * 80)
+print("Training completed!")
+print("=" * 80)
+
+# Final evaluation
+print("\n9. Final evaluation on test set...")
+test_loss, test_acc, test_precision, test_recall, test_f1, predictions, true_labels = evaluate(
+    model, test_loader, criterion, config.device
+)
+
+print("\nFinal Test Results:")
+print(f"Accuracy:  {test_acc:.4f}")
+print(f"Precision: {test_precision:.4f}")
+print(f"Recall:    {test_recall:.4f}")
+print(f"F1 Score:  {test_f1:.4f}")
+
+# Classification report
+print("\nClassification Report:")
+print(classification_report(true_labels, predictions, target_names=['Benign', 'Malicious']))
+
+# Confusion matrix
+cm = confusion_matrix(true_labels, predictions)
+print("\nConfusion Matrix:")
+print(cm)
+print(f"True Negatives: {cm[0][0]}")
+print(f"False Positives: {cm[0][1]}")
+print(f"False Negatives: {cm[1][0]}")
+print(f"True Positives: {cm[1][1]}")
+
+# Save results
+results = {
+    'final_metrics': {
+        'accuracy': float(test_acc),
+        'precision': float(test_precision),
+        'recall': float(test_recall),
+        'f1_score': float(test_f1)
+    },
+    'confusion_matrix': cm.tolist(),
+    'training_history': training_history,
+    'config': {
+        'epochs': config.epochs,
+        'batch_size': config.batch_size,
+        'learning_rate': config.learning_rate,
+        'max_length': config.max_length,
+        'sampling_strategy': config.sampling_strategy,
+        'train_samples': len(train_df),
+        'test_samples': len(test_df)
+    }
+}
+
+results_path = os.path.join(config.output_dir, 'training_results.json')
+with open(results_path, 'w') as f:
+    json.dump(results, f, indent=2)
+print(f"\nResults saved to: {results_path}")
+
+# Save tokenizer config
+tokenizer_config = {
+    'model_name': config.model_base_path,
+    'max_length': config.max_length
+}
+tokenizer_config_path = os.path.join(config.output_dir, 'tokenizer_config.json')
+with open(tokenizer_config_path, 'w') as f:
+    json.dump(tokenizer_config, f, indent=2)
+print(f"Tokenizer config saved to: {tokenizer_config_path}")
+
+print("\n" + "=" * 80)
+print("Training script completed successfully!")
+print(f"Best F1 Score: {best_f1:.4f}")
+print(f"Best Accuracy: {best_accuracy:.4f}")
+print(f"Model saved to: {config.output_dir}")
+print("=" * 80)
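
The `apply_sampling` target-count rules above are independent of pandas; a minimal pure-Python sketch of the same three strategies (the `balance` helper and its toy data are illustrative, not part of the repo):

```python
import random
from collections import Counter

def balance(items, strategy="balanced", seed=42):
    """items: list of (label, payload) pairs. Mirrors apply_sampling's target rules."""
    rng = random.Random(seed)
    by_label = {}
    for label, payload in items:
        by_label.setdefault(label, []).append(payload)
    counts = [len(v) for v in by_label.values()]
    if strategy == "balanced":
        target = sum(counts) // len(counts)   # mean count per class
    elif strategy == "oversample":
        target = max(counts)                  # grow minority to majority size
    elif strategy == "undersample":
        target = min(counts)                  # shrink majority to minority size
    else:
        return list(items)                    # "none": keep original distribution
    out = []
    for label, payloads in by_label.items():
        if len(payloads) < target:
            out += [(label, rng.choice(payloads)) for _ in range(target)]  # with replacement
        else:
            out += [(label, p) for p in rng.sample(payloads, target)]      # without replacement
    rng.shuffle(out)
    return out

# 6 benign vs 2 malicious -> "balanced" targets the mean (4 per class)
data = [(0, f"b{i}") for i in range(6)] + [(1, f"m{i}") for i in range(2)]
print(dict(sorted(Counter(label for label, _ in balance(data, "balanced")).items())))  # → {0: 4, 1: 4}
```

Note that `"balanced"` both undersamples the majority and oversamples the minority toward the mean, whereas the other two strategies move only one class.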
training_results.json ADDED
@@ -0,0 +1,59 @@
+{
+  "final_metrics": {
+    "accuracy": 0.993800976186482,
+    "precision": 0.9946941482577419,
+    "recall": 0.992129590365533,
+    "f1_score": 0.9934102141680395
+  },
+  "confusion_matrix": [
+    [
+      65914,
+      312
+    ],
+    [
+      464,
+      58491
+    ]
+  ],
+  "training_history": [
+    {
+      "epoch": 1,
+      "train_loss": 0.028937281491438017,
+      "train_acc": 0.9884227175957917,
+      "test_loss": 0.01923593317120965,
+      "test_acc": 0.9908692213674599,
+      "precision": 0.9826998864471311,
+      "recall": 0.9981850563989484,
+      "f1": 0.9903819453209806
+    },
+    {
+      "epoch": 2,
+      "train_loss": 0.020080875358571924,
+      "train_acc": 0.992353042207053,
+      "test_loss": 0.01691129035232433,
+      "test_acc": 0.9908053139054649,
+      "precision": 0.9840881682969316,
+      "recall": 0.9965906199643796,
+      "f1": 0.9902999351081672
+    },
+    {
+      "epoch": 3,
+      "train_loss": 0.01749216435263267,
+      "train_acc": 0.9929661568694804,
+      "test_loss": 0.02742091244277926,
+      "test_acc": 0.993800976186482,
+      "precision": 0.9946941482577419,
+      "recall": 0.992129590365533,
+      "f1": 0.9934102141680395
+    }
+  ],
+  "config": {
+    "epochs": 3,
+    "batch_size": 8,
+    "learning_rate": 2e-05,
+    "max_length": 256,
+    "sampling_strategy": "balanced",
+    "train_samples": 500722,
+    "test_samples": 125181
+  }
+}
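
The final metrics above can be recomputed directly from the stored confusion-matrix cells as a sanity check (the `metrics_from_cm` helper name is illustrative):

```python
def metrics_from_cm(tn, fp, fn, tp):
    """Derive accuracy/precision/recall/F1 from binary confusion-matrix cells."""
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    precision = tp / (tp + fp)   # of everything flagged malicious, how much truly was
    recall = tp / (tp + fn)      # of all malicious samples, how many were caught
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1

# Cells from training_results.json: [[65914, 312], [464, 58491]]
acc, prec, rec, f1 = metrics_from_cm(tn=65914, fp=312, fn=464, tp=58491)
print(f"{acc:.4f} {prec:.4f} {rec:.4f} {f1:.4f}")  # → 0.9938 0.9947 0.9921 0.9934
```

These reproduce the `final_metrics` block exactly, which is a quick way to confirm the matrix and the summary numbers were written from the same evaluation run.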