File size: 3,228 Bytes
d25ab77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Python Env Environment Client."""

from __future__ import annotations

from typing import Any, Dict
from urllib.parse import urlparse

import httpx

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

try:
    from .models import (
        HealthResponse,
        MetricsResponse,
        PythonAction,
        PythonObservation,
        PythonState,
        TaskListResponse,
    )
except ImportError:
    from models import (  # type: ignore
        HealthResponse,
        MetricsResponse,
        PythonAction,
        PythonObservation,
        PythonState,
        TaskListResponse,
    )


def _to_http_base_url(base_url: str) -> str:
    parsed = urlparse(base_url)
    scheme = "https" if parsed.scheme == "wss" else "http"
    if parsed.scheme in {"http", "https"}:
        scheme = parsed.scheme
    return f"{scheme}://{parsed.netloc}{parsed.path}".rstrip("/")


class PythonEnv(EnvClient[PythonAction, PythonObservation, PythonState]):
    """Typed client for the Python code-review environment.

    The ``_step_payload`` / ``_parse_result`` / ``_parse_state`` overrides
    adapt between the typed models and JSON payloads (presumably invoked by
    the ``EnvClient`` base class). The auxiliary ``/tasks``, ``/metrics`` and
    ``/health`` endpoints are fetched with a one-off ``httpx.AsyncClient``
    against an HTTP(S) URL derived from ``base_url``.
    """

    def __init__(self, base_url: str, **kwargs: Any):
        super().__init__(base_url=base_url, **kwargs)
        # ws:// / wss:// URLs are mapped to http:// / https:// so the REST
        # endpoints below can be reached with a plain HTTP client.
        self._http_base_url = _to_http_base_url(base_url)

    def _step_payload(self, action: PythonAction) -> Dict[str, Any]:
        """Convert a validated action model to the JSON payload expected by the server.

        Fields whose value is ``None`` are omitted from the payload.
        """

        return action.model_dump(exclude_none=True)

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[PythonObservation]:
        """Parse a server response into a typed step result.

        Top-level ``done`` / ``reward`` values are copied into the observation
        only when the observation does not already carry its own. A missing
        ``done`` defaults to ``False``.
        """

        # ``or {}`` also covers an explicit ``"observation": null`` payload,
        # which ``dict(None)`` would reject with a TypeError.
        obs_data = dict(payload.get("observation") or {})
        obs_data.setdefault("done", payload.get("done", False))
        obs_data.setdefault("reward", payload.get("reward"))
        observation = PythonObservation.model_validate(obs_data)

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> PythonState:
        """Parse the server state payload into the shared state model."""

        return PythonState.model_validate(payload)

    async def _get_json(self, path: str) -> Any:
        """GET *path* relative to the HTTP base URL and return the decoded JSON body.

        Args:
            path: Endpoint path starting with ``/`` (e.g. ``"/tasks"``).

        Raises:
            httpx.HTTPStatusError: if the server responds with a non-2xx status.
        """

        async with httpx.AsyncClient() as client:
            response = await client.get(f"{self._http_base_url}{path}")
            response.raise_for_status()
            return response.json()

    async def get_tasks(self) -> TaskListResponse:
        """Fetch and validate the response of the ``/tasks`` endpoint."""

        return TaskListResponse.model_validate(await self._get_json("/tasks"))

    async def get_metrics(self) -> MetricsResponse:
        """Fetch and validate the response of the ``/metrics`` endpoint."""

        return MetricsResponse.model_validate(await self._get_json("/metrics"))

    async def get_health(self) -> HealthResponse:
        """Fetch and validate the response of the ``/health`` endpoint."""

        return HealthResponse.model_validate(await self._get_json("/health"))