File size: 9,960 Bytes
4b7e54c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
#!/usr/bin/env python3
"""Test script for the ADHD coaching environment.

Tests the environment directly (no server needed) and via HTTP if a server is running.

Usage:
    # Direct test (no server):
    cd adhd_env && .venv/bin/python test_environment.py

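    # With server running:
    cd adhd_env && .venv/bin/uvicorn server.app:app --host 0.0.0.0 --port 8000 &
    cd adhd_env && .venv/bin/python test_environment.py --http

    # Against a server at a different address, pass its URL after --http:
    cd adhd_env && .venv/bin/python test_environment.py --http http://localhost:9000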
"""

import sys


def test_direct():
    """Exercise ``ADHDEnvironment`` in-process, without an HTTP server.

    Resets the environment once and checks the observation: a non-empty
    ``scenario``, ``done`` is ``False``, ``reward`` is ``0.0``, and ``state``
    holds ``time_of_day``, ``position_in_chair`` (one of ``"normal"``,
    ``"slouching"``, ``"standing"``) and ``minutes_since_last_stood``
    (0 to 240 inclusive). It then resets ten more times and requires at
    least two distinct states, catching an environment that always returns
    the same state.

    Raises:
        AssertionError: if any check fails.
    """
    from server.adhd_env_environment import ADHDEnvironment

    # Keys every reset state must carry; also used to build the variety tuples.
    state_keys = ("time_of_day", "position_in_chair", "minutes_since_last_stood")

    env = ADHDEnvironment()
    print("=" * 60)
    print("DIRECT ENVIRONMENT TEST")
    print("=" * 60)

    # Test reset returns valid state
    obs = env.reset()
    print("\n--- Reset ---")
    print(f"Scenario: {obs.scenario}")
    print(f"State: {obs.state}")
    print(f"Done: {obs.done}")
    print(f"Reward: {obs.reward}")

    assert obs.scenario, "Scenario should not be empty"
    assert obs.done is False
    assert obs.reward == 0.0

    # Validate state has all 3 keys
    for key in state_keys:
        assert key in obs.state, f"Missing {key}"
    assert obs.state["position_in_chair"] in ("normal", "slouching", "standing")
    assert 0 <= obs.state["minutes_since_last_stood"] <= 240
    print("State validation: PASS")

    # Variety check: reset 10x and verify we get at least 2 distinct states
    states = [tuple(env.reset().state[key] for key in state_keys) for _ in range(10)]
    unique_states = len(set(states))
    assert unique_states >= 2, f"Expected at least 2 distinct states, got {unique_states}"
    print(f"State variety check ({unique_states} unique in 10 resets): PASS")

    print(f"\n{'=' * 60}")
    print("ALL DIRECT TESTS PASSED")
    print(f"{'=' * 60}")


def test_rubric():
    """Check ``reward.score_rubric`` against hand-picked positive and negative cases.

    Each case builds an ``ADHDAction``, scores it, and asserts that the
    resulting ``total_score`` falls in the band expected for that kind of
    reply. Examples: >= 0.7 for a state-aware reply using ``adhd_coach_tool``
    on an ADHD scenario, < 0.3 for a wrong-domain tool. The last case asserts
    that a reflective, clarifying reply outscores a plain one on the same
    input.

    Raises:
        AssertionError: if any score falls outside its expected band.
    """
    from models import ADHDAction
    from reward import score_rubric

    def total(tool_calls, message, scenario, state, is_adhd, expected_tool=None):
        """Build an action, score it with ``score_rubric``, and return ``total_score``.

        ``is_adhd`` and ``expected_tool`` are forwarded positionally. From the
        cases below they appear to mean "is this an ADHD scenario" and "the
        tool a non-ADHD scenario should use" -- confirm against ``reward.py``.
        """
        action = ADHDAction(tool_calls=tool_calls, message=message)
        return score_rubric(action, scenario, state, is_adhd, expected_tool)["total_score"]

    print(f"\n{'=' * 60}")
    print("RUBRIC TEST")
    print(f"{'=' * 60}")

    # State where user has been sitting a long time and is slouching
    tired_state = {
        "time_of_day": "14:00",
        "position_in_chair": "slouching",
        "minutes_since_last_stood": 90,
    }

    evening_state = {
        "time_of_day": "21:00",
        "position_in_chair": "normal",
        "minutes_since_last_stood": 30,
    }

    # POSITIVE: ADHD scenario + primary tool + state-aware message
    score = total(
        ["adhd_coach_tool"],
        "Stand up and stretch for 30 seconds, then type just the recipient name.",
        "I can't start the email", tired_state, True,
    )
    print(f"\nPOSITIVE (ADHD + primary tool + state-aware): {score}")
    assert score >= 0.7, f"Expected >= 0.7, got {score}"
    print("PASS")

    # NEGATIVE: ADHD scenario + wrong-domain tool
    score = total(
        ["web_search_tool"],
        "Let me search for tips on email writing.",
        "I can't start the email", tired_state, True,
    )
    print(f"\nNEGATIVE (ADHD + web_search_tool): {score}")
    assert score < 0.3, f"Expected < 0.3, got {score}"
    print("PASS")

    # NEGATIVE: Non-ADHD scenario + ADHD tool
    score = total(
        ["adhd_coach_tool"],
        "Let me help you initiate that task.",
        "What's the weather?", tired_state, False, "web_search_tool",
    )
    print(f"\nNEGATIVE (non-ADHD + ADHD tool): {score}")
    assert score < 0.3, f"Expected < 0.3, got {score}"
    print("PASS")

    # SLIGHTLY POSITIVE: Non-ADHD factual + correct tool
    score = total(
        ["web_search_tool"],
        "Let me look that up for you.",
        "What is the capital of France?", tired_state, False, "web_search_tool",
    )
    print(f"\nSLIGHTLY POSITIVE (non-ADHD + correct tool): {score}")
    assert score >= 0.5, f"Expected >= 0.5, got {score}"
    print("PASS")

    # NEUTRAL: Non-ADHD creative + no tool
    score = total(
        [],
        "Here is a poem about cats.",
        "Write me a poem about cats", tired_state, False,
    )
    print(f"\nNEUTRAL (non-ADHD creative + no tool): {score}")
    assert 0.3 <= score <= 0.7, f"Expected 0.3-0.7, got {score}"
    print("PASS")

    # MEDIUM: ADHD + primary tool + generic message (no state awareness)
    score = total(
        ["adhd_coach_tool"],
        "Try breaking this task into smaller pieces.",
        "I'm stuck on this report", tired_state, True,
    )
    print(f"\nMEDIUM (ADHD + primary tool + generic): {score}")
    assert 0.4 <= score <= 0.85, f"Expected 0.4-0.85, got {score}"
    print("PASS")

    # EVENING: ADHD + primary tool + evening-aware message
    score = total(
        ["adhd_coach_tool"],
        "It's late. Pick a small easy task to finish tonight, save the rest for tomorrow.",
        "I can't focus on this", evening_state, True,
    )
    print(f"\nEVENING AWARE (ADHD + primary tool + evening tips): {score}")
    assert score >= 0.7, f"Expected >= 0.7, got {score}"
    print("PASS")

    # REFLECTIVE QUESTION: ADHD + primary tool + clarifying question,
    # compared against the same scenario with a generic non-reflective message.
    stuck_scenario = "I've been stuck for 30 minutes"
    reflective = total(
        ["adhd_coach_tool"],
        "What are you specifically stuck on? Explain the first step you think you need to take.",
        stuck_scenario, tired_state, True,
    )
    plain = total(
        ["adhd_coach_tool"],
        "Just try to get started on it.",
        stuck_scenario, tired_state, True,
    )
    print(f"\nREFLECTIVE Q (ADHD + primary tool + clarifying question): {reflective}")
    print(f"  vs PLAIN (ADHD + primary tool + generic): {plain}")
    assert reflective > plain, (
        f"Reflective question should score higher than plain: {reflective} vs {plain}"
    )
    print("PASS")

    print(f"\n{'=' * 60}")
    print("ALL RUBRIC TESTS PASSED")
    print(f"{'=' * 60}")


def test_http(base_url="http://localhost:8000", timeout=10.0):
    """Exercise the environment through a running HTTP server.

    Calls ``GET /health``, ``GET /schema``, ``POST /reset`` and then
    ``POST /step`` twice. It checks that every response is HTTP 200 and has
    the expected JSON shape.

    Args:
        base_url: Server root URL; a trailing slash is tolerated.
        timeout: Seconds to wait for each request. A hung server then fails
            the test instead of blocking it forever.

    Raises:
        AssertionError: on a non-200 status or an unexpected payload.
        requests.RequestException: if the server is unreachable or times out.
    """
    import requests

    # Tolerate "http://host:8000/" so the paths below don't become "//health".
    base_url = base_url.rstrip("/")

    def call(method, path, **kwargs):
        """Send one request, require HTTP 200, and return the decoded JSON body."""
        r = requests.request(method, f"{base_url}{path}", timeout=timeout, **kwargs)
        assert r.status_code == 200, f"{method} {path} -> {r.status_code}: {r.text[:500]}"
        return r.json()

    print(f"\n{'=' * 60}")
    print(f"HTTP TEST ({base_url})")
    print(f"{'=' * 60}")

    # Health check
    health = call("GET", "/health")
    print(f"\nHealth: {health}")

    # Schema
    schema = call("GET", "/schema")
    assert "action" in schema
    assert "observation" in schema
    print(f"Schema: action has {list(schema['action']['properties'].keys())}")
    print(f"Schema: observation has {list(schema['observation']['properties'].keys())}")

    # Reset
    data = call("POST", "/reset")
    assert data["done"] is False
    assert data["reward"] == 0.0
    obs = data["observation"]
    assert "scenario" in obs
    assert "state" in obs
    for key in ("time_of_day", "position_in_chair", "minutes_since_last_stood"):
        assert key in obs["state"], f"Missing {key}"
    print(f"\nReset: scenario='{obs['scenario']}'")
    print(f"  state={obs['state']}")
    print("  State keys present: PASS")

    # Good action (ADHD scenario + primary tool).
    # NOTE(review): the reward > 0 check assumes /reset returned an ADHD
    # scenario. If /reset can sample non-ADHD scenarios, this may be flaky --
    # confirm.
    data = call("POST", "/step", json={
        "action": {
            "tool_calls": ["adhd_coach_tool"],
            "message": "Stand up and stretch, then type just the recipient name.",
        }
    })
    assert data["done"] is True
    assert data["reward"] > 0
    print(f"Good action: reward={data['reward']} PASS")

    # Bad action (no tools on presumed ADHD scenario).
    # NOTE(review): no /reset is sent first, even though the previous step
    # reported done=True. Confirm the server accepts a step after done.
    data = call("POST", "/step", json={
        "action": {
            "tool_calls": [],
            "message": "What do you want to work on?",
        }
    })
    print(f"No-tool action: reward={data['reward']}")

    # Verify scoring details in response
    assert "scoring" in data["observation"]
    assert "total_score" in data["observation"]["scoring"]
    assert "criteria" in data["observation"]["scoring"]
    print("Scoring details present: PASS")

    print(f"\n{'=' * 60}")
    print("ALL HTTP TESTS PASSED")
    print(f"{'=' * 60}")


def http_url_from_argv(argv, default="http://localhost:8000"):
    """Return the server URL given on the command line, or ``default``.

    ``argv[0]`` (the script path) is skipped, so a script stored under a
    directory such as ``http_env/`` is never mistaken for a URL. Only
    arguments with an ``http://`` or ``https://`` scheme count. When several
    are given, the last one wins.

    Args:
        argv: Command-line vector in ``sys.argv`` form (script path first).
        default: URL returned when no URL argument is present.

    Returns:
        The chosen server URL.
    """
    url = default
    for arg in argv[1:]:
        if arg.startswith(("http://", "https://")):
            url = arg
    return url


if __name__ == "__main__":
    test_direct()
    test_rubric()

    if "--http" in sys.argv:
        test_http(http_url_from_argv(sys.argv))