Spaces:
Sleeping
Sleeping
File size: 5,024 Bytes
"""
Quick state verification script.
Runs a single hard episode (vanishing_gradients — requires all 3 sources)
and calls env.state() after each action. Prints a diff of what changed.
Usage:
uv run python test_state.py
SERVER_URL=http://localhost:8001 uv run python test_state.py
"""
import asyncio
import os
from client import WhyDidItFailEnv
from models import WhyDidItFailAction
# Base URL of the environment server; the SERVER_URL environment variable
# overrides the local default.
SERVER_URL = os.environ.get("SERVER_URL", "http://localhost:8000")
# Scenario exercised by this script: hard — needs logs + config + gradients.
SCENARIO = "vanishing_gradients"
def _fmt(state) -> dict:
    """Return the state fields this script verifies as a plain dict.

    Args:
        state: Object exposing the attributes listed in ``fields`` below
            (in this script, the value returned by ``env.state()``).

    Returns:
        A dict mapping each field name to the attribute's current value,
        in the order the fields are listed.

    Raises:
        AttributeError: If ``state`` lacks one of the listed attributes.
    """
    fields = (
        "episode_id",
        "step_count",
        "scenario_key",
        "difficulty",
        "inspection_order",
        "required_sources",
        "max_steps",
    )
    return {name: getattr(state, name) for name in fields}
def _check(label: str, state, expected: dict) -> bool:
    """Compare selected state fields with expected values and print the outcome.

    Args:
        label: Short name of the checkpoint, shown in the printed report.
        state: State object to inspect; it is flattened with ``_fmt`` first.
        expected: Mapping of field name to expected value. Only the listed
            fields are compared; every key must be one that ``_fmt`` produces.

    Returns:
        True when every listed field matched, False otherwise, so a caller
        can aggregate results instead of relying on the printed text alone.

    Raises:
        KeyError: If ``expected`` names a field that ``_fmt`` does not produce.
    """
    s = _fmt(state)
    failures = [
        f" FAIL {k}: expected {v!r}, got {s[k]!r}"
        for k, v in expected.items()
        if s[k] != v
    ]
    if failures:
        print(f"\n[{label}] FAILED")
        for line in failures:
            print(line)
        return False
    # On success, dump the whole formatted state, not just the compared fields.
    print(f"[{label}] OK — {s}")
    return True
async def main():
    """Run one episode of ``SCENARIO`` and verify ``env.state()`` after each action.

    The sequence is: read the state before any reset, reset into ``SCENARIO``,
    inspect logs, config and gradients, re-inspect logs, then submit a
    diagnosis. After every action the state is fetched and compared against
    the expected fields with ``_check``, which prints the result.

    The client is always closed, even when a request raises part-way through.
    """
    env = WhyDidItFailEnv(base_url=SERVER_URL)
    print(f"Connecting to {SERVER_URL} ...")
    try:
        # ── 1. State before reset ──────────────────────────────────────────
        state = await env.state()
        _check("before reset", state, {
            "scenario_key": None,
            "difficulty": None,
            "inspection_order": [],
            "required_sources": [],
            "max_steps": 0,
        })

        # ── 2. State after reset ───────────────────────────────────────────
        await env.reset(scenario_key=SCENARIO)
        state = await env.state()
        _check("after reset", state, {
            "scenario_key": SCENARIO,
            "difficulty": "hard",
            "inspection_order": [],
            "required_sources": ["logs", "config", "gradients"],
            "max_steps": 11,  # 3 required × 3 + 2
        })

        # ── 3. After inspect_logs ──────────────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="inspect_logs",
            diagnosis=None,
            suggested_fix=None,
            reasoning=None,
        ))
        state = await env.state()
        _check("after inspect_logs", state, {
            "step_count": 1,
            "inspection_order": ["logs"],
            "required_sources": ["logs", "config", "gradients"],
        })

        # ── 4. After inspect_config ────────────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="inspect_config",
            diagnosis=None,
            suggested_fix=None,
            reasoning=None,
        ))
        state = await env.state()
        _check("after inspect_config", state, {
            "step_count": 2,
            "inspection_order": ["logs", "config"],
        })

        # ── 5. After inspect_gradients ─────────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="inspect_gradients",
            diagnosis=None,
            suggested_fix=None,
            reasoning=None,
        ))
        state = await env.state()
        _check("after inspect_gradients", state, {
            "step_count": 3,
            "inspection_order": ["logs", "config", "gradients"],
        })

        # ── 6. After re-inspection (should not duplicate) ──────────────────
        await env.step(WhyDidItFailAction(
            action_type="inspect_logs",
            diagnosis=None,
            suggested_fix=None,
            reasoning=None,
        ))
        state = await env.state()
        _check("after re-inspect logs", state, {
            "step_count": 4,
            "inspection_order": ["logs", "config", "gradients"],  # no duplicate
        })

        # ── 7. After submit — episode done ─────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="submit_diagnosis",
            diagnosis="vanishing gradients",
            suggested_fix="switch activation to relu and add batch normalization",
            reasoning="gradient norms decay from 0.21 at output to 1e-8 at layer_1; config shows activation=sigmoid",
        ))
        state = await env.state()
        _check("after submit", state, {
            "step_count": 5,
            "scenario_key": SCENARIO,
            "inspection_order": ["logs", "config", "gradients"],
        })
    finally:
        # Release the client even if one of the requests above raised.
        await env.close()
    print("\nDone.")
if __name__ == "__main__":
    # Run the async checks only when executed as a script, not on import.
    asyncio.run(main())