File size: 12,222 Bytes
399b80c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""
BaseTool.
All pre-defined grounding atomic operations will inherit this tool class.
RemoteTool needs to pass in connector.
"""
import asyncio, time, inspect
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, create_model

from ..types import BackendType, ToolResult, ToolSchema, ToolStatus
from ..exceptions import GroundingError, ErrorCode
from openspace.utils.logging import Logger
import jsonschema

if TYPE_CHECKING:
    from ..grounding_client import GroundingClient

logger = Logger.get_logger(__name__)


class ToolRuntimeInfo:
    """Backend/session context that gets attached to a tool instance at runtime.

    Attributes:
        backend: Backend the tool is bound to (its ``.value`` is shown in the repr).
        session_name: Name of the session the tool is bound to.
        server_name: Optional server name; ``None`` when not given.
        grounding_client: Optional client reference; ``None`` when not given.
    """

    def __init__(
        self,
        backend: BackendType,
        session_name: str,
        server_name: Optional[str] = None,
        grounding_client: Optional['GroundingClient'] = None,
    ):
        # Plain container: every argument is stored unchanged under its own name.
        self.backend, self.session_name = backend, session_name
        self.server_name, self.grounding_client = server_name, grounding_client

    def __repr__(self):
        # Only backend and session are shown; server/client are left out.
        return "<ToolRuntimeInfo backend={} session={}>".format(
            self.backend.value, self.session_name
        )
    

class BaseTool(ABC):
    """Base class for grounding tools.

    A concrete tool overrides ``_arun`` (preferred) or the synchronous ``_run``.
    The signature of that method is turned into the tool's JSON parameter
    schema, and ``arun`` wraps the call with parameter validation, timing,
    error handling and execution recording.

    Class attributes:
        _name: Tool name used for the default schema; the lower-cased class
            name is used when it is empty.
        _description: Tool description used for the default schema.
        backend_type: Backend the tool belongs to; ``BackendType.NOT_SET``
            unless a subclass overrides it.
    """
    _name: ClassVar[str] = ""
    _description: ClassVar[str] = ""
    backend_type: ClassVar[BackendType] = BackendType.NOT_SET

    def __init__(self,
                 schema: Optional[ToolSchema] = None,
                 *,
                 verbose: bool = False,
                 handle_errors: bool = True) -> None:
        """Create a tool instance.

        Args:
            schema: Explicit tool schema. When omitted (or falsy) a schema is
                built from the class attributes and ``get_parameters_schema()``.
            verbose: When true, ``_wrap_result`` logs the elapsed time.
            handle_errors: When true, ``arun`` turns exceptions into an
                ``ERROR`` ``ToolResult`` instead of re-raising them.

        Raises:
            TypeError: If ``BaseTool`` itself (not a subclass) is instantiated.
        """
        if type(self) is BaseTool:
            # _arun has a default implementation, so ABC alone does not prevent
            # instantiating the base class; reject it explicitly.
            raise TypeError(
                "Can't instantiate abstract class BaseTool directly; "
                "subclass it and implement _run() or _arun()"
            )

        self.verbose = verbose
        self.handle_errors = handle_errors
        self.schema: ToolSchema = schema or ToolSchema(
            name=self._name or self.__class__.__name__.lower(),
            description=self._description,
            parameters=self.get_parameters_schema(),
            backend_type=self.backend_type,
        )

        self._runtime_info: Optional[ToolRuntimeInfo] = None
        # When true, _auto_record_execution skips the trajectory recording of
        # the whole call (the quality-manager record is still written).
        self._disable_outer_recording = True

    @property
    def name(self) -> str:
        """Get tool name from schema (supports both class-defined and runtime-injected names)"""
        return self.schema.name if hasattr(self, 'schema') and self.schema else self._name

    @property
    def description(self) -> str:
        """Get tool description from schema (supports both class-defined and runtime-injected descriptions)"""
        return self.schema.description if hasattr(self, 'schema') and self.schema else self._description

    @classmethod
    @lru_cache
    def get_parameters_schema(cls) -> Dict[str, Any]:
        """Auto-generate a JSON schema from the ``_arun()`` or ``_run()`` signature.

        ``_arun`` is used when the class overrides it, otherwise ``_run``.
        ``self``, ``*args`` and ``**kwargs`` are skipped; un-annotated
        parameters are treated as ``str``; parameters without a default are
        required. The result is cached per class.

        Returns:
            The JSON schema of the generated pydantic model, or an empty dict
            for tools with no parameters.

        Raises:
            ValueError: If the class overrides neither ``_run`` nor ``_arun``.
        """
        # Prefer _arun so async-first tools can define their signature there.
        if cls._arun is not BaseTool._arun:
            sig_src = cls._arun
        elif cls._run is not BaseTool._run:
            sig_src = cls._run
        else:
            raise ValueError(
                f"{cls.__name__} must implement _run() or _arun() to define its parameters schema"
            )

        empty = inspect.Parameter.empty
        variadic = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)

        fields: dict[str, Any] = {}
        for name, p in inspect.signature(sig_src).parameters.items():
            if name == "self" or p.kind in variadic:
                continue
            typ = str if p.annotation is empty else p.annotation
            # Ellipsis marks the field as required for pydantic.
            default = ... if p.default is empty else p.default
            fields[name] = (typ, Field(default))

        if not fields:
            return {}

        PModel: type[BaseModel] = create_model(
            f"{cls.__name__}Params",
            __config__=ConfigDict(arbitrary_types_allowed=True),
            **fields
        )
        return PModel.model_json_schema()

    def validate_parameters(self, params: Dict[str, Any]) -> None:
        """Validate ``params`` against the tool schema.

        Raises:
            GroundingError: With code ``TOOL_EXECUTION_FAIL`` when the schema
                validation raises ``jsonschema.ValidationError``.
        """
        try:
            self.schema.validate_parameters(params, raise_exc=True)
        except jsonschema.ValidationError as ve:
            raise GroundingError(
                f"Invalid parameters: {ve.message}",
                code=ErrorCode.TOOL_EXECUTION_FAIL,
                tool_name=self.schema.name,
            ) from ve

    def run(self, **kwargs):
        """Synchronous entry point around ``invoke``.

        Returns:
            The ``ToolResult`` when no event loop is running in this thread.
            When called from inside a running loop, an ``asyncio.Task`` wrapping
            the invocation, which the caller must await.
        """
        # Probe for a running loop first. Wrapping asyncio.run() in
        # `except RuntimeError` would leave an un-awaited coroutine behind and
        # would also swallow RuntimeErrors raised by the tool itself.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self.invoke(**kwargs))
        return loop.create_task(self.invoke(**kwargs))

    def __call__(self, **kwargs):
        """Shorthand for ``run(**kwargs)``."""
        return self.run(**kwargs)

    async def __acall__(self, **kwargs):
        """Async shorthand for ``arun(**kwargs)``.

        Python has no ``__acall__`` protocol, so this must be called explicitly.
        """
        return await self.arun(**kwargs)

    async def arun(self, **kwargs) -> ToolResult:
        """Validate parameters, execute ``_arun`` and record the execution.

        Returns:
            The wrapped ``ToolResult``; an ``ERROR`` result when execution
            fails and ``handle_errors`` is true.

        Raises:
            Exception: Whatever validation or ``_arun`` raised, when
                ``handle_errors`` is false.
        """
        # perf_counter is monotonic, so durations are immune to clock changes.
        start = time.perf_counter()
        try:
            self.validate_parameters(kwargs)
            raw = await self._arun(**kwargs)
            result = self._wrap_result(raw, time.perf_counter() - start)
        except Exception as e:
            if not self.handle_errors:
                raise
            result = ToolResult(
                status=ToolStatus.ERROR,
                error=str(e),
                execution_time=time.perf_counter() - start,
                metadata={"tool": self.schema.name},
            )

        await self._auto_record_execution(kwargs, result, time.perf_counter() - start)
        return result

    # to be implemented by subclasses (or implement _run instead)
    async def _arun(self, **kwargs):
        """Execute the tool; override in subclasses.

        The default implementation runs the synchronous ``_run`` in a worker
        thread, so a tool that only implements ``_run`` can still be
        instantiated and awaited. ``NotImplementedError`` propagates when
        ``_run`` is not overridden either.
        """
        return await asyncio.to_thread(self._run, **kwargs)

    def bind_runtime_info(
        self,
        backend: BackendType,
        session_name: str,
        server_name: Optional[str] = None,
        grounding_client: Optional['GroundingClient'] = None,
    ) -> 'BaseTool':
        """
        Bind runtime information to the tool instance.
        Allow the tool to be invoked directly without specifying backend/session/server.

        Args:
            backend: Backend type
            session_name: Session name
            server_name: Server name (for MCP)
            grounding_client: Optional reference to GroundingClient for direct invocation

        Returns:
            The tool itself, so calls can be chained.
        """
        self._runtime_info = ToolRuntimeInfo(
            backend=backend,
            session_name=session_name,
            server_name=server_name,
            grounding_client=grounding_client,
        )
        return self

    @property
    def runtime_info(self) -> Optional['ToolRuntimeInfo']:
        """Get runtime information if bound"""
        return self._runtime_info

    @property
    def is_bound(self) -> bool:
        """Check if tool has runtime information bound"""
        return self._runtime_info is not None

    async def invoke(
        self,
        parameters: Dict[str, Any] | None = None,
        keep_session: bool = True,
        **kwargs
    ) -> ToolResult:
        """
        Invoke this tool, through the bound grounding client when there is one.

        If runtime info with a grounding client has been bound via
        bind_runtime_info(), the call is delegated to that client's
        invoke_tool(). Otherwise the tool is executed locally via arun().

        Args:
            parameters: Tool parameters; when None or empty, ``kwargs`` are
                used instead.
            keep_session: Forwarded to the grounding client; not used for
                local execution.
            **kwargs: Tool parameters given as keyword arguments.
        """
        params = parameters or kwargs

        if self.is_bound and self._runtime_info.grounding_client:
            return await self._runtime_info.grounding_client.invoke_tool(
                tool=self,
                parameters=params,
                keep_session=keep_session,
            )

        return await self.arun(**params)

    def _wrap_result(self, obj: Any, elapsed: float) -> ToolResult:
        """Normalize the raw return value of ``_arun`` into a ``ToolResult``.

        A ``ToolResult`` is returned as-is with ``execution_time`` updated;
        bytes are decoded as UTF-8 (invalid bytes replaced); anything else is
        converted with ``str()`` into a ``SUCCESS`` result.
        """
        # Log before the early return so ToolResult-returning tools are timed too.
        if self.verbose:
            logger.debug("[%s] done in %.2f s", self.schema.name, elapsed)
        if isinstance(obj, ToolResult):
            obj.execution_time = elapsed
            return obj
        if isinstance(obj, (bytes, bytearray)):
            obj = obj.decode("utf-8", errors="replace")
        return ToolResult(
            status=ToolStatus.SUCCESS,
            content=str(obj),
            execution_time=elapsed,
            metadata={"tool": self.schema.name},
        )

    async def _auto_record_execution(
        self,
        parameters: Dict[str, Any],
        result: ToolResult,
        execution_time: float,
    ):
        """Auto-record tool execution to recording manager and quality manager.

        Args:
            parameters: Parameters the tool was called with.
            result: Result of the execution.
            execution_time: Elapsed time in seconds (converted to milliseconds
                for the quality manager).

        Failures while recording are logged and never propagated.
        """
        # Record to quality manager (for quality tracking)
        await self._record_to_quality_manager(result, execution_time * 1000)

        # Record to recording manager (for trajectory recording)
        try:
            from openspace.recording import RecordingManager

            if not RecordingManager.is_recording():
                return

            # Check if tool has disabled outer recording (e.g., GUI agent with intermediate steps)
            if getattr(self, '_disable_outer_recording', False):
                logger.debug(
                    "Skipping outer recording for %s (intermediate steps recorded)",
                    self.schema.name,
                )
                return

            # Get backend and server_name from runtime_info (if bound)
            backend = self.backend_type.value
            server_name = None

            runtime_info = self._runtime_info
            if self.is_bound and runtime_info:
                # Prefer runtime_info information (more accurate)
                backend = runtime_info.backend.value
                server_name = runtime_info.server_name

            # Get screenshot (if GUI backend)
            # NOTE(review): the captured screenshot is never passed on below;
            # either hand it to the recorder or drop the capture - confirm intent.
            screenshot = None
            if self.backend_type == BackendType.GUI and hasattr(self, 'connector'):
                try:
                    screenshot = await self.connector.get_screenshot()
                except Exception as e:
                    logger.debug("Failed to capture screenshot: %s", e)

            # Record tool execution with complete runtime information
            await RecordingManager.record_tool_execution(
                tool_name=self.schema.name,
                backend=backend,
                parameters=parameters,
                result=result.content,
                server_name=server_name,
                is_success=result.is_success,  # Pass actual success status from ToolResult
            )
        except Exception as e:
            logger.warning(
                "Failed to auto-record tool execution for %s: %s", self.schema.name, e
            )

    async def _record_to_quality_manager(
        self,
        result: ToolResult,
        execution_time_ms: float,
    ):
        """Record execution result to quality manager for quality tracking.

        Args:
            result: Result of the execution.
            execution_time_ms: Elapsed time in milliseconds.
        """
        try:
            from openspace.grounding.core.quality import get_quality_manager

            manager = get_quality_manager()
            if manager:
                await manager.record_execution(self, result, execution_time_ms)
        except Exception as e:
            # Quality recording failure should not affect tool execution
            logger.debug("Failed to record to quality manager: %s", e)

    # keep _run for backward-compatibility / thread-pool fallback
    def _run(self, **kwargs):
        """Synchronous implementation hook; the default ``_arun`` runs it in a thread."""
        raise NotImplementedError

    def __repr__(self):
        base = f"<Tool {self.schema.name} ({self.backend_type.value})"
        if self.is_bound:
            base += f" @ {self._runtime_info.session_name}"
        return base + ">"

    def __init_subclass__(cls, **kwargs):
        """
        Validate a new subclass at class-creation time.

        - it must implement at least one of _run or _arun (ValueError otherwise)
        - if backend_type is NOT_SET, only a debug message is logged, which
          allows RemoteTool to inject the backend at runtime
        """
        super().__init_subclass__(**kwargs)

        if cls._arun is BaseTool._arun and cls._run is BaseTool._run:
            raise ValueError(f"{cls.__name__} must implement _run() or _arun()")

        if cls.backend_type is BackendType.NOT_SET:
            logger.debug(
                "%s.backend_type is NOT_SET; remember to override or set at runtime.",
                cls.__name__,
            )