File size: 6,477 Bytes
a5fc2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697e014
a5fc2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
697e014
 
 
a5fc2da
697e014
 
a5fc2da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Skill Forge Environment Implementation.

An RL training environment where LLM Agents evolve from "reinventing the wheel" to "building a skill library."
"""

import json
import traceback
from uuid import uuid4

import pandas as pd

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import SkillForgeAction, SkillForgeObservation
    from .data_generator import TASKS
except ImportError:
    from models import SkillForgeAction, SkillForgeObservation
    from data_generator import TASKS


class SkillForgeEnvironment(Environment):
    """
    SkillForge RL environment.

    The agent solves chained pandas tasks and can build a reusable skill library.
    Skills persist across episodes so the agent can discover and reuse patterns.

    Reward shaping:
        * create_skill: flat +0.5 (investment, not evaluated against the task).
        * correct solution: +2.0 minus 0.001 per token spent, +0.5 bonus if it
          came from a library skill.
        * incorrect / failed execution / unknown skill: -0.3.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Per-episode bookkeeping. skill_library is the one piece of state that
        # deliberately survives reset() — see reset() docstring.
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.skill_library: dict = {}  # name -> {"template", "description", "used_count"}
        self.task_idx: int = 0         # index of the current task in TASKS
        self.tasks_solved: int = 0     # tasks solved in the current episode
        self.total_tokens: int = 0     # cumulative token cost in the current episode

    def reset(self) -> SkillForgeObservation:
        """
        Reset episode state. skill_library is NOT reset — persists across episodes.

        Returns:
            Observation for the first task with a zero reward and done=False.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self.task_idx = 0
        self.tasks_solved = 0
        self.total_tokens = 0

        task = TASKS[self.task_idx]
        return self._make_observation(task, result_correct=False, result_output="", reward=0.0, done=False)

    def step(self, action: SkillForgeAction) -> SkillForgeObservation:
        """
        Apply one agent action and return the resulting observation.

        Args:
            action: One of three action types:
                * "create_skill" — store ``action.content`` as a template under
                  ``action.skill_name``; the current task is NOT attempted.
                * "use_skill" — look up the skill named by ``action.content``,
                  instantiate its template with ``action.params``, and execute it.
                * anything else (raw code) — execute ``action.content`` directly.

        Returns:
            Observation for the next task on success, or for the same task on
            failure; ``done`` is True once all TASKS are solved.
        """
        self._state.step_count += 1
        task = TASKS[self.task_idx]

        # --- create_skill: store template, stay on current task ---
        if action.action_type == "create_skill":
            # `or ""` guards against None content — _evaluate admits str | None,
            # so step() should tolerate it too rather than raise TypeError.
            content = action.content or ""
            self.total_tokens += len(content)

            self.skill_library[action.skill_name] = {
                "template": content,
                "description": action.reasoning,
                "used_count": 0,
            }
            return self._make_observation(
                task, result_correct=False,
                result_output=f"Skill '{action.skill_name}' saved.",
                reward=0.5, done=False,
            )

        # --- use_skill or raw_code: execute and evaluate ---
        if action.action_type == "use_skill":
            skill = self.skill_library.get(action.content)
            if skill is None:
                # skill not found — treat as error
                self.total_tokens += len(action.content or "")
                return self._make_observation(
                    task, result_correct=False,
                    result_output=f"Skill '{action.content}' not found in library.",
                    reward=-0.3, done=False,
                )
            try:
                exec_code = skill["template"].format(**(action.params or {}))
            except (KeyError, IndexError, ValueError) as err:
                # Bad or missing template parameters must not crash the
                # environment; penalize like any other malformed action.
                self.total_tokens += len(action.content or "")
                return self._make_observation(
                    task, result_correct=False,
                    result_output=f"Failed to instantiate skill '{action.content}': {err!r}",
                    reward=-0.3, done=False,
                )
            skill["used_count"] += 1
            # token cost for use_skill: skill name + serialized params (much shorter than full code)
            token_cost = len(action.content + json.dumps(action.params or {}))
        else:
            # raw_code
            exec_code = action.content
            token_cost = len(action.content or "")

        self.total_tokens += token_cost

        result_correct, result_output = self._evaluate(exec_code, task["dataframe"], task["expected_output"])

        if result_correct:
            # Base reward minus a per-token cost; extra bonus for skill reuse.
            reward = 2.0 - 0.001 * token_cost
            if action.action_type == "use_skill":
                reward += 0.5
            self.tasks_solved += 1
            self.task_idx += 1
        else:
            reward = -0.3

        done = self.task_idx >= len(TASKS)
        # When done, there is no next task; repeat the last one in the final observation.
        next_task = TASKS[self.task_idx] if not done else task

        return self._make_observation(
            next_task,
            result_correct=result_correct,
            result_output=result_output,
            reward=reward,
            done=done,
        )

    def _evaluate(self, exec_code: str | None, dataframe: pd.DataFrame, expected_output) -> tuple[bool, str]:
        """
        Evaluate ``exec_code`` against ``dataframe`` and compare to ``expected_output``.

        The code is evaluated as a single expression with ``df`` bound to a copy
        of the task DataFrame and a deliberately restricted set of builtins.
        NOTE: eval on agent-supplied code is sandboxed only by the builtins
        whitelist; this is an RL training environment, not a security boundary.

        Returns:
            (is_correct, output_string) — on any exception, (False, traceback).
        """
        if exec_code is None:
            return False, "No code to execute."
        try:
            namespace = {"df": dataframe.copy(), "pd": pd, "__builtins__": {"len": len, "str": str, "int": int, "float": float, "list": list, "dict": dict, "bool": bool, "range": range, "abs": abs, "min": min, "max": max, "sum": sum, "sorted": sorted, "round": round, "True": True, "False": False, "None": None}}
            result = eval(exec_code, namespace)

            # normalize for comparison
            if isinstance(result, pd.DataFrame):
                result = result.values.tolist()
            if isinstance(result, (pd.Series, pd.Index)):
                result = result.tolist()

            # Normalize the expected side the same way so DataFrame/Index
            # expectations can match (previously only Series was handled).
            expected = expected_output
            if isinstance(expected, pd.DataFrame):
                expected = expected.values.tolist()
            if isinstance(expected, (pd.Series, pd.Index)):
                expected = expected.tolist()

            try:
                is_correct = result == expected
            except (ValueError, TypeError):
                is_correct = False

            return bool(is_correct), str(result)
        except Exception:
            return False, traceback.format_exc()

    def _make_observation(self, task: dict, result_correct: bool, result_output: str,
                          reward: float, done: bool) -> SkillForgeObservation:
        """Build an observation for ``task`` with the given result/reward fields."""
        return SkillForgeObservation(
            task_id=task["id"],
            task_description=task["description"],
            snapshot_data=task["dataframe"].head(5).to_string(),
            skill_library=self.skill_library,
            context="",
            step_count=self._state.step_count,
            total_tokens=self.total_tokens,
            result_correct=result_correct,
            result_output=result_output,
            expected_output=str(task["expected_output"]),
            reward=reward,
            done=done,
        )

    @property
    def state(self) -> State:
        """Current episode State (episode_id and step_count)."""
        return self._state