"""DDoS Detector inference API.

Downloads a scaler plus three base learners (Random Forest, XGBoost,
LightGBM) from the HuggingFace Hub at startup and serves weighted
soft-voting predictions over HTTP.
"""
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import Any

import joblib
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field, field_validator
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("ddos_detector")

REPO_ID = "AliMusaRizvi/Ddos_detector"
MODEL_CACHE_DIR = os.getenv("HF_HOME", "/tmp/hf_cache")
API_VERSION = "1.0.0"

# Soft-voting weights mirror the training ensemble: [rf=1, xgb=2, lgb=2]
_RAW_WEIGHTS = np.array([1.0, 2.0, 2.0], dtype=np.float64)
ENSEMBLE_WEIGHTS = _RAW_WEIGHTS / _RAW_WEIGHTS.sum()

# Hub filename of each artifact, keyed by its slot in `_registry`.
# NOTE(review): "models" is fused to each name with no separator
# ("modelsfeature_scaler.pkl"), which looks like a lost "/" — confirm these
# match the filenames actually stored in the Hub repo.
ARTIFACT_FILENAMES: dict[str, str] = {
    "scaler": "modelsfeature_scaler.pkl",
    "xgb": "modelsxgboost_model.pkl",
    "lgb": "modelslightgbm_model.pkl",
    "rf": "modelsrandom_forest_model.pkl",
}

# Every registry entry inference depends on: the four artifacts plus the
# feature-name list derived from the scaler during startup.
_REQUIRED_KEYS: tuple[str, ...] = (*ARTIFACT_FILENAMES, "feature_names")

# Filled by `lifespan` at startup, cleared at shutdown.
_registry: dict[str, Any] = {}


def _download_and_load(filename: str) -> Any:
    """
    Download a single artifact from the Hub and deserialise it.

    scikit-learn internally uses joblib for serialisation, so even files
    named *.pkl must be loaded with joblib.load rather than pickle.load.
    Passing a joblib-serialised file to pickle.load produces the opaque
    STACK_GLOBAL requires str error.
    """
    local_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        cache_dir=MODEL_CACHE_DIR,
    )
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point REPO_ID at a repository you control and trust.
    return joblib.load(local_path)


def _extract_feature_names(scaler: Any) -> list[str]:
    """Return the ordered feature names the scaler was fitted on."""
    if hasattr(scaler, "feature_names_in_"):
        return scaler.feature_names_in_.tolist()
    # Fallback: positional names when the scaler was fitted on a plain array
    return [f"feature_{i}" for i in range(scaler.n_features_in_)]


@asynccontextmanager
async def lifespan(application: FastAPI):
    """Load all model artifacts at startup; release them on shutdown."""
    logger.info("Loading models from HuggingFace Hub (repo: %s) …", REPO_ID)

    for key, filename in ARTIFACT_FILENAMES.items():
        logger.info("  Fetching %s", filename)
        _registry[key] = _download_and_load(filename)

    # XGBoost stores the training device in its booster config.
    # HF free-tier has no GPU, so we redirect inference to CPU.
    if hasattr(_registry["xgb"], "set_params"):
        _registry["xgb"].set_params(device="cpu")

    _registry["feature_names"] = _extract_feature_names(_registry["scaler"])
    logger.info(
        "Startup complete — %d features expected.",
        len(_registry["feature_names"]),
    )

    yield

    _registry.clear()
    logger.info("Shutdown — model registry cleared.")


# NOTE(review): slowapi applies `default_limits` to undecorated routes only
# when SlowAPIMiddleware is installed, and it is not installed here. As a
# result, only the routes decorated with `@limiter.limit` are actually
# rate limited.
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=["300/hour"],
)

app = FastAPI(
    title="DDoS Detector API",
    description=(
        "Classifies network flows as **Benign** or **Attack** using a "
        "weighted soft-voting ensemble of XGBoost, LightGBM, and Random "
        "Forest trained on the CIC-IDS2018 dataset.\n\n"
        "Supply feature values in the order returned by `GET /features`."
    ),
    version=API_VERSION,
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type"],
    max_age=600,
)

# Interactive documentation pages (Swagger UI / ReDoc) load scripts and
# styles from a CDN and rely on inline <script> blocks. A
# "default-src 'none'" CSP would block all of that and render them blank.
_DOCS_PATHS = frozenset(
    path
    for path in (app.docs_url, app.redoc_url, app.swagger_ui_oauth2_redirect_url)
    if path
)


