"""
BaseTool.

All pre-defined grounding atomic operations inherit from this tool class.
A RemoteTool additionally needs a connector to be passed in.
"""
import asyncio, time, inspect
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, ClassVar, Dict, Optional, TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, create_model
from ..types import BackendType, ToolResult, ToolSchema, ToolStatus
from ..exceptions import GroundingError, ErrorCode
from openspace.utils.logging import Logger
import jsonschema
if TYPE_CHECKING:
from ..grounding_client import GroundingClient
logger = Logger.get_logger(__name__)
class ToolRuntimeInfo:
    """Runtime context a tool instance has been bound to.

    Attributes:
        backend: Backend type the tool is bound to.
        session_name: Name of the session the tool is bound to.
        server_name: Optional server name; ``None`` when not supplied.
        grounding_client: Optional client reference; ``None`` when not supplied.
    """

    def __init__(
        self,
        backend: BackendType,
        session_name: str,
        server_name: Optional[str] = None,
        grounding_client: Optional['GroundingClient'] = None,
    ):
        # Plain attribute storage; the values are kept exactly as given.
        self.grounding_client = grounding_client
        self.server_name = server_name
        self.session_name = session_name
        self.backend = backend

    def __repr__(self):
        # Only the backend value and the session name are shown.
        backend_value = self.backend.value
        session = self.session_name
        return f"<ToolRuntimeInfo backend={backend_value} session={session}>"
class BaseTool(ABC):
    """Abstract base class for grounding tools.

    Subclasses describe themselves through the class attributes below and
    supply their behaviour by overriding ``_arun()`` or ``_run()``.
    """

    # Default tool name; when empty, __init__ falls back to the lowercased class name.
    _name: ClassVar[str] = ""
    # Default description used when no schema is passed to __init__.
    _description: ClassVar[str] = ""
    # Backend this tool targets; stays NOT_SET until a subclass overrides it
    # or it is set at runtime.
    backend_type: ClassVar[BackendType] = BackendType.NOT_SET
def __init__(self,
             schema: Optional[ToolSchema] = None,
             *,
             verbose: bool = False,
             handle_errors: bool = True) -> None:
    """Initialise the tool and resolve the schema it exposes.

    Args:
        schema: Schema to use as-is. When omitted (or falsy) a schema is
            built from the class attributes and ``get_parameters_schema()``.
        verbose: Enables the timing debug log emitted by ``_wrap_result``.
        handle_errors: When true, ``arun`` converts exceptions into error
            results instead of re-raising them.
    """
    self.verbose = verbose
    self.handle_errors = handle_errors

    if schema:
        resolved_schema = schema
    else:
        # No usable schema supplied: derive one from class-level metadata.
        tool_name = self._name or self.__class__.__name__.lower()
        resolved_schema = ToolSchema(
            name=tool_name,
            description=self._description,
            parameters=self.get_parameters_schema(),
            backend_type=self.backend_type,
        )
    self.schema: ToolSchema = resolved_schema

    # Unbound until bind_runtime_info() is called.
    self._runtime_info: Optional[ToolRuntimeInfo] = None
    # Outer (tool-level) trajectory recording is switched off by default.
    self._disable_outer_recording = True
@property
def name(self) -> str:
    """Return the tool name, preferring the schema over the class default.

    Falls back to the class-level ``_name`` when the instance has no
    ``schema`` attribute or the schema is falsy.
    """
    schema = getattr(self, 'schema', None)
    if schema:
        return schema.name
    return self._name
@property
def description(self) -> str:
    """Return the tool description, preferring the schema over the class default.

    Falls back to the class-level ``_description`` when the instance has no
    ``schema`` attribute or the schema is falsy.
    """
    schema = getattr(self, 'schema', None)
    if schema:
        return schema.description
    return self._description
