File size: 4,075 Bytes
d8328bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""FastAPI-based tool server exposing NexaSci capabilities to the agent."""

from __future__ import annotations

from functools import lru_cache
from pathlib import Path
from typing import Any, Dict

import yaml
from fastapi import FastAPI, HTTPException

from tools.paper_sources import ArxivClient, CorpusPaths, CorpusSearcher
from tools.python_sandbox import DEFAULT_ALLOWED_MODULES, SandboxConfig, execute_python
from tools.schemas import (
    CorpusSearchRequest,
    CorpusSearchResponse,
    PaperFetchRequest,
    PaperFetchResponse,
    PaperSearchRequest,
    PaperSearchResponse,
    PythonRunRequest,
    PythonRunResponse,
    ToolResult,
)

# Service name; passed to FastAPI as the application title below.
APP_TITLE = "NexaSci Tool Server"
# Shared agent/tool YAML configuration. The path is relative, so it is
# interpreted against the process's current working directory when opened.
CONFIG_PATH = Path("agent/config.yaml")

# Application object that the route decorators in this module register on.
app = FastAPI(title=APP_TITLE, version="0.1.0")


@lru_cache(maxsize=1)
def _load_yaml_config() -> Dict[str, Any]:
    """Load the shared agent/tool configuration file.

    A successful load is memoised by ``lru_cache``, so later calls return the
    same parsed object without re-reading the file.

    Returns:
        The parsed YAML document. An empty (or comments-only) file yields an
        empty dict rather than ``None``.

    Raises:
        FileNotFoundError: If ``CONFIG_PATH`` does not exist.
    """

    if not CONFIG_PATH.exists():
        raise FileNotFoundError(f"Configuration file not found at {CONFIG_PATH}")
    with CONFIG_PATH.open("r", encoding="utf-8") as handle:
        document = yaml.safe_load(handle)
    # safe_load returns None for an empty document; callers chain ``.get`` on
    # the result, so normalise that case to an empty mapping.
    return {} if document is None else document


@lru_cache(maxsize=1)
def get_sandbox_config() -> SandboxConfig:
    """Initialise SandboxConfig from the shared YAML configuration.

    Reads the ``sandbox`` section of the configuration and falls back to
    built-in defaults for any key that is absent: a 10 second timeout, a
    2048 MB memory limit, ``./tmp/python`` as working directory and
    ``DEFAULT_ALLOWED_MODULES`` as the module allow-list.

    Returns:
        A ``SandboxConfig`` whose working directory has been resolved to an
        absolute path. The instance is cached and shared between calls.
    """

    # ``or {}`` covers a missing section as well as one that is present but
    # null/empty in the YAML (``sandbox:`` with no children parses to None).
    section = (_load_yaml_config() or {}).get("sandbox") or {}
    working_directory = Path(section.get("working_directory", "./tmp/python")).resolve()
    allowed_modules = tuple(section.get("allowed_modules", DEFAULT_ALLOWED_MODULES))
    return SandboxConfig(
        timeout_s=int(section.get("timeout_s", 10)),
        memory_limit_mb=int(section.get("memory_limit_mb", 2048)),
        working_directory=working_directory,
        allowed_modules=allowed_modules,
    )


@lru_cache(maxsize=1)
def get_arxiv_client() -> ArxivClient:
    """Build the ArxivClient on first use and hand back that same instance afterwards."""

    client = ArxivClient()
    return client


@lru_cache(maxsize=1)
def get_corpus_searcher() -> CorpusSearcher:
    """Return a cached CorpusSearcher initialised from configuration.

    Reads the ``corpus`` section of the configuration. ``corpus_path`` and
    ``embeddings_path`` default to ``./index/corpus.json`` and
    ``./index/embeddings.npy`` and are resolved to absolute paths;
    ``embedding_device`` is forwarded as-is and is ``None`` when not set.

    Returns:
        The ``CorpusSearcher`` built from those settings. The instance is
        cached and shared between calls.
    """

    # ``or {}`` covers a missing section as well as one that is present but
    # null/empty in the YAML (``corpus:`` with no children parses to None).
    section = (_load_yaml_config() or {}).get("corpus") or {}
    corpus_path = Path(section.get("corpus_path", "./index/corpus.json")).resolve()
    embeddings_path = Path(section.get("embeddings_path", "./index/embeddings.npy")).resolve()
    paths = CorpusPaths(corpus_path=corpus_path, embeddings_path=embeddings_path)
    device = section.get("embedding_device")
    return CorpusSearcher(paths, device=device)


@app.get("/healthz", response_model=ToolResult)
async def health_check() -> ToolResult:
    """Answer the readiness probe with a fixed ``{"status": "ok"}`` result."""

    payload = {"status": "ok"}
    return ToolResult.ok(tool="health", output=payload)


@app.post("/tools/python.run", response_model=PythonRunResponse)
async def run_python(request: PythonRunRequest) -> PythonRunResponse:
    """Execute code snippets inside the sandboxed Python environment.

    Args:
        request: The validated request body describing the code to run.

    Returns:
        Whatever ``execute_python`` produces for the request and the cached
        sandbox configuration.
    """

    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Resolved on the event-loop thread so the cached config is built only once.
    sandbox_config = get_sandbox_config()
    # execute_python is a plain synchronous call; running it directly inside
    # this coroutine would stall every other request (including /healthz)
    # until the snippet finishes, so hand it to a worker thread instead.
    return await run_in_threadpool(execute_python, request, sandbox_config)


@app.post("/tools/papers.search", response_model=PaperSearchResponse)
async def papers_search(request: PaperSearchRequest) -> PaperSearchResponse:
    """Search the arXiv API for relevant papers.

    Args:
        request: The validated search request, passed unchanged to the client.

    Returns:
        A ``PaperSearchResponse`` wrapping the client's search results.
    """

    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    client = get_arxiv_client()
    # client.search is synchronous; offload it so the event loop keeps
    # serving other requests while the lookup is in flight.
    results = await run_in_threadpool(client.search, request)
    return PaperSearchResponse(results=results)


@app.post("/tools/papers.fetch", response_model=PaperFetchResponse)
async def papers_fetch(request: PaperFetchRequest) -> PaperFetchResponse:
    """Fetch detailed metadata for a specific paper from arXiv.

    Args:
        request: The validated request carrying ``arxiv_id`` and ``doi``.

    Returns:
        A ``PaperFetchResponse`` wrapping the metadata returned by the client.

    Raises:
        HTTPException: With status 404 when the client returns ``None``.
    """

    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    client = get_arxiv_client()
    # client.fetch is synchronous; offload it so the event loop keeps
    # serving other requests while the lookup is in flight.
    metadata = await run_in_threadpool(
        client.fetch,
        arxiv_id=request.arxiv_id,
        doi=request.doi,
    )
    if metadata is None:
        raise HTTPException(status_code=404, detail="Paper not found.")
    return PaperFetchResponse(paper=metadata)


@app.post("/tools/papers.search_corpus", response_model=CorpusSearchResponse)
async def papers_search_corpus(request: CorpusSearchRequest) -> CorpusSearchResponse:
    """Search the locally stored scientific corpus using SPECTER2 embeddings.

    Args:
        request: The validated search request, passed unchanged to the searcher.

    Returns:
        A ``CorpusSearchResponse`` wrapping the searcher's results.
    """

    # Function-scope import: fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    # Resolved on the event-loop thread so the cached searcher is built only once.
    searcher = get_corpus_searcher()
    # searcher.search is synchronous; offload it so the event loop keeps
    # serving other requests while the query runs.
    results = await run_in_threadpool(searcher.search, request)
    return CorpusSearchResponse(results=results)


# Only the application object is exported; the helpers and handlers above are
# left out of the star-import surface.
__all__ = ["app"]