""" api_server.py ============= AI Firewall — FastAPI Security Gateway Exposes a REST API that acts as a security proxy between end-users and any AI/LLM backend. All input/output is validated by the firewall pipeline before being forwarded or returned. Endpoints --------- POST /secure-inference Full pipeline: check → model → output guardrail POST /check-prompt Input-only check (no model call) GET /health Liveness probe GET /metrics Basic request counters GET /docs Swagger UI (auto-generated) Run --- uvicorn ai_firewall.api_server:app --reload --port 8000 Environment variables (all optional) -------------------------------------- FIREWALL_BLOCK_THRESHOLD float default 0.70 FIREWALL_FLAG_THRESHOLD float default 0.40 FIREWALL_USE_EMBEDDINGS bool default false FIREWALL_LOG_DIR str default "." FIREWALL_MAX_LENGTH int default 4096 DEMO_ECHO_MODE bool default true (echo prompt as model output in /secure-inference) """ from __future__ import annotations import logging import os import time from contextlib import asynccontextmanager from typing import Any, Dict, Optional import uvicorn from fastapi import FastAPI, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator, ConfigDict from ai_firewall.guardrails import Guardrails, FirewallDecision from ai_firewall.risk_scoring import RequestStatus # --------------------------------------------------------------------------- # Logging setup # --------------------------------------------------------------------------- logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", ) logger = logging.getLogger("ai_firewall.api_server") # --------------------------------------------------------------------------- # Configuration from environment # --------------------------------------------------------------------------- BLOCK_THRESHOLD = float(os.getenv("FIREWALL_BLOCK_THRESHOLD", "0.70")) FLAG_THRESHOLD = float(os.getenv("FIREWALL_FLAG_THRESHOLD", "0.40")) USE_EMBEDDINGS = os.getenv("FIREWALL_USE_EMBEDDINGS", "false").lower() in ("1", "true", "yes") LOG_DIR = os.getenv("FIREWALL_LOG_DIR", ".") MAX_LENGTH = int(os.getenv("FIREWALL_MAX_LENGTH", "4096")) DEMO_ECHO_MODE = os.getenv("DEMO_ECHO_MODE", "true").lower() in ("1", "true", "yes") # --------------------------------------------------------------------------- # Shared state # --------------------------------------------------------------------------- _guardrails: Optional[Guardrails] = None _metrics: Dict[str, int] = { "total_requests": 0, "blocked": 0, "flagged": 0, "safe": 0, "output_blocked": 0, } # --------------------------------------------------------------------------- # Lifespan (startup / shutdown) # --------------------------------------------------------------------------- @asynccontextmanager async def lifespan(app: FastAPI): global _guardrails logger.info("Initialising AI Firewall pipeline…") _guardrails = Guardrails( block_threshold=BLOCK_THRESHOLD, flag_threshold=FLAG_THRESHOLD, use_embeddings=USE_EMBEDDINGS, log_dir=LOG_DIR, sanitizer_max_length=MAX_LENGTH, ) logger.info( "AI Firewall ready | block=%.2f flag=%.2f embeddings=%s", BLOCK_THRESHOLD, FLAG_THRESHOLD, USE_EMBEDDINGS, ) yield logger.info("AI Firewall shutting down.") # --------------------------------------------------------------------------- # FastAPI app # --------------------------------------------------------------------------- app = FastAPI( title="AI Firewall", description=( "Production-ready AI Security Firewall. " "Protects LLM systems from prompt injection, adversarial inputs, " "and data leakage." ), version="1.0.0", lifespan=lifespan, docs_url="/docs", redoc_url="/redoc", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # --------------------------------------------------------------------------- # Request / Response schemas # --------------------------------------------------------------------------- class InferenceRequest(BaseModel): model_config = ConfigDict(protected_namespaces=()) prompt: str = Field(..., min_length=1, max_length=32_000, description="The user prompt to secure.") model_endpoint: Optional[str] = Field(None, description="External model endpoint URL (future use).") metadata: Optional[Dict[str, Any]] = Field(None, description="Arbitrary caller metadata.") @field_validator("prompt") @classmethod def prompt_not_empty(cls, v: str) -> str: if not v.strip(): raise ValueError("Prompt must not be blank.") return v class CheckRequest(BaseModel): prompt: str = Field(..., min_length=1, max_length=32_000) class RiskReportSchema(BaseModel): status: str risk_score: float risk_level: str injection_score: float adversarial_score: float attack_type: Optional[str] = None attack_category: Optional[str] = None flags: list latency_ms: float class InferenceResponse(BaseModel): model_config = ConfigDict(protected_namespaces=()) status: str risk_score: float risk_level: str sanitized_prompt: str model_output: Optional[str] = None safe_output: Optional[str] = None attack_type: Optional[str] = None flags: list = [] total_latency_ms: float class CheckResponse(BaseModel): status: str risk_score: float risk_level: str attack_type: Optional[str] = None attack_category: Optional[str] = None flags: list sanitized_prompt: str injection_score: float adversarial_score: float latency_ms: float # --------------------------------------------------------------------------- # Middleware — request timing & metrics # --------------------------------------------------------------------------- @app.middleware("http") async def metrics_middleware(request: Request, call_next): _metrics["total_requests"] += 1 start = time.perf_counter() response = await call_next(request) elapsed = (time.perf_counter() - start) * 1000 response.headers["X-Process-Time-Ms"] = f"{elapsed:.2f}" return response # --------------------------------------------------------------------------- # Helper # --------------------------------------------------------------------------- def _demo_model(prompt: str) -> str: """Echo model used in DEMO_ECHO_MODE — returns the prompt as output.""" return f"[DEMO ECHO] {prompt}" def _decision_to_inference_response(decision: FirewallDecision) -> InferenceResponse: rr = decision.risk_report _update_metrics(rr.status.value, decision) return InferenceResponse( status=rr.status.value, risk_score=rr.risk_score, risk_level=rr.risk_level.value, sanitized_prompt=decision.sanitized_prompt, model_output=decision.model_output, safe_output=decision.safe_output, attack_type=rr.attack_type, flags=rr.flags, total_latency_ms=decision.total_latency_ms, ) def _update_metrics(status: str, decision: FirewallDecision) -> None: if status == "blocked": _metrics["blocked"] += 1 elif status == "flagged": _metrics["flagged"] += 1 else: _metrics["safe"] += 1 if decision.model_output is not None and decision.safe_output != decision.model_output: _metrics["output_blocked"] += 1 # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @app.get("/health", tags=["System"]) async def health(): """Liveness / readiness probe.""" return {"status": "ok", "service": "ai-firewall", "version": "1.0.0"} @app.get("/metrics", tags=["System"]) async def metrics(): """Basic request counters for monitoring.""" return _metrics @app.post( "/check-prompt", response_model=CheckResponse, tags=["Security"], summary="Check a prompt without calling an AI model", ) async def check_prompt(body: CheckRequest): """ Run the full input security pipeline (sanitization + injection detection + adversarial detection + risk scoring) without forwarding the prompt to any model. Returns a detailed risk report so you can decide whether to proceed. """ if _guardrails is None: raise HTTPException(status_code=503, detail="Firewall not initialised.") decision = _guardrails.check_input(body.prompt) rr = decision.risk_report _update_metrics(rr.status.value, decision) return CheckResponse( status=rr.status.value, risk_score=rr.risk_score, risk_level=rr.risk_level.value, attack_type=rr.attack_type, attack_category=rr.attack_category, flags=rr.flags, sanitized_prompt=decision.sanitized_prompt, injection_score=rr.injection_score, adversarial_score=rr.adversarial_score, latency_ms=decision.total_latency_ms, ) @app.post( "/secure-inference", response_model=InferenceResponse, tags=["Security"], summary="Secure end-to-end inference with input + output guardrails", ) async def secure_inference(body: InferenceRequest): """ Full security pipeline: 1. Sanitize input 2. Detect prompt injection 3. Detect adversarial inputs 4. Compute risk score → block if too risky 5. Forward to AI model (demo echo in DEMO_ECHO_MODE) 6. Validate model output 7. Return safe, redacted response **status** values: - `safe` → passed all checks - `flagged` → suspicious but allowed through - `blocked` → rejected; no model output returned """ if _guardrails is None: raise HTTPException(status_code=503, detail="Firewall not initialised.") model_fn = _demo_model # replace with real model integration decision = _guardrails.secure_call(body.prompt, model_fn) return _decision_to_inference_response(decision) # --------------------------------------------------------------------------- # Global exception handler # --------------------------------------------------------------------------- @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): logger.error("Unhandled exception: %s", exc, exc_info=True) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"detail": "Internal server error. Check server logs."}, ) # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- if __name__ == "__main__": uvicorn.run( "ai_firewall.api_server:app", host="0.0.0.0", port=8000, reload=False, log_level="info", )