Spaces:
Sleeping
Sleeping
| """ | |
| guardrails.py | |
| ============= | |
| High-level Guardrails orchestrator. | |
| This module wires together all detection and sanitization layers into a | |
| single cohesive pipeline. It is the primary entry point used by both | |
| the SDK (`sdk.py`) and the REST API (`api_server.py`). | |
| Pipeline order: | |
| Input → InputSanitizer → InjectionDetector → AdversarialDetector → RiskScorer | |
| ↓ | |
| [block or pass to AI model] | |
| ↓ | |
| AI Model → OutputGuardrail → RiskScorer (output pass) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Callable, Dict, Optional | |
| from ai_firewall.injection_detector import InjectionDetector, AttackCategory | |
| from ai_firewall.adversarial_detector import AdversarialDetector | |
| from ai_firewall.sanitizer import InputSanitizer | |
| from ai_firewall.output_guardrail import OutputGuardrail | |
| from ai_firewall.risk_scoring import RiskScorer, RiskReport, RequestStatus | |
| from ai_firewall.security_logger import SecurityLogger | |
| logger = logging.getLogger("ai_firewall.guardrails") | |
@dataclass
class FirewallDecision:
    """
    Complete result of a full firewall check cycle.

    Declared as a (mutable) dataclass so it can be built with keyword
    arguments and updated in place as the pipeline progresses.

    Attributes
    ----------
    allowed : bool
        Whether the request was allowed through.
    sanitized_prompt : str
        The sanitized input prompt (may differ from original).
    risk_report : RiskReport
        Detailed risk scoring breakdown.
    model_output : Optional[str]
        The raw model output (None if request was blocked).
    safe_output : Optional[str]
        The guardrail-validated output: the raw output when it passed
        validation, otherwise the redacted version (None if the request
        was blocked or the model was never called).
    total_latency_ms : float
        End-to-end pipeline latency.
    """

    allowed: bool
    sanitized_prompt: str
    risk_report: RiskReport
    model_output: Optional[str] = None
    safe_output: Optional[str] = None
    total_latency_ms: float = 0.0

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dict view of the decision.

        ``risk_report`` is serialized through its own ``to_dict()`` and the
        latency is rounded to two decimals. ``model_output`` and
        ``safe_output`` are included only when they are not None.
        """
        d: Dict[str, Any] = {
            "allowed": self.allowed,
            "sanitized_prompt": self.sanitized_prompt,
            "risk_report": self.risk_report.to_dict(),
            "total_latency_ms": round(self.total_latency_ms, 2),
        }
        # Optional fields are omitted entirely rather than emitted as null.
        if self.model_output is not None:
            d["model_output"] = self.model_output
        if self.safe_output is not None:
            d["safe_output"] = self.safe_output
        return d
