#!/usr/bin/env python3
"""Test script for the ADHD coaching environment.
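
Tests the environment directly (no server needed) and via HTTP if a server is running.

Usage:
    # Direct test (no server):
    cd adhd_env && .venv/bin/python test_environment.py

    # With server running:
    cd adhd_env && .venv/bin/uvicorn server.app:app --host 0.0.0.0 --port 8000 &
    cd adhd_env && .venv/bin/python test_environment.py --http

    # With a server at another address (an argument starting with "http" is used as the base URL):
    cd adhd_env && .venv/bin/python test_environment.py --http http://HOST:PORT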
"""
import sys
def test_direct():
    """Drive ADHDEnvironment in-process, without any HTTP server.

    Resets the environment once and checks the returned observation: a
    non-empty scenario, ``done`` False, ``reward`` 0.0, and a ``state``
    mapping holding ``time_of_day``, ``position_in_chair`` and
    ``minutes_since_last_stood`` with the chair position and minute count
    inside the accepted values. It then resets ten more times and requires
    at least two distinct states among those resets.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    from server.adhd_env_environment import ADHDEnvironment
    from models import ADHDAction  # not referenced in this function; import left in place

    state_keys = ("time_of_day", "position_in_chair", "minutes_since_last_stood")
    rule = "=" * 60

    environment = ADHDEnvironment()
    print(rule)
    print("DIRECT ENVIRONMENT TEST")
    print(rule)

    # A reset must hand back an unfinished observation with zero reward.
    first = environment.reset()
    print("\n--- Reset ---")
    print(f"Scenario: {first.scenario}")
    print(f"State: {first.state}")
    print(f"Done: {first.done}")
    print(f"Reward: {first.reward}")
    assert first.scenario, "Scenario should not be empty"
    assert first.done is False
    assert first.reward == 0.0

    # Every expected key must be present, and the bounded ones in range.
    state = first.state
    for key in state_keys:
        assert key in state, f"Missing {key}"
    assert state["position_in_chair"] in ("normal", "slouching", "standing")
    assert 0 <= state["minutes_since_last_stood"] <= 240
    print("State validation: PASS")

    # Ten further resets should not all land on the same state.
    distinct = set()
    for _ in range(10):
        snapshot = environment.reset().state
        distinct.add(tuple(snapshot[key] for key in state_keys))
    unique_states = len(distinct)
    assert unique_states >= 2, f"Expected at least 2 distinct states, got {unique_states}"
    print(f"State variety check ({unique_states} unique in 10 resets): PASS")

    print(f"\n{rule}")
    print("ALL DIRECT TESTS PASSED")
    print(rule)
def test_rubric():
    """Score hand-written actions with ``score_rubric`` and check the totals.

    Each case builds an ``ADHDAction`` (tool calls plus a message), scores it
    against a scenario string and a state dict, and asserts that the
    ``total_score`` entry of the result falls in the range expected for that
    case. The final case scores two messages for the same scenario and
    requires the reflective one to score strictly higher than the plain one.

    Raises:
        AssertionError: if any score falls outside its expected range.
    """
    from server.adhd_env_environment import ADHDEnvironment
    from models import ADHDAction
    from reward import score_rubric

    rule = "=" * 60
    print(f"\n{rule}")
    print("RUBRIC TEST")
    print(rule)

    # State where user has been sitting a long time and is slouching
    tired_state = {
        "time_of_day": "14:00",
        "position_in_chair": "slouching",
        "minutes_since_last_stood": 90,
    }
    evening_state = {
        "time_of_day": "21:00",
        "position_in_chair": "normal",
        "minutes_since_last_stood": 30,
    }

    def evaluate(tools, text, scenario, state, *rubric_args):
        """Build an ADHDAction from tools/text and return score_rubric's result."""
        # rubric_args are forwarded positionally exactly as given; presumably an
        # "is ADHD scenario" flag and an expected tool name — confirm against
        # reward.score_rubric.
        candidate = ADHDAction(tool_calls=tools, message=text)
        return score_rubric(candidate, scenario, state, *rubric_args)

    # POSITIVE: ADHD scenario + primary tool + state-aware message
    score = evaluate(
        ["adhd_coach_tool"],
        "Stand up and stretch for 30 seconds, then type just the recipient name.",
        "I can't start the email",
        tired_state,
        True,
        None,
    )["total_score"]
    print(f"\nPOSITIVE (ADHD + primary tool + state-aware): {score}")
    assert score >= 0.7, f"Expected >= 0.7, got {score}"
    print("PASS")

    # NEGATIVE: ADHD scenario + wrong-domain tool
    score = evaluate(
        ["web_search_tool"],
        "Let me search for tips on email writing.",
        "I can't start the email",
        tired_state,
        True,
        None,
    )["total_score"]
    print(f"\nNEGATIVE (ADHD + web_search_tool): {score}")
    assert score < 0.3, f"Expected < 0.3, got {score}"
    print("PASS")

    # NEGATIVE: Non-ADHD scenario + ADHD tool
    score = evaluate(
        ["adhd_coach_tool"],
        "Let me help you initiate that task.",
        "What's the weather?",
        tired_state,
        False,
        "web_search_tool",
    )["total_score"]
    print(f"\nNEGATIVE (non-ADHD + ADHD tool): {score}")
    assert score < 0.3, f"Expected < 0.3, got {score}"
    print("PASS")

    # SLIGHTLY POSITIVE: Non-ADHD factual + correct tool
    score = evaluate(
        ["web_search_tool"],
        "Let me look that up for you.",
        "What is the capital of France?",
        tired_state,
        False,
        "web_search_tool",
    )["total_score"]
    print(f"\nSLIGHTLY POSITIVE (non-ADHD + correct tool): {score}")
    assert score >= 0.5, f"Expected >= 0.5, got {score}"
    print("PASS")

    # NEUTRAL: Non-ADHD creative + no tool
    score = evaluate(
        [],
        "Here is a poem about cats.",
        "Write me a poem about cats",
        tired_state,
        False,
        None,
    )["total_score"]
    print(f"\nNEUTRAL (non-ADHD creative + no tool): {score}")
    assert 0.3 <= score <= 0.7, f"Expected 0.3-0.7, got {score}"
    print("PASS")

    # MEDIUM: ADHD + primary tool + generic message (no state awareness)
    score = evaluate(
        ["adhd_coach_tool"],
        "Try breaking this task into smaller pieces.",
        "I'm stuck on this report",
        tired_state,
        True,
        None,
    )["total_score"]
    print(f"\nMEDIUM (ADHD + primary tool + generic): {score}")
    assert 0.4 <= score <= 0.85, f"Expected 0.4-0.85, got {score}"
    print("PASS")

    # EVENING: ADHD + primary tool + evening-aware message
    score = evaluate(
        ["adhd_coach_tool"],
        "It's late. Pick a small easy task to finish tonight, save the rest for tomorrow.",
        "I can't focus on this",
        evening_state,
        True,
        None,
    )["total_score"]
    print(f"\nEVENING AWARE (ADHD + primary tool + evening tips): {score}")
    assert score >= 0.7, f"Expected >= 0.7, got {score}"
    print("PASS")

    # REFLECTIVE QUESTION: ADHD + primary tool + clarifying question,
    # compared against a generic non-reflective message for the same scenario.
    stuck_scenario = "I've been stuck for 30 minutes"
    reflective_result = evaluate(
        ["adhd_coach_tool"],
        "What are you specifically stuck on? Explain the first step you think you need to take.",
        stuck_scenario,
        tired_state,
        True,
        None,
    )
    plain_result = evaluate(
        ["adhd_coach_tool"],
        "Just try to get started on it.",
        stuck_scenario,
        tired_state,
        True,
        None,
    )
    reflective_score = reflective_result["total_score"]
    plain_score = plain_result["total_score"]
    print(f"\nREFLECTIVE Q (ADHD + primary tool + clarifying question): {reflective_score}")
    print(f" vs PLAIN (ADHD + primary tool + generic): {plain_score}")
    assert reflective_score > plain_score, (
        f"Reflective question should score higher than plain: {reflective_score} vs {plain_score}"
    )
    print("PASS")

    print(f"\n{rule}")
    print("ALL RUBRIC TESTS PASSED")
    print(rule)
def test_http(base_url="http://localhost:8000", timeout=10.0):
    """Exercise a running environment server through its HTTP endpoints.

    Calls GET /health and GET /schema, then POST /reset followed by two
    POST /step requests, asserting on the status code and on the JSON fields
    this script expects in each response.

    Args:
        base_url: Root URL of the server. Paths are appended as
            ``f"{base_url}/health"``, so pass it without a trailing slash.
        timeout: Seconds allowed for each individual request. Without a
            timeout ``requests`` waits indefinitely, so an unresponsive
            server would hang the script instead of failing it.

    Raises:
        AssertionError: if a response has an unexpected status or payload.
        requests.exceptions.RequestException: if a request cannot be
            completed (for example connection refused or timed out).
    """
    import requests

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"HTTP TEST ({base_url})")
    print(rule)

    # Health check
    r = requests.get(f"{base_url}/health", timeout=timeout)
    assert r.status_code == 200, f"GET /health returned {r.status_code}"
    print(f"\nHealth: {r.json()}")

    # Schema
    r = requests.get(f"{base_url}/schema", timeout=timeout)
    assert r.status_code == 200, f"GET /schema returned {r.status_code}"
    schema = r.json()
    assert "action" in schema
    assert "observation" in schema
    print(f"Schema: action has {list(schema['action']['properties'].keys())}")
    print(f"Schema: observation has {list(schema['observation']['properties'].keys())}")

    # Reset
    r = requests.post(f"{base_url}/reset", timeout=timeout)
    assert r.status_code == 200, f"POST /reset returned {r.status_code}"
    data = r.json()
    assert data["done"] is False
    assert data["reward"] == 0.0
    assert "scenario" in data["observation"]
    obs = data["observation"]
    assert "state" in obs
    assert "time_of_day" in obs["state"]
    assert "position_in_chair" in obs["state"]
    assert "minutes_since_last_stood" in obs["state"]
    print(f"\nReset: scenario='{obs['scenario']}'")
    print(f" state={obs['state']}")
    print(" State keys present: PASS")

    # Good action (ADHD scenario + primary tool)
    r = requests.post(
        f"{base_url}/step",
        json={
            "action": {
                "tool_calls": ["adhd_coach_tool"],
                "message": "Stand up and stretch, then type just the recipient name.",
            }
        },
        timeout=timeout,
    )
    assert r.status_code == 200, f"POST /step returned {r.status_code}"
    data = r.json()
    assert data["done"] is True
    assert data["reward"] > 0
    print(f"Good action: reward={data['reward']} PASS")

    # Bad action (no tools on presumed ADHD scenario)
    # NOTE(review): this step is sent after the previous step reported
    # done=True, with no reset in between — confirm the server accepts that.
    r = requests.post(
        f"{base_url}/step",
        json={
            "action": {
                "tool_calls": [],
                "message": "What do you want to work on?",
            }
        },
        timeout=timeout,
    )
    assert r.status_code == 200, f"POST /step returned {r.status_code}"
    data = r.json()
    print(f"No-tool action: reward={data['reward']}")

    # Verify scoring details in response
    assert "scoring" in data["observation"]
    assert "total_score" in data["observation"]["scoring"]
    assert "criteria" in data["observation"]["scoring"]
    print("Scoring details present: PASS")

    print(f"\n{rule}")
    print("ALL HTTP TESTS PASSED")
    print(rule)
if __name__ == "__main__":
    # The in-process tests always run; the HTTP test is opt-in via --http.
    test_direct()
    test_rubric()
    if "--http" in sys.argv:
        # Scan only the real arguments: sys.argv[0] is the script path and
        # must not be mistaken for a base URL if it happens to start with
        # "http". When several arguments match, the last one wins.
        url_args = [arg for arg in sys.argv[1:] if arg.startswith("http")]
        url = url_args[-1] if url_args else "http://localhost:8000"
        test_http(url)
|