# Ddos_Preventor / app.py
# AliMusaRizvi's picture
# Update app.py
# 91074f0 verified
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import Any

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field, field_validator
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
# Service-wide logger. Fetching it does not depend on the root configuration
# below, so the two statements are independent of each other.
logger = logging.getLogger("ddos_detector")

# Root logging: INFO and above, timestamped records written to stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    level=logging.INFO,
)
# Hub repository holding the serialized artifacts, the local cache directory
# used for downloads (HF_HOME when set), and the version reported by the API.
REPO_ID = "AliMusaRizvi/Ddos_detector"
MODEL_CACHE_DIR = os.environ.get("HF_HOME", "/tmp/hf_cache")
API_VERSION = "1.0.0"

# Soft-voting weights mirror the training ensemble: [rf=1, xgb=2, lgb=2]
# They are normalised here so that the three weights sum to 1.
_RAW_WEIGHTS = np.array((1.0, 2.0, 2.0), dtype=np.float64)
ENSEMBLE_WEIGHTS = _RAW_WEIGHTS / np.sum(_RAW_WEIGHTS)

# Registry key -> filename requested from the Hub repository.
# NOTE(review): the names carry no separator after "models" — presumably they
# match the filenames stored in the Hub repo exactly; verify before changing.
ARTIFACT_FILENAMES: dict[str, str] = dict(
    scaler="modelsfeature_scaler.pkl",
    xgb="modelsxgboost_model.pkl",
    lgb="modelslightgbm_model.pkl",
    rf="modelsrandom_forest_model.pkl",
)

# Loaded artifacts plus the derived "feature_names" entry; populated by the
# application lifespan at startup and emptied again on shutdown.
_registry: dict[str, Any] = dict()
def _download_and_load(filename: str) -> Any:
    """
    Fetch one artifact from the Hub repository and return the loaded object.

    The file is obtained with hf_hub_download (repository REPO_ID, cache
    directory MODEL_CACHE_DIR) and deserialised with joblib.load.  joblib is
    used even though the files carry a .pkl suffix: scikit-learn objects are
    serialised with joblib, and handing such a file to pickle.load fails
    with the opaque "STACK_GLOBAL requires str" error.

    Args:
        filename: Name of the artifact inside the Hub repository.

    Returns:
        Whatever object joblib reconstructs from the downloaded file.
    """
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code — the Hub repository contents must be trusted.
    return joblib.load(
        hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            cache_dir=MODEL_CACHE_DIR,
        )
    )
def _extract_feature_names(scaler: Any) -> list[str]:
    """
    Return the ordered feature names the scaler was fitted on.

    When the scaler exposes ``feature_names_in_`` its contents are returned
    as a plain list.  Otherwise (scaler fitted on a plain array) positional
    names ``feature_0`` … ``feature_{n-1}`` are generated from
    ``n_features_in_``.
    """
    try:
        fitted_names = scaler.feature_names_in_
    except AttributeError:
        # No recorded names: fall back to one positional name per column.
        width = scaler.n_features_in_
        return [f"feature_{i}" for i in range(width)]
    return fitted_names.tolist()
@asynccontextmanager
async def lifespan(application: FastAPI):
    """
    Load all model artifacts at startup; release them on shutdown.

    Startup downloads and deserialises every entry of ARTIFACT_FILENAMES
    into the module-level registry, switches the XGBoost model to CPU
    inference, and stores the ordered feature names taken from the scaler
    under the "feature_names" key.

    The whole sequence runs inside try/finally so that the registry is
    emptied on normal shutdown AND when loading fails part-way or an
    exception is thrown into the context manager.  A half-filled registry
    would otherwise look "loaded" to the emptiness check in
    _require_models.

    Args:
        application: The FastAPI instance being started (not used directly).

    Yields:
        None, once every artifact is loaded.

    Raises:
        Any exception raised while downloading or deserialising an artifact
        propagates unchanged (after the registry has been cleared).
    """
    logger.info("Loading models from HuggingFace Hub (repo: %s) …", REPO_ID)
    try:
        for key, filename in ARTIFACT_FILENAMES.items():
            logger.info(" Fetching %s", filename)
            _registry[key] = _download_and_load(filename)

        # XGBoost stores the training device in its booster config.
        # HF free-tier has no GPU, so we redirect inference to CPU.
        if hasattr(_registry["xgb"], "set_params"):
            _registry["xgb"].set_params(device="cpu")

        _registry["feature_names"] = _extract_feature_names(_registry["scaler"])
        logger.info(
            "Startup complete — %d features expected.",
            len(_registry["feature_names"]),
        )
        yield
    finally:
        # Runs on clean shutdown, on startup failure, and on errors raised
        # at the yield point, so no stale or partial artifacts survive.
        _registry.clear()
        logger.info("Shutdown — model registry cleared.")
