File size: 11,270 Bytes
7c918e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""
api_server.py
=============
AI Firewall — FastAPI Security Gateway

Exposes a REST API that acts as a security proxy between end-users
and any AI/LLM backend.  All input/output is validated by the firewall
pipeline before being forwarded or returned.

Endpoints
---------
  POST  /secure-inference      Full pipeline: check → model → output guardrail
  POST  /check-prompt          Input-only check (no model call)
  GET   /health                Liveness probe
  GET   /metrics               Basic request counters
  GET   /docs                  Swagger UI (auto-generated)

Run
---
  uvicorn ai_firewall.api_server:app --reload --port 8000

Environment variables (all optional)
--------------------------------------
  FIREWALL_BLOCK_THRESHOLD   float  default 0.70
  FIREWALL_FLAG_THRESHOLD    float  default 0.40
  FIREWALL_USE_EMBEDDINGS    bool   default false
  FIREWALL_LOG_DIR           str    default "."
  FIREWALL_MAX_LENGTH        int    default 4096
  DEMO_ECHO_MODE             bool   default true  (echo prompt as model output in /secure-inference)
"""

from __future__ import annotations

import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, field_validator, ConfigDict

from ai_firewall.guardrails import Guardrails, FirewallDecision
from ai_firewall.risk_scoring import RequestStatus

# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
# Configure root logging once at import time so every log line from this
# module (and un-configured library loggers) shares one format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
# Module-level logger used throughout this file.
logger = logging.getLogger("ai_firewall.api_server")

# ---------------------------------------------------------------------------
# Configuration from environment
# ---------------------------------------------------------------------------
def _env_bool(name: str, default: str = "false") -> bool:
    """Read environment variable *name* as a boolean.

    Accepts "1"/"true"/"yes" (case-insensitive) as True; anything else,
    including the variable being unset with a falsy *default*, is False.
    """
    return os.getenv(name, default).lower() in ("1", "true", "yes")


# All settings are read once at import time; restart the server to pick up
# changes.  Defaults mirror the module docstring.
BLOCK_THRESHOLD    = float(os.getenv("FIREWALL_BLOCK_THRESHOLD", "0.70"))
FLAG_THRESHOLD     = float(os.getenv("FIREWALL_FLAG_THRESHOLD", "0.40"))
USE_EMBEDDINGS     = _env_bool("FIREWALL_USE_EMBEDDINGS", "false")
LOG_DIR            = os.getenv("FIREWALL_LOG_DIR", ".")
MAX_LENGTH         = int(os.getenv("FIREWALL_MAX_LENGTH", "4096"))
DEMO_ECHO_MODE     = _env_bool("DEMO_ECHO_MODE", "true")

# ---------------------------------------------------------------------------
# Shared state
# ---------------------------------------------------------------------------
# Pipeline instance — populated during the app lifespan startup hook.
_guardrails: "Optional[Guardrails]" = None

# In-process request counters, served verbatim by GET /metrics.
_metrics: Dict[str, int] = {
    counter: 0
    for counter in ("total_requests", "blocked", "flagged", "safe", "output_blocked")
}


# ---------------------------------------------------------------------------
# Lifespan (startup / shutdown)
# ---------------------------------------------------------------------------

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Builds the guardrail pipeline once at startup from the module-level
    configuration, stores it in ``_guardrails``, then logs on shutdown.
    """
    global _guardrails
    logger.info("Initialising AI Firewall pipeline…")
    pipeline = Guardrails(
        block_threshold=BLOCK_THRESHOLD,
        flag_threshold=FLAG_THRESHOLD,
        use_embeddings=USE_EMBEDDINGS,
        log_dir=LOG_DIR,
        sanitizer_max_length=MAX_LENGTH,
    )
    _guardrails = pipeline
    logger.info(
        "AI Firewall ready | block=%.2f flag=%.2f embeddings=%s",
        BLOCK_THRESHOLD, FLAG_THRESHOLD, USE_EMBEDDINGS,
    )
    yield
    logger.info("AI Firewall shutting down.")


# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------

# Application object; lifespan wires up the guardrail pipeline, and the
# Swagger/ReDoc URLs are pinned explicitly.
app = FastAPI(
    title="AI Firewall",
    description=(
        "Production-ready AI Security Firewall. "
        "Protects LLM systems from prompt injection, adversarial inputs, "
        "and data leakage."
    ),
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)

# NOTE(review): wildcard CORS (all origins/methods/headers) is convenient for
# demos but very permissive — confirm and tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------

class InferenceRequest(BaseModel):
    """Request body for POST /secure-inference."""
    # Allow field names starting with "model_" (pydantic v2 reserves that prefix).
    model_config = ConfigDict(protected_namespaces=())
    prompt: str = Field(..., min_length=1, max_length=32_000, description="The user prompt to secure.")
    model_endpoint: Optional[str] = Field(None, description="External model endpoint URL (future use).")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Arbitrary caller metadata.")

    @field_validator("prompt")
    @classmethod
    def prompt_not_empty(cls, v: str) -> str:
        """Reject whitespace-only prompts (min_length alone would allow " ")."""
        if not v.strip():
            raise ValueError("Prompt must not be blank.")
        return v


class CheckRequest(BaseModel):
    """Request body for POST /check-prompt: only the prompt to analyse."""
    prompt: str = Field(..., min_length=1, max_length=32_000)


class RiskReportSchema(BaseModel):
    """Serialisable view of a firewall risk report.

    NOTE(review): not referenced by any endpoint in this module — presumably
    kept for external consumers or future use; confirm before removing.
    """
    status: str
    risk_score: float
    risk_level: str
    injection_score: float
    adversarial_score: float
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    latency_ms: float


