# Uploaded by yuki-sui ("Upload 169 files", commit ed71b0e, verified).
from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional
from .plugins.base import get_registry, PluginRegistry
from .plugins.loader import PluginLoader
@dataclass
class RiskResult:
    """Aggregated risk verdict in the legacy result shape.

    Retained for backwards compatibility: ``compute_risk`` packs the
    plugin registry's aggregated output into this structure.
    """

    # Numeric risk score for the assessed call.
    score: float
    # Human-readable explanations collected for the assessment.
    reasons: List[str]
    # Named boolean indicators raised during the assessment.
    flags: Dict[str, bool]
def risk_result_to_dict(r: RiskResult) -> Dict[str, Any]:
    """Serialize a RiskResult into a plain dictionary.

    Kept for backwards compatibility with callers that expect a dict.
    ``dataclasses.asdict`` recurses into lists and dicts, so the returned
    containers are copies rather than references to the fields of ``r``.

    Args:
        r: The result to convert.

    Returns:
        A dict keyed by the dataclass field names.
    """
    as_mapping: Dict[str, Any] = asdict(r)
    return as_mapping
# ============================================================================
# PLUGIN SYSTEM - PRIMARY RISK ASSESSMENT ENGINE
# ============================================================================
# All risk assessment is handled by the plugin system.
# Built-in plugins are loaded and registered by initialize_plugins(); the
# compute_risk*() entry points call it themselves when the registry is empty.

# Loader built by initialize_plugins(); None until that call creates it.
_plugin_loader: Optional[PluginLoader] = None
# True once initialize_plugins() has been attempted, whether it succeeded or failed.
_plugins_initialized: bool = False
# Elapsed seconds (time.perf_counter delta) of the initialization attempt; None before any attempt.
_plugin_initialization_time: Optional[float] = None
def initialize_plugins() -> PluginRegistry:
    """
    Initialize the plugin system once and reuse it on later calls.

    The first call builds a PluginLoader around the global registry and asks
    it to load and register the built-in plugins. According to the loader's
    built-in set these are:
    - JailbreakDetector: Jailbreak attempts and prompt injection
    - SSRFDetector: Server-side request forgery attacks
    - SQLInjectionDetector: SQL injection patterns
    - PathTraversalDetector: Path traversal and sensitive file access
    - ExfiltrationDetector: Data leakage and exfiltration intent
    - DataTheftDetector: Competitive intelligence and proprietary data attempts
    - PayloadSizeDetector: Large payloads and excessive fetch sizes

    Every later call skips loading and returns the registry immediately.
    A failed attempt is reported on stdout and is NOT retried: the module is
    marked as initialized either way, so a broken plugin set cannot cause a
    reload on every request.

    Returns:
        The global PluginRegistry (possibly empty if loading failed).

    Example:
        >>> registry = initialize_plugins()
        >>> print(f"Loaded {len(registry.get_all_plugins())} plugins")
    """
    global _plugin_loader, _plugins_initialized, _plugin_initialization_time
    import time

    if _plugins_initialized:
        # Compare against None, not truthiness: a measured 0.0s is still a
        # valid timing and must be reported.
        if _plugin_initialization_time is not None:
            print(f"[RiskModel] Using cached plugin system (initialized in {_plugin_initialization_time:.3f}s)")
        return get_registry()

    start_time = time.perf_counter()
    try:
        _plugin_loader = PluginLoader(get_registry())
        count = _plugin_loader.load_and_register_builtin()
        _plugin_initialization_time = time.perf_counter() - start_time
        cache_status = _plugin_loader.get_cache_status()
        print(f"[RiskModel] Initialized plugin system with {count} built-in plugins in {_plugin_initialization_time:.3f}s")
        print(f"[RiskModel] Plugin cache: {cache_status['loaded_modules']} modules, {cache_status['loaded_plugins']} plugins")
    except Exception as e:
        # Deliberate best-effort boundary: plugin loading must never take the
        # caller down, so the failure is reported and swallowed here.
        _plugin_initialization_time = time.perf_counter() - start_time
        print(f"[RiskModel] Failed to initialize plugins: {e}")

    # Mark as attempted on both paths to prevent retry loops.
    _plugins_initialized = True
    return get_registry()
