| |
| """Test script for the ADHD coaching environment. |
| |
| Tests the environment directly (no server needed) and via HTTP if a server is running. |
| |
| Usage: |
| # Direct test (no server): |
| cd adhd_env && .venv/bin/python test_environment.py |
| |
| # With server running: |
| cd adhd_env && .venv/bin/uvicorn server.app:app --host 0.0.0.0 --port 8000 & |
| cd adhd_env && .venv/bin/python test_environment.py --http |
| """ |
|
|
| import sys |
|
|
|
|
def test_direct():
    """Test the environment directly, without an HTTP server.

    Checks that ``reset()`` returns a non-empty scenario with ``done`` False
    and ``reward`` 0.0. Checks that the observation state carries
    ``time_of_day``, ``position_in_chair`` and ``minutes_since_last_stood``,
    with the latter two inside their allowed values. Checks that ten
    consecutive resets yield at least two distinct states.
    """
    # Project import is function-local; per the module docstring the script is
    # run from inside the adhd_env directory so ``server`` is importable.
    from server.adhd_env_environment import ADHDEnvironment

    env = ADHDEnvironment()
    print("=" * 60)
    print("DIRECT ENVIRONMENT TEST")
    print("=" * 60)

    # --- A fresh reset must give an unfinished, unscored episode ---
    obs = env.reset()
    print("\n--- Reset ---")
    print(f"Scenario: {obs.scenario}")
    print(f"State: {obs.state}")
    print(f"Done: {obs.done}")
    print(f"Reward: {obs.reward}")

    assert obs.scenario, "Scenario should not be empty"
    assert obs.done is False, f"Expected done=False after reset, got {obs.done!r}"
    assert obs.reward == 0.0, f"Expected reward=0.0 after reset, got {obs.reward!r}"

    # --- The state dict must carry every expected key within range ---
    state = obs.state
    for key in ("time_of_day", "position_in_chair", "minutes_since_last_stood"):
        assert key in state, f"Missing {key}"
    assert state["position_in_chair"] in ("normal", "slouching", "standing"), (
        f"Unexpected position_in_chair: {state['position_in_chair']!r}"
    )
    assert 0 <= state["minutes_since_last_stood"] <= 240, (
        f"minutes_since_last_stood out of range: {state['minutes_since_last_stood']!r}"
    )
    print("State validation: PASS")

    # --- Repeated resets should not always produce the same state ---
    states = [
        (o.state["time_of_day"], o.state["position_in_chair"], o.state["minutes_since_last_stood"])
        for o in (env.reset() for _ in range(10))
    ]
    unique_states = len(set(states))
    assert unique_states >= 2, f"Expected at least 2 distinct states, got {unique_states}"
    print(f"State variety check ({unique_states} unique in 10 resets): PASS")

    print(f"\n{'=' * 60}")
    print("ALL DIRECT TESTS PASSED")
    print(f"{'=' * 60}")