@classmethod
@lru_cache
def get_parameters_schema(cls) -> Dict[str, Any]:
    """Build a JSON schema for the tool parameters from a method signature.

    The signature of ``_arun`` is used when the subclass overrides it;
    otherwise the signature of an overridden ``_run`` is used. ``self`` and
    variadic (``*args`` / ``**kwargs``) parameters are ignored. Parameters
    without an annotation are typed as ``str``; parameters without a
    default are required.

    The result is memoised by ``lru_cache``, so callers receive the same
    dict object on repeated calls for the same class.

    Returns:
        The JSON schema of the generated pydantic model, or ``{}`` when the
        signature declares no usable parameters.

    Raises:
        ValueError: if the class overrides neither ``_arun`` nor ``_run``.
    """
    # Async-first: an overridden _arun wins over an overridden _run.
    if cls._arun is not BaseTool._arun:
        source_fn = cls._arun
    elif cls._run is not BaseTool._run:
        source_fn = cls._run
    else:
        raise ValueError(
            f"{cls.__name__} must implement _run() or _arun() to define its parameters schema"
        )

    empty = inspect.Parameter.empty
    variadic_kinds = (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)

    field_defs: dict[str, Any] = {}
    for param_name, param in inspect.signature(source_fn).parameters.items():
        if param_name == "self" or param.kind in variadic_kinds:
            continue
        annotation = str if param.annotation is empty else param.annotation
        # Ellipsis marks the field as required for pydantic.
        default_value = ... if param.default is empty else param.default
        field_defs[param_name] = (annotation, Field(default_value))

    if not field_defs:
        return {}

    params_model: type[BaseModel] = create_model(
        f"{cls.__name__}Params",
        __config__=ConfigDict(arbitrary_types_allowed=True),
        **field_defs
    )
    return params_model.model_json_schema()
def validate_parameters(self, params: Dict[str, Any]) -> None:
    """Validate ``params`` against the tool schema.

    Args:
        params: Parameter mapping to check.

    Raises:
        GroundingError: with code ``TOOL_EXECUTION_FAIL`` when the schema
            raises a ``jsonschema.ValidationError``; the original error is
            chained as the cause.
    """
    schema = self.schema
    try:
        schema.validate_parameters(params, raise_exc=True)
    except jsonschema.ValidationError as validation_error:
        message = f"Invalid parameters: {validation_error.message}"
        raise GroundingError(
            message,
            code=ErrorCode.TOOL_EXECUTION_FAIL,
            tool_name=schema.name,
        ) from validation_error
def run(self, **kwargs):
    """Invoke the tool from synchronous or asynchronous calling code.

    When no event loop is running in the current thread, the invocation is
    driven to completion with ``asyncio.run`` and its result is returned.
    When a loop is already running, the invocation is scheduled on it and
    the resulting ``asyncio.Task`` is returned for the caller to await.

    Args:
        **kwargs: Forwarded unchanged to ``invoke``.

    Returns:
        The invocation result, or an ``asyncio.Task`` wrapping it when
        called from inside a running event loop.
    """
    # Probe for a running loop *before* creating the coroutine. Wrapping
    # asyncio.run() in ``except RuntimeError`` would also swallow
    # RuntimeErrors raised by the tool itself and would leave a
    # never-awaited coroutine behind when a loop is already running.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(self.invoke(**kwargs))
    return loop.create_task(self.invoke(**kwargs))
def __call__(self, **kwargs):
    """Make the tool instance callable; equivalent to ``self.run(**kwargs)``."""
    outcome = self.run(**kwargs)
    return outcome
async def __acall__(self, **kwargs):
    """Awaitable counterpart of calling the tool; delegates to ``arun``.

    ``__acall__`` is not a special method recognised by Python, so it has
    to be called explicitly.
    """
    outcome = await self.arun(**kwargs)
    return outcome
