| from __future__ import annotations
|
|
|
| from dataclasses import dataclass, asdict
|
| from typing import Any, Dict, List, Optional
|
|
|
| from .plugins.base import get_registry, PluginRegistry
|
| from .plugins.loader import PluginLoader
|
|
|
|
|
@dataclass()
class RiskResult:
    """
    Outcome of a single risk assessment, kept in its legacy shape.

    Retained for backwards compatibility with callers that expect a
    score / reasons / flags triple.
    """

    # Aggregated risk score reported for the tool call.
    score: float
    # Human-readable explanations contributing to the score.
    reasons: List[str]
    # Named boolean indicators raised during scanning.
    flags: Dict[str, bool]
|
|
|
|
|
def risk_result_to_dict(r: RiskResult) -> Dict[str, Any]:
    """
    Serialize a RiskResult into a plain dictionary (legacy helper).

    Uses dataclasses.asdict, so the returned lists/dicts are copies and
    mutating them does not affect ``r``.
    """
    payload: Dict[str, Any] = asdict(r)
    return payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Set by initialize_plugins(); while True, later calls reuse the shared
# registry instead of loading the built-in plugins again.
_plugins_initialized: bool = False

# Loader created by initialize_plugins(); None until that function creates one.
_plugin_loader: Optional[PluginLoader] = None

# Seconds measured by initialize_plugins() for its load attempt; None until set.
_plugin_initialization_time: Optional[float] = None
|
|
|
|
|
def initialize_plugins() -> PluginRegistry:
    """
    Initialize the plugin system with caching.

    Loads all built-in plugins:
    - JailbreakDetector: Jailbreak attempts and prompt injection
    - SSRFDetector: Server-side request forgery attacks
    - SQLInjectionDetector: SQL injection patterns
    - PathTraversalDetector: Path traversal and sensitive file access
    - ExfiltrationDetector: Data leakage and exfiltration intent
    - DataTheftDetector: Competitive intelligence and proprietary data attempts
    - PayloadSizeDetector: Large payloads and excessive fetch sizes

    Once the built-in plugins have been registered the module remembers it,
    and later calls return the shared registry without loading again.
    If loading raises, the error is printed but NOT remembered. A later call
    (for example compute_risk() finding an empty registry) therefore retries
    instead of leaving risk scoring permanently without plugins.

    Call this once at startup.

    Returns:
        The shared PluginRegistry from get_registry(). After a failed load it
        may hold no plugins (or fewer than expected).

    Example:
        >>> registry = initialize_plugins()
        >>> print(f"Loaded {len(registry.get_all_plugins())} plugins")
    """
    global _plugin_loader, _plugins_initialized, _plugin_initialization_time
    import time

    if _plugins_initialized:
        if _plugin_initialization_time is not None:
            print(f"[RiskModel] Using cached plugin system (initialized in {_plugin_initialization_time:.3f}s)")
        return get_registry()

    start_time = time.perf_counter()
    try:
        _plugin_loader = PluginLoader(get_registry())
        count = _plugin_loader.load_and_register_builtin()
        # Latch immediately after a successful load, so an error in the
        # reporting below cannot cause the plugins to be registered twice.
        _plugins_initialized = True
        _plugin_initialization_time = time.perf_counter() - start_time
        cache_status = _plugin_loader.get_cache_status()
        print(f"[RiskModel] Initialized plugin system with {count} built-in plugins in {_plugin_initialization_time:.3f}s")
        print(f"[RiskModel] Plugin cache: {cache_status['loaded_modules']} modules, {cache_status['loaded_plugins']} plugins")
        return get_registry()
    except Exception as e:
        # Best-effort: report and hand back the registry as-is. The flag is
        # deliberately not set here when loading failed. Caching a failure
        # would make every later call skip loading and scan with no plugins.
        _plugin_initialization_time = time.perf_counter() - start_time
        print(f"[RiskModel] Failed to initialize plugins: {e}")
        return get_registry()
|
|
|
|
|
# Flags retained for the downstream ``web-search`` server (only truthy ones).
_WEB_SEARCH_CRITICAL_FLAGS = frozenset({
    "jailbreak_detected",
    "jailbreak_like",
    "prompt_injection",
    "exfiltration_like",
    "data_theft_like",
    "ssrf_attempt",
    "malicious_url",
})

