File size: 6,324 Bytes
eb1ebe6
 
 
 
 
 
 
8fa7af1
eb1ebe6
 
 
 
 
 
 
 
 
 
8fa7af1
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
 
eb1ebe6
 
8fa7af1
eb1ebe6
 
8fa7af1
eb1ebe6
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
8fa7af1
43f41de
 
 
 
 
eb1ebe6
 
 
 
 
 
 
 
43f41de
 
 
 
 
eb1ebe6
43f41de
eb1ebe6
 
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
 
 
 
43f41de
 
eb1ebe6
 
 
 
 
43f41de
eb1ebe6
 
 
 
 
 
 
 
43f41de
 
 
 
 
eb1ebe6
 
 
 
 
8fa7af1
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
 
 
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
eb1ebe6
 
 
8fa7af1
43f41de
 
 
 
 
 
 
 
 
 
 
 
 
8fa7af1
 
 
 
 
 
 
 
 
 
 
 
43f41de
 
8fa7af1
43f41de
 
eb1ebe6
 
 
 
 
 
 
 
 
 
 
 
 
 
8fa7af1
eb1ebe6
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Tests for ExplainerEnvironment — multi-step explore→generate lifecycle."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS
from models import ExplainerAction, ExplainerObservation
from server.explainer_env_environment import ExplainerEnvironment


def test_reset_returns_observation():
    """A seeded reset yields a fresh observation positioned at the explore phase."""
    environment = ExplainerEnvironment()
    first_obs = environment.reset(seed=1)

    assert isinstance(first_obs, ExplainerObservation)
    assert first_obs.topic != ""
    assert first_obs.phase == "explore"
    # The full explore budget must be available immediately after reset.
    assert first_obs.explore_steps_left == MAX_EXPLORE_STEPS
    assert first_obs.done is False


def test_reset_deterministic_with_seed():
    """Two resets with the same seed on one environment select the same topic."""
    environment = ExplainerEnvironment()
    # Both resets run before either topic is read, mirroring back-to-back calls.
    first, second = (environment.reset(seed=42) for _ in range(2))
    assert first.topic == second.topic


def test_explore_step():
    """One explore action consumes a step and returns a non-negative numeric reward."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    search = ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent optimization",
        intent="beginner explanation",
    )
    result = environment.step(search)

    assert result.done is False
    # Exactly one unit of the explore budget should have been spent.
    assert result.explore_steps_left == MAX_EXPLORE_STEPS - 1
    assert isinstance(result.reward, (int, float))
    assert result.reward >= 0.0
    assert isinstance(result.top_chunks, list)


def test_explore_empty_query():
    """An explore action with an empty query earns zero reward and explanatory feedback."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    result = environment.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="")
    )
    assert result.reward == 0.0
    assert "Empty query" in result.feedback


def test_explore_max_steps():
    """Exhausting the explore budget moves the episode into the generate phase."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    for step_index in range(MAX_EXPLORE_STEPS):
        query_text = f"search {step_index}"
        last_obs = environment.step(
            ExplainerAction(
                action_type="explore",
                tool="search_wikipedia",
                query=query_text,
            )
        )

    # Only the observation from the final explore step is inspected.
    assert last_obs.phase == "generate"
    assert last_obs.explore_steps_left == 0


def test_explore_then_generate():
    """A search followed by a generate action leaves explore and yields a numeric reward."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)

    # Phase 1: gather research material.
    explore_result = environment.step(ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent",
    ))
    assert explore_result.done is False
    assert explore_result.search_results != ""

    # Phase 2: submit a minimal marimo notebook with a single empty cell.
    notebook_src = (
        "import marimo as mo\n"
        "app = mo.App()\n"
        "@app.cell\n"
        "def _():\n"
        "    return\n"
    )
    generate_result = environment.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code=notebook_src,
    ))
    assert generate_result.phase in ("repair", "done")
    assert isinstance(generate_result.reward, (int, float))


