Spaces:
Sleeping
Sleeping
"""
api_server.py
=============

AI Firewall — FastAPI Security Gateway

Exposes a REST API that acts as a security proxy between end-users
and any AI/LLM backend. All input/output is validated by the firewall
pipeline before being forwarded or returned.

Endpoints
---------
POST /secure-inference   Full pipeline: check → model → output guardrail
POST /check-prompt       Input-only check (no model call)
GET  /health             Liveness probe
GET  /metrics            Basic request counters
GET  /docs               Swagger UI (auto-generated)

Run
---
    uvicorn ai_firewall.api_server:app --reload --port 8000

Environment variables (all optional)
------------------------------------
FIREWALL_BLOCK_THRESHOLD   float   default 0.70
FIREWALL_FLAG_THRESHOLD    float   default 0.40
FIREWALL_USE_EMBEDDINGS    bool    default false
FIREWALL_LOG_DIR           str     default "."
FIREWALL_MAX_LENGTH        int     default 4096
DEMO_ECHO_MODE             bool    default true (echo prompt as model output in /secure-inference)
"""
| from __future__ import annotations | |
| import logging | |
| import os | |
| import time | |
| from contextlib import asynccontextmanager | |
| from typing import Any, Dict, Optional | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException, Request, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field, field_validator, ConfigDict | |
| from ai_firewall.guardrails import Guardrails, FirewallDecision | |
| from ai_firewall.risk_scoring import RequestStatus | |
# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
# Configure the root logger once at import time; every module in the process
# inherits this handler/format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
)
# Module-level logger used by the lifespan handler and exception handler below.
logger = logging.getLogger("ai_firewall.api_server")
| # --------------------------------------------------------------------------- | |
| # Configuration from environment | |
| # --------------------------------------------------------------------------- | |
| BLOCK_THRESHOLD = float(os.getenv("FIREWALL_BLOCK_THRESHOLD", "0.70")) | |
| FLAG_THRESHOLD = float(os.getenv("FIREWALL_FLAG_THRESHOLD", "0.40")) | |
| USE_EMBEDDINGS = os.getenv("FIREWALL_USE_EMBEDDINGS", "false").lower() in ("1", "true", "yes") | |
| LOG_DIR = os.getenv("FIREWALL_LOG_DIR", ".") | |
| MAX_LENGTH = int(os.getenv("FIREWALL_MAX_LENGTH", "4096")) | |
| DEMO_ECHO_MODE = os.getenv("DEMO_ECHO_MODE", "true").lower() in ("1", "true", "yes") | |
# ---------------------------------------------------------------------------
# Shared state
# ---------------------------------------------------------------------------
# Singleton firewall pipeline; created during startup in `lifespan`. Handlers
# check for None and answer 503 if a request arrives before initialisation.
_guardrails: Optional[Guardrails] = None
# Process-local request counters served by GET /metrics. Best-effort only:
# plain int increments, no locking — adequate for a single-process deployment.
_metrics: Dict[str, int] = {
    "total_requests": 0,
    "blocked": 0,
    "flagged": 0,
    "safe": 0,
    "output_blocked": 0,
}
# ---------------------------------------------------------------------------
# Lifespan (startup / shutdown)
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build the Guardrails pipeline at startup.

    FastAPI's ``lifespan=`` argument expects an async context-manager
    factory, so the ``@asynccontextmanager`` decorator is required — it was
    imported at the top of the file but never applied, which breaks app
    startup. Code before ``yield`` runs on startup, code after on shutdown.
    """
    global _guardrails
    logger.info("Initialising AI Firewall pipeline…")
    _guardrails = Guardrails(
        block_threshold=BLOCK_THRESHOLD,
        flag_threshold=FLAG_THRESHOLD,
        use_embeddings=USE_EMBEDDINGS,
        log_dir=LOG_DIR,
        sanitizer_max_length=MAX_LENGTH,
    )
    logger.info(
        "AI Firewall ready | block=%.2f flag=%.2f embeddings=%s",
        BLOCK_THRESHOLD, FLAG_THRESHOLD, USE_EMBEDDINGS,
    )
    yield
    logger.info("AI Firewall shutting down.")
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(
    title="AI Firewall",
    description=(
        "Production-ready AI Security Firewall. "
        "Protects LLM systems from prompt injection, adversarial inputs, "
        "and data leakage."
    ),
    version="1.0.0",
    lifespan=lifespan,
    docs_url="/docs",
    redoc_url="/redoc",
)
# NOTE(review): wildcard CORS (all origins/methods/headers) is wide open —
# fine for a demo, but should be restricted for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
| # --------------------------------------------------------------------------- | |
| # Request / Response schemas | |
| # --------------------------------------------------------------------------- | |
class InferenceRequest(BaseModel):
    """Request body for POST /secure-inference."""

    # Allow the "model_" field-name prefix (model_endpoint) without pydantic
    # protected-namespace warnings.
    model_config = ConfigDict(protected_namespaces=())

    prompt: str = Field(..., min_length=1, max_length=32_000, description="The user prompt to secure.")
    model_endpoint: Optional[str] = Field(None, description="External model endpoint URL (future use).")
    metadata: Optional[Dict[str, Any]] = Field(None, description="Arbitrary caller metadata.")

    @field_validator("prompt")
    @classmethod
    def prompt_not_empty(cls, v: str) -> str:
        """Reject prompts that are empty after stripping whitespace.

        The method existed but was never registered as a validator (the
        ``@field_validator`` decorator — imported at the top of the file —
        was missing), so whitespace-only prompts slipped past validation.
        """
        if not v.strip():
            raise ValueError("Prompt must not be blank.")
        return v
class CheckRequest(BaseModel):
    """Request body for POST /check-prompt — just the prompt to analyse."""

    prompt: str = Field(..., min_length=1, max_length=32_000)
class RiskReportSchema(BaseModel):
    """Serializable risk report mirroring the pipeline's risk-report fields.

    NOTE(review): not referenced by any endpoint visible in this file —
    presumably kept for OpenAPI docs or external consumers; confirm before
    removing.
    """

    status: str                        # overall decision, e.g. "safe"/"flagged"/"blocked"
    risk_score: float                  # combined risk score
    risk_level: str
    injection_score: float             # prompt-injection detector score
    adversarial_score: float           # adversarial-input detector score
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    latency_ms: float                  # pipeline processing time
class InferenceResponse(BaseModel):
    """Response body for POST /secure-inference."""

    # Allow the "model_" field-name prefix (model_output) without pydantic
    # protected-namespace warnings.
    model_config = ConfigDict(protected_namespaces=())

    status: str                            # "safe" | "flagged" | "blocked"
    risk_score: float
    risk_level: str
    sanitized_prompt: str
    model_output: Optional[str] = None     # raw model output (None when blocked)
    safe_output: Optional[str] = None      # output after the output guardrail
    attack_type: Optional[str] = None
    # default_factory is the idiomatic way to give a mutable default; the
    # original `= []` relied on pydantic copying the literal per instance.
    flags: list = Field(default_factory=list)
    total_latency_ms: float
class CheckResponse(BaseModel):
    """Response body for POST /check-prompt — the full input risk report."""

    status: str                        # overall decision, e.g. "safe"/"flagged"/"blocked"
    risk_score: float
    risk_level: str
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list
    sanitized_prompt: str              # prompt after sanitization
    injection_score: float
    adversarial_score: float
    latency_ms: float
# ---------------------------------------------------------------------------
# Middleware — request timing & metrics
# ---------------------------------------------------------------------------
@app.middleware("http")
async def metrics_middleware(request: Request, call_next):
    """Count every request and attach an ``X-Process-Time-Ms`` timing header.

    The ``@app.middleware("http")`` registration was missing, so this
    function was defined but never ran: ``total_requests`` stayed at 0 and
    no timing header was ever emitted.
    """
    _metrics["total_requests"] += 1
    start = time.perf_counter()
    response = await call_next(request)
    elapsed = (time.perf_counter() - start) * 1000
    response.headers["X-Process-Time-Ms"] = f"{elapsed:.2f}"
    return response
| # --------------------------------------------------------------------------- | |
| # Helper | |
| # --------------------------------------------------------------------------- | |
| def _demo_model(prompt: str) -> str: | |
| """Echo model used in DEMO_ECHO_MODE — returns the prompt as output.""" | |
| return f"[DEMO ECHO] {prompt}" | |
def _decision_to_inference_response(decision: FirewallDecision) -> InferenceResponse:
    """Translate a pipeline FirewallDecision into the public response schema,
    bumping the request counters as a side effect."""
    report = decision.risk_report
    _update_metrics(report.status.value, decision)
    payload = {
        "status": report.status.value,
        "risk_score": report.risk_score,
        "risk_level": report.risk_level.value,
        "sanitized_prompt": decision.sanitized_prompt,
        "model_output": decision.model_output,
        "safe_output": decision.safe_output,
        "attack_type": report.attack_type,
        "flags": report.flags,
        "total_latency_ms": decision.total_latency_ms,
    }
    return InferenceResponse(**payload)
def _update_metrics(status: str, decision: FirewallDecision) -> None:
    """Bump the counter for this request's status; also count
    ``output_blocked`` when the output guardrail altered or withheld the
    model's answer."""
    bucket = status if status in ("blocked", "flagged") else "safe"
    _metrics[bucket] += 1
    output_changed = (
        decision.model_output is not None
        and decision.safe_output != decision.model_output
    )
    if output_changed:
        _metrics["output_blocked"] += 1
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
async def health():
    """Liveness / readiness probe.

    Registered at GET /health per the module docstring — the route
    decorator was missing, so the handler was never reachable.
    """
    return {"status": "ok", "service": "ai-firewall", "version": "1.0.0"}