|
|
|
|
def test_rubric():
    """Test rubric scoring with positive, negative and neutral cases.

    Each case builds an ``ADHDAction``, scores it with ``score_rubric`` and
    asserts that ``total_score`` falls in the expected band. The last case is
    relative: a reflective, clarifying reply must outscore a plain generic
    reply for the same message and state.
    """
    # Only the action model and the scorer are needed; the environment itself
    # is never instantiated here, so it is deliberately not imported.
    from models import ADHDAction
    from reward import score_rubric

    # NOTE(review): score_rubric is called positionally throughout. Judging by
    # the cases below, the arguments appear to be
    # (action, user_message, state, is_adhd, expected_tool) — confirm against
    # reward.score_rubric before relying on this.

    print(f"\n{'=' * 60}")
    print("RUBRIC TEST")
    print(f"{'=' * 60}")

    # Afternoon, slouching, long time seated.
    tired_state = {
        "time_of_day": "14:00",
        "position_in_chair": "slouching",
        "minutes_since_last_stood": 90,
    }

    # Late evening, normal posture, recently stood.
    evening_state = {
        "time_of_day": "21:00",
        "position_in_chair": "normal",
        "minutes_since_last_stood": 30,
    }

    # ADHD message + primary coaching tool + advice that uses the state.
    action_good = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="Stand up and stretch for 30 seconds, then type just the recipient name.",
    )
    result = score_rubric(action_good, "I can't start the email", tired_state, True, None)
    print(f"\nPOSITIVE (ADHD + primary tool + state-aware): {result['total_score']}")
    assert result["total_score"] >= 0.7, f"Expected >= 0.7, got {result['total_score']}"
    print("PASS")

    # ADHD message answered with the wrong tool.
    action_wrong_tool = ADHDAction(
        tool_calls=["web_search_tool"],
        message="Let me search for tips on email writing.",
    )
    result = score_rubric(action_wrong_tool, "I can't start the email", tired_state, True, None)
    print(f"\nNEGATIVE (ADHD + web_search_tool): {result['total_score']}")
    assert result["total_score"] < 0.3, f"Expected < 0.3, got {result['total_score']}"
    print("PASS")

    # Non-ADHD question answered with the ADHD coaching tool.
    action_adhd_on_non = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="Let me help you initiate that task.",
    )
    result = score_rubric(action_adhd_on_non, "What's the weather?", tired_state, False, "web_search_tool")
    print(f"\nNEGATIVE (non-ADHD + ADHD tool): {result['total_score']}")
    assert result["total_score"] < 0.3, f"Expected < 0.3, got {result['total_score']}"
    print("PASS")

    # Non-ADHD question answered with the expected tool.
    action_correct_non_adhd = ADHDAction(
        tool_calls=["web_search_tool"],
        message="Let me look that up for you.",
    )
    result = score_rubric(
        action_correct_non_adhd, "What is the capital of France?", tired_state, False, "web_search_tool"
    )
    print(f"\nSLIGHTLY POSITIVE (non-ADHD + correct tool): {result['total_score']}")
    assert result["total_score"] >= 0.5, f"Expected >= 0.5, got {result['total_score']}"
    print("PASS")

    # Non-ADHD creative request with no tool call at all.
    action_no_tool_creative = ADHDAction(
        tool_calls=[],
        message="Here is a poem about cats.",
    )
    result = score_rubric(action_no_tool_creative, "Write me a poem about cats", tired_state, False, None)
    print(f"\nNEUTRAL (non-ADHD creative + no tool): {result['total_score']}")
    assert 0.3 <= result["total_score"] <= 0.7, f"Expected 0.3-0.7, got {result['total_score']}"
    print("PASS")

    # Right tool, but generic advice that ignores the state.
    action_generic = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="Try breaking this task into smaller pieces.",
    )
    result = score_rubric(action_generic, "I'm stuck on this report", tired_state, True, None)
    print(f"\nMEDIUM (ADHD + primary tool + generic): {result['total_score']}")
    assert 0.4 <= result["total_score"] <= 0.85, f"Expected 0.4-0.85, got {result['total_score']}"
    print("PASS")

    # Right tool with advice tailored to the evening state.
    action_evening = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="It's late. Pick a small easy task to finish tonight, save the rest for tomorrow.",
    )
    result = score_rubric(action_evening, "I can't focus on this", evening_state, True, None)
    print(f"\nEVENING AWARE (ADHD + primary tool + evening tips): {result['total_score']}")
    assert result["total_score"] >= 0.7, f"Expected >= 0.7, got {result['total_score']}"
    print("PASS")

    # Relative check: a clarifying question should beat a plain nudge.
    stuck_message = "I've been stuck for 30 minutes"
    action_reflective = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="What are you specifically stuck on? Explain the first step you think you need to take.",
    )
    result_reflective = score_rubric(action_reflective, stuck_message, tired_state, True, None)

    action_plain = ADHDAction(
        tool_calls=["adhd_coach_tool"],
        message="Just try to get started on it.",
    )
    result_plain = score_rubric(action_plain, stuck_message, tired_state, True, None)
    print(f"\nREFLECTIVE Q (ADHD + primary tool + clarifying question): {result_reflective['total_score']}")
    print(f" vs PLAIN (ADHD + primary tool + generic): {result_plain['total_score']}")
    assert result_reflective["total_score"] > result_plain["total_score"], (
        f"Reflective question should score higher than plain: "
        f"{result_reflective['total_score']} vs {result_plain['total_score']}"
    )
    print("PASS")

    print(f"\n{'=' * 60}")
    print("ALL RUBRIC TESTS PASSED")
    print(f"{'=' * 60}")