# Base score for native code execution; plugin findings are added on top.
_CODE_INTERPRETER_BASE_SCORE = 0.8
# Upper bound of the risk score range.
_MAX_RISK_SCORE = 1.0
# Ceiling applied to the aggregated score of the "web-search" server.
_WEB_SEARCH_SCORE_CAP = 0.35
# Flags that are kept on web-search results; every other flag is dropped.
_WEB_SEARCH_CRITICAL_FLAGS = frozenset({
    "jailbreak_detected",
    "jailbreak_like",
    "prompt_injection",
    "exfiltration_like",
    "data_theft_like",
    "ssrf_attempt",
    "malicious_url",
})
# Lower-case substrings that mark a reason as critical for web-search results.
_WEB_SEARCH_CRITICAL_KEYWORDS = (
    "jailbreak",
    "injection",
    "prompt",
    "override",
    "malicious",
    "exfiltration",
    "data theft",
    "competitive",
    "ssrf",
)


def compute_risk(
    user_id: Optional[str],
    server_key: str,
    tool: str,
    arguments: Dict[str, Any],
    llm_context: Optional[str] = None,
) -> RiskResult:
    """
    Compute risk using the plugin-based scanner system.

    This is the PRIMARY risk assessment method. Every call is scanned by all
    enabled plugins in the registry; built-in plugins are loaded on demand if
    the registry is still empty.

    Plugins and their documented weights (the plugins themselves are
    authoritative):
    - JailbreakDetector (0.5 for jailbreak, 0.3 for secrets)
    - SSRFDetector (0.7 for internal IPs, 0.5 for malicious URLs)
    - SQLInjectionDetector (0.35 for SQL patterns)
    - PathTraversalDetector (0.35 for traversal, 0.2 for sensitive paths)
    - ExfiltrationDetector (0.3 for exfiltration intent, 0.15 for sensitive servers, 0.15 for dangerous tools, +0.4 for fetch)
    - DataTheftDetector (weights defined by the plugin)
    - PayloadSizeDetector (0.2 for large payloads, 0.2-0.4 for excessive fetch sizes)

    Special handling:
    - native/web_search: no adjustment, the aggregated plugin result is
      returned as-is.
    - native/code_interpreter: 0.8 base score plus the aggregated plugin
      score, clamped to 1.0. The two code-execution reasons and the
      ``code_execution`` / ``high_risk`` flags are always present; plugin
      reasons and flags are appended/merged.
    - web-search (downstream): aggregated score is capped at 0.35; only
      truthy flags from a fixed critical set (jailbreak, prompt injection,
      exfiltration, data theft, SSRF, malicious URL) are kept; reasons are
      narrowed to those mentioning a critical keyword, falling back to the
      full list when none match.

    Args:
        user_id: Logical user identifier (e.g., 'admin', 'judge-1')
        server_key: Downstream server key (e.g., 'fetch', 'filesystem', 'native')
        tool: Tool name
        arguments: Tool arguments
        llm_context: Optional context/prompt

    Returns:
        RiskResult with aggregated risk assessment

    Example:
        >>> result = compute_risk(
        ...     user_id="user1",
        ...     server_key="fetch",
        ...     tool="fetch_url",
        ...     arguments={"url": "http://localhost:8080"}
        ... )
        >>> print(f"Risk score: {result.score}")
    """
    registry = get_registry()

    # If plugins haven't been loaded yet, initialize them.
    if not registry.get_all_plugins():
        initialize_plugins()

    # Scan every request, including code execution: a jailbreak or
    # exfiltration attempt routed through the interpreter must still be seen.
    plugin_results = registry.scan_all(
        user_id=user_id,
        server_key=server_key,
        tool=tool,
        arguments=arguments,
        llm_context=llm_context,
    )
    aggregated = registry.aggregate_results(plugin_results)

    if server_key == "native" and tool == "code_interpreter":
        # Code execution is inherently high-risk: start from the base score
        # and let plugin findings add on top, without leaving the score range.
        combined_score = min(
            _MAX_RISK_SCORE,
            _CODE_INTERPRETER_BASE_SCORE + aggregated["total_score"],
        )
        return RiskResult(
            score=combined_score,
            reasons=[
                "Code execution requires explicit user approval and sandboxing",
                "Potential for arbitrary code execution",
                *aggregated["reasons"],
            ],
            # The code-execution flags are written last so a plugin flag of
            # the same name can never clear them.
            flags={**aggregated["flags"], "code_execution": True, "high_risk": True},
        )

    if server_key == "web-search":
        # web-search exists to make external requests, so its score is capped
        # and callers are expected to act on the surviving critical flags.
        capped_score = min(aggregated["total_score"], _WEB_SEARCH_SCORE_CAP)
        critical_flags = {
            name: value
            for name, value in aggregated["flags"].items()
            if value and name in _WEB_SEARCH_CRITICAL_FLAGS
        }
        critical_reasons = [
            reason
            for reason in aggregated["reasons"]
            if any(keyword in reason.lower() for keyword in _WEB_SEARCH_CRITICAL_KEYWORDS)
        ]
        # Nothing matched the keyword filter: show every reason rather than
        # hiding the explanation for the score.
        if not critical_reasons:
            critical_reasons = aggregated["reasons"]
        return RiskResult(
            score=capped_score,
            reasons=critical_reasons,
            flags=critical_flags,
        )

    # Default path (includes native/web_search): legacy RiskResult built
    # straight from the aggregated plugin output.
    return RiskResult(
        score=aggregated["total_score"],
        reasons=aggregated["reasons"],
        flags=aggregated["flags"],
    )