# Shared slowapi rate limiter; clients are keyed by their remote address.
# NOTE(review): per the slowapi documentation, default_limits are applied to
# routes that lack their own @limiter.limit decorator only when slowapi's
# middleware is installed on the app — confirm that the middleware is
# registered, otherwise this 300/hour default is never evaluated.
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["300/hour"],
)
# The ASGI application. `lifespan` loads the model artifacts before the first
# request is served; the title, description and version below are published in
# the generated OpenAPI schema, and the interactive documentation is exposed
# at /docs (Swagger UI) and /redoc (ReDoc).
app = FastAPI(
    title="DDoS Detector API",
    description=(
        "Classifies network flows as **Benign** or **Attack** using a "
        "weighted soft-voting ensemble of XGBoost, LightGBM, and Random "
        "Forest trained on the CIC-IDS2018 dataset.\n\n"
        "Supply feature values in the order returned by `GET /features`."
    ),
    version=API_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)
# Rate limiting: slowapi looks the limiter up on app.state, and the exception
# handler converts RateLimitExceeded into the HTTP error response.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# Without slowapi's middleware the limiter's default_limits ("300/hour") are
# never evaluated — only routes carrying an explicit @limiter.limit decorator
# would be throttled.  Routes that do have their own decorator keep being
# handled by that decorator.  The middleware is registered before CORS on
# purpose: middleware added later wraps middleware added earlier, so responses
# produced here (including rate-limit rejections) still receive CORS headers.
app.add_middleware(SlowAPIMiddleware)

# CORS: any origin may call the API, restricted to GET/POST with a
# Content-Type header; preflight responses may be cached for 10 minutes.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type"],
    max_age=600,
)
# Content-Security-Policy used ONLY for the interactive documentation pages.
# The strict API policy ("default-src 'none'") blocks every script and
# stylesheet, which leaves Swagger UI and ReDoc blank.
# NOTE(review): the sources below assume FastAPI's default CDN-hosted assets
# (jsDelivr bundles, Google Fonts for ReDoc, an inline bootstrap script);
# adjust if custom asset URLs are configured.
_DOCS_CONTENT_SECURITY_POLICY = "; ".join(
    (
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net",
        "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net "
        "https://fonts.googleapis.com",
        "font-src 'self' data: https://fonts.gstatic.com",
        "img-src 'self' data: https:",
        "worker-src 'self' blob:",
        "connect-src 'self'",
    )
)


@app.middleware("http")
async def security_headers(request: Request, call_next):
    """
    Attach defensive HTTP headers to every response.

    All responses get the same nosniff / frame / referrer / cache headers.
    The Content-Security-Policy is "default-src 'none'" for API responses,
    and a relaxed documentation policy for the Swagger UI, ReDoc and OAuth2
    redirect pages so that those HTML pages can load their assets.

    Args:
        request: The incoming request; its path selects the CSP variant.
        call_next: Callable that runs the rest of the stack for the request.

    Returns:
        The downstream response with the security headers set.
    """
    response = await call_next(request)

    # Documentation URLs are read from the application itself so the check
    # follows whatever docs_url / redoc_url the app was created with; a
    # disabled page yields None, which never equals a request path.
    docs_paths = {
        getattr(request.app, attr, None)
        for attr in ("docs_url", "redoc_url", "swagger_ui_oauth2_redirect_url")
    }
    is_docs_page = request.url.path in docs_paths

    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["X-XSS-Protection"] = "1; mode=block"
    response.headers["Referrer-Policy"] = "no-referrer"
    response.headers["Content-Security-Policy"] = (
        _DOCS_CONTENT_SECURITY_POLICY if is_docs_page else "default-src 'none'"
    )
    response.headers["Cache-Control"] = "no-store"
    return response
