# Hugging Face upload artifact (uploader "cloud450", "Upload 48 files",
# commit 4afcb3a, verified) — page residue, not part of the module; kept as a
# comment so the file parses.
"""
api_server.py
=============
AI Firewall — FastAPI Security Gateway
Exposes a REST API that acts as a security proxy between end-users
and any AI/LLM backend. All input/output is validated by the firewall
pipeline before being forwarded or returned.
Endpoints
---------
POST /secure-inference Full pipeline: check → model → output guardrail
POST /check-prompt Input-only check (no model call)
GET /health Liveness probe
GET /metrics Basic request counters
GET /docs Swagger UI (auto-generated)
Run
---
uvicorn ai_firewall.api_server:app --reload --port 8000
Environment variables (all optional)
--------------------------------------
FIREWALL_BLOCK_THRESHOLD float default 0.70
FIREWALL_FLAG_THRESHOLD float default 0.40
FIREWALL_USE_EMBEDDINGS bool default false
FIREWALL_LOG_DIR str default "."
FIREWALL_MAX_LENGTH int default 4096
DEMO_ECHO_MODE bool default true (echo prompt as model output in /secure-inference)
"""
from __future__ import annotations
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional
import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator, ConfigDict
from ai_firewall.guardrails import Guardrails, FirewallDecision
from ai_firewall.risk_scoring import RequestStatus
# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
logger = logging.getLogger("ai_firewall.api_server")


def _env_flag(name: str, default: str) -> bool:
    """Interpret the environment variable *name* as a boolean flag."""
    return os.getenv(name, default).lower() in ("1", "true", "yes")


# ---------------------------------------------------------------------------
# Configuration from environment (see module docstring for the full list)
# ---------------------------------------------------------------------------
BLOCK_THRESHOLD = float(os.getenv("FIREWALL_BLOCK_THRESHOLD", "0.70"))
FLAG_THRESHOLD = float(os.getenv("FIREWALL_FLAG_THRESHOLD", "0.40"))
USE_EMBEDDINGS = _env_flag("FIREWALL_USE_EMBEDDINGS", "false")
LOG_DIR = os.getenv("FIREWALL_LOG_DIR", ".")
MAX_LENGTH = int(os.getenv("FIREWALL_MAX_LENGTH", "4096"))
DEMO_ECHO_MODE = _env_flag("DEMO_ECHO_MODE", "true")

