File size: 9,051 Bytes
ed71b0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
"""

Base classes for security scanner plugins.



Defines the interface that all plugins must implement and the registry for managing them.

"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Optional, Callable
import re


@dataclass
class PluginMetadata:
    """
    Descriptive information attached to a scanner plugin.

    Fields:
        name: Plugin name; the registry uses it as the plugin's key.
        version: Version string (defaults to "1.0.0").
        description: Human-readable summary (defaults to empty).
        author: Plugin author (defaults to empty).
        enabled: Enabled flag carried with the metadata (defaults to True).
    """
    name: str
    version: str = "1.0.0"
    description: str = ""
    author: str = ""
    enabled: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping


@dataclass
class ScanResult:
    """
    Outcome produced by one scanner plugin for one tool call.

    Fields:
        plugin_name: Name of the plugin that produced this result.
        detected: Whether the plugin considers the call a threat.
        risk_score: Numeric risk contribution (defaults to 0.0).
        reasons: Human-readable explanations (fresh empty list per instance).
        flags: Named boolean indicators (fresh empty dict per instance).
        metadata: Free-form extra data (fresh empty dict per instance).
    """
    plugin_name: str
    detected: bool
    risk_score: float = 0.0
    reasons: List[str] = field(default_factory=list)
    flags: Dict[str, bool] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a deep, plain-dict copy of this result (``dataclasses.asdict``)."""
        serialized = asdict(self)
        return serialized


class ScannerPlugin(ABC):
    """
    Abstract base class for all security scanner plugins.

    Plugins scan tool calls, arguments, and context for security threats.
    Each plugin is responsible for detecting a specific class of
    vulnerabilities. Subclasses must implement :meth:`scan`.
    """

    def __init__(self, metadata: Optional[PluginMetadata] = None):
        """
        Initialize plugin with optional metadata.

        Args:
            metadata: Plugin metadata (name, version, etc). If not provided,
                get_metadata() derives defaults from the class name and
                docstring; subclasses may also override get_metadata().

        The initial enabled state follows ``metadata.enabled`` when metadata
        is supplied, and defaults to enabled otherwise.
        """
        self._metadata = metadata
        # Honour the flag carried by the metadata instead of always starting
        # enabled; otherwise PluginMetadata(enabled=False) would be ignored.
        self._enabled = metadata.enabled if metadata is not None else True

    def get_metadata(self) -> PluginMetadata:
        """
        Get plugin metadata.

        Returns:
            The metadata passed to the constructor, if any. Otherwise a new
            PluginMetadata named after the class, described by its cleaned
            docstring (empty if none), and reporting the current enabled state.
        """
        import inspect

        if self._metadata is not None:
            return self._metadata
        doc = type(self).__doc__
        return PluginMetadata(
            name=type(self).__name__,
            # cleandoc strips the leading blank line and source indentation
            # that a raw class __doc__ carries.
            description=inspect.cleandoc(doc) if doc else "",
            enabled=self._enabled,
        )

    def set_enabled(self, enabled: bool) -> None:
        """
        Enable or disable this plugin.

        When explicit metadata was supplied, its ``enabled`` field is updated
        too, so exported metadata does not contradict is_enabled().
        """
        self._enabled = enabled
        if self._metadata is not None:
            self._metadata.enabled = enabled

    def is_enabled(self) -> bool:
        """Check if plugin is enabled."""
        return self._enabled

    @abstractmethod
    def scan(
        self,
        user_id: Optional[str],
        server_key: str,
        tool: str,
        arguments: Dict[str, Any],
        llm_context: Optional[str] = None,
    ) -> ScanResult:
        """
        Scan a tool call for security threats.

        Args:
            user_id: Logical user identifier (e.g., 'admin', 'judge-1')
            server_key: Downstream server key (e.g., 'filesystem', 'fetch')
            tool: Tool name on the downstream server
            arguments: Arguments passed to the tool
            llm_context: Optional prompt or reasoning context

        Returns:
            ScanResult with detection status, risk score, and reasons
        """
        pass