# Case-insensitive substrings that mark a plugin reason as critical for web-search.
_WEB_SEARCH_CRITICAL_KEYWORDS = (
    "jailbreak",
    "injection",
    "prompt",
    "override",
    "malicious",
    "exfiltration",
    "data theft",
    "competitive",
    "ssrf",
)

# Maximum score reported for the downstream ``web-search`` server.
_WEB_SEARCH_SCORE_CAP = 0.35


def _web_search_risk(aggregated: Dict[str, Any]) -> RiskResult:
    """Reduce an aggregated plugin assessment to the capped web-search view."""
    capped_score = min(aggregated["total_score"], _WEB_SEARCH_SCORE_CAP)

    critical_flags = {
        k: v for k, v in aggregated["flags"].items()
        if k in _WEB_SEARCH_CRITICAL_FLAGS and v
    }

    all_reasons = aggregated["reasons"]
    critical_reasons = [
        r for r in all_reasons
        if any(keyword in r.lower() for keyword in _WEB_SEARCH_CRITICAL_KEYWORDS)
    ]

    return RiskResult(
        score=capped_score,
        # Fall back to every reason when none matched a critical keyword.
        reasons=critical_reasons or all_reasons,
        flags=critical_flags,
    )


def compute_risk(
    user_id: Optional[str],
    server_key: str,
    tool: str,
    arguments: Dict[str, Any],
    llm_context: Optional[str] = None,
) -> RiskResult:
    """
    Compute risk using the plugin-based scanner system.

    This is the PRIMARY risk assessment method. Apart from the special cases
    below, all security checks are performed by independent, modular plugins
    that can be enabled/disabled dynamically. If no plugins are registered
    yet, initialize_plugins() is called first.

    Built-in plugins (weights as documented for each plugin):
    - JailbreakDetector (0.5 for jailbreak, 0.3 for secrets)
    - SSRFDetector (0.7 for internal IPs, 0.5 for malicious URLs)
    - SQLInjectionDetector (0.35 for SQL patterns)
    - PathTraversalDetector (0.35 for traversal, 0.2 for sensitive paths)
    - ExfiltrationDetector (0.3 for exfiltration intent, 0.15 for sensitive servers, 0.15 for dangerous tools, +0.4 for fetch)
    - DataTheftDetector (competitive intelligence and proprietary data attempts)
    - PayloadSizeDetector (0.2 for large payloads, 0.2-0.4 for excessive fetch sizes)

    Special handling:
    - native/code_interpreter: fixed score 0.8 with code-execution flags.
      Plugins are NOT run for this tool.
    - native/web_search: no special treatment; all plugins apply (legacy
      native function).
    - web-search (downstream): plugins run, but the score is capped at 0.35
      (read-only internet access). Only critical findings are kept:
      - flags: truthy flags from a fixed allow-list;
      - reasons: those mentioning jailbreak, injection, prompt, override,
        malicious, exfiltration, data theft, competitive or ssrf. All
        reasons are kept if none match.

    Args:
        user_id: Logical user identifier (e.g., 'admin', 'judge-1')
        server_key: Downstream server key (e.g., 'fetch', 'filesystem', 'native')
        tool: Tool name
        arguments: Tool arguments
        llm_context: Optional context/prompt

    Returns:
        RiskResult with aggregated risk assessment

    Example:
        >>> result = compute_risk(
        ...     user_id="user1",
        ...     server_key="fetch",
        ...     tool="fetch_url",
        ...     arguments={"url": "http://localhost:8080"}
        ... )
        >>> print(f"Risk score: {result.score}")
    """
    registry = get_registry()

    # Lazily load the built-in plugins if nothing has been registered yet.
    if not registry.get_all_plugins():
        initialize_plugins()

    # Code execution is always treated as high risk; plugins are not consulted.
    if server_key == "native" and tool == "code_interpreter":
        return RiskResult(
            score=0.8,
            reasons=[
                "Code execution requires explicit user approval and sandboxing",
                "Potential for arbitrary code execution"
            ],
            flags={"code_execution": True, "high_risk": True},
        )

    # Every other call (including native/web_search) is scanned by all plugins.
    plugin_results = registry.scan_all(
        user_id=user_id,
        server_key=server_key,
        tool=tool,
        arguments=arguments,
        llm_context=llm_context,
    )
    aggregated = registry.aggregate_results(plugin_results)

    if server_key == "web-search":
        return _web_search_risk(aggregated)

    return RiskResult(
        score=aggregated["total_score"],
        reasons=aggregated["reasons"],
        flags=aggregated["flags"],
    )
