File size: 2,696 Bytes
1e3b07a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""
WebSocket client for Code Debugging Challenge environment.
"""

from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult, ResetResult
from .models import DebugAction, DebugObservation, DebugState


class DebugEnv(EnvClient[DebugAction, DebugObservation, DebugState]):
    """Client for interacting with the Code Debugging Challenge environment.

    Converts between the typed ``DebugAction`` / ``DebugObservation`` /
    ``DebugState`` models and the plain dictionaries exchanged with the
    server. NOTE(review): the methods below look like hooks invoked by the
    ``EnvClient`` base class; confirm the dispatch in
    ``openenv.core.env_client``.
    """

    # Fallbacks used when the server omits a field. The attempts default was
    # previously duplicated as a bare ``5`` in the reset and state parsers.
    # A fresh episode presumably starts with all attempts remaining, which
    # is why the same value serves both; verify against the server.
    _DEFAULT_MAX_ATTEMPTS = 5
    _DEFAULT_TOTAL_PROBLEMS = 7

    # Observation fields that the step parser always forwards and that the
    # reset parser forwards only when the server actually sends them.
    _OPTIONAL_OBS_FIELDS = ("current_output", "error_message", "hint")

    def _step_payload(self, action: DebugAction) -> dict:
        """Convert an action into the JSON-serializable payload for the server.

        Args:
            action: The action to send; only ``action_type`` and ``content``
                are transmitted.

        Returns:
            A dict with keys ``"action_type"`` and ``"content"``.
        """
        return {
            "action_type": action.action_type,
            "content": action.content,
        }

    def _parse_result(self, data: dict) -> StepResult[DebugObservation]:
        """Parse a step response from the server into a typed result.

        Args:
            data: Decoded server response. It must contain ``"observation"``,
                ``"reward"``, ``"terminated"`` and ``"truncated"``. The
                observation must contain ``"buggy_code"``,
                ``"expected_output"`` and ``"attempts_remaining"``.

        Returns:
            A ``StepResult`` wrapping a ``DebugObservation``.

        Raises:
            KeyError: If any of the required keys listed above is missing.
        """
        obs_data = data["observation"]

        observation = DebugObservation(
            buggy_code=obs_data["buggy_code"],
            expected_output=obs_data["expected_output"],
            # ``or []`` also covers an explicit JSON null, not just absence.
            test_inputs=obs_data.get("test_inputs") or [],
            current_output=obs_data.get("current_output"),
            error_message=obs_data.get("error_message"),
            attempts_remaining=obs_data["attempts_remaining"],
            hint=obs_data.get("hint"),
            success=obs_data.get("success", False),
        )

        return StepResult(
            observation=observation,
            reward=data["reward"],
            terminated=data["terminated"],
            truncated=data["truncated"],
            # ``or {}`` keeps ``info`` a dict even if the server sends null.
            info=data.get("info") or {},
        )

    def _parse_reset_result(self, data: dict) -> ResetResult[DebugObservation]:
        """Parse a reset response from the server into a typed result.

        Args:
            data: Decoded server response. It must contain ``"observation"``,
                and that observation must contain ``"buggy_code"`` and
                ``"expected_output"``.

        Returns:
            A ``ResetResult`` wrapping a ``DebugObservation`` whose
            ``success`` flag is always ``False``.

        Raises:
            KeyError: If any of the required keys listed above is missing.
        """
        obs_data = data["observation"]

        # Forward optional fields only when present. Previously they were
        # dropped on reset even if the server supplied them (e.g. a hint).
        # When they are absent, the model's own defaults still apply.
        optional_fields = {
            key: obs_data[key]
            for key in self._OPTIONAL_OBS_FIELDS
            if key in obs_data
        }

        observation = DebugObservation(
            buggy_code=obs_data["buggy_code"],
            expected_output=obs_data["expected_output"],
            test_inputs=obs_data.get("test_inputs") or [],
            attempts_remaining=obs_data.get(
                "attempts_remaining", self._DEFAULT_MAX_ATTEMPTS
            ),
            # A freshly reset episode is never already solved.
            success=False,
            **optional_fields,
        )

        return ResetResult(
            observation=observation,
            info=data.get("info") or {},
        )

    def _parse_state(self, data: dict) -> DebugState:
        """Parse a state response from the server into a typed state.

        Every field is optional; missing fields fall back to the defaults
        below.

        Args:
            data: Decoded server response describing the episode state.

        Returns:
            A ``DebugState`` populated from ``data``.
        """
        return DebugState(
            current_problem_index=data.get("current_problem_index", 0),
            attempts_made=data.get("attempts_made", 0),
            max_attempts=data.get("max_attempts", self._DEFAULT_MAX_ATTEMPTS),
            score=data.get("score", 0.0),
            solved=data.get("solved", False),
            total_problems=data.get(
                "total_problems", self._DEFAULT_TOTAL_PROBLEMS
            ),
            episode_id=data.get("episode_id", ""),
        )