class Guardrails:
    """
    Full-pipeline AI security orchestrator.

    Instantiate once and reuse across requests for optimal performance
    (models and embedders are loaded once at init time).

    Parameters
    ----------
    injection_threshold : float
        Injection confidence above which input is blocked (default 0.55).
    adversarial_threshold : float
        Adversarial risk score above which input is blocked (default 0.60).
    block_threshold : float
        Combined risk score threshold for blocking (default 0.70).
    flag_threshold : float
        Combined risk score threshold for flagging (default 0.40).
    use_embeddings : bool
        Enable embedding-based detection layers (default False, adds latency).
    log_dir : str, optional
        Directory to write security logs to (default: current dir).
    sanitizer_max_length : int
        Max prompt length after sanitization (default 4096).
    """

    def __init__(
        self,
        injection_threshold: float = 0.55,
        adversarial_threshold: float = 0.60,
        block_threshold: float = 0.70,
        flag_threshold: float = 0.40,
        use_embeddings: bool = False,
        log_dir: str = ".",
        sanitizer_max_length: int = 4096,
    ) -> None:
        # Remember the configured thresholds: the output pass in
        # ``secure_call`` re-derives the flagged state from the stored scores
        # and must use the same values the detectors were configured with.
        self.injection_threshold = injection_threshold
        self.adversarial_threshold = adversarial_threshold

        self.injection_detector = InjectionDetector(
            threshold=injection_threshold,
            use_embeddings=use_embeddings,
        )
        self.adversarial_detector = AdversarialDetector(
            threshold=adversarial_threshold,
        )
        self.sanitizer = InputSanitizer(max_length=sanitizer_max_length)
        self.output_guardrail = OutputGuardrail()
        self.risk_scorer = RiskScorer(
            block_threshold=block_threshold,
            flag_threshold=flag_threshold,
        )
        self.security_logger = SecurityLogger(log_dir=log_dir)
        logger.info("Guardrails pipeline initialised.")

    # ------------------------------------------------------------------
    # Core pipeline
    # ------------------------------------------------------------------

    def check_input(self, prompt: str) -> FirewallDecision:
        """
        Run input-only pipeline (no model call).

        Use this when you want to decide whether to forward the prompt
        to your model yourself. The decision is also written to the
        security logger before it is returned.

        Parameters
        ----------
        prompt : str
            Raw user prompt.

        Returns
        -------
        FirewallDecision (model_output and safe_output will be None)
        """
        t0 = time.perf_counter()

        # 1. Sanitize
        san_result = self.sanitizer.sanitize(prompt)
        clean_prompt = san_result.sanitized

        # 2. Injection detection
        inj_result = self.injection_detector.detect(clean_prompt)

        # 3. Adversarial detection
        adv_result = self.adversarial_detector.detect(clean_prompt)

        # 4. Risk scoring
        # De-duplicate while keeping first-seen order so the reported flags
        # are deterministic (a plain set() would order them by string hash).
        all_flags = list(
            dict.fromkeys(inj_result.matched_patterns[:5] + adv_result.flags)
        )

        # Injection takes precedence over adversarial when both fire.
        attack_type = None
        if inj_result.is_injection:
            attack_type = "prompt_injection"
        elif adv_result.is_adversarial:
            attack_type = "adversarial_input"

        risk_report = self.risk_scorer.score(
            injection_score=inj_result.confidence,
            adversarial_score=adv_result.risk_score,
            injection_is_flagged=inj_result.is_injection,
            adversarial_is_flagged=adv_result.is_adversarial,
            attack_type=attack_type,
            attack_category=(
                inj_result.attack_category.value if inj_result.is_injection else None
            ),
            flags=all_flags,
            latency_ms=(time.perf_counter() - t0) * 1000,
        )

        # Anything that is not explicitly BLOCKED is let through.
        allowed = risk_report.status != RequestStatus.BLOCKED
        total_latency = (time.perf_counter() - t0) * 1000

        decision = FirewallDecision(
            allowed=allowed,
            sanitized_prompt=clean_prompt,
            risk_report=risk_report,
            total_latency_ms=total_latency,
        )

        # Log
        self.security_logger.log_request(
            prompt=prompt,
            sanitized=clean_prompt,
            decision=decision,
        )
        return decision

    def secure_call(
        self,
        prompt: str,
        model_fn: Callable[[str], str],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> FirewallDecision:
        """
        Full pipeline: check input → call model → validate output.

        If the input is blocked the model is never called. If ``model_fn``
        raises, the exception is logged (not re-raised) and the returned
        decision has ``allowed=False`` and ``model_output=None``.

        Parameters
        ----------
        prompt : str
            Raw user prompt.
        model_fn : Callable[[str], str]
            Your AI model function. Must accept a string prompt and
            return a string response. It receives the sanitized prompt.
        model_kwargs : dict, optional
            Extra kwargs forwarded to model_fn (as keyword args).

        Returns
        -------
        FirewallDecision
        """
        t0 = time.perf_counter()

        # Input pipeline
        decision = self.check_input(prompt)
        if not decision.allowed:
            decision.total_latency_ms = (time.perf_counter() - t0) * 1000
            return decision

        # Call the model. Only the user-supplied callable sits inside the
        # try; the broad except is deliberate because model_fn is arbitrary
        # caller code and a failure must yield a decision, not a crash.
        call_kwargs = model_kwargs or {}
        try:
            raw_output = model_fn(decision.sanitized_prompt, **call_kwargs)
        except Exception as exc:
            logger.exception("Model function raised an exception: %s", exc)
            decision.allowed = False
            decision.model_output = None
            decision.total_latency_ms = (time.perf_counter() - t0) * 1000
            return decision

        decision.model_output = raw_output

        # Output guardrail: pass the raw text through when it validates,
        # otherwise hand back the guardrail's redacted version.
        out_result = self.output_guardrail.validate(raw_output)
        if out_result.is_safe:
            decision.safe_output = raw_output
        else:
            decision.safe_output = out_result.redacted_output

        # Update risk report with output score. The flagged state is
        # re-derived from the stored scores using the thresholds this
        # instance was configured with (not the library defaults).
        # NOTE(review): attack_type falls back to "output_guardrail" even
        # when the output validated as safe - confirm this is intended.
        input_report = decision.risk_report
        updated_report = self.risk_scorer.score(
            injection_score=input_report.injection_score,
            adversarial_score=input_report.adversarial_score,
            injection_is_flagged=(
                input_report.injection_score >= self.injection_threshold
            ),
            adversarial_is_flagged=(
                input_report.adversarial_score >= self.adversarial_threshold
            ),
            attack_type=input_report.attack_type or "output_guardrail",
            attack_category=input_report.attack_category,
            flags=input_report.flags + out_result.flags,
            output_score=out_result.risk_score,
            # Carry the latency into the replacement report as well.
            latency_ms=(time.perf_counter() - t0) * 1000,
        )
        # NOTE(review): decision.allowed is not re-evaluated against the
        # updated report's status - confirm callers are expected to inspect
        # risk_report.status themselves after the output pass.
        decision.risk_report = updated_report
        decision.total_latency_ms = (time.perf_counter() - t0) * 1000

        self.security_logger.log_response(
            output=raw_output,
            safe_output=decision.safe_output,
            guardrail_result=out_result,
        )
        return decision