|
|
|
|
|
def compute_risk_detailed(
    user_id: Optional[str],
    server_key: str,
    tool: str,
    arguments: Dict[str, Any],
    llm_context: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run every registered plugin and return the aggregate plus per-plugin detail.

    Intended for debugging, to see which plugins fired and why. Unlike
    compute_risk(), no server/tool-specific adjustments are applied here:
    - native/code_interpreter is scanned like any other tool;
    - web-search scores are not capped.
    total_score can therefore differ from compute_risk() in those cases.

    If the registry has no plugins yet, initialize_plugins() is called first.

    Args:
        user_id: Logical user identifier
        server_key: Downstream server key
        tool: Tool name
        arguments: Tool arguments
        llm_context: Optional context

    Returns:
        Dict with these keys, in this order:
        - total_score: aggregate score from registry.aggregate_results()
        - reasons: all reasons from the aggregate
        - flags: combined flags from the aggregate
        - plugin_results: plugin name -> that plugin's result converted with to_dict()
        - plugin_count: number of plugins executed (from the aggregate)
        - threat_count: number of plugins that detected threats (from the aggregate)
        - detected_threats: names of plugins that detected threats (from the aggregate)

    Example:
        >>> result = compute_risk_detailed(
        ...     user_id="user1",
        ...     server_key="fetch",
        ...     tool="fetch_url",
        ...     arguments={"url": "http://localhost:8080"}
        ... )
        >>> print(f"Detected threats: {result['detected_threats']}")
        >>> for name, plugin_result in result['plugin_results'].items():
        ...     print(name, plugin_result)
    """
    registry = get_registry()
    if not registry.get_all_plugins():
        initialize_plugins()

    scan_kwargs: Dict[str, Any] = {
        "user_id": user_id,
        "server_key": server_key,
        "tool": tool,
        "arguments": arguments,
        "llm_context": llm_context,
    }
    per_plugin = registry.scan_all(**scan_kwargs)
    summary = registry.aggregate_results(per_plugin)

    # Build the report in the documented key order.
    report: Dict[str, Any] = {key: summary[key] for key in ("total_score", "reasons", "flags")}
    report["plugin_results"] = {name: outcome.to_dict() for name, outcome in per_plugin.items()}
    for key in ("plugin_count", "threat_count", "detected_threats"):
        report[key] = summary[key]
    return report
|
|
|
|
|
def get_plugin_loader() -> Optional[PluginLoader]:
    """
    Return the module-wide plugin loader.

    The loader is created by initialize_plugins(); call that first if you
    need one.

    Returns:
        The PluginLoader stored by initialize_plugins(), or None if it has not
        created one yet.
    """
    loader: Optional[PluginLoader] = _plugin_loader
    return loader
|
|
|
|
|
def get_plugin_cache_status() -> Dict[str, Any]:
    """
    Get current plugin cache status and statistics.

    Returns:
        If initialize_plugins() has not created a loader yet:
            {"error": "Plugins not initialized. Call initialize_plugins() first."}

        Otherwise, a copy of the loader's get_cache_status() dict. It includes
        loaded_modules and loaded_plugins, and is expected (per the loader) to
        also carry builtin_loaded, registered_plugins and enabled_plugins.
        One extra key is added:
        - initialization_time_ms: milliseconds measured by
          initialize_plugins(), or None if no timing was recorded.
    """
    if _plugin_loader is None:
        return {"error": "Plugins not initialized. Call initialize_plugins() first."}

    # Copy so the extra key never leaks into a dict the loader may keep internally.
    status = dict(_plugin_loader.get_cache_status())
    # ``is not None`` so a measured 0.0s is reported as 0.0 ms rather than None.
    status["initialization_time_ms"] = (
        _plugin_initialization_time * 1000
        if _plugin_initialization_time is not None
        else None
    )
    return status
|
|
|