File size: 4,507 Bytes
d8328bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""HTTP client for interacting with the NexaSci tool server."""

from __future__ import annotations

import os
from contextlib import AbstractContextManager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar

import httpx
import yaml
from pydantic import BaseModel

from tools.schemas import (
    CorpusSearchRequest,
    CorpusSearchResponse,
    PaperFetchRequest,
    PaperFetchResponse,
    PaperSearchRequest,
    PaperSearchResponse,
    PythonRunRequest,
    PythonRunResponse,
    ToolCall,
    ToolResult,
)

# Generic response-model type for ``ToolClient._post``: the pydantic model class passed
# as ``response_model`` determines the (parsed) return type.
T = TypeVar("T", bound=BaseModel)


@dataclass(frozen=True)
class ToolClientConfig:
    """Configuration required to initialise the ToolClient.

    Attributes:
        base_url: Root URL of the tool server; used as the ``httpx.Client`` base URL.
        timeout_s: Request timeout in seconds, passed to ``httpx.Client``.
    """

    base_url: str
    timeout_s: int = 30

    @classmethod
    def from_yaml(cls, path: Path | str = "agent/config.yaml") -> "ToolClientConfig":
        """Load configuration, letting environment variables take precedence.

        When ``TOOL_SERVER_BASE_URL`` is set (non-empty), the YAML file is not read:
        the base URL comes from that variable and the timeout from
        ``TOOL_SERVER_TIMEOUT`` (default 30). Otherwise the ``tool_server`` section
        of the YAML file at *path* supplies ``base_url`` (default
        ``http://127.0.0.1:8000``) and ``request_timeout_s`` (default 30). An empty
        file, an empty ``tool_server`` section, or keys explicitly set to null all
        fall back to those defaults.

        Args:
            path: Location of the YAML configuration file.

        Returns:
            A populated ``ToolClientConfig``.

        Raises:
            FileNotFoundError: If no env override is set and *path* does not exist.
            ValueError: If the YAML document or its ``tool_server`` section is not a
                mapping, or a timeout string is not a valid integer.
        """

        env_base_url = os.environ.get("TOOL_SERVER_BASE_URL")
        if env_base_url:
            # NOTE: TOOL_SERVER_TIMEOUT is only honoured together with TOOL_SERVER_BASE_URL.
            env_timeout = os.environ.get("TOOL_SERVER_TIMEOUT")
            timeout_value = int(env_timeout) if env_timeout else 30
            return cls(base_url=env_base_url, timeout_s=timeout_value)

        config_path = Path(path)
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {config_path}")
        with config_path.open("r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle)
        # safe_load returns None for an empty document; treat that as "no settings".
        if data is None:
            data = {}
        if not isinstance(data, dict):
            raise ValueError(f"Configuration file {config_path} must contain a mapping")

        # A bare ``tool_server:`` key parses to None rather than an empty mapping.
        tool_cfg = data.get("tool_server")
        if tool_cfg is None:
            tool_cfg = {}
        if not isinstance(tool_cfg, dict):
            raise ValueError(f"'tool_server' section in {config_path} must be a mapping")

        # Explicit nulls would otherwise become the URL "None" or crash in int(None).
        raw_base_url = tool_cfg.get("base_url")
        raw_timeout = tool_cfg.get("request_timeout_s")
        return cls(
            base_url="http://127.0.0.1:8000" if raw_base_url is None else str(raw_base_url),
            timeout_s=30 if raw_timeout is None else int(raw_timeout),
        )


class ToolClient(AbstractContextManager["ToolClient"]):
    """Synchronous HTTP client for the NexaSci tool server.

    Wraps one ``httpx.Client`` bound to the configured base URL and timeout. Use it
    as a context manager (or call :meth:`close`) so the client is closed.
    """

    # Maps tool names (``ToolCall.tool`` values) to their server endpoint paths.
    _tool_to_endpoint: Dict[str, str] = {
        "python.run": "/tools/python.run",
        "papers.search": "/tools/papers.search",
        "papers.fetch": "/tools/papers.fetch",
        "papers.search_corpus": "/tools/papers.search_corpus",
    }

    def __init__(self, config: ToolClientConfig) -> None:
        """Create the underlying HTTP client from *config* (base URL and timeout)."""

        self._config = config
        self._client = httpx.Client(base_url=self._config.base_url, timeout=self._config.timeout_s)

    @classmethod
    def from_config(cls, path: Path | str = "agent/config.yaml") -> "ToolClient":
        """Construct a ToolClient by loading configuration from disk.

        Delegates to ``ToolClientConfig.from_yaml``, so environment overrides apply.
        """

        return cls(ToolClientConfig.from_yaml(path))

    def __exit__(self, exc_type, exc, exc_tb) -> None:  # type: ignore[override]
        """Close the client when leaving a ``with`` block; exceptions still propagate."""

        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""

        self._client.close()

    def call_tool(self, call: ToolCall) -> ToolResult:
        """Invoke a tool using the generic tool call schema.

        ``call.arguments`` is sent as the JSON request body. Non-2xx responses,
        transport failures (connection errors, timeouts) and non-JSON success
        bodies are all reported as a failed ``ToolResult`` rather than raised.

        Raises:
            ValueError: If ``call.tool`` is not a known tool name.
        """

        endpoint = self._resolve_endpoint(call.tool)
        try:
            response = self._client.post(endpoint, json=call.arguments)
        except httpx.RequestError as exc:
            # Transport-level failure: report it the same way as an HTTP error status.
            return ToolResult.failed(
                call.tool, f"Tool invocation failed: {type(exc).__name__}: {exc}"
            )
        if not response.is_success:
            return ToolResult.failed(call.tool, f"Tool invocation failed: {response.text}")
        try:
            payload = response.json()
        except ValueError:
            # 2xx with an unparseable body (json.JSONDecodeError subclasses ValueError).
            return ToolResult.failed(
                call.tool, f"Tool invocation failed: invalid JSON response: {response.text}"
            )
        return ToolResult.ok(call.tool, payload)

    def python_run(self, request: PythonRunRequest) -> PythonRunResponse:
        """Execute code snippets inside the Python sandbox."""

        return self._post("python.run", request, PythonRunResponse)

    def papers_search(self, request: PaperSearchRequest) -> PaperSearchResponse:
        """Search the arXiv API via the tool server."""

        return self._post("papers.search", request, PaperSearchResponse)

    def papers_fetch(self, request: PaperFetchRequest) -> PaperFetchResponse:
        """Fetch a single paper's metadata from arXiv."""

        return self._post("papers.fetch", request, PaperFetchResponse)

    def papers_search_corpus(self, request: CorpusSearchRequest) -> CorpusSearchResponse:
        """Search the local SPECTER2 corpus."""

        return self._post("papers.search_corpus", request, CorpusSearchResponse)

    def _post(self, tool: str, model: BaseModel, response_model: Type[T]) -> T:
        """POST *model* to *tool*'s endpoint and parse the reply as *response_model*.

        Unlike :meth:`call_tool`, failures are raised, not wrapped.

        Raises:
            ValueError: If *tool* is unknown.
            httpx.HTTPStatusError: On a non-2xx response (via ``raise_for_status``).
            httpx.RequestError: On transport failures.
        """

        endpoint = self._resolve_endpoint(tool)
        # NOTE: ``.dict()`` / ``.parse_obj()`` are the pydantic v1 API; under pydantic v2
        # they still work but are deprecated (``model_dump`` / ``model_validate``).
        response = self._client.post(endpoint, json=model.dict())
        response.raise_for_status()
        return response_model.parse_obj(response.json())

    def _resolve_endpoint(self, tool: str) -> str:
        """Map a tool name to its endpoint path, raising ValueError if unknown."""

        try:
            return self._tool_to_endpoint[tool]
        except KeyError as exc:
            raise ValueError(f"Unknown tool: {tool}") from exc