def test_generate_without_explore_penalty():
    """Generating before any exploration routes to repair and mentions the penalty."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code="x = 1")
    )
    assert result.done is False
    assert result.phase == "repair"
    # Case-insensitive match on either wording of the penalty message.
    feedback_text = result.feedback.lower()
    assert "penalty" in feedback_text or "without" in feedback_text


def test_step_without_reset():
    """Stepping an environment that was never reset ends the episode with reward -1.0."""
    fresh_env = ExplainerEnvironment()
    result = fresh_env.step(
        ExplainerAction(action_type="explore", tool="search_wikipedia", query="test")
    )
    assert result.done is True
    assert result.reward == -1.0


def test_generate_reward_in_metadata():
    """After explore then generate, each reward component is reported in metadata."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    environment.step(ExplainerAction(
        action_type="explore",
        tool="search_wikipedia",
        query="gradient descent",
    ))
    result = environment.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    ))

    expected_components = ("validity", "task_alignment", "structure", "research_usage")
    for component in expected_components:
        assert component in result.metadata, f"missing {component} in metadata"
    assert "explore_steps_used" in result.metadata


def test_state_episode_id_changes():
    """Each unseeded reset starts a new episode with a distinct episode id."""
    environment = ExplainerEnvironment()
    episode_ids = []
    for _ in range(2):
        environment.reset()
        episode_ids.append(environment.state.episode_id)
    assert episode_ids[0] != episode_ids[1]


def test_step_increments_count():
    """state.step_count is zero after reset and grows by one per step, whatever the action."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    assert environment.state.step_count == 0

    action_specs = (
        dict(action_type="explore", tool="search_wikipedia", query="test"),
        dict(action_type="generate", format="marimo", code="x=1"),
    )
    # Each action is built just before it is stepped, then the counter is checked.
    for expected_count, spec in enumerate(action_specs, start=1):
        environment.step(ExplainerAction(**spec))
        assert environment.state.step_count == expected_count


def test_bad_code_does_not_crash():
    """Unparseable code is reported in feedback and routed to repair rather than raising."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    broken_source = ")))syntax error((("
    result = environment.step(
        ExplainerAction(action_type="generate", format="marimo", code=broken_source)
    )
    assert result.done is False
    assert result.phase == "repair"
    assert "SYNTAX ERROR" in result.feedback


def test_failed_repair_can_continue_until_limit():
    """Failed repairs keep the episode open until the repair budget is spent."""
    environment = ExplainerEnvironment()
    environment.reset(seed=1)
    # Start with a generate step; the assertions below expect it to land in repair.
    environment.step(ExplainerAction(
        action_type="generate",
        format="marimo",
        code="x = 1",
    ))

    def _repair(source, notes):
        # Submit one repair attempt and return the resulting observation.
        return environment.step(ExplainerAction(
            action_type="repair",
            format="marimo",
            code=source,
            repair_notes=notes,
        ))

    obs = _repair("x = 2", "attempted fix")
    assert obs.done is False
    assert obs.phase == "repair"
    assert obs.repair_attempts_left == MAX_REPAIR_STEPS - 1
    assert obs.metadata["phase"] == "repair"

    # Spend the remaining repair budget on submissions that are still invalid.
    for attempt in range(MAX_REPAIR_STEPS - 1):
        obs = _repair(f"x = {attempt + 3}", "still invalid")
    assert obs.done is True
    assert obs.phase == "done"
    assert obs.repair_attempts_left == 0


if __name__ == "__main__":
    # Minimal standalone runner so the suite can be executed without pytest.
    tests = [
        test_reset_returns_observation,
        test_reset_deterministic_with_seed,
        test_explore_step,
        test_explore_empty_query,
        test_explore_max_steps,
        test_explore_then_generate,
        test_generate_without_explore_penalty,
        test_step_without_reset,
        test_generate_reward_in_metadata,
        test_state_episode_id_changes,
        test_step_increments_count,
        test_bad_code_does_not_crash,
        test_failed_repair_can_continue_until_limit,
    ]
    failed = []
    for t in tests:
        try:
            t()
        except Exception as e:
            failed.append(t.__name__)
            # A bare `assert` raises AssertionError with an empty message, so
            # include the exception type to keep the failure line informative.
            print(f"FAIL: {t.__name__}: {type(e).__name__}: {e}")
    passed = len(tests) - len(failed)
    # Report FAIL in the summary when anything failed; it previously said PASS.
    status = "FAIL" if failed else "PASS"
    print(f"{status}: test_environment ({passed}/{len(tests)})")
    # Non-zero exit status lets shells/CI detect a failing suite.
    if failed:
        sys.exit(1)