|
|
|
|
def test_http(base_url="http://localhost:8000", timeout=10.0):
    """Test the environment via its HTTP endpoints.

    Exercises ``/health``, ``/schema``, ``/reset`` and ``/step`` on a running
    server. It checks the reset observation's keys, that a good coaching
    action ends the episode with a positive reward, and that step responses
    include scoring details.

    Args:
        base_url: Root URL of the running server.
        timeout: Per-request timeout in seconds, so an unresponsive server
            fails the test instead of hanging it indefinitely.
    """
    import requests

    def call(method, path, **kwargs):
        """Send ``method`` to ``base_url + path``, assert HTTP 200, return JSON."""
        response = requests.request(method, f"{base_url}{path}", timeout=timeout, **kwargs)
        assert response.status_code == 200, (
            f"{method} {path} returned {response.status_code}: {response.text}"
        )
        return response.json()

    print(f"\n{'=' * 60}")
    print(f"HTTP TEST ({base_url})")
    print(f"{'=' * 60}")

    # --- Health ---
    health = call("GET", "/health")
    print(f"\nHealth: {health}")

    # --- Schema exposes both action and observation models ---
    schema = call("GET", "/schema")
    assert "action" in schema
    assert "observation" in schema
    print(f"Schema: action has {list(schema['action']['properties'].keys())}")
    print(f"Schema: observation has {list(schema['observation']['properties'].keys())}")

    # --- Reset returns an unfinished, unscored episode with full state ---
    data = call("POST", "/reset")
    assert data["done"] is False
    assert data["reward"] == 0.0
    obs = data["observation"]
    assert "scenario" in obs
    assert "state" in obs
    for key in ("time_of_day", "position_in_chair", "minutes_since_last_stood"):
        assert key in obs["state"], f"Missing {key}"
    print(f"\nReset: scenario='{obs['scenario']}'")
    print(f" state={obs['state']}")
    print(" State keys present: PASS")

    # --- A good coaching action ends the episode with positive reward ---
    data = call("POST", "/step", json={
        "action": {
            "tool_calls": ["adhd_coach_tool"],
            "message": "Stand up and stretch, then type just the recipient name.",
        }
    })
    assert data["done"] is True
    assert data["reward"] > 0
    print(f"Good action: reward={data['reward']} PASS")

    # The previous step reported done=True, so start a fresh episode before
    # stepping again rather than stepping a finished one.
    call("POST", "/reset")
    data = call("POST", "/step", json={
        "action": {
            "tool_calls": [],
            "message": "What do you want to work on?",
        }
    })
    print(f"No-tool action: reward={data['reward']}")

    # --- Step responses carry the rubric breakdown ---
    assert "scoring" in data["observation"]
    assert "total_score" in data["observation"]["scoring"]
    assert "criteria" in data["observation"]["scoring"]
    print("Scoring details present: PASS")

    print(f"\n{'=' * 60}")
    print("ALL HTTP TESTS PASSED")
    print(f"{'=' * 60}")
|
|
|
|
def parse_base_url(argv, default="http://localhost:8000"):
    """Return the server URL given on the command line, or ``default``.

    ``argv[0]`` (the script path) is skipped, so a script stored under a
    directory such as ``http_tests/`` is not mistaken for a URL. When
    several arguments start with ``http``, the last one wins.
    """
    urls = [arg for arg in argv[1:] if arg.startswith("http")]
    return urls[-1] if urls else default


if __name__ == "__main__":
    test_direct()
    test_rubric()

    # HTTP tests are opt-in because they need a running server.
    if "--http" in sys.argv[1:]:
        test_http(parse_base_url(sys.argv))
|
|