File size: 4,474 Bytes
ff9fcbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
HTTP client for the Code Review Environment.

Usage:
    from client import CodeReviewEnv, ReviewAction

    with CodeReviewEnv(base_url="http://localhost:7860").sync() as env:
        result = env.reset(task_id="bug-detection")
        obs = result.observation
        print(obs.task_description)
        print(obs.code_files)

        # Flag an issue
        result = env.step(ReviewAction(
            action_type="flag_issue",
            line_number=6,
            filename="utils.py",
            issue_type="bug",
            severity="high",
            description="Off-by-one in range()"
        ))
        print(result.observation.feedback)

        # Submit
        result = env.step(ReviewAction(action_type="submit_review"))
        print(f"Score: {result.reward:.3f}")
"""
from __future__ import annotations

import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from typing import Optional, Generic, TypeVar
from models import ReviewAction, ReviewObservation, ReviewState, Issue

ObsT = TypeVar("ObsT")


class StepResult(Generic[ObsT]):
    """Bundle returned by an environment call.

    Attributes are plain instance fields:
        observation: the observation object (generic in ``ObsT``).
        reward: optional float reward; ``None`` when not supplied.
        done: whether the episode is flagged as finished.
    """

    def __init__(
        self,
        observation: ObsT,
        reward: Optional[float] = None,
        done: bool = False,
    ):
        """Store the three fields unchanged on the instance."""
        self.observation, self.reward, self.done = observation, reward, done

    def __repr__(self) -> str:
        """Return a compact summary including the observation's score, if any."""
        # ``ObsT`` is unconstrained, so the observation may not carry a
        # ``current_score`` attribute; report None in that case.
        score = getattr(self.observation, "current_score", None)
        return f"StepResult(done={self.done}, reward={self.reward}, score={score})"


# Optional dependency: httpx. Availability is recorded in a flag rather than
# failing at import time, so this module can be imported without httpx;
# SyncCodeReviewEnv checks the flag and raises ImportError on construction.
try:
    import httpx
    _HAS_HTTPX = True
except ImportError:
    _HAS_HTTPX = False

# Optional dependency: the openenv HTTP client. Only its availability is
# recorded here; neither _OfficialClient nor _HAS_OPENENV_CLIENT is referenced
# again in the visible part of this file.
try:
    from openenv.core.http_env_client import HTTPEnvClient as _OfficialClient
    _HAS_OPENENV_CLIENT = True
except ImportError:
    _HAS_OPENENV_CLIENT = False


class SyncCodeReviewEnv:
    """Blocking HTTP client for a Code Review Environment server.

    Wraps an ``httpx.Client`` and exposes the server's ``/reset``, ``/step``,
    ``/state``, ``/health`` and ``/tasks`` endpoints as methods. Can be used
    as a context manager; leaving the ``with`` block closes the HTTP client.
    Every request calls ``raise_for_status()`` on the response, so error
    responses propagate as exceptions instead of being decoded.
    """

    def __init__(self, base_url: str = "http://localhost:7860"):
        """Create the client.

        Args:
            base_url: Server root URL; trailing slashes are stripped.

        Raises:
            ImportError: If httpx could not be imported.
        """
        self.base_url = base_url.rstrip("/")
        if not _HAS_HTTPX:
            raise ImportError("httpx is required: pip install httpx")
        import httpx

        self._client = httpx.Client(timeout=30.0)

    def __enter__(self):
        """Return this client for use inside a ``with`` block."""
        return self

    def __exit__(self, *args):
        """Close the underlying HTTP client when the ``with`` block ends."""
        self.close()

    def close(self):
        """Close the underlying ``httpx.Client``."""
        self._client.close()

    def _url(self, path: str) -> str:
        """Join the stored base URL with an endpoint path such as ``/step``."""
        return f"{self.base_url}{path}"

    def _get_json(self, path: str):
        """GET ``path`` and return the decoded JSON body."""
        response = self._client.get(self._url(path))
        response.raise_for_status()
        return response.json()

    def _post_for_result(self, path: str, payload: dict) -> StepResult[ReviewObservation]:
        """POST ``payload`` as JSON to ``path`` and wrap the reply in a StepResult."""
        response = self._client.post(self._url(path), json=payload)
        response.raise_for_status()
        observation = ReviewObservation.from_dict(response.json())
        return StepResult(
            observation=observation,
            reward=observation.reward,
            done=observation.done,
        )

    def reset(
        self,
        task_id: Optional[str] = None,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
    ) -> StepResult[ReviewObservation]:
        """POST to ``/reset`` and return the resulting observation.

        Only supplied values are sent: ``task_id`` and ``episode_id`` are
        included when truthy, ``seed`` whenever it is not ``None`` (so a seed
        of ``0`` is still sent).
        """
        candidates = (
            ("task_id", task_id, bool(task_id)),
            ("seed", seed, seed is not None),
            ("episode_id", episode_id, bool(episode_id)),
        )
        payload = {key: value for key, value, keep in candidates if keep}
        return self._post_for_result("/reset", payload)

    def step(self, action: ReviewAction) -> StepResult[ReviewObservation]:
        """POST ``action.to_dict()`` to ``/step`` and return the resulting observation."""
        return self._post_for_result("/step", action.to_dict())

    def state(self) -> ReviewState:
        """GET ``/state`` and build a ``ReviewState`` from the JSON reply.

        Missing keys fall back to defaults: empty strings for ``task_id`` and
        ``difficulty``, ``None`` for ``episode_id``, ``0`` for ``step_count``,
        an empty list for ``flagged_issues``, ``0.0`` for ``current_score``
        and ``False`` for ``submitted``.
        """
        payload = self._get_json("/state")
        issues = [Issue.from_dict(raw) for raw in payload.get("flagged_issues", [])]
        return ReviewState(
            task_id=payload.get("task_id", ""),
            difficulty=payload.get("difficulty", ""),
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            flagged_issues=issues,
            current_score=payload.get("current_score", 0.0),
            submitted=payload.get("submitted", False),
        )

    def health(self) -> dict:
        """GET ``/health`` and return the decoded JSON body."""
        return self._get_json("/health")

    def list_tasks(self) -> dict:
        """GET ``/tasks`` and return the decoded JSON body."""
        return self._get_json("/tasks")


class CodeReviewEnv:
    """Entry point that hands out ``SyncCodeReviewEnv`` clients.

    Either call ``sync()`` to obtain a client explicitly, or use the instance
    itself in a ``with`` statement, which yields a client and closes it on exit.
    """

    def __init__(self, base_url: str = "http://localhost:7860"):
        """Remember the server URL; no client is created yet."""
        self.base_url = base_url

    def sync(self) -> SyncCodeReviewEnv:
        """Create a new blocking client bound to ``base_url``."""
        return SyncCodeReviewEnv(self.base_url)

    def __enter__(self):
        """Create a client, keep a reference for ``__exit__``, and return it."""
        session = self.sync()
        self._sync = session
        return session

    def __exit__(self, *args):
        """Close the client created by ``__enter__``, if one exists."""
        # ``_sync`` is only set by ``__enter__``; nothing to close otherwise.
        if not hasattr(self, "_sync"):
            return
        self._sync.close()