# Source: gakrchat1 / backend/tool_client.py
# Uploaded by extraplus ("Upload 8 files"), commit 2c31cbe (verified).
"""Tool client — provides the local web_search tool only."""
import asyncio
import copy
import json
import sys
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
# Repository root: the parent of the directory that contains this module.
REPO_ROOT = Path(__file__).resolve().parent.parent

# Make top-level packages (e.g. ``tools.web_search.search``) importable even
# when the backend is launched with cwd=backend. The root is prepended at
# most once so repeated imports do not grow sys.path.
repo_root_str = str(REPO_ROOT)
_already_on_path = repo_root_str in sys.path
if not _already_on_path:
    sys.path.insert(0, repo_root_str)
# One-line "Purpose" text per tool name, rendered into the system-prompt
# tool description.
TOOL_PURPOSE_HINTS = dict(
    web_search="Keyword web search with base/medium/pro modes and structured snippets/content output.",
)

# Hand-written example arguments per tool name; preferred over arguments
# inferred from the tool's parameter schema when building example calls.
TOOL_EXAMPLE_ARGUMENTS = dict(
    web_search=dict(query="latest ai model updates", mode="base"),
)
# Per-argument JSON-Schema fragments for web_search. They are kept as separate
# names for readability; their order inside "properties" below is the order in
# which parameters appear in the generated tool description.
_WEB_SEARCH_QUERY_PARAM = {
    "type": "string",
    "description": "Search query keyword or phrase.",
}
_WEB_SEARCH_MODE_PARAM = {
    "type": "string",
    "enum": ["base", "medium", "pro"],
    "default": "base",
    "description": "Result detail mode. base/medium return snippets, pro includes content sections.",
}
_WEB_SEARCH_WORKERS_PARAM = {
    "type": "integer",
    "default": 8,
    "minimum": 2,
    "description": "Parallel workers for the search pipeline.",
}

