File size: 5,731 Bytes
9b47159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
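HTTP clients for the Bug Triage OpenEnv environment.

Provides a synchronous client (BugTriageEnvClient, built on requests) and an
asynchronous client (AsyncBugTriageEnvClient, built on aiohttp) for
interacting with the environment server. aiohttp is imported only when the
async client is entered, so it is required only if you use that client.

Both clients default to the URL in the BUG_TRIAGE_ENV_URL environment
variable, falling back to http://localhost:8000.

Sync usage:
    from bug_triage_env.client import BugTriageEnvClient
    from bug_triage_env.models import BugTriageAction

    with BugTriageEnvClient("http://localhost:8000") as env:
        obs = env.reset(task_id="task_1")
        action = BugTriageAction(task_id="task_1", bug_type="crash")
        result = env.step(obs["episode_id"], action)
        print(result["grader_score"])

Async usage (the client must be entered with ``async with``, because its
HTTP session is created on entry):
    from bug_triage_env.client import AsyncBugTriageEnvClient

    async with AsyncBugTriageEnvClient("http://localhost:8000") as env:
        obs = await env.reset("task_3")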
"""

from __future__ import annotations

import os
from typing import Any, Dict, Optional

import requests

from .models import BugTriageAction

# Server root used by both clients when no base_url is given; read once at
# import time from BUG_TRIAGE_ENV_URL, falling back to a local server.
DEFAULT_URL = os.environ.get("BUG_TRIAGE_ENV_URL", "http://localhost:8000")


class BugTriageEnvClient:
    """
    Synchronous HTTP client for the Bug Triage environment.

    Follows the OpenEnv HTTPEnvClient interface: reset(), step(), state().
    Also exposes health(), tasks() and grade() helpers.

    The client owns a ``requests.Session``. Release it with ``close()`` or by
    using the client as a context manager (``with BugTriageEnvClient() as env``).
    """

    def __init__(self, base_url: str = DEFAULT_URL, timeout: float = 30) -> None:
        """
        Args:
            base_url: Server root URL; a trailing slash is stripped.
            timeout: Per-request timeout in seconds, passed to requests.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()
        # Remembered from the last reset() so step() can omit episode_id.
        self._episode_id: Optional[str] = None

    def _get(self, path: str) -> Dict[str, Any]:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            requests.HTTPError: if the server answers with a non-2xx status.
        """
        resp = self._session.get(f"{self.base_url}{path}", timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path``; return the JSON body.

        Raises:
            requests.HTTPError: if the server answers with a non-2xx status.
        """
        resp = self._session.post(
            f"{self.base_url}{path}",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()

    def reset(self, task_id: str = "task_1") -> Dict[str, Any]:
        """Start a new episode. Returns observation dict with episode_id.

        The returned episode_id (if any) is stored so later step() calls may
        omit it.
        """
        data = self._post("/reset", {"task_id": task_id})
        self._episode_id = data.get("episode_id")
        return data

    def step(
        self,
        episode_id: Optional[str] = None,
        action: Optional[BugTriageAction] = None,
    ) -> Dict[str, Any]:
        """Execute one action. Returns observation with grader_score.

        Args:
            episode_id: Episode to act in; falls back to the one stored by the
                last reset() when falsy.
            action: The triage action to submit (serialized via model_dump()).

        Raises:
            ValueError: if no episode id is available or action is missing.
            requests.HTTPError: on a non-2xx response.
        """
        ep_id = episode_id or self._episode_id
        if ep_id is None:
            raise ValueError("episode_id required. Call reset() first.")
        if action is None:
            raise ValueError("action required.")
        return self._post("/step", {"episode_id": ep_id, "action": action.model_dump()})

    def state(self) -> Dict[str, Any]:
        """Return current episode state (GET /state)."""
        return self._get("/state")

    def health(self) -> bool:
        """Return True if GET /health responds with ``{"status": "healthy"}``.

        Best-effort: connection errors, timeouts, undecodable bodies and
        non-object JSON all yield False instead of raising.
        """
        try:
            resp = self._session.get(f"{self.base_url}/health", timeout=self.timeout)
            data = resp.json()
        except (requests.RequestException, ValueError):
            # requests' JSON decode errors subclass ValueError.
            return False
        return isinstance(data, dict) and data.get("status") == "healthy"

    def tasks(self) -> Dict[str, Any]:
        """Get task list and action schemas (GET /tasks)."""
        return self._get("/tasks")

    def grade(self, episode_id: str, task_id: str) -> Dict[str, Any]:
        """Get grader score for a completed episode (POST /grader)."""
        return self._post("/grader", {"episode_id": episode_id, "task_id": task_id})

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    def __enter__(self) -> BugTriageEnvClient:
        return self

    def __exit__(self, *args: Any) -> None:
        self.close()


class AsyncBugTriageEnvClient:
    """
    Async HTTP client using aiohttp.

    Designed for use with GRPO training and vLLM colocate mode.

    The aiohttp session is created in ``__aenter__``, so the client must be
    used as ``async with AsyncBugTriageEnvClient(...) as env:``. As with the
    synchronous client, non-2xx responses raise (here
    ``aiohttp.ClientResponseError``) rather than being returned as data.
    """

    def __init__(self, base_url: str = DEFAULT_URL, timeout: float = 30) -> None:
        """
        Args:
            base_url: Server root URL; a trailing slash is stripped.
            timeout: Total per-request timeout in seconds (aiohttp ClientTimeout).
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Remembered from the last reset() so step() can omit episode_id.
        self._episode_id: Optional[str] = None
        # aiohttp.ClientSession, created in __aenter__. aiohttp is imported
        # lazily there, so this module can be imported without it.
        self._session: Any = None

    async def __aenter__(self) -> AsyncBugTriageEnvClient:
        import aiohttp

        timeout = aiohttp.ClientTimeout(total=self.timeout)
        self._session = aiohttp.ClientSession(timeout=timeout)
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    async def close(self) -> None:
        """Close the aiohttp session if one is open; safe to call repeatedly."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    def _require_session(self) -> Any:
        """Return the open session, or raise a clear error if not entered."""
        if self._session is None:
            raise RuntimeError(
                "Session not open; use "
                "'async with AsyncBugTriageEnvClient(...) as env:'."
            )
        return self._session

    async def _get(self, path: str) -> Dict[str, Any]:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            RuntimeError: if the client has not been entered.
            aiohttp.ClientResponseError: on a non-2xx response.
        """
        async with self._require_session().get(f"{self.base_url}{path}") as resp:
            resp.raise_for_status()
            return await resp.json()

    async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + path``; return the JSON body.

        Raises:
            RuntimeError: if the client has not been entered.
            aiohttp.ClientResponseError: on a non-2xx response.
        """
        async with self._require_session().post(
            f"{self.base_url}{path}",
            json=payload,
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def reset(self, task_id: str = "task_1") -> Dict[str, Any]:
        """Start a new episode and remember its episode_id for step()."""
        data = await self._post("/reset", {"task_id": task_id})
        self._episode_id = data.get("episode_id")
        return data

    async def step(
        self,
        episode_id: Optional[str] = None,
        action: Optional[BugTriageAction] = None,
    ) -> Dict[str, Any]:
        """Execute one triage action.

        Args:
            episode_id: Episode to act in; falls back to the one stored by the
                last reset() when falsy.
            action: The triage action to submit (serialized via model_dump()).

        Raises:
            ValueError: if no episode id is available or action is missing.
        """
        ep_id = episode_id or self._episode_id
        if ep_id is None:
            raise ValueError("Call reset() first.")
        if action is None:
            raise ValueError("action required.")
        return await self._post(
            "/step", {"episode_id": ep_id, "action": action.model_dump()}
        )

    async def state(self) -> Dict[str, Any]:
        """Return current episode state (GET /state)."""
        return await self._get("/state")

    async def health(self) -> bool:
        """Return True if GET /health responds with ``{"status": "healthy"}``.

        Best-effort: client errors, timeouts, undecodable bodies and
        non-object JSON yield False instead of raising.
        """
        import asyncio

        import aiohttp

        session = self._require_session()
        try:
            async with session.get(f"{self.base_url}/health") as resp:
                data = await resp.json()
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError):
            # ContentTypeError is a ClientError; JSON decode errors are ValueError.
            return False
        return isinstance(data, dict) and data.get("status") == "healthy"

    async def tasks(self) -> Dict[str, Any]:
        """Get task list and action schemas (GET /tasks)."""
        return await self._get("/tasks")

    async def grade(self, episode_id: str, task_id: str) -> Dict[str, Any]:
        """Get grader score for a completed episode (POST /grader)."""
        return await self._post(
            "/grader", {"episode_id": episode_id, "task_id": task_id}
        )