# yuki-sui's picture
# Upload 169 files
# ed71b0e verified
"""
Base classes for security scanner plugins.
Defines the interface that all plugins must implement and the registry for managing them.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Callable
import re
@dataclass
class PluginMetadata:
    """Descriptive information attached to a scanner plugin.

    A plain value object; ``to_dict`` exposes it as a dictionary whose
    values are all strings or booleans.
    """

    # Identifier of the plugin.
    name: str
    # Version string of the plugin implementation.
    version: str = "1.0.0"
    # Free-form human-readable summary.
    description: str = ""
    # Plugin author / maintainer.
    author: str = ""
    # Declared enabled state.
    enabled: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a new ``dict`` (built by ``dataclasses.asdict``)."""
        serialized = asdict(self)
        return serialized
@dataclass
class ScanResult:
    """Outcome produced by one scanner plugin for one tool call."""

    # Name of the plugin that produced this result.
    plugin_name: str
    # Whether the plugin considers the call a threat.
    detected: bool
    # Numeric risk contribution; 0.0 unless the plugin sets it.
    risk_score: float = 0.0
    # Human-readable explanations for the verdict (fresh list per instance).
    reasons: List[str] = field(default_factory=list)
    # Named boolean indicators raised by the plugin (fresh dict per instance).
    flags: Dict[str, bool] = field(default_factory=dict)
    # Arbitrary plugin-specific extra data (fresh dict per instance).
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a recursively-copied ``dict`` of this result (``dataclasses.asdict``)."""
        serialized = asdict(self)
        return serialized
class ScannerPlugin(ABC):
    """
    Abstract base class for all security scanner plugins.

    Plugins scan tool calls, arguments, and context for security threats.
    Each plugin is responsible for detecting a specific class of vulnerabilities.
    Subclasses must implement :meth:`scan`.
    """

    def __init__(self, metadata: Optional[PluginMetadata] = None):
        """
        Initialize plugin with optional metadata.

        Args:
            metadata: Plugin metadata (name, version, etc). If not provided,
                get_metadata() builds a default from the class name and
                docstring (subclasses may override get_metadata()). When
                provided, its ``enabled`` field sets the plugin's initial
                enabled state.
        """
        self._metadata = metadata
        # Honor the declared flag so PluginMetadata(enabled=False) really
        # starts the plugin disabled; without metadata, start enabled.
        self._enabled = metadata.enabled if metadata is not None else True

    def get_metadata(self) -> PluginMetadata:
        """
        Get plugin metadata.

        Returns:
            The metadata object passed to ``__init__`` if there was one;
            otherwise a new PluginMetadata named after the concrete class,
            with the class's own docstring (dedented) as description.
        """
        if self._metadata is not None:
            return self._metadata
        import inspect  # only needed on this fallback path

        # Read __doc__ directly rather than inspect.getdoc(): getdoc would
        # fall back to ScannerPlugin's docstring for undocumented subclasses.
        doc = self.__class__.__doc__
        return PluginMetadata(
            name=self.__class__.__name__,
            description=inspect.cleandoc(doc) if doc else "",
        )

    def set_enabled(self, enabled: bool) -> None:
        """Enable or disable this plugin."""
        self._enabled = enabled

    def is_enabled(self) -> bool:
        """Check if plugin is enabled."""
        return self._enabled

    @abstractmethod
    def scan(
        self,
        user_id: Optional[str],
        server_key: str,
        tool: str,
        arguments: Dict[str, Any],
        llm_context: Optional[str] = None,
    ) -> ScanResult:
        """
        Scan a tool call for security threats.

        Args:
            user_id: Logical user identifier (e.g., 'admin', 'judge-1')
            server_key: Downstream server key (e.g., 'filesystem', 'fetch')
            tool: Tool name on the downstream server
            arguments: Arguments passed to the tool
            llm_context: Optional prompt or reasoning context

        Returns:
            ScanResult with detection status, risk score, and reasons
        """
class PluginRegistry:
    """
    Central registry for managing security scanner plugins.

    Handles plugin registration, lookup, enable/disable, execution, and
    result aggregation. Plugins are keyed by ``plugin.get_metadata().name``.
    """

    def __init__(self):
        """Initialize an empty registry."""
        self._plugins: Dict[str, ScannerPlugin] = {}
        # Metadata snapshot taken at registration time, keyed like _plugins.
        self._metadata: Dict[str, PluginMetadata] = {}

    def register(self, plugin: ScannerPlugin) -> None:
        """
        Register a plugin under its metadata name.

        Args:
            plugin: ScannerPlugin instance to register

        Raises:
            ValueError: If a plugin with the same name is already registered
        """
        metadata = plugin.get_metadata()
        name = metadata.name
        if name in self._plugins:
            raise ValueError(f"Plugin '{name}' is already registered")
        self._plugins[name] = plugin
        self._metadata[name] = metadata

    def unregister(self, plugin_name: str) -> bool:
        """
        Unregister a plugin by name.

        Args:
            plugin_name: Name of plugin to remove

        Returns:
            True if plugin was removed, False if not found
        """
        if plugin_name not in self._plugins:
            return False
        del self._plugins[plugin_name]
        del self._metadata[plugin_name]
        return True

    def get_plugin(self, plugin_name: str) -> Optional[ScannerPlugin]:
        """Get a plugin by name, or None if not registered."""
        return self._plugins.get(plugin_name)

    def get_all_plugins(self) -> Dict[str, ScannerPlugin]:
        """Get a shallow copy of the name -> plugin mapping."""
        return self._plugins.copy()

    def get_enabled_plugins(self) -> Dict[str, ScannerPlugin]:
        """Get a new name -> plugin mapping containing only enabled plugins."""
        return {
            name: plugin
            for name, plugin in self._plugins.items()
            if plugin.is_enabled()
        }

    def get_metadata(self, plugin_name: str) -> Optional[PluginMetadata]:
        """Get the metadata captured at registration, or None if not registered."""
        return self._metadata.get(plugin_name)

    def get_all_metadata(self) -> Dict[str, PluginMetadata]:
        """Get a shallow copy of the name -> metadata mapping."""
        return self._metadata.copy()

    def _set_plugin_enabled(self, plugin_name: str, enabled: bool) -> bool:
        """Set a registered plugin's enabled state; False if it is not registered."""
        plugin = self._plugins.get(plugin_name)
        if plugin is None:
            return False
        plugin.set_enabled(enabled)
        return True

    def enable_plugin(self, plugin_name: str) -> bool:
        """
        Enable a plugin.

        Returns:
            True if enabled, False if plugin not found
        """
        return self._set_plugin_enabled(plugin_name, True)

    def disable_plugin(self, plugin_name: str) -> bool:
        """
        Disable a plugin.

        Returns:
            True if disabled, False if plugin not found
        """
        return self._set_plugin_enabled(plugin_name, False)

    def scan_all(
        self,
        user_id: Optional[str],
        server_key: str,
        tool: str,
        arguments: Dict[str, Any],
        llm_context: Optional[str] = None,
    ) -> Dict[str, ScanResult]:
        """
        Run all enabled plugins against a tool call.

        A plugin that raises does not abort the scan: the exception is
        logged and replaced by a ScanResult with ``detected=False``,
        ``risk_score=0.0`` and the ``plugin_error`` flag set.

        Args:
            user_id: Logical user identifier
            server_key: Downstream server key
            tool: Tool name
            arguments: Tool arguments
            llm_context: Optional context

        Returns:
            Dict mapping plugin name -> ScanResult
        """
        results: Dict[str, ScanResult] = {}
        for name, plugin in self.get_enabled_plugins().items():
            try:
                result = plugin.scan(
                    user_id=user_id,
                    server_key=server_key,
                    tool=tool,
                    arguments=arguments,
                    llm_context=llm_context,
                )
            except Exception as e:  # plugin boundary: isolate faulty plugins
                import logging

                # Arguments/context are deliberately not logged: they may
                # contain sensitive user data.
                logging.getLogger(__name__).exception(
                    "Scanner plugin %r failed on %s/%s", name, server_key, tool
                )
                # Fail-open: the failure is surfaced via the plugin_error
                # flag and reason, not via detected/risk_score.
                results[name] = ScanResult(
                    plugin_name=name,
                    detected=False,
                    risk_score=0.0,
                    reasons=[f"Plugin execution error: {str(e)}"],
                    flags={"plugin_error": True},
                )
            else:
                results[name] = result
        return results

    def aggregate_results(self, results: Dict[str, ScanResult]) -> Dict[str, Any]:
        """
        Aggregate scan results across all plugins.

        Risk scores are summed and capped at 1.0, and reasons are concatenated
        in iteration order. Flags are OR-merged: a flag that any plugin reports
        as True stays True.

        Args:
            results: Dict from scan_all()

        Returns:
            Dict with keys total_score, reasons, flags, detected_threats
            (names of plugins with detected=True), plugin_count, threat_count.
        """
        total_score = 0.0
        all_reasons: List[str] = []
        all_flags: Dict[str, bool] = {}
        detected_threats: List[str] = []
        for plugin_name, result in results.items():
            if result.detected:
                detected_threats.append(plugin_name)
            total_score += result.risk_score
            all_reasons.extend(result.reasons)
            # A plain dict.update would let a later plugin's False overwrite
            # an earlier plugin's True and hide a raised indicator.
            for flag, value in result.flags.items():
                all_flags[flag] = all_flags.get(flag, False) or value
        # Cap score at 1.0
        total_score = min(1.0, total_score)
        return {
            "total_score": total_score,
            "reasons": all_reasons,
            "flags": all_flags,
            "detected_threats": detected_threats,
            "plugin_count": len(results),
            "threat_count": len(detected_threats),
        }
# Module-level registry shared by get_registry()/set_registry(); created lazily.
_global_registry: Optional[PluginRegistry] = None


def get_registry() -> PluginRegistry:
    """Return the shared PluginRegistry, creating an empty one on first use."""
    global _global_registry
    if _global_registry is not None:
        return _global_registry
    _global_registry = PluginRegistry()
    return _global_registry


def set_registry(registry: PluginRegistry) -> None:
    """Replace the shared PluginRegistry (intended for tests)."""
    global _global_registry
    _global_registry = registry