# ---------------------------------------------------------------------------
# Shared state: the firewall pipeline (built in `lifespan`) and request counters
# ---------------------------------------------------------------------------
_guardrails: Optional[Guardrails] = None
_metrics: Dict[str, int] = {
    "total_requests": 0,
    "blocked": 0,
    "flagged": 0,
    "safe": 0,
    "output_blocked": 0,
}
# ---------------------------------------------------------------------------
# Lifespan (startup / shutdown)
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared firewall pipeline on startup; log on shutdown."""
    global _guardrails
    logger.info("Initialising AI Firewall pipeline…")
    settings = {
        "block_threshold": BLOCK_THRESHOLD,
        "flag_threshold": FLAG_THRESHOLD,
        "use_embeddings": USE_EMBEDDINGS,
        "log_dir": LOG_DIR,
        "sanitizer_max_length": MAX_LENGTH,
    }
    _guardrails = Guardrails(**settings)
    logger.info(
        "AI Firewall ready | block=%.2f flag=%.2f embeddings=%s",
        BLOCK_THRESHOLD, FLAG_THRESHOLD, USE_EMBEDDINGS,
    )
    yield
    logger.info("AI Firewall shutting down.")
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# FastAPI application object. The `lifespan` context manager (defined above)
# constructs the shared Guardrails pipeline before the app starts serving.
app = FastAPI(
title="AI Firewall",
description=(
"Production-ready AI Security Firewall. "
"Protects LLM systems from prompt injection, adversarial inputs, "
"and data leakage."
),
version="1.0.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
)
# NOTE(review): CORS is wide open (any origin, method, and header). That is
# convenient for demos, but should be narrowed to known origins in production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class InferenceRequest(BaseModel):
    """Request payload for POST /secure-inference."""

    # Allow field names beginning with "model_" (pydantic v2 reserves the prefix).
    model_config = ConfigDict(protected_namespaces=())

    prompt: str = Field(..., min_length=1, max_length=32_000, description="The user prompt to secure.")
    model_endpoint: Optional[str] = Field(None, description="External model endpoint URL (future use).")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Arbitrary caller metadata.")

    @field_validator("prompt")
    @classmethod
    def prompt_not_empty(cls, v: str) -> str:
        """Reject prompts that contain only whitespace."""
        if v.strip():
            return v
        raise ValueError("Prompt must not be blank.")
class CheckRequest(BaseModel):
    """Request payload for POST /check-prompt: just the prompt to analyse."""

    prompt: str = Field(..., min_length=1, max_length=32_000)
class RiskReportSchema(BaseModel):
    """Detailed per-prompt risk report produced by the firewall pipeline."""

    status: str
    risk_score: float
    risk_level: str
    injection_score: float
    adversarial_score: float
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    latency_ms: float
class InferenceResponse(BaseModel):
    """Response payload for POST /secure-inference."""

    # Allow field names beginning with "model_" (pydantic v2 reserves the prefix).
    model_config = ConfigDict(protected_namespaces=())

    status: str
    risk_score: float
    risk_level: str
    sanitized_prompt: str
    model_output: Optional[str] = None   # None when the prompt was blocked
    safe_output: Optional[str] = None    # model_output after output guardrails
    attack_type: Optional[str] = None
    flags: list = []
    total_latency_ms: float
class CheckResponse(BaseModel):
    """Response payload for POST /check-prompt."""

    status: str
    risk_score: float
    risk_level: str
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    sanitized_prompt: str
    injection_score: float
    adversarial_score: float
    latency_ms: float
# ---------------------------------------------------------------------------
# Middleware — request timing & metrics
# ---------------------------------------------------------------------------
@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
_metrics["total_requests"] += 1
start = time.perf_counter()
response = await call_next(request)
elapsed = (time.perf_counter() - start) * 1000
response.headers["X-Process-Time-Ms"] = f"{elapsed:.2f}"
return response
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _demo_model(prompt: str) -> str:
"""Echo model used in DEMO_ECHO_MODE — returns the prompt as output."""
return f"[DEMO ECHO] {prompt}"
def _decision_to_inference_response(decision: FirewallDecision) -> InferenceResponse:
    """Record metrics for *decision* and map it onto the API response schema."""
    report = decision.risk_report
    _update_metrics(report.status.value, decision)
    return InferenceResponse(
        status=report.status.value,
        risk_score=report.risk_score,
        risk_level=report.risk_level.value,
        sanitized_prompt=decision.sanitized_prompt,
        model_output=decision.model_output,
        safe_output=decision.safe_output,
        attack_type=report.attack_type,
        flags=report.flags,
        total_latency_ms=decision.total_latency_ms,
    )
def _update_metrics(status: str, decision: FirewallDecision) -> None:
    """Bump the per-status counter, plus the output-redaction counter.

    Any status other than "blocked"/"flagged" is counted as "safe". The
    output counter increments whenever the guarded output differs from the
    raw model output (i.e. the output guardrail changed something).
    """
    key = status if status in ("blocked", "flagged") else "safe"
    _metrics[key] += 1
    output_was_altered = (
        decision.model_output is not None
        and decision.safe_output != decision.model_output
    )
    if output_was_altered:
        _metrics["output_blocked"] += 1
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", tags=["System"])
async def health():
"""Liveness / readiness probe."""
return {"status": "ok", "service": "ai-firewall", "version": "1.0.0"}
@app.get("/metrics", tags=["System"])
async def metrics():
"""Basic request counters for monitoring."""
return _metrics
@app.post(
    "/check-prompt",
    response_model=CheckResponse,
    tags=["Security"],
    summary="Check a prompt without calling an AI model",
)
async def check_prompt(body: CheckRequest):
    """
    Run the full input security pipeline (sanitization + injection detection
    + adversarial detection + risk scoring) without forwarding the prompt to
    any model.
    Returns a detailed risk report so you can decide whether to proceed.
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")
    decision = _guardrails.check_input(body.prompt)
    report = decision.risk_report
    _update_metrics(report.status.value, decision)
    # Assemble the response fields first, then validate via the schema.
    payload = dict(
        status=report.status.value,
        risk_score=report.risk_score,
        risk_level=report.risk_level.value,
        attack_type=report.attack_type,
        attack_category=report.attack_category,
        flags=report.flags,
        sanitized_prompt=decision.sanitized_prompt,
        injection_score=report.injection_score,
        adversarial_score=report.adversarial_score,
        latency_ms=decision.total_latency_ms,
    )
    return CheckResponse(**payload)
@app.post(
    "/secure-inference",
    response_model=InferenceResponse,
    tags=["Security"],
    summary="Secure end-to-end inference with input + output guardrails",
)
async def secure_inference(body: InferenceRequest):
    """
    Full security pipeline:
    1. Sanitize input
    2. Detect prompt injection
    3. Detect adversarial inputs
    4. Compute risk score → block if too risky
    5. Forward to AI model (demo echo in DEMO_ECHO_MODE)
    6. Validate model output
    7. Return safe, redacted response
    **status** values:
    - `safe` → passed all checks
    - `flagged` → suspicious but allowed through
    - `blocked` → rejected; no model output returned
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")
    # _demo_model is a placeholder; replace with real model integration.
    decision = _guardrails.secure_call(body.prompt, _demo_model)
    return _decision_to_inference_response(decision)
# ---------------------------------------------------------------------------
# Global exception handler
# ---------------------------------------------------------------------------
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all boundary: log the full traceback, return an opaque 500."""
    logger.error("Unhandled exception: %s", exc, exc_info=True)
    body = {"detail": "Internal server error. Check server logs."}
    return JSONResponse(
        content=body,
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
uvicorn.run(
"ai_firewall.api_server:app",
host="0.0.0.0",
port=8000,
reload=False,
log_level="info",
)