class PluginRegistry:
    """
    Central registry for managing security scanner plugins.

    Handles plugin registration, discovery, and execution, and provides a
    single point of access for all plugins. Plugins are keyed by the
    ``name`` of the metadata they report at registration time.
    """

    def __init__(self):
        """Initialize empty registry."""
        self._plugins: Dict[str, ScannerPlugin] = {}
        # Metadata object returned by each plugin at registration, same keys.
        self._metadata: Dict[str, PluginMetadata] = {}

    def register(self, plugin: ScannerPlugin) -> None:
        """
        Register a plugin under the name reported by its metadata.

        Args:
            plugin: ScannerPlugin instance to register

        Raises:
            ValueError: If plugin with same name already registered
        """
        metadata = plugin.get_metadata()
        name = metadata.name

        if name in self._plugins:
            raise ValueError(f"Plugin '{name}' is already registered")

        self._plugins[name] = plugin
        self._metadata[name] = metadata

    def unregister(self, plugin_name: str) -> bool:
        """
        Unregister a plugin by name.

        Args:
            plugin_name: Name of plugin to remove

        Returns:
            True if plugin was removed, False if not found
        """
        if plugin_name not in self._plugins:
            return False
        del self._plugins[plugin_name]
        del self._metadata[plugin_name]
        return True

    def get_plugin(self, plugin_name: str) -> Optional[ScannerPlugin]:
        """Get a plugin by name, or None if it is not registered."""
        return self._plugins.get(plugin_name)

    def get_all_plugins(self) -> Dict[str, ScannerPlugin]:
        """Get a shallow copy of all registered plugins, keyed by name."""
        return self._plugins.copy()

    def get_enabled_plugins(self) -> Dict[str, ScannerPlugin]:
        """Get only the plugins whose is_enabled() currently returns True."""
        return {
            name: plugin
            for name, plugin in self._plugins.items()
            if plugin.is_enabled()
        }

    def get_metadata(self, plugin_name: str) -> Optional[PluginMetadata]:
        """Get metadata for a plugin, or None if it is not registered."""
        return self._metadata.get(plugin_name)

    def get_all_metadata(self) -> Dict[str, PluginMetadata]:
        """Get a shallow copy of the metadata for all registered plugins."""
        return self._metadata.copy()

    def enable_plugin(self, plugin_name: str) -> bool:
        """
        Enable a plugin.

        Returns:
            True if enabled, False if plugin not found
        """
        return self._set_plugin_enabled(plugin_name, True)

    def disable_plugin(self, plugin_name: str) -> bool:
        """
        Disable a plugin.

        Returns:
            True if disabled, False if plugin not found
        """
        return self._set_plugin_enabled(plugin_name, False)

    def _set_plugin_enabled(self, plugin_name: str, enabled: bool) -> bool:
        """Set a registered plugin's enabled state; False if it is unknown."""
        plugin = self._plugins.get(plugin_name)
        # Compare with None explicitly: a plugin object that defines
        # __len__/__bool__ could otherwise be mistaken for "not found".
        if plugin is None:
            return False
        plugin.set_enabled(enabled)
        return True

    def scan_all(
        self,
        user_id: Optional[str],
        server_key: str,
        tool: str,
        arguments: Dict[str, Any],
        llm_context: Optional[str] = None,
    ) -> Dict[str, ScanResult]:
        """
        Run all enabled plugins against a tool call.

        A plugin that raises does not abort the scan. It is logged, and its
        entry becomes a non-detecting ScanResult whose reasons carry the
        error text and whose flags contain ``plugin_error: True``.

        Args:
            user_id: Logical user identifier
            server_key: Downstream server key
            tool: Tool name
            arguments: Tool arguments
            llm_context: Optional context

        Returns:
            Dict mapping plugin name -> ScanResult
        """
        import logging

        logger = logging.getLogger(__name__)
        results: Dict[str, ScanResult] = {}
        for name, plugin in self.get_enabled_plugins().items():
            try:
                results[name] = plugin.scan(
                    user_id=user_id,
                    server_key=server_key,
                    tool=tool,
                    arguments=arguments,
                    llm_context=llm_context,
                )
            except Exception as e:
                # Best-effort: keep scanning with the other plugins, but make
                # the failure visible to operators instead of swallowing it.
                logger.exception("Scanner plugin %r raised during scan", name)
                results[name] = ScanResult(
                    plugin_name=name,
                    detected=False,
                    risk_score=0.0,
                    reasons=[f"Plugin execution error: {str(e)}"],
                    flags={"plugin_error": True},
                )
        return results

    def aggregate_results(self, results: Dict[str, ScanResult]) -> Dict[str, Any]:
        """
        Aggregate scan results across all plugins.

        Scores are summed and clamped to [0.0, 1.0]. Reasons are concatenated
        in result order. Flags are OR-merged, so a flag raised by any plugin
        stays raised.

        Args:
            results: Dict from scan_all()

        Returns:
            Dict with keys "total_score", "reasons", "flags",
            "detected_threats", "plugin_count" and "threat_count".
        """
        total_score = 0.0
        all_reasons: List[str] = []
        all_flags: Dict[str, bool] = {}
        detected_threats: List[str] = []

        for plugin_name, result in results.items():
            if result.detected:
                detected_threats.append(plugin_name)

            total_score += result.risk_score
            all_reasons.extend(result.reasons)
            # A plain dict.update() would let a later plugin reporting False
            # silently clear a flag an earlier plugin raised.
            for flag, value in result.flags.items():
                all_flags[flag] = all_flags.get(flag, False) or value

        # Clamp the combined score into [0.0, 1.0].
        total_score = max(0.0, min(1.0, total_score))

        return {
            "total_score": total_score,
            "reasons": all_reasons,
            "flags": all_flags,
            "detected_threats": detected_threats,
            "plugin_count": len(results),
            "threat_count": len(detected_threats),
        }


# Process-wide registry singleton: unset until get_registry() first creates it,
# and replaceable wholesale through set_registry().
_global_registry: Optional[PluginRegistry] = None


def get_registry() -> PluginRegistry:
    """
    Return the shared plugin registry.

    A new PluginRegistry is created on the first call; later calls return the
    same object until set_registry() installs a different one.
    """
    global _global_registry
    if _global_registry is not None:
        return _global_registry
    _global_registry = PluginRegistry()
    return _global_registry


def set_registry(registry: PluginRegistry) -> None:
    """
    Install *registry* as the shared plugin registry.

    Intended for tests that need an isolated registry.
    """
    global _global_registry
    _global_registry = registry