#!/usr/bin/env python3
"""
FastAPI server for Web Attack Detection using ONNX Runtime.
Supports both CPU and GPU inference.

Usage:
    python server_onnx.py --host 0.0.0.0 --port 8000 --device gpu
    python server_onnx.py --host 0.0.0.0 --port 8000 --device cpu
    python server_onnx.py --quantized   # Use quantized model (smaller, faster)
"""
import os
import sys
import json
import time
import argparse
import numpy as np
from typing import List, Optional
from contextlib import asynccontextmanager

import onnxruntime as ort
from transformers import RobertaTokenizer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Configuration
ONNX_MODEL_PATH = "/c1/new-models/model.onnx"
ONNX_QUANTIZED_PATH = "/c1/new-models/model_quantized.onnx"
TOKENIZER_PATH = "/c1/huggingface/codebert-base"
MAX_LENGTH = 256  # token truncation/padding length fed to the model


class PredictRequest(BaseModel):
    """Single prediction request."""
    payload: str = Field(..., description="The payload/request to classify")


class BatchPredictRequest(BaseModel):
    """Batch prediction request."""
    payloads: List[str] = Field(..., description="List of payloads to classify")


class PredictResponse(BaseModel):
    """Prediction response."""
    payload: str
    prediction: str  # "malicious" or "benign"
    confidence: float
    probabilities: dict
    inference_time_ms: float


class BatchPredictResponse(BaseModel):
    """Batch prediction response."""
    predictions: List[PredictResponse]
    total_inference_time_ms: float
    avg_inference_time_ms: float


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_loaded: bool
    device: str
    provider: str
    model_path: str
    version: str


# Global inference state, populated by load_model() at startup.
tokenizer = None
ort_session = None
device_type = "cpu"
model_path = ONNX_MODEL_PATH


def load_model(use_gpu: bool = True, use_quantized: bool = False):
    """Load the ONNX model and tokenizer into module-level globals.

    Args:
        use_gpu: Prefer the CUDA execution provider when available;
            falls back to CPU (with a warning) otherwise.
        use_quantized: Prefer the quantized model file; falls back to the
            full-precision model (with a warning) if it does not exist.

    Returns:
        The created ``onnxruntime.InferenceSession``.
    """
    global tokenizer, ort_session, device_type, model_path

    print("Loading model...")

    # Load tokenizer
    print(f"  Loading tokenizer from: {TOKENIZER_PATH}")
    tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)

    # Select model file; warn instead of silently ignoring --quantized.
    model_path = ONNX_QUANTIZED_PATH if use_quantized else ONNX_MODEL_PATH
    if not os.path.exists(model_path):
        if use_quantized:
            print(f"  Warning: quantized model not found at {model_path}, "
                  f"falling back to {ONNX_MODEL_PATH}")
        model_path = ONNX_MODEL_PATH

    print(f"  Loading ONNX model from: {model_path}")

    # Configure execution providers (ordered by preference).
    providers = []
    if use_gpu:
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            providers.append('CUDAExecutionProvider')
            device_type = "gpu"
        else:
            print("  Warning: CUDA not available, falling back to CPU")
    # CPU is always appended as the fallback provider.
    providers.append('CPUExecutionProvider')
    if device_type != "gpu":
        device_type = "cpu"

    # Create session with full graph optimizations enabled.
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    ort_session = ort.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=providers
    )

    actual_provider = ort_session.get_providers()[0]
    print(f"  Model loaded successfully!")
    print(f"  Provider: {actual_provider}")
    print(f"  Device: {device_type}")

    return ort_session


def _run_inference(texts: List[str]) -> np.ndarray:
    """Tokenize *texts* and run the ONNX session; return the first output.

    NOTE(review): the first graph output is treated downstream as a
    per-class probability vector — assumes softmax is baked into the
    exported graph; confirm against the export script.
    """
    inputs = tokenizer(
        texts,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='np'
    )
    outputs = ort_session.run(
        None,
        {
            # ONNX graph expects int64 ids/masks; HF may return int32.
            'input_ids': inputs['input_ids'].astype(np.int64),
            'attention_mask': inputs['attention_mask'].astype(np.int64)
        }
    )
    return outputs[0]


def _format_result(payload: str, probs, inference_ms: float) -> dict:
    """Build one prediction dict from a 2-class probability vector.

    Index 0 is "benign", index 1 is "malicious". The payload is truncated
    to 100 chars in the response to keep payloads out of logs/clients.
    """
    pred_idx = int(np.argmax(probs))
    return {
        "payload": payload[:100] + "..." if len(payload) > 100 else payload,
        "prediction": "malicious" if pred_idx == 1 else "benign",
        "confidence": round(float(probs[pred_idx]), 4),
        "probabilities": {
            "benign": round(float(probs[0]), 4),
            "malicious": round(float(probs[1]), 4)
        },
        "inference_time_ms": round(inference_ms, 2)
    }


