#!/usr/bin/env python3
"""
FastAPI server for Web Attack Detection using ONNX Runtime.
Supports both CPU and GPU inference.

Usage:
    python server_onnx.py --host 0.0.0.0 --port 8000 --device gpu
    python server_onnx.py --host 0.0.0.0 --port 8000 --device cpu
    python server_onnx.py --quantized  # Use quantized model (smaller, faster)
"""

import os
import time
import argparse
import numpy as np
from typing import List
from contextlib import asynccontextmanager

import onnxruntime as ort
from transformers import RobertaTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configuration
ONNX_MODEL_PATH = "/c1/new-models/model.onnx"
ONNX_QUANTIZED_PATH = "/c1/new-models/model_quantized.onnx"
TOKENIZER_PATH = "/c1/huggingface/codebert-base"
MAX_LENGTH = 256


class PredictRequest(BaseModel):
    """Single prediction request."""
    payload: str = Field(..., description="The payload/request to classify")


class BatchPredictRequest(BaseModel):
    """Batch prediction request."""
    payloads: List[str] = Field(..., description="List of payloads to classify")


class PredictResponse(BaseModel):
    """Prediction response."""
    payload: str
    prediction: str  # "malicious" or "benign"
    confidence: float
    probabilities: dict
    inference_time_ms: float
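
# An illustrative /predict response body (field names match PredictResponse
# above; the values are made up):
#
#   {
#     "payload": "GET /index.html",
#     "prediction": "benign",
#     "confidence": 0.9987,
#     "probabilities": {"benign": 0.9987, "malicious": 0.0013},
#     "inference_time_ms": 4.21
#   }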


class BatchPredictResponse(BaseModel):
    """Batch prediction response."""
    predictions: List[PredictResponse]
    total_inference_time_ms: float
    avg_inference_time_ms: float


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_loaded: bool
    device: str
    provider: str
    model_path: str
    version: str


# Global variables
tokenizer = None
ort_session = None
device_type = "cpu"
model_path = ONNX_MODEL_PATH


def load_model(use_gpu: bool = True, use_quantized: bool = False):
    """Load ONNX model and tokenizer."""
    global tokenizer, ort_session, device_type, model_path
    
    print("Loading model...")
    
    # Load tokenizer
    print(f"  Loading tokenizer from: {TOKENIZER_PATH}")
    tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)
    
    # Select model
    model_path = ONNX_QUANTIZED_PATH if use_quantized else ONNX_MODEL_PATH
    if not os.path.exists(model_path):
        print(f"  Warning: {model_path} not found, falling back to {ONNX_MODEL_PATH}")
        model_path = ONNX_MODEL_PATH
    
    print(f"  Loading ONNX model from: {model_path}")
    
    # Configure execution providers: try CUDA first if requested, always
    # keep CPU as a fallback. Reset device_type so a reload with
    # use_gpu=False does not report a stale "gpu" value.
    device_type = "cpu"
    providers = []
    if use_gpu:
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            providers.append('CUDAExecutionProvider')
            device_type = "gpu"
        else:
            print("  Warning: CUDA not available, falling back to CPU")
    
    providers.append('CPUExecutionProvider')
    
    # Create session
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    
    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=providers
    )
    
    actual_provider = ort_session.get_providers()[0]
    print(f"  Model loaded successfully!")
    print(f"  Provider: {actual_provider}")
    print(f"  Device: {device_type}")
    
    return ort_session


def predict_single(payload: str) -> dict:
    """Make prediction for a single payload."""
    global tokenizer, ort_session
    
    start_time = time.time()
    
    # Tokenize
    inputs = tokenizer(
        payload,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='np'
    )
    
    # Run inference
    outputs = ort_session.run(
        None,
        {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64)
        }
    )
    
    # Process results. The exported graph is assumed to emit class
    # probabilities; if it emits raw logits, apply a softmax here first.
    probs = outputs[0][0]
    pred_idx = int(np.argmax(probs))
    confidence = float(probs[pred_idx])
    prediction = "malicious" if pred_idx == 1 else "benign"
    
    inference_time = (time.time() - start_time) * 1000
    
    return {
        "payload": payload[:100] + "..." if len(payload) > 100 else payload,
        "prediction": prediction,
        "confidence": round(confidence, 4),
        "probabilities": {
            "benign": round(float(probs[0]), 4),
            "malicious": round(float(probs[1]), 4)
        },
        "inference_time_ms": round(inference_time, 2)
    }


def predict_batch(payloads: List[str]) -> dict:
    """Make predictions for a batch of payloads."""
    global tokenizer, ort_session
    
    start_time = time.time()
    
    # Tokenize batch
    inputs = tokenizer(
        payloads,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='np'
    )
    
    # Run inference
    outputs = ort_session.run(
        None,
        {
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64)
        }
    )
    
    total_time = (time.time() - start_time) * 1000
    
    # Process results (as above, the model is assumed to emit probabilities).
    # Per-item inference_time_ms is the batch time amortized over all items.
    predictions = []
    probs_batch = outputs[0]
    
    for payload, probs in zip(payloads, probs_batch):
        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx])
        prediction = "malicious" if pred_idx == 1 else "benign"
        
        predictions.append({
            "payload": payload[:100] + "..." if len(payload) > 100 else payload,
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probabilities": {
                "benign": round(float(probs[0]), 4),
                "malicious": round(float(probs[1]), 4)
            },
            "inference_time_ms": round(total_time / len(payloads), 2)
        })
    
    return {
        "predictions": predictions,
        "total_inference_time_ms": round(total_time, 2),
        "avg_inference_time_ms": round(total_time / len(payloads), 2)
    }


# Startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model on startup. Config comes from app.state when the app
    # object is run in-process, or from environment variables when uvicorn
    # re-imports the module to spawn worker processes (see main()).
    use_gpu = getattr(app.state, 'use_gpu',
                      os.environ.get("WAD_USE_GPU", "1") == "1")
    use_quantized = getattr(app.state, 'use_quantized',
                            os.environ.get("WAD_USE_QUANTIZED", "0") == "1")
    load_model(use_gpu=use_gpu, use_quantized=use_quantized)
    yield
    # Cleanup on shutdown
    print("Shutting down...")


# Create FastAPI app
app = FastAPI(
    title="Web Attack Detection API",
    description="CodeBERT-based web attack detection using ONNX Runtime. Supports CPU and GPU inference.",
    version="2.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", response_model=dict)
async def root():
    """API root endpoint."""
    return {
        "name": "Web Attack Detection API",
        "version": "2.0.0",
        "model": "CodeBERT + ONNX Runtime",
        "endpoints": {
            "/predict": "POST - Single payload prediction",
            "/batch_predict": "POST - Batch payload prediction",
            "/health": "GET - Health check"
        }
    }


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return {
        "status": "healthy" if ort_session is not None else "unhealthy",
        "model_loaded": ort_session is not None,
        "device": device_type,
        "provider": ort_session.get_providers()[0] if ort_session else "none",
        "model_path": model_path,
        "version": "2.0.0"
    }


@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict if a single payload is malicious or benign.
    
    - **payload**: The HTTP request/payload string to analyze
    """
    if not ort_session:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    try:
        result = predict_single(request.payload)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/batch_predict", response_model=BatchPredictResponse)
async def batch_predict(request: BatchPredictRequest):
    """
    Predict if multiple payloads are malicious or benign.
    
    - **payloads**: List of HTTP request/payload strings to analyze
    """
    if not ort_session:
        raise HTTPException(status_code=503, detail="Model not loaded")
    
    if len(request.payloads) == 0:
        raise HTTPException(status_code=400, detail="Empty payload list")
    
    if len(request.payloads) > 100:
        raise HTTPException(status_code=400, detail="Maximum batch size is 100")
    
    try:
        result = predict_batch(request.payloads)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Web Attack Detection API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--device", type=str, default="gpu", choices=["cpu", "gpu"],
                        help="Device to use for inference")
    parser.add_argument("--quantized", action="store_true", 
                        help="Use quantized model (smaller, potentially faster)")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")
    
    args = parser.parse_args()
    
    # Store config in app state
    app.state.use_gpu = (args.device == "gpu")
    app.state.use_quantized = args.quantized
    
    print("=" * 60)
    print("Web Attack Detection API Server")
    print("=" * 60)
    print(f"Host: {args.host}")
    print(f"Port: {args.port}")
    print(f"Device: {args.device}")
    print(f"Quantized: {args.quantized}")
    print("=" * 60)
    
    import uvicorn
    if args.workers > 1:
        # uvicorn requires an import string (not an app object) to spawn
        # multiple workers; pass the config via environment variables so
        # each worker process picks it up in lifespan().
        os.environ["WAD_USE_GPU"] = "1" if app.state.use_gpu else "0"
        os.environ["WAD_USE_QUANTIZED"] = "1" if app.state.use_quantized else "0"
        uvicorn.run("server_onnx:app", host=args.host, port=args.port,
                    workers=args.workers, log_level="info")
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()