async def arun(self, **kwargs) -> ToolResult:
    """Validate the parameters, execute the tool and record the execution.

    Args:
        **kwargs: Tool parameters; validated with ``validate_parameters``
            and then forwarded to ``_arun``.

    Returns:
        The wrapped result of ``_arun``. When ``handle_errors`` is true and
        validation or execution raises, an ERROR ``ToolResult`` carrying
        the exception text and the elapsed time is returned instead.

    Raises:
        Exception: whatever validation or ``_arun`` raised, when
            ``handle_errors`` is false. Nothing is recorded in that case.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or turn negative) when the system clock is adjusted mid-call.
    started = time.perf_counter()
    try:
        self.validate_parameters(kwargs)
        raw = await self._arun(**kwargs)
        result = self._wrap_result(raw, time.perf_counter() - started)
    except Exception as e:
        if not self.handle_errors:
            raise
        result = ToolResult(
            status=ToolStatus.ERROR,
            error=str(e),
            # Failed runs report their duration just like successful ones.
            execution_time=time.perf_counter() - started,
            metadata={"tool": self.schema.name},
        )
    await self._auto_record_execution(kwargs, result, time.perf_counter() - started)
    return result
# Asynchronous execution hook; async-first tools override this.
async def _arun(self, **kwargs):
    """Execute the tool asynchronously.

    The default implementation runs the synchronous ``_run`` hook in a
    worker thread, so a tool that only implements ``_run`` can be
    instantiated and awaited. It is deliberately not abstract: marking it
    abstract would make every ``_run``-only subclass impossible to
    instantiate.

    Args:
        **kwargs: Tool parameters, forwarded unchanged to ``_run``.

    Returns:
        Whatever ``_run`` returns.

    Raises:
        NotImplementedError: if ``_run`` has not been overridden either.
    """
    return await asyncio.to_thread(self._run, **kwargs)
def bind_runtime_info(
    self,
    backend: BackendType,
    session_name: str,
    server_name: Optional[str] = None,
    grounding_client: Optional['GroundingClient'] = None,
) -> 'BaseTool':
    """Attach runtime information to this tool instance.

    Once bound, the tool can be invoked directly without the caller having
    to specify backend, session or server again. Any previously bound
    information is replaced.

    Args:
        backend: Backend type.
        session_name: Session name.
        server_name: Server name (for MCP).
        grounding_client: Optional reference to the GroundingClient used
            for direct invocation.

    Returns:
        This tool instance, so calls can be chained.
    """
    runtime_info = ToolRuntimeInfo(
        backend=backend,
        session_name=session_name,
        server_name=server_name,
        grounding_client=grounding_client,
    )
    self._runtime_info = runtime_info
    return self
@property
def runtime_info(self) -> Optional['ToolRuntimeInfo']:
    """Return the bound runtime information, or ``None`` when unbound."""
    bound_info = self._runtime_info
    return bound_info
@property
def is_bound(self) -> bool:
    """Return ``True`` when runtime information has been bound to the tool."""
    if self._runtime_info is None:
        return False
    return True
async def invoke(
    self,
    parameters: Dict[str, Any] | None = None,
    keep_session: bool = True,
    **kwargs
) -> ToolResult:
    """Invoke this tool, through the bound grounding client when available.

    When runtime information with a grounding client has been bound via
    ``bind_runtime_info()``, the call is delegated to that client.
    Otherwise the tool is executed locally through ``arun``.

    Args:
        parameters: Tool parameters. When omitted or empty, ``kwargs`` are
            used as the parameters instead.
        keep_session: Forwarded to the grounding client; not used for
            local execution.
        **kwargs: Alternative way of passing the tool parameters.

    Returns:
        The result produced by the grounding client or by ``arun``.
    """
    call_params = parameters if parameters else kwargs

    # Only a bound tool can have a client; an unbound one always runs locally.
    client = self._runtime_info.grounding_client if self.is_bound else None
    if not client:
        return await self.arun(**call_params)

    return await client.invoke_tool(
        tool=self,
        parameters=call_params,
        keep_session=keep_session,
    )
def _wrap_result(self, obj: Any, elapsed: float) -> ToolResult:
    """Normalise the raw return value of ``_arun`` into a ``ToolResult``.

    Args:
        obj: Raw value returned by the tool implementation.
        elapsed: Execution time to store on the result.

    Returns:
        ``obj`` itself (with ``execution_time`` updated) when it already is
        a ``ToolResult``; otherwise a SUCCESS result whose content is
        ``str(obj)``. ``bytes`` / ``bytearray`` values are first decoded as
        UTF-8, with undecodable bytes replaced.
    """
    # Log before branching so verbose mode also reports tools that return
    # a ready-made ToolResult (the early return below used to skip it).
    if self.verbose:
        logger.debug("[%s] done in %.2f s", self.schema.name, elapsed)

    if isinstance(obj, ToolResult):
        obj.execution_time = elapsed
        return obj

    if isinstance(obj, (bytes, bytearray)):
        obj = obj.decode("utf-8", errors="replace")

    return ToolResult(
        status=ToolStatus.SUCCESS,
        content=str(obj),
        execution_time=elapsed,
        metadata={"tool": self.schema.name},
    )
async def _auto_record_execution(
    self,
    parameters: Dict[str, Any],
    result: ToolResult,
    execution_time: float,
):
    """Report a finished execution to the quality and recording managers.

    The quality manager is always notified first. The recording manager is
    only used while a recording is active and the tool has not disabled
    outer recording. Failures in the recording part are logged as warnings
    and never propagate.

    Args:
        parameters: Parameters the tool was called with.
        result: Result of the execution.
        execution_time: Duration in seconds; converted to milliseconds for
            the quality manager.
    """
    # Quality tracking
    await self._record_to_quality_manager(result, execution_time * 1000)

    # Trajectory recording
    try:
        from openspace.recording import RecordingManager

        if not RecordingManager.is_recording():
            return

        # Tools that record their own intermediate steps (e.g. a GUI agent)
        # opt out of the outer, tool-level record.
        if getattr(self, '_disable_outer_recording', False):
            logger.debug(f"Skipping outer recording for {self.schema.name} (intermediate steps recorded)")
            return

        # Class-level defaults, overridden by the bound runtime info,
        # which is the more accurate source.
        backend = self.backend_type.value
        server_name = None
        if self.is_bound and self._runtime_info:
            bound_info = self._runtime_info
            backend = bound_info.backend.value
            server_name = bound_info.server_name

        # Screenshot capture for GUI tools that expose a connector.
        # NOTE(review): the captured screenshot is not passed to
        # record_tool_execution() below - confirm whether it should be.
        screenshot = None
        is_gui_with_connector = self.backend_type == BackendType.GUI and hasattr(self, 'connector')
        if is_gui_with_connector:
            try:
                screenshot = await self.connector.get_screenshot()
            except Exception as e:
                logger.debug(f"Failed to capture screenshot: {e}")

        await RecordingManager.record_tool_execution(
            tool_name=self.schema.name,
            backend=backend,
            parameters=parameters,
            result=result.content,
            server_name=server_name,
            is_success=result.is_success,  # actual success status from the ToolResult
        )
    except Exception as e:
        logger.warning(f"Failed to auto-record tool execution for {self.schema.name}: {e}")
async def _record_to_quality_manager(
    self,
    result: ToolResult,
    execution_time_ms: float,
):
    """Forward an execution result to the quality manager, if one exists.

    Any failure (including a failing import of the quality module) is
    logged at debug level and swallowed, so quality tracking can never
    break tool execution.

    Args:
        result: Result of the execution.
        execution_time_ms: Duration of the execution in milliseconds.
    """
    try:
        from openspace.grounding.core.quality import get_quality_manager

        quality_manager = get_quality_manager()
        if not quality_manager:
            return
        await quality_manager.record_execution(self, result, execution_time_ms)
    except Exception as e:
        logger.debug(f"Failed to record to quality manager: {e}")
# Synchronous hook, kept for backward-compatibility / thread-pool fallback.
def _run(self, **kwargs):
    """Synchronous implementation hook; the base version always raises.

    Raises:
        NotImplementedError: unless overridden by a subclass.
    """
    raise NotImplementedError()
def __repr__(self):
    """Return ``<Tool NAME (BACKEND)>``, plus ``@ SESSION`` when bound."""
    pieces = [f"<Tool {self.schema.name} ({self.backend_type.value})"]
    if self.is_bound:
        pieces.append(f" @ {self._runtime_info.session_name}")
    pieces.append(">")
    return "".join(pieces)
def __init_subclass__(cls, **kwargs):
    """Check every subclass at class-creation time.

    - The subclass must implement at least one of ``_run`` / ``_arun``.
    - A ``backend_type`` of NOT_SET is tolerated and only produces a debug
      message, so that e.g. RemoteTool can inject it at runtime.

    Raises:
        ValueError: if neither ``_run`` nor ``_arun`` is overridden.
    """
    super().__init_subclass__(**kwargs)

    if not (cls._arun is not BaseTool._arun or cls._run is not BaseTool._run):
        raise ValueError(f"{cls.__name__} must implement _run() or _arun()")

    if cls.backend_type is not BackendType.NOT_SET:
        return
    logger.debug(
        "%s.backend_type is NOT_SET; remember to override or set at runtime.",
        cls.__name__,
    )