Spaces:
Sleeping
Sleeping
File size: 6,651 Bytes
4afcb3a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 | """
sdk.py
======
AI Firewall Python SDK
The SDK provides the simplest possible integration for developers who
want to add a security layer to an existing LLM call without touching
their model code.
Quick-start
-----------
from ai_firewall import secure_llm_call
def my_llm(prompt: str) -> str:
# your existing model call
...
response = secure_llm_call(my_llm, "What is the capital of France?")
Full SDK usage
--------------
from ai_firewall.sdk import FirewallSDK
sdk = FirewallSDK(block_threshold=0.70)
# Check only (no model call)
result = sdk.check("ignore all previous instructions")
print(result.risk_report.status) # "blocked"
# Secure call
result = sdk.secure_call(my_llm, "Hello!")
if result.allowed:
print(result.safe_output)
"""
from __future__ import annotations
import functools
import logging
from typing import Any, Callable, Dict, Optional
from ai_firewall.guardrails import Guardrails, FirewallDecision
logger = logging.getLogger("ai_firewall.sdk")
class FirewallSDK:
"""
High-level SDK wrapping the Guardrails pipeline.
Designed for simplicity: instantiate once, use everywhere.
Parameters
----------
block_threshold : float
Requests with risk_score >= this are blocked (default 0.70).
flag_threshold : float
Requests with risk_score >= this are flagged (default 0.40).
use_embeddings : bool
Enable embedding-based detection (default False).
log_dir : str
Directory for security logs (default ".").
sanitizer_max_length : int
Max allowed prompt length after sanitization (default 4096).
raise_on_block : bool
If True, raise FirewallBlockedError when a request is blocked.
If False (default), return the FirewallDecision with allowed=False.
"""
def __init__(
self,
block_threshold: float = 0.70,
flag_threshold: float = 0.40,
use_embeddings: bool = False,
log_dir: str = ".",
sanitizer_max_length: int = 4096,
raise_on_block: bool = False,
) -> None:
self._guardrails = Guardrails(
block_threshold=block_threshold,
flag_threshold=flag_threshold,
use_embeddings=use_embeddings,
log_dir=log_dir,
sanitizer_max_length=sanitizer_max_length,
)
self.raise_on_block = raise_on_block
logger.info("FirewallSDK ready | block=%.2f flag=%.2f embeddings=%s", block_threshold, flag_threshold, use_embeddings)
def check(self, prompt: str) -> FirewallDecision:
"""
Run the input firewall pipeline without calling any model.
Parameters
----------
prompt : str
Raw user prompt to evaluate.
Returns
-------
FirewallDecision
"""
decision = self._guardrails.check_input(prompt)
if self.raise_on_block and not decision.allowed:
raise FirewallBlockedError(decision)
return decision
def secure_call(
self,
model_fn: Callable[[str], str],
prompt: str,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> FirewallDecision:
"""
Run the full secure pipeline: check → model → output guardrail.
Parameters
----------
model_fn : Callable[[str], str]
Your AI model function.
prompt : str
Raw user prompt.
model_kwargs : dict, optional
Extra kwargs passed to model_fn.
Returns
-------
FirewallDecision
"""
decision = self._guardrails.secure_call(prompt, model_fn, model_kwargs)
if self.raise_on_block and not decision.allowed:
raise FirewallBlockedError(decision)
return decision
def wrap(self, model_fn: Callable[[str], str]) -> Callable[[str], str]:
"""
Decorator / wrapper factory.
Returns a new callable that automatically runs the firewall pipeline
around every call to `model_fn`.
Example
-------
sdk = FirewallSDK()
safe_model = sdk.wrap(my_llm)
response = safe_model("Hello!") # returns safe_output or raises
"""
@functools.wraps(model_fn)
def _secured(prompt: str, **kwargs: Any) -> str:
decision = self.secure_call(model_fn, prompt, model_kwargs=kwargs)
if not decision.allowed:
raise FirewallBlockedError(decision)
return decision.safe_output or ""
return _secured
def get_risk_score(self, prompt: str) -> float:
"""Return only the aggregated risk score (0-1)."""
return self.check(prompt).risk_report.risk_score
def is_safe(self, prompt: str) -> bool:
"""Return True if the prompt passes all security checks."""
return self.check(prompt).allowed
class FirewallBlockedError(Exception):
"""Raised when `raise_on_block=True` and a request is blocked."""
def __init__(self, decision: FirewallDecision) -> None:
self.decision = decision
super().__init__(
f"Request blocked by AI Firewall | "
f"risk_score={decision.risk_report.risk_score:.3f} | "
f"attack_type={decision.risk_report.attack_type}"
)
# ---------------------------------------------------------------------------
# Module-level convenience function
# ---------------------------------------------------------------------------
_default_sdk: Optional[FirewallSDK] = None
def _get_default_sdk() -> FirewallSDK:
global _default_sdk
if _default_sdk is None:
_default_sdk = FirewallSDK()
return _default_sdk
def secure_llm_call(
model_fn: Callable[[str], str],
prompt: str,
firewall: Optional[FirewallSDK] = None,
**model_kwargs: Any,
) -> FirewallDecision:
"""
Top-level convenience function for one-liner integration.
Parameters
----------
model_fn : Callable[[str], str]
Your LLM/AI callable.
prompt : str
The user's prompt.
firewall : FirewallSDK, optional
Custom SDK instance. Uses a shared default instance if not provided.
**model_kwargs
Extra kwargs forwarded to model_fn.
Returns
-------
FirewallDecision
Example
-------
from ai_firewall import secure_llm_call
result = secure_llm_call(my_llm, "What is 2+2?")
print(result.safe_output)
"""
sdk = firewall or _get_default_sdk()
return sdk.secure_call(model_fn, prompt, model_kwargs=model_kwargs or None)
|