Spaces:
Sleeping
Sleeping
"""
Quick state verification script.

Runs a single hard episode (vanishing_gradients — requires all 3 sources)
and calls env.state() after each action. For every checkpoint it prints an
OK line with the checked state fields, or one FAIL line per field whose
value differs from the expected one.

Usage:
    uv run python test_state.py
    SERVER_URL=http://localhost:8001 uv run python test_state.py
"""
| import asyncio | |
| import os | |
| from client import WhyDidItFailEnv | |
| from models import WhyDidItFailAction | |
# Base URL of the environment server; override with the SERVER_URL env var.
SERVER_URL = os.environ.get("SERVER_URL", "http://localhost:8000")

# Scenario exercised by this script. It is a hard one: the checks below
# expect it to require logs, config and gradients to be inspected.
SCENARIO = "vanishing_gradients"
def _fmt(state) -> dict:
    """Project the env-state fields this script verifies into a plain dict.

    Args:
        state: Object returned by ``env.state()``; it must expose every
            attribute read below (missing ones raise ``AttributeError``).

    Returns:
        dict mapping each field name to the attribute's current value.
    """
    return {
        "episode_id": state.episode_id,
        "step_count": state.step_count,
        "scenario_key": state.scenario_key,
        "difficulty": state.difficulty,
        "inspection_order": state.inspection_order,
        "required_sources": state.required_sources,
        "max_steps": state.max_steps,
    }


def _check(label: str, state, expected: dict) -> bool:
    """Compare selected state fields against ``expected`` and print a report.

    Only the keys present in ``expected`` are compared; other fields returned
    by ``_fmt`` are shown in the OK line but not checked.

    Args:
        label: Checkpoint name used as the ``[label]`` prefix in the output.
        state: Object returned by ``env.state()``.
        expected: Mapping of ``_fmt`` field names to their expected values.

    Returns:
        True if every expected field matched, False otherwise.
    """
    actual = _fmt(state)
    failures = []
    for key, want in expected.items():
        if key not in actual:
            # A typo in ``expected`` is reported as a failure instead of
            # aborting the whole run with a KeyError.
            failures.append(f" FAIL {key}: not a field reported by _fmt")
            continue
        got = actual[key]
        if got != want:
            failures.append(f" FAIL {key}: expected {want!r}, got {got!r}")
    if failures:
        print(f"\n[{label}] FAILED")
        for line in failures:
            print(line)
        return False
    print(f"[{label}] OK — {actual}")
    return True


def _inspect_action(source: str) -> "WhyDidItFailAction":
    """Build an ``inspect_<source>`` action that carries no diagnosis payload."""
    return WhyDidItFailAction(
        action_type=f"inspect_{source}",
        diagnosis=None,
        suggested_fix=None,
        reasoning=None,
    )


async def main() -> int:
    """Drive one SCENARIO episode, verifying ``env.state()`` after each action.

    Checkpoints: before reset, after reset, after each of the three
    inspections, after re-inspecting logs (which must not duplicate an entry
    in ``inspection_order``), and after submitting a diagnosis.

    Returns:
        Process exit status: 0 if every checkpoint matched, 1 otherwise.
    """
    env = WhyDidItFailEnv(base_url=SERVER_URL)
    print(f"Connecting to {SERVER_URL} ...")
    results = []
    try:
        # ── 1. State before reset ──────────────────────────────────────────
        state = await env.state()
        results.append(_check("before reset", state, {
            "scenario_key": None,
            "difficulty": None,
            "inspection_order": [],
            "required_sources": [],
            "max_steps": 0,
        }))

        # ── 2. State after reset ───────────────────────────────────────────
        await env.reset(scenario_key=SCENARIO)
        state = await env.state()
        results.append(_check("after reset", state, {
            "scenario_key": SCENARIO,
            "difficulty": "hard",
            "inspection_order": [],
            "required_sources": ["logs", "config", "gradients"],
            "max_steps": 11,  # 3 required × 3 + 2
        }))

        # ── 3. After inspect_logs ──────────────────────────────────────────
        await env.step(_inspect_action("logs"))
        state = await env.state()
        results.append(_check("after inspect_logs", state, {
            "step_count": 1,
            "inspection_order": ["logs"],
            "required_sources": ["logs", "config", "gradients"],
        }))

        # ── 4. After inspect_config ────────────────────────────────────────
        await env.step(_inspect_action("config"))
        state = await env.state()
        results.append(_check("after inspect_config", state, {
            "step_count": 2,
            "inspection_order": ["logs", "config"],
        }))

        # ── 5. After inspect_gradients ─────────────────────────────────────
        await env.step(_inspect_action("gradients"))
        state = await env.state()
        results.append(_check("after inspect_gradients", state, {
            "step_count": 3,
            "inspection_order": ["logs", "config", "gradients"],
        }))

        # ── 6. After re-inspection (should not duplicate) ──────────────────
        await env.step(_inspect_action("logs"))
        state = await env.state()
        results.append(_check("after re-inspect logs", state, {
            "step_count": 4,
            "inspection_order": ["logs", "config", "gradients"],  # no duplicate
        }))

        # ── 7. After submit — episode done ─────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="submit_diagnosis",
            diagnosis="vanishing gradients",
            suggested_fix="switch activation to relu and add batch normalization",
            reasoning="gradient norms decay from 0.21 at output to 1e-8 at layer_1; config shows activation=sigmoid",
        ))
        state = await env.state()
        results.append(_check("after submit", state, {
            "step_count": 5,
            "scenario_key": SCENARIO,
            "inspection_order": ["logs", "config", "gradients"],
        }))
    finally:
        # Release the client even when a step raises (e.g. server unreachable).
        await env.close()

    print("\nDone.")
    failed = results.count(False)
    if failed:
        print(f"{failed} of {len(results)} checks FAILED.")
    return 1 if failed else 0


if __name__ == "__main__":
    # Propagate the result so shell/CI callers see a non-zero exit on failure.
    raise SystemExit(asyncio.run(main()))