class PredictionRequest(BaseModel):
    # Request body for a single flow: one ordered list of numeric feature
    # values.  (Documented with comments rather than a class docstring so the
    # generated schema is unchanged.)
    features: list[float] = Field(
        description=(
            "Ordered list of numeric network-flow feature values. "
            "Length must match the count returned by GET /features."
        ),
        examples=[[0.1, 0.4, 1200.0, 0.0]],
    )

    @field_validator("features")
    @classmethod
    def validate_features(cls, values: list[float]) -> list[float]:
        # Rejects an empty list, then the first NaN / infinite entry found;
        # a valid list is returned unchanged.
        if not values:
            raise ValueError("Feature list must not be empty.")

        first_bad = next(
            (
                (position, value)
                for position, value in enumerate(values)
                if not np.isfinite(value)
            ),
            None,
        )
        if first_bad is not None:
            i, v = first_bad
            raise ValueError(
                f"Non-finite value at index {i}: {v!r}. "
                "All features must be real numbers."
            )
        return values
class BatchPredictionRequest(BaseModel):
    # Request body for batch classification: between 1 and 100 single-flow
    # records, each validated as a PredictionRequest.
    records: list[PredictionRequest] = Field(
        description="One to 100 flow records for batch classification.",
        min_length=1,
        max_length=100,
    )
class ModelAgreement(BaseModel):
    # Per-model verdicts of the three base learners (0 = Benign, 1 = Attack).
    random_forest: int = Field(description="0 = Benign, 1 = Attack")
    xgboost: int = Field(description="0 = Benign, 1 = Attack")
    lightgbm: int = Field(description="0 = Benign, 1 = Attack")
class PredictionResponse(BaseModel):
    # Result for one classified flow: the ensemble verdict, its label and
    # probabilities, and the individual vote of each base learner.
    prediction: int = Field(description="0 = Benign, 1 = Attack")
    label: str = Field(description="Human-readable label.")
    confidence: float = Field(
        description="Ensemble confidence for the predicted class (0–1).",
    )
    attack_probability: float = Field(
        description="Weighted ensemble probability that the flow is an attack.",
    )
    model_agreement: ModelAgreement = Field(
        description="Individual prediction from each base learner.",
    )
class BatchPredictionResponse(BaseModel):
    # Result for a batch: one PredictionResponse per submitted record plus
    # aggregate counts of the two classes.
    results: list[PredictionResponse]
    total: int = Field(description="Total number of flows in the batch.")
    attack_count: int = Field(description="Flows classified as Attack.")
    benign_count: int = Field(description="Flows classified as Benign.")
class HealthResponse(BaseModel):
    # Payload of the /health endpoint: overall status string, whether the
    # models are loaded, and how many input features are expected.
    status: str = Field(description="'healthy' or 'degraded'")
    models_loaded: bool
    feature_count: int
class FeatureListResponse(BaseModel):
    # Payload of the /features endpoint: the ordered feature names and their
    # count.
    feature_count: int
    features: list[str]
def _classify(feature_values: list[float]) -> PredictionResponse:
    """
    Classify one flow with the weighted soft-voting ensemble.

    The values are checked against the expected feature count, scaled with
    the registered scaler, and passed to the random-forest, XGBoost and
    LightGBM models.  Column 1 of each model's predict_proba output is used
    as that model's attack probability; the three values are combined with
    ENSEMBLE_WEIGHTS, and a combined probability of 0.5 or more is reported
    as an attack.

    Args:
        feature_values: Feature values in the order of the registered
            feature names.

    Returns:
        The ensemble verdict, label, confidence (probability of the
        predicted class), attack probability, and each model's own vote.

    Raises:
        HTTPException: 422 when the number of values differs from the
            number of registered feature names.
    """
    columns: list[str] = _registry["feature_names"]
    expected, received = len(columns), len(feature_values)
    if received != expected:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Feature count mismatch: expected {expected}, "
                f"received {received}. "
                "Call GET /features to retrieve the ordered feature list."
            ),
        )

    # Single-row frames keep the column names attached for scaler and models.
    raw_row = pd.DataFrame([feature_values], columns=columns)
    scaled_row = pd.DataFrame(
        _registry["scaler"].transform(raw_row),
        columns=columns,
    )

    # One attack probability per base learner, in the same order as the
    # ensemble weights (rf, xgb, lgb).
    rf_p, xgb_p, lgb_p = (
        float(_registry[model_key].predict_proba(scaled_row)[0, 1])
        for model_key in ("rf", "xgb", "lgb")
    )

    # Weighted probability aggregation (replicates soft VotingClassifier)
    attack_prob = float(np.dot(ENSEMBLE_WEIGHTS, [rf_p, xgb_p, lgb_p]))
    is_attack = attack_prob >= 0.5
    prediction = int(is_attack)
    if is_attack:
        confidence = attack_prob
    else:
        confidence = 1.0 - attack_prob

    votes = ModelAgreement(
        random_forest=int(rf_p >= 0.5),
        xgboost=int(xgb_p >= 0.5),
        lightgbm=int(lgb_p >= 0.5),
    )
    return PredictionResponse(
        prediction=prediction,
        label=("Benign", "Attack")[prediction],
        confidence=round(confidence, 6),
        attack_probability=round(attack_prob, 6),
        model_agreement=votes,
    )