def predict_single(payload: str) -> dict:
    """Classify a single payload; returns a PredictResponse-shaped dict.

    The reported time covers tokenization + inference.
    """
    start_time = time.time()
    probs = _run_inference([payload])[0]
    inference_time = (time.time() - start_time) * 1000
    return _format_result(payload, probs, inference_time)


def predict_batch(payloads: List[str]) -> dict:
    """Classify a batch of payloads in one ONNX run.

    Raises:
        ValueError: if *payloads* is empty (avoids a ZeroDivisionError
        when computing the per-item average time).
    """
    if not payloads:
        raise ValueError("payloads must be non-empty")

    start_time = time.time()
    probs_batch = _run_inference(payloads)
    total_time = (time.time() - start_time) * 1000

    # Per-item time is the batch average — individual timings are not
    # observable inside a single batched run.
    avg_time = total_time / len(payloads)
    predictions = [
        _format_result(payload, probs, avg_time)
        for payload, probs in zip(payloads, probs_batch)
    ]

    return {
        "predictions": predictions,
        "total_inference_time_ms": round(total_time, 2),
        "avg_inference_time_ms": round(avg_time, 2)
    }


# Startup/shutdown events
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model once at startup; config comes from app.state (set in main)."""
    use_gpu = getattr(app.state, 'use_gpu', True)
    use_quantized = getattr(app.state, 'use_quantized', False)
    load_model(use_gpu=use_gpu, use_quantized=use_quantized)
    yield
    # Cleanup on shutdown
    print("Shutting down...")


# Create FastAPI app
app = FastAPI(
    title="Web Attack Detection API",
    description="CodeBERT-based web attack detection using ONNX Runtime. Supports CPU and GPU inference.",
    version="2.0.0",
    lifespan=lifespan
)

# Add CORS middleware (wide open — intended for internal/dev use).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/", response_model=dict)
async def root():
    """API root endpoint."""
    return {
        "name": "Web Attack Detection API",
        "version": "2.0.0",
        "model": "CodeBERT + ONNX Runtime",
        "endpoints": {
            "/predict": "POST - Single payload prediction",
            "/batch_predict": "POST - Batch payload prediction",
            "/health": "GET - Health check"
        }
    }


@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check endpoint."""
    return {
        "status": "healthy" if ort_session is not None else "unhealthy",
        "model_loaded": ort_session is not None,
        "device": device_type,
        "provider": ort_session.get_providers()[0] if ort_session else "none",
        "model_path": model_path,
        "version": "2.0.0"
    }


@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    """
    Predict if a single payload is malicious or benign.

    - **payload**: The HTTP request/payload string to analyze
    """
    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = predict_single(request.payload)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/batch_predict", response_model=BatchPredictResponse)
async def batch_predict(request: BatchPredictRequest):
    """
    Predict if multiple payloads are malicious or benign.

    - **payloads**: List of HTTP request/payload strings to analyze
    """
    if ort_session is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if len(request.payloads) == 0:
        raise HTTPException(status_code=400, detail="Empty payload list")

    if len(request.payloads) > 100:
        raise HTTPException(status_code=400, detail="Maximum batch size is 100")

    try:
        result = predict_batch(request.payloads)
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def main():
    """Main entry point: parse CLI args, stash config on app.state, run uvicorn."""
    parser = argparse.ArgumentParser(description="Web Attack Detection API Server")
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000,
                        help="Port to bind to")
    parser.add_argument("--device", type=str, default="gpu",
                        choices=["cpu", "gpu"],
                        help="Device to use for inference")
    parser.add_argument("--quantized", action="store_true",
                        help="Use quantized model (smaller, potentially faster)")
    parser.add_argument("--workers", type=int, default=1,
                        help="Number of workers")

    args = parser.parse_args()

    # Store config in app state (read by lifespan at startup)
    app.state.use_gpu = (args.device == "gpu")
    app.state.use_quantized = args.quantized

    # uvicorn ignores `workers` when given an app *object* (multiple workers
    # require an import string, which would also lose app.state config in the
    # subprocesses) — surface that instead of failing silently.
    if args.workers > 1:
        print("Warning: --workers > 1 has no effect when passing an app object "
              "to uvicorn.run(); running a single worker.")

    print("=" * 60)
    print("Web Attack Detection API Server")
    print("=" * 60)
    print(f"Host: {args.host}")
    print(f"Port: {args.port}")
    print(f"Device: {args.device}")
    print(f"Quantized: {args.quantized}")
    print("=" * 60)

    import uvicorn
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()