class InferenceResponse(BaseModel):
    """Response body for POST /secure-inference."""
    # Allow field names starting with "model_" (pydantic v2 reserves that prefix).
    model_config = ConfigDict(protected_namespaces=())
    status: str
    risk_score: float
    risk_level: str
    sanitized_prompt: str
    model_output: Optional[str] = None
    safe_output: Optional[str] = None
    attack_type: Optional[str] = None
    # default_factory avoids declaring one shared mutable list as the default.
    flags: list = Field(default_factory=list)
    total_latency_ms: float


class CheckResponse(BaseModel):
    """Response body for POST /check-prompt: the full input-side risk report."""
    status: str
    risk_score: float
    risk_level: str
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    sanitized_prompt: str
    injection_score: float
    adversarial_score: float
    latency_ms: float


# ---------------------------------------------------------------------------
# Middleware — request timing & metrics
# ---------------------------------------------------------------------------

@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Count every incoming request and stamp the wall-clock processing
    time (ms) onto the response headers."""
    _metrics["total_requests"] += 1
    t0 = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - t0) * 1000
    response.headers["X-Process-Time-Ms"] = f"{elapsed_ms:.2f}"
    return response


# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------

def _demo_model(prompt: str) -> str:
    """Stand-in model for DEMO_ECHO_MODE: echoes the prompt back, prefixed."""
    return "[DEMO ECHO] " + prompt


def _decision_to_inference_response(decision: FirewallDecision) -> InferenceResponse:
    """Map a pipeline FirewallDecision onto the public InferenceResponse schema.

    Side effect: folds the decision into the global ``_metrics`` counters
    before building the response.
    """
    rr = decision.risk_report
    _update_metrics(rr.status.value, decision)
    return InferenceResponse(
        status=rr.status.value,
        risk_score=rr.risk_score,
        risk_level=rr.risk_level.value,
        sanitized_prompt=decision.sanitized_prompt,
        model_output=decision.model_output,
        safe_output=decision.safe_output,
        attack_type=rr.attack_type,
        flags=rr.flags,
        total_latency_ms=decision.total_latency_ms,
    )


def _update_metrics(status: str, decision: FirewallDecision) -> None:
    """Fold one firewall decision into the global request counters.

    ``status`` picks the blocked/flagged/safe bucket.  Separately, when a
    model output exists but differs from the safe output, the output-side
    guardrail altered it, so ``output_blocked`` is bumped as well.
    """
    bucket = status if status in ("blocked", "flagged") else "safe"
    _metrics[bucket] += 1
    raw = decision.model_output
    if raw is not None and decision.safe_output != raw:
        _metrics["output_blocked"] += 1


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/health", tags=["System"])
async def health():
    """Liveness / readiness probe.

    Returns a static payload; no firewall state is read or modified.
    """
    return dict(status="ok", service="ai-firewall", version="1.0.0")


@app.get("/metrics", tags=["System"])
async def metrics():
    """Basic request counters for monitoring.

    Returns a snapshot copy so the live mutable counter dict is never
    handed out to callers; the serialized payload is unchanged.
    """
    return dict(_metrics)


@app.post(
    "/check-prompt",
    response_model=CheckResponse,
    tags=["Security"],
    summary="Check a prompt without calling an AI model",
)
async def check_prompt(body: CheckRequest):
    """
    Run the full input security pipeline (sanitization + injection detection
    + adversarial detection + risk scoring) without forwarding the prompt to
    any model.

    Returns a detailed risk report so you can decide whether to proceed.
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")

    verdict = _guardrails.check_input(body.prompt)
    report = verdict.risk_report
    _update_metrics(report.status.value, verdict)

    return CheckResponse(
        status=report.status.value,
        risk_score=report.risk_score,
        risk_level=report.risk_level.value,
        attack_type=report.attack_type,
        attack_category=report.attack_category,
        flags=report.flags,
        sanitized_prompt=verdict.sanitized_prompt,
        injection_score=report.injection_score,
        adversarial_score=report.adversarial_score,
        latency_ms=verdict.total_latency_ms,
    )


@app.post(
    "/secure-inference",
    response_model=InferenceResponse,
    tags=["Security"],
    summary="Secure end-to-end inference with input + output guardrails",
)
async def secure_inference(body: InferenceRequest):
    """
    Full security pipeline:

    1. Sanitize input
    2. Detect prompt injection
    3. Detect adversarial inputs
    4. Compute risk score → block if too risky
    5. Forward to AI model (demo echo in DEMO_ECHO_MODE)
    6. Validate model output
    7. Return safe, redacted response

    **status** values:
    - `safe`    → passed all checks
    - `flagged` → suspicious but allowed through
    - `blocked` → rejected; no model output returned
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")

    # Demo echo backend stands in here — swap in a real model integration.
    backend = _demo_model
    outcome = _guardrails.secure_call(body.prompt, backend)
    return _decision_to_inference_response(outcome)


# ---------------------------------------------------------------------------
# Global exception handler
# ---------------------------------------------------------------------------

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the full traceback, return an opaque 500.

    The response detail is deliberately generic so internal error details
    never leak to clients; the traceback goes to the server log only.
    """
    logger.error("Unhandled exception: %s", exc, exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": "Internal server error. Check server logs."},
    )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Direct-run convenience: serve on all interfaces, port 8000, no reload.
    # The import-string form ("module:app") matches the docstring's uvicorn
    # invocation and lets uvicorn manage its own worker/reload lifecycle.
    uvicorn.run(
        "ai_firewall.api_server:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
        log_level="info",
    )