# Function-calling style tool definition (name, description, JSON-Schema
# parameters) registered by ToolClient.initialize().
WEB_SEARCH_TOOL = {
    "name": "web_search",
    "description": (
        "Search by keyword and return structured web results. Modes: "
        "base(15), medium(35), pro(5 with content_sections)."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": _WEB_SEARCH_QUERY_PARAM,
            "mode": _WEB_SEARCH_MODE_PARAM,
            "workers": _WEB_SEARCH_WORKERS_PARAM,
        },
        "required": ["query"],
    },
}
# Normalized (lower-case, snake_case) tool names that are all routed to
# web_search. "featch_content" is presumably kept to tolerate a model typo.
WEB_SEARCH_ALIASES = set(
    "web_search websearch search_web fetch_content featch_content".split()
)
class ToolClient:
    """Local tool client — provides only the web_search tool.

    Call :meth:`initialize` to register the tool schema. :meth:`call_tool`
    dispatches by normalized name and returns a JSON string; failures are
    reported in-band as ``{"status": "error", "tool": ..., "error": ...}``.
    """

    # Mirror the "workers" schema advertised in WEB_SEARCH_TOOL (default 8,
    # minimum 2) so the runner never receives a value the schema rejects.
    _DEFAULT_WORKERS = 8
    _MIN_WORKERS = 2

    def __init__(self):
        self._tools: List[Dict[str, Any]] = []
        self._initialized = False
        self._init_error: Optional[str] = None
        # Lazily imported search callable; see _load_web_search_runner.
        self._web_search_runner: Optional[Callable[..., Any]] = None

    @staticmethod
    def _normalize_tool_name(name: str) -> str:
        """Lower-case/snake_case ``name`` and map any web-search alias to "web_search"."""
        raw = str(name or "").strip().lower().replace("-", "_").replace(" ", "_")
        return "web_search" if raw in WEB_SEARCH_ALIASES else raw

    @staticmethod
    def _coerce_positive_int(value: Any, default: int) -> int:
        """Return ``value`` as a positive int, or ``default`` when it is not one.

        ``bool`` is rejected explicitly: it subclasses ``int``, so a JSON
        ``true`` would otherwise silently become ``1``. Non-finite floats raise
        ``OverflowError`` in ``int()`` and also fall back to ``default``.
        """
        if isinstance(value, bool):
            return default
        try:
            parsed = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return parsed if parsed > 0 else default

    @staticmethod
    def _extract_query(args: Dict[str, Any]) -> str:
        """Pull the search text from ``query``, ``keyword`` or ``keywords`` (first truthy wins).

        A list/tuple value (e.g. ``{"keywords": ["a", "b"]}``) is joined with
        spaces instead of being rendered as a Python list literal.
        """
        raw = args.get("query") or args.get("keyword") or args.get("keywords") or ""
        if isinstance(raw, (list, tuple)):
            pieces = (str(piece).strip() for piece in raw)
            return " ".join(piece for piece in pieces if piece)
        return str(raw).strip()

    @staticmethod
    def _error_payload(tool: Any, message: str) -> str:
        """Serialize a uniform error result shared by all tool-call failure paths."""
        return json.dumps(
            {"status": "error", "tool": tool, "error": message},
            ensure_ascii=False,
            default=str,
        )

    def _load_web_search_runner(self) -> Callable[..., Any]:
        """Import ``tools.web_search.search.run_search_tool`` once and cache it.

        Raises:
            ImportError: if the search module cannot be imported (any import-time
                exception is wrapped and chained).
        """
        if self._web_search_runner is not None:
            return self._web_search_runner
        try:
            from tools.web_search.search import run_search_tool  # pylint: disable=import-outside-toplevel
        except Exception as exc:
            raise ImportError(f"Unable to load web search runner: {exc}") from exc
        self._web_search_runner = run_search_tool
        return run_search_tool

    async def _call_web_search(self, arguments: Dict[str, Any]) -> str:
        """Validate arguments, run the search in a worker thread and JSON-encode the result.

        Args:
            arguments: Tool-call arguments; non-dict values are treated as ``{}``.
                ``mode`` outside base/medium/pro falls back to "base"; ``workers``
                is coerced to an int and clamped to the schema minimum.

        Returns:
            The runner's result as JSON (non-serializable values via ``str``), or
            an error payload for a missing query or any runner/import failure.
        """
        args = arguments if isinstance(arguments, dict) else {}
        query = self._extract_query(args)
        if not query:
            return self._error_payload("web_search", "Missing required parameter: query")

        mode = str(args.get("mode") or "base").strip().lower()
        if mode not in {"base", "medium", "pro"}:
            mode = "base"
        workers = max(
            self._coerce_positive_int(args.get("workers"), default=self._DEFAULT_WORKERS),
            self._MIN_WORKERS,
        )

        try:
            runner = self._load_web_search_runner()
            # The runner is called synchronously; to_thread keeps it off the event loop.
            result = await asyncio.to_thread(runner, query=query, mode=mode, workers=workers)
            return json.dumps(result, ensure_ascii=False, default=str)
        except Exception as exc:  # tool boundary: report to the model instead of raising
            err = str(exc).strip() or repr(exc)
            print(f"[TOOL] web_search failed: {err}")
            return self._error_payload("web_search", err)

    @property
    def is_available(self) -> bool:
        """True after :meth:`initialize` and until :meth:`shutdown`."""
        return self._initialized

    @property
    def tools(self) -> List[Dict[str, Any]]:
        """Deep copy of the registered tool schemas; mutating it never touches the registry."""
        return copy.deepcopy(self._tools)

    @property
    def init_error(self) -> Optional[str]:
        """Last initialization error message (always None with the local tool)."""
        return self._init_error

    async def initialize(self) -> bool:
        """Register the web_search tool. Always succeeds."""
        # Deep copy so later edits to the registry cannot leak into WEB_SEARCH_TOOL.
        self._tools = [copy.deepcopy(WEB_SEARCH_TOOL)]
        self._initialized = True
        self._init_error = None
        print("[TOOLS] web_search tool registered")
        return True

    async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
        """Execute a tool by name. Only web_search (and its aliases) is supported.

        Returns:
            A JSON string: the tool result, or an error payload for unknown tools.
        """
        if self._normalize_tool_name(name) == "web_search":
            return await self._call_web_search(arguments)
        return self._error_payload(name, f"Unknown tool: {name}. Only web_search is supported.")

    @staticmethod
    def _infer_example_value(parameter_name: str, parameter_schema: Dict[str, Any]) -> Any:
        """Guess a plausible example value from a parameter's name and JSON-Schema type.

        A missing or non-dict ``parameter_schema`` is treated as an empty schema.
        """
        schema = parameter_schema if isinstance(parameter_schema, dict) else {}
        name = str(parameter_name or "").lower()
        schema_type = str(schema.get("type") or "").lower()
        if name in {"query", "q", "search", "prompt"}:
            return "latest official update"
        if name in {"limit", "max", "count", "max_results", "top_k"}:
            return 5
        if schema_type == "boolean":
            return False
        if schema_type in {"integer", "number"}:
            return 1
        if schema_type == "array":
            return []
        if schema_type == "object":
            return {}
        return "value"

    @classmethod
    def _build_tool_example_arguments(cls, tool_name: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Build example arguments for ``tool_name``.

        Uses TOOL_EXAMPLE_ARGUMENTS when an entry exists. Otherwise fills every
        required parameter; if none are required, falls back to the first two
        declared properties.
        """
        explicit = TOOL_EXAMPLE_ARGUMENTS.get(tool_name)
        if explicit:
            return dict(explicit)
        props = parameters.get("properties", {}) if isinstance(parameters, dict) else {}
        required = parameters.get("required", []) if isinstance(parameters, dict) else []
        example: Dict[str, Any] = {}
        if isinstance(required, list):
            for param_name in required:
                if isinstance(param_name, str):
                    param_schema = props.get(param_name, {}) if isinstance(props, dict) else {}
                    example[param_name] = cls._infer_example_value(param_name, param_schema)
        if not example and isinstance(props, dict):
            for param_name, param_schema in props.items():
                example[param_name] = cls._infer_example_value(param_name, param_schema or {})
                if len(example) >= 2:
                    break
        return example

    @classmethod
    def _build_tool_example_call(cls, tool_name: str, parameters: Dict[str, Any]) -> str:
        """Return ``{"tool": ..., "arguments": ...}`` as a JSON string."""
        payload = {
            "tool": tool_name,
            "arguments": cls._build_tool_example_arguments(tool_name, parameters),
        }
        return json.dumps(payload, ensure_ascii=False)

    def get_tools_description(self) -> str:
        """Format tool descriptions for the system prompt.

        Emits, per registered tool, a bullet list with description, purpose,
        parameters (type and required/optional), an example JSON call and the
        expected result style.
        """
        if not self._tools:
            return "No tools currently available."
        parts: List[str] = []
        for t in self._tools:
            name = t["name"]
            description = t.get("description", "")
            purpose = TOOL_PURPOSE_HINTS.get(name, "General external data retrieval tool.")
            parts.append(f"- **{name}**")
            parts.append(f" - Description: {description}")
            parts.append(f" - Purpose: {purpose}")
            # Tolerate explicit None values in hand-written schemas.
            params = t.get("parameters") or {}
            props = params.get("properties") or {}
            required = set(params.get("required") or [])
            if not props:
                parts.append(" - Parameters: none")
            else:
                parts.append(" - Parameters:")
                for pname, pinfo in props.items():
                    info = pinfo if isinstance(pinfo, dict) else {}
                    req_label = " (required)" if pname in required else " (optional)"
                    ptype = info.get("type", "any")
                    pdesc = info.get("description", "")
                    parts.append(f" - {pname} ({ptype}{req_label}): {pdesc}")
            parts.append(" - Example Tool Call (JSON):")
            parts.append(f" {self._build_tool_example_call(name, params)}")
            parts.append(" - Expected Result Style: Return concise factual extraction with source-aware context.")
        return "\n".join(parts)

    def get_tool_call_format(self) -> str:
        """Return an example JSON tool-call format for the system prompt.

        Uses the first registered tool, or a generic placeholder when none are registered.
        """
        if self._tools:
            first = self._tools[0]
            # Already compact JSON; no need to re-parse and re-serialize.
            return self._build_tool_example_call(first["name"], first.get("parameters", {}))
        return '{"tool": "tool_name", "arguments": {...}}'

    def get_tool_names(self) -> List[str]:
        """Names of the registered tools, in registration order."""
        return [t["name"] for t in self._tools]

    async def shutdown(self):
        """Clean up tool registry."""
        self._initialized = False
        self._tools = []
        self._web_search_runner = None
        print("[TOOLS] Tool registry shut down")
# Global singleton shared by every importer of this module. Its tool list
# stays empty until ``initialize()`` has been awaited.
tool_client: ToolClient = ToolClient()