"""
guardrails.py
=============
High-level Guardrails orchestrator.
This module wires together all detection and sanitization layers into a
single cohesive pipeline. It is the primary entry point used by both
the SDK (`sdk.py`) and the REST API (`api_server.py`).
Pipeline order:
Input → InputSanitizer → InjectionDetector → AdversarialDetector → RiskScorer
                                   ↓
                      [block or pass to AI model]
                                   ↓
AI Model → OutputGuardrail → RiskScorer (output pass)
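Example usage (illustrative sketch; ``my_model`` stands for any callable
that takes a prompt string and returns a response string):

    from ai_firewall.guardrails import Guardrails

    firewall = Guardrails()
    decision = firewall.secure_call("Summarise this document.", model_fn=my_model)
    if decision.allowed:
        print(decision.safe_output)
    else:
        print("Blocked:", decision.risk_report.to_dict())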
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional
from ai_firewall.injection_detector import InjectionDetector, AttackCategory
from ai_firewall.adversarial_detector import AdversarialDetector
from ai_firewall.sanitizer import InputSanitizer
from ai_firewall.output_guardrail import OutputGuardrail
from ai_firewall.risk_scoring import RiskScorer, RiskReport, RequestStatus
from ai_firewall.security_logger import SecurityLogger
logger = logging.getLogger("ai_firewall.guardrails")
@dataclass
class FirewallDecision:
"""
Complete result of a full firewall check cycle.
Attributes
----------
allowed : bool
Whether the request was allowed through.
sanitized_prompt : str
The sanitized input prompt (may differ from original).
risk_report : RiskReport
Detailed risk scoring breakdown.
model_output : Optional[str]
The raw model output (None if the request was blocked or the model call failed).
safe_output : Optional[str]
The guardrail-validated output: the raw output if it passed the output guardrail, a redacted version if it did not, or None if the request was blocked.
total_latency_ms : float
End-to-end pipeline latency.
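Illustrative ``to_dict()`` shape for an allowed request (field values are
examples only; the ``risk_report`` payload comes from ``RiskReport.to_dict``):

    {
        "allowed": True,
        "sanitized_prompt": "What is the capital of France?",
        "risk_report": {...},
        "total_latency_ms": 12.34,
        "model_output": "Paris.",
        "safe_output": "Paris.",
    }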
"""
allowed: bool
sanitized_prompt: str
risk_report: RiskReport
model_output: Optional[str] = None
safe_output: Optional[str] = None
total_latency_ms: float = 0.0
def to_dict(self) -> dict:
d = {
"allowed": self.allowed,
"sanitized_prompt": self.sanitized_prompt,
"risk_report": self.risk_report.to_dict(),
"total_latency_ms": round(self.total_latency_ms, 2),
}
if self.model_output is not None:
d["model_output"] = self.model_output
if self.safe_output is not None:
d["safe_output"] = self.safe_output
return d
class Guardrails:
"""
Full-pipeline AI security orchestrator.
Instantiate once and reuse across requests for optimal performance
(models and embedders are loaded once at init time).
Parameters
----------
injection_threshold : float
Confidence above which the injection detector flags the input as an injection (default 0.55).
adversarial_threshold : float
Risk score above which the adversarial detector flags the input as adversarial (default 0.60).
block_threshold : float
Combined risk score threshold for blocking (default 0.70).
flag_threshold : float
Combined risk score threshold for flagging (default 0.40).
use_embeddings : bool
Enable embedding-based detection layers (default False, adds latency).
log_dir : str, optional
Directory to write security logs to (default: current dir).
sanitizer_max_length : int
Max prompt length after sanitization (default 4096).
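Example (illustrative; the threshold values and log directory shown here are
arbitrary, not recommended settings):

    firewall = Guardrails(
        block_threshold=0.80,
        flag_threshold=0.50,
        use_embeddings=True,
        log_dir="/var/log/ai_firewall",
    )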
"""
def __init__(
self,
injection_threshold: float = 0.55,
adversarial_threshold: float = 0.60,
block_threshold: float = 0.70,
flag_threshold: float = 0.40,
use_embeddings: bool = False,
log_dir: str = ".",
sanitizer_max_length: int = 4096,
) -> None:
self.injection_detector = InjectionDetector(
threshold=injection_threshold,
use_embeddings=use_embeddings,
)
self.adversarial_detector = AdversarialDetector(
threshold=adversarial_threshold,
)
self.sanitizer = InputSanitizer(max_length=sanitizer_max_length)
self.output_guardrail = OutputGuardrail()
self.risk_scorer = RiskScorer(
block_threshold=block_threshold,
flag_threshold=flag_threshold,
)
self.security_logger = SecurityLogger(log_dir=log_dir)
logger.info("Guardrails pipeline initialised.")
# ------------------------------------------------------------------
# Core pipeline
# ------------------------------------------------------------------
def check_input(self, prompt: str) -> FirewallDecision:
"""
Run the input-only pipeline (no model call).
Use this when you want to decide for yourself whether to forward the
prompt to your model.
Parameters
----------
prompt : str
Raw user prompt.
Returns
-------
FirewallDecision (model_output and safe_output will be None)
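
Example (illustrative; ``firewall`` is an existing Guardrails instance and
``call_model`` is a placeholder for your own model call):

    decision = firewall.check_input(user_prompt)
    if decision.allowed:
        reply = call_model(decision.sanitized_prompt)
    else:
        reply = "Request blocked by AI firewall."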
"""
t0 = time.perf_counter()
# 1. Sanitize
san_result = self.sanitizer.sanitize(prompt)
clean_prompt = san_result.sanitized
# 2. Injection detection
inj_result = self.injection_detector.detect(clean_prompt)
# 3. Adversarial detection
adv_result = self.adversarial_detector.detect(clean_prompt)
# 4. Risk scoring
all_flags = list(set(inj_result.matched_patterns[:5] + adv_result.flags))
attack_type = None
if inj_result.is_injection:
attack_type = "prompt_injection"
elif adv_result.is_adversarial:
attack_type = "adversarial_input"
risk_report = self.risk_scorer.score(
injection_score=inj_result.confidence,
adversarial_score=adv_result.risk_score,
injection_is_flagged=inj_result.is_injection,
adversarial_is_flagged=adv_result.is_adversarial,
attack_type=attack_type,
attack_category=inj_result.attack_category.value if inj_result.is_injection else None,
flags=all_flags,
latency_ms=(time.perf_counter() - t0) * 1000,
)
allowed = risk_report.status != RequestStatus.BLOCKED
total_latency = (time.perf_counter() - t0) * 1000
decision = FirewallDecision(
allowed=allowed,
sanitized_prompt=clean_prompt,
risk_report=risk_report,
total_latency_ms=total_latency,
)
# Log
self.security_logger.log_request(
prompt=prompt,
sanitized=clean_prompt,
decision=decision,
)
return decision
def secure_call(
self,
prompt: str,
model_fn: Callable[[str], str],
model_kwargs: Optional[Dict[str, Any]] = None,
) -> FirewallDecision:
"""
Full pipeline: check input → call model → validate output.
Parameters
----------
prompt : str
Raw user prompt.
model_fn : Callable[[str], str]
Your AI model function. Must accept a string prompt and
return a string response.
model_kwargs : dict, optional
Extra kwargs forwarded to model_fn (as keyword args).
Returns
-------
FirewallDecision
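
Example (illustrative; ``chat_completion`` is a placeholder for your own
model wrapper, and ``temperature`` is only meaningful if that wrapper
accepts it):

    decision = firewall.secure_call(
        user_prompt,
        model_fn=chat_completion,
        model_kwargs={"temperature": 0.2},
    )
    answer = decision.safe_output if decision.allowed else None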
"""
t0 = time.perf_counter()
# Input pipeline
decision = self.check_input(prompt)
if not decision.allowed:
decision.total_latency_ms = (time.perf_counter() - t0) * 1000
return decision
# Call the model
try:
model_kwargs = model_kwargs or {}
raw_output = model_fn(decision.sanitized_prompt, **model_kwargs)
except Exception as exc:
logger.error("Model function raised an exception: %s", exc)
decision.allowed = False
decision.model_output = None
decision.total_latency_ms = (time.perf_counter() - t0) * 1000
return decision
decision.model_output = raw_output
# Output guardrail
out_result = self.output_guardrail.validate(raw_output)
if out_result.is_safe:
decision.safe_output = raw_output
else:
decision.safe_output = out_result.redacted_output
# Update risk report with output score
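# NOTE: the 0.55 / 0.60 values below mirror the constructor defaults for
# injection_threshold / adversarial_threshold; they are not read back from
# the configured detectors.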
updated_report = self.risk_scorer.score(
injection_score=decision.risk_report.injection_score,
adversarial_score=decision.risk_report.adversarial_score,
injection_is_flagged=decision.risk_report.injection_score >= 0.55,
adversarial_is_flagged=decision.risk_report.adversarial_score >= 0.60,
attack_type=decision.risk_report.attack_type or "output_guardrail",
attack_category=decision.risk_report.attack_category,
flags=decision.risk_report.flags + out_result.flags,
output_score=out_result.risk_score,
)
decision.risk_report = updated_report
decision.total_latency_ms = (time.perf_counter() - t0) * 1000
self.security_logger.log_response(
output=raw_output,
safe_output=decision.safe_output,
guardrail_result=out_result,
)
return decision
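if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): exercise the full
    # pipeline with a trivial echo "model" so the module can be run directly.
    def _echo_model(prompt: str) -> str:
        return f"Echo: {prompt}"

    _fw = Guardrails()
    _decision = _fw.secure_call("Hello, world!", model_fn=_echo_model)
    print(_decision.to_dict())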