@app.get("/metrics")
async def metrics():
    """Basic request counters for monitoring.

    Registered at GET /metrics per the module docstring — the route
    decorator was missing, so the handler was never reachable.
    """
    return _metrics
@app.post("/check-prompt", response_model=CheckResponse)
async def check_prompt(body: CheckRequest):
    """
    Run the full input security pipeline (sanitization + injection detection
    + adversarial detection + risk scoring) without forwarding the prompt to
    any model.

    Returns a detailed risk report so you can decide whether to proceed.

    Registered at POST /check-prompt per the module docstring — the route
    decorator was missing, so the handler was never reachable.
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")
    decision = _guardrails.check_input(body.prompt)
    rr = decision.risk_report
    _update_metrics(rr.status.value, decision)
    return CheckResponse(
        status=rr.status.value,
        risk_score=rr.risk_score,
        risk_level=rr.risk_level.value,
        attack_type=rr.attack_type,
        attack_category=rr.attack_category,
        flags=rr.flags,
        sanitized_prompt=decision.sanitized_prompt,
        injection_score=rr.injection_score,
        adversarial_score=rr.adversarial_score,
        latency_ms=decision.total_latency_ms,
    )
@app.post("/secure-inference", response_model=InferenceResponse)
async def secure_inference(body: InferenceRequest):
    """
    Full security pipeline:
        1. Sanitize input
        2. Detect prompt injection
        3. Detect adversarial inputs
        4. Compute risk score → block if too risky
        5. Forward to AI model (demo echo in DEMO_ECHO_MODE)
        6. Validate model output
        7. Return safe, redacted response

    **status** values:
        - `safe`    → passed all checks
        - `flagged` → suspicious but allowed through
        - `blocked` → rejected; no model output returned

    Registered at POST /secure-inference per the module docstring — the
    route decorator was missing, so the handler was never reachable.
    """
    if _guardrails is None:
        raise HTTPException(status_code=503, detail="Firewall not initialised.")
    model_fn = _demo_model  # replace with real model integration
    decision = _guardrails.secure_call(body.prompt, model_fn)
    return _decision_to_inference_response(decision)
# ---------------------------------------------------------------------------
# Global exception handler
# ---------------------------------------------------------------------------
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the full traceback, return an opaque 500.

    The ``@app.exception_handler(Exception)`` registration was missing, so
    unhandled exceptions fell through to the default handler and this
    function never ran.
    """
    logger.error("Unhandled exception: %s", exc, exc_info=True)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": "Internal server error. Check server logs."},
    )
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Run the gateway directly (development convenience; use uvicorn's CLI
    # in production).
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": False,
        "log_level": "info",
    }
    uvicorn.run("ai_firewall.api_server:app", **server_options)