# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
RlCodeFix Environment: OpenEnv-compliant server wrapper.

Wraps the core CodeEnv engine and exposes it through the OpenEnv
HTTP/WebSocket interface via create_app().

Episode lifecycle:
    reset()  loads a task at the next difficulty in the easy -> medium -> hard
             cycle, or at the difficulty named by TRACERL_TASK if it is set
    step()   dispatches apply_patch / run_tests / get_logs
    state    returns current episode_id + step_count
"""

from uuid import uuid4
import os

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation
from rl_code_fix_env.src.environment.environment import CodeEnv

_DIFFICULTY_CYCLE = ["easy", "medium", "hard"]


class RlCodeFixEnvironment(Environment):
    """
    OpenEnv-compliant wrapper around CodeEnv.

    Exposes reset / step / state to the HTTP server produced by create_app().
    Task difficulty advances through easy, medium, hard on each reset (unless
    TRACERL_TASK pins it), so the agent sees a variety of problems across
    episodes.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._engine = CodeEnv()  # core engine
        self._difficulty_idx = 0  # position in _DIFFICULTY_CYCLE for the next reset

    def _select_difficulty(self) -> str:
        """Return the TRACERL_TASK override if valid, else the next difficulty in the cycle."""
        forced = (os.getenv("TRACERL_TASK") or "").strip().lower()
        if forced in _DIFFICULTY_CYCLE:
            return forced

        difficulty = _DIFFICULTY_CYCLE[self._difficulty_idx % len(_DIFFICULTY_CYCLE)]
        self._difficulty_idx += 1
        return difficulty

    def reset(self) -> CodeFixerObservation:
        """Load a task at the selected difficulty and return the initial observation."""
        difficulty = self._select_difficulty()

        self._state = State(episode_id=str(uuid4()), step_count=0)
        obs_dict = self._engine.reset(difficulty=difficulty)

        return CodeFixerObservation(
            code=obs_dict["code"],
            logs=obs_dict["logs"],
            test_score=float(obs_dict["test_score"]),
            total_tests=obs_dict["total_tests"],
            steps=obs_dict["steps"],
            done=False,
            reward=0.0,
        )

    def step(self, action: CodeFixerAction) -> CodeFixerObservation:  # type: ignore[override]
        """Dispatch the action to the core engine and return the resulting observation."""
        self._state.step_count += 1

        obs_dict, reward, done, _ = self._engine.step(
            {"type": action.type, "payload": action.payload}
        )

        return CodeFixerObservation(
            code=obs_dict["code"],
            logs=obs_dict["logs"],
            test_score=float(obs_dict["test_score"]),
            total_tests=obs_dict["total_tests"],
            steps=obs_dict["steps"],
            done=done,
            reward=float(reward),
        )

    @property
    def state(self) -> State:
        """Current episode_id and step_count."""
        return self._state
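

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the server API: the selection rule from
# RlCodeFixEnvironment._select_difficulty restated as a pure function, so the
# cycle-versus-override behavior can be checked without openenv or CodeEnv.
# The names `pick_difficulty` and `_SKETCH_CYCLE` are hypothetical and exist
# only for this example; `forced` stands in for the TRACERL_TASK value.
# ---------------------------------------------------------------------------
_SKETCH_CYCLE = ("easy", "medium", "hard")


def pick_difficulty(idx: int, forced: str = "") -> tuple:
    """Return (difficulty, next_idx) for a cycle position and optional override."""
    normalized = (forced or "").strip().lower()
    if normalized in _SKETCH_CYCLE:
        # A valid override wins and leaves the cycle position unchanged.
        return normalized, idx
    # Otherwise walk easy -> medium -> hard -> easy ...
    return _SKETCH_CYCLE[idx % len(_SKETCH_CYCLE)], idx + 1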