def _require_models() -> None:
    """
    Guard for endpoints that need the models.

    Returns silently when the registry holds anything; raises an
    HTTPException with status 503 while the registry is still empty.
    """
    if _registry:
        return
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Models are still loading. Retry in a moment.",
    )
@app.get("/", include_in_schema=False)
async def root():
    # Landing response: service name, API version and where the docs live.
    # The route is excluded from the OpenAPI schema.
    service_info = dict(
        service="DDoS Detector API",
        version=API_VERSION,
        documentation="/docs",
    )
    return service_info
@app.get(
    "/health",
    response_model=HealthResponse,
    tags=["System"],
    summary="Liveness and readiness check",
)
async def health_check():
    """Returns 200 when all models are loaded and ready to serve requests."""
    # Every artifact key must be present; a single missing one means 503.
    required_keys = ("scaler", "xgb", "lgb", "rf")
    if any(key not in _registry for key in required_keys):
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="One or more models failed to load.",
        )

    expected_features = _registry.get("feature_names", [])
    return HealthResponse(
        status="healthy",
        models_loaded=True,
        feature_count=len(expected_features),
    )
@app.get(
    "/features",
    response_model=FeatureListResponse,
    tags=["System"],
    summary="List expected input features",
)
async def list_features():
    """
    Returns the ordered list of feature names that the `/predict` endpoint expects.
    Supply values in this exact order.
    """
    # 503 while the registry is empty; otherwise echo the stored names.
    _require_models()
    ordered_names = _registry["feature_names"]
    return FeatureListResponse(
        features=ordered_names,
        feature_count=len(ordered_names),
    )
@app.post(
    "/predict",
    response_model=PredictionResponse,
    tags=["Inference"],
    summary="Classify a single network flow",
    status_code=status.HTTP_200_OK,
)
@limiter.limit("30/minute")
async def predict(request: Request, body: PredictionRequest):
    """
    Classifies one network flow record.
    - Feature values must be ordered as returned by `GET /features`.
    - Rate limited to **30 requests per minute** per IP address.
    """
    # `request` is not read here; presumably the rate-limit decorator needs
    # it in the signature — verify against the slowapi documentation.
    _require_models()
    outcome = _classify(body.features)
    return outcome
# Batch endpoint.  Parameters: `request` is not read in the body (presumably
# required by the rate-limit decorator — verify against slowapi docs); `body`
# carries 1–100 validated records.  Returns a BatchPredictionResponse; raises
# HTTPException 503 while models are not loaded and 422 on a feature-count
# mismatch.  (Kept as comments because the docstring below is published as the
# endpoint description.)
@app.post(
    "/predict/batch",
    response_model=BatchPredictionResponse,
    tags=["Inference"],
    summary="Classify a batch of network flows",
    status_code=status.HTTP_200_OK,
)
@limiter.limit("10/minute")
async def predict_batch(request: Request, body: BatchPredictionRequest):
    """
    Classifies up to **100** network flow records in a single request.
    - Rate limited to **10 requests per minute** per IP address.
    - Each record follows the same feature ordering as `/predict`.
    - The whole batch is rejected with **422** if any record has the wrong
      number of features; the error names the zero-based record index.
    """
    _require_models()

    # Check every record's length up front: a malformed record is reported
    # with its position, and no inference is spent on a batch that is going
    # to be rejected anyway.
    expected = len(_registry["feature_names"])
    for index, record in enumerate(body.records):
        received = len(record.features)
        if received != expected:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=(
                    f"Feature count mismatch in record {index}: "
                    f"expected {expected}, received {received}. "
                    "Call GET /features to retrieve the ordered feature list."
                ),
            )

    results = [_classify(record.features) for record in body.records]
    # prediction is 0 or 1, so the sum is the number of attacks.
    attack_count = sum(r.prediction for r in results)
    return BatchPredictionResponse(
        results=results,
        total=len(results),
        attack_count=attack_count,
        benign_count=len(results) - attack_count,
    )