def compute_risk_detailed(
    user_id: Optional[str],
    server_key: str,
    tool: str,
    arguments: Dict[str, Any],
    llm_context: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run the plugin scan and report per-plugin results next to the aggregate.

    Intended for debugging: it shows which plugins fired in addition to the
    combined assessment. No server- or tool-specific adjustments are applied
    here; the values are exactly what the registry aggregates.

    Args:
        user_id: Logical user identifier
        server_key: Downstream server key
        tool: Tool name
        arguments: Tool arguments
        llm_context: Optional context

    Returns:
        Dict with aggregated and detailed risk assessment:
        - total_score: Combined risk score (0.0-1.0)
        - reasons: List of all detected reasons
        - flags: Combined flags from all plugins
        - plugin_results: Dict mapping plugin name -> ScanResult.to_dict() output
        - plugin_count: Number of plugins executed
        - threat_count: Number of plugins that detected threats
        - detected_threats: List of plugin names that detected threats

    Example:
        >>> result = compute_risk_detailed(
        ...     user_id="user1",
        ...     server_key="fetch",
        ...     tool="fetch_url",
        ...     arguments={"url": "http://localhost:8080"}
        ... )
        >>> print(f"Detected threats: {result['detected_threats']}")
        >>> for name, plugin_result in result['plugin_results'].items():
        ...     print(f"{name}: {plugin_result['detected']}")
    """
    plugin_registry = get_registry()

    # Load the built-in plugins on demand when nothing is registered yet.
    nothing_registered = not plugin_registry.get_all_plugins()
    if nothing_registered:
        initialize_plugins()

    scan_kwargs = dict(
        user_id=user_id,
        server_key=server_key,
        tool=tool,
        arguments=arguments,
        llm_context=llm_context,
    )
    per_plugin = plugin_registry.scan_all(**scan_kwargs)
    summary = plugin_registry.aggregate_results(per_plugin)

    # Assemble the report in the same key order callers have always seen.
    detailed: Dict[str, Any] = {}
    for key in ("total_score", "reasons", "flags"):
        detailed[key] = summary[key]

    serialized: Dict[str, Any] = {}
    for plugin_name, scan_result in per_plugin.items():
        serialized[plugin_name] = scan_result.to_dict()
    detailed["plugin_results"] = serialized

    for key in ("plugin_count", "threat_count", "detected_threats"):
        detailed[key] = summary[key]
    return detailed
def get_plugin_loader() -> Optional[PluginLoader]:
    """
    Return the module-wide plugin loader.

    The loader only exists after initialize_plugins() has created it, so
    call that first if a loader is required.

    Returns:
        The PluginLoader instance, or None when no loader has been created
    """
    loader = _plugin_loader
    return loader
def get_plugin_cache_status() -> Dict[str, Any]:
    """
    Get current plugin cache status and statistics.

    Returns:
        When no loader exists yet, a dict with a single "error" key.
        Otherwise a copy of the loader's cache status plus one extra key.
        Metrics reported by the loader (its get_cache_status() is the
        authority on the exact set):
        - loaded_modules: Number of cached modules
        - loaded_plugins: Number of cached plugin instances
        - builtin_loaded: Whether builtin plugins are loaded
        - registered_plugins: Total registered plugins
        - enabled_plugins: Number of enabled plugins
        Added here:
        - initialization_time_ms: Initialization time in milliseconds, or
          None when no timing has been recorded
    """
    if _plugin_loader is None:
        return {"error": "Plugins not initialized. Call initialize_plugins() first."}

    # Copy first: the extra key must not leak into a dict the loader may keep.
    status = dict(_plugin_loader.get_cache_status())
    # Test against None so a recorded 0.0s is reported as 0.0, not as missing.
    status["initialization_time_ms"] = (
        _plugin_initialization_time * 1000
        if _plugin_initialization_time is not None
        else None
    )
    return status