@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Attach defensive HTTP headers to every response.

    The strict Content-Security-Policy is omitted on the interactive
    documentation routes so that those pages can load their assets.
    """
    response = await call_next(request)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["X-XSS-Protection"] = "1; mode=block"
    response.headers["Referrer-Policy"] = "no-referrer"
    if request.url.path not in _DOCS_PATHS:
        response.headers["Content-Security-Policy"] = "default-src 'none'"
    response.headers["Cache-Control"] = "no-store"
    return response


class PredictionRequest(BaseModel):
    """A single network-flow record to classify."""

    features: list[float] = Field(
        ...,
        description=(
            "Ordered list of numeric network-flow feature values. "
            "Length must match the count returned by GET /features."
        ),
        examples=[[0.1, 0.4, 1200.0, 0.0]],
    )

    @field_validator("features")
    @classmethod
    def validate_features(cls, values: list[float]) -> list[float]:
        """Reject empty lists and NaN/inf values; length is checked at inference."""
        if len(values) == 0:
            raise ValueError("Feature list must not be empty.")
        for i, v in enumerate(values):
            if not np.isfinite(v):
                raise ValueError(
                    f"Non-finite value at index {i}: {v!r}. "
                    "All features must be real numbers."
                )
        return values


class BatchPredictionRequest(BaseModel):
    """Between 1 and 100 flow records classified in one call."""

    records: list[PredictionRequest] = Field(
        ...,
        min_length=1,
        max_length=100,
        description="One to 100 flow records for batch classification.",
    )


class ModelAgreement(BaseModel):
    """Hard (threshold 0.5) vote of each base learner."""

    random_forest: int = Field(..., description="0 = Benign, 1 = Attack")
    xgboost: int = Field(..., description="0 = Benign, 1 = Attack")
    lightgbm: int = Field(..., description="0 = Benign, 1 = Attack")


class PredictionResponse(BaseModel):
    """Ensemble verdict for one flow."""

    prediction: int = Field(..., description="0 = Benign, 1 = Attack")
    label: str = Field(..., description="Human-readable label.")
    confidence: float = Field(..., description="Ensemble confidence for the predicted class (0–1).")
    attack_probability: float = Field(..., description="Weighted ensemble probability that the flow is an attack.")
    model_agreement: ModelAgreement = Field(..., description="Individual prediction from each base learner.")


class BatchPredictionResponse(BaseModel):
    """Per-record verdicts plus aggregate counts."""

    results: list[PredictionResponse]
    total: int = Field(..., description="Total number of flows in the batch.")
    attack_count: int = Field(..., description="Flows classified as Attack.")
    benign_count: int = Field(..., description="Flows classified as Benign.")


class HealthResponse(BaseModel):
    """Readiness payload.

    `health_check` only ever returns "healthy"; an unready service
    responds with HTTP 503 instead.
    """

    status: str = Field(..., description="'healthy' or 'degraded'")
    models_loaded: bool
    feature_count: int


class FeatureListResponse(BaseModel):
    """Ordered feature names expected by the inference endpoints."""

    feature_count: int
    features: list[str]


def _check_feature_count(feature_values: list[float], n_expected: int) -> None:
    """Raise HTTP 422 unless *feature_values* has exactly *n_expected* entries."""
    if len(feature_values) != n_expected:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=(
                f"Feature count mismatch: expected {n_expected}, "
                f"received {len(feature_values)}. "
                "Call GET /features to retrieve the ordered feature list."
            ),
        )


def _build_response(rf_p: float, xgb_p: float, lgb_p: float) -> PredictionResponse:
    """Combine the three base-learner attack probabilities into one response."""
    # Weighted probability aggregation (replicates soft VotingClassifier)
    attack_prob = float(np.dot(ENSEMBLE_WEIGHTS, [rf_p, xgb_p, lgb_p]))
    prediction = int(attack_prob >= 0.5)
    confidence = attack_prob if prediction == 1 else (1.0 - attack_prob)

    return PredictionResponse(
        prediction=prediction,
        label="Attack" if prediction == 1 else "Benign",
        confidence=round(confidence, 6),
        attack_probability=round(attack_prob, 6),
        model_agreement=ModelAgreement(
            random_forest=int(rf_p >= 0.5),
            xgboost=int(xgb_p >= 0.5),
            lightgbm=int(lgb_p >= 0.5),
        ),
    )


def _classify_many(rows: list[list[float]]) -> list[PredictionResponse]:
    """
    Scale a batch of flows and run the weighted soft-voting ensemble.

    Every row is length-checked before any model is touched. Each model is
    then called once for the whole batch instead of once per row.

    Raises HTTPException (422) if any row has the wrong feature count.
    """
    if not rows:
        return []

    feature_names: list[str] = _registry["feature_names"]
    n_expected = len(feature_names)
    for values in rows:
        _check_feature_count(values, n_expected)

    input_df = pd.DataFrame(rows, columns=feature_names)
    scaled = _registry["scaler"].transform(input_df)
    scaled_df = pd.DataFrame(scaled, columns=feature_names)

    # Column 1 of predict_proba is treated as the Attack (label 1) probability.
    rf_probs = _registry["rf"].predict_proba(scaled_df)[:, 1]
    xgb_probs = _registry["xgb"].predict_proba(scaled_df)[:, 1]
    lgb_probs = _registry["lgb"].predict_proba(scaled_df)[:, 1]

    return [
        _build_response(float(rf), float(xgb), float(lgb))
        for rf, xgb, lgb in zip(rf_probs, xgb_probs, lgb_probs)
    ]


def _classify(feature_values: list[float]) -> PredictionResponse:
    """
    Scale features and run the weighted soft-voting ensemble for one flow.

    Raises HTTPException on dimension mismatch.
    """
    return _classify_many([feature_values])[0]


def _models_ready() -> bool:
    """True once every artifact and the feature-name list are in the registry."""
    return all(key in _registry for key in _REQUIRED_KEYS)


def _require_models() -> None:
    """Raise 503 unless the model registry is fully populated."""
    if not _models_ready():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Models are still loading. Retry in a moment.",
        )


@app.get("/", include_in_schema=False)
async def root():
    """Minimal service banner pointing at the interactive docs."""
    return {
        "service": "DDoS Detector API",
        "version": API_VERSION,
        "documentation": "/docs",
    }


@app.get(
    "/health",
    response_model=HealthResponse,
    tags=["System"],
    summary="Liveness and readiness check",
)
async def health_check():
    """Returns 200 when all models are loaded and ready to serve requests."""
    if not _models_ready():
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="One or more models failed to load.",
        )
    return HealthResponse(
        status="healthy",
        models_loaded=True,
        feature_count=len(_registry.get("feature_names", [])),
    )


@app.get(
    "/features",
    response_model=FeatureListResponse,
    tags=["System"],
    summary="List expected input features",
)
async def list_features():
    """
    Returns the ordered list of feature names that the `/predict` endpoint
    expects. Supply values in this exact order.
    """
    _require_models()
    names = _registry["feature_names"]
    return FeatureListResponse(feature_count=len(names), features=names)


# NOTE(review): the inference handlers below are `async def` but run
# CPU-bound model calls inline, which blocks the event loop while they
# execute. Offloading to a worker thread would first require confirming that
# the loaded models are safe to call concurrently.
@app.post(
    "/predict",
    response_model=PredictionResponse,
    tags=["Inference"],
    summary="Classify a single network flow",
    status_code=status.HTTP_200_OK,
)
@limiter.limit("30/minute")
async def predict(request: Request, body: PredictionRequest):
    """
    Classifies one network flow record.

    - Feature values must be ordered as returned by `GET /features`.
    - Rate limited to **30 requests per minute** per IP address.
    """
    _require_models()
    return _classify(body.features)


@app.post(
    "/predict/batch",
    response_model=BatchPredictionResponse,
    tags=["Inference"],
    summary="Classify a batch of network flows",
    status_code=status.HTTP_200_OK,
)
@limiter.limit("10/minute")
async def predict_batch(request: Request, body: BatchPredictionRequest):
    """
    Classifies up to **100** network flow records in a single request.

    - Rate limited to **10 requests per minute** per IP address.
    - Each record follows the same feature ordering as `/predict`.
    """
    _require_models()
    results = _classify_many([record.features for record in body.records])
    attack_count = sum(r.prediction for r in results)

    return BatchPredictionResponse(
        results=results,
        total=len(results),
        attack_count=attack_count,
        benign_count=len(results) - attack_count,
    )