File size: 5,024 Bytes
ff8ce5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Quick state verification script.

Runs a single hard episode (vanishing_gradients — requires all 3 sources)
and calls env.state() after each action. Prints a diff of what changed.

Usage:
    uv run python test_state.py
    SERVER_URL=http://localhost:8001 uv run python test_state.py
"""

import asyncio
import os

from client import WhyDidItFailEnv
from models import WhyDidItFailAction


# Base URL of the environment server; the SERVER_URL environment variable
# overrides the localhost default.
SERVER_URL = os.getenv("SERVER_URL", "http://localhost:8000")
# Scenario key passed to env.reset() for the single episode this script runs.
SCENARIO    = "vanishing_gradients"   # hard — needs logs + config + gradients


def _fmt(state) -> dict:
    """Snapshot the state attributes this script compares and prints.

    Args:
        state: Object exposing the attributes listed below (as returned by
            ``env.state()`` in this script).

    Returns:
        A plain dict mapping each tracked attribute name to its current
        value, in the fixed order shown.

    Raises:
        AttributeError: If ``state`` lacks one of the tracked attributes.
    """
    tracked_fields = (
        "episode_id",
        "step_count",
        "scenario_key",
        "difficulty",
        "inspection_order",
        "required_sources",
        "max_steps",
    )
    # Attributes are read in declaration order, so the resulting dict keeps
    # a stable key order for printing.
    return {name: getattr(state, name) for name in tracked_fields}


def _check(label: str, state, expected: dict) -> None:
    """Compare selected state fields against expected values and print the outcome.

    Args:
        label: Name of the checkpoint; shown in square brackets in the output.
        state: State object to inspect; it is flattened with ``_fmt`` first.
        expected: Mapping of field name to expected value. Only the fields
            listed here are compared, so a checkpoint may assert on a subset.

    Raises:
        KeyError: If ``expected`` names a field that ``_fmt`` does not return.

    Prints a single ``OK`` line containing the full snapshot when every
    expected field matches; otherwise prints a ``FAILED`` header followed by
    one line per mismatching field. A mismatch is reported, not raised.
    """
    snapshot = _fmt(state)
    failures = [
        f"  FAIL  {key}: expected {want!r}, got {snapshot[key]!r}"
        for key, want in expected.items()
        if snapshot[key] != want
    ]
    if failures:
        print(f"\n[{label}] FAILED")
        for line in failures:
            print(line)
    else:
        # The separator is an em dash; the previous literal was a mis-encoded
        # (mojibake) rendering of the same character.
        print(f"[{label}] OK  — {snapshot}")


async def main():
    """Run one scripted episode and verify ``env.state()`` after every action.

    The sequence is: query state before reset, reset into ``SCENARIO``,
    inspect logs, config and gradients, inspect logs a second time, then
    submit a diagnosis. After each action the state is fetched and compared
    with ``_check``, which prints OK / FAILED per checkpoint.

    The client is closed in a ``finally`` clause so that it is released even
    when a request to the server raises part-way through the sequence.
    """
    env = WhyDidItFailEnv(base_url=SERVER_URL)

    print(f"Connecting to {SERVER_URL} ...")

    try:
        # ── 1. State before reset ────────────────────────────────────────────
        state = await env.state()
        _check("before reset", state, {
            "scenario_key":     None,
            "difficulty":       None,
            "inspection_order": [],
            "required_sources": [],
            "max_steps":        0,
        })

        # ── 2. State after reset ─────────────────────────────────────────────
        await env.reset(scenario_key=SCENARIO)
        state = await env.state()
        _check("after reset", state, {
            "scenario_key":     SCENARIO,
            "difficulty":       "hard",
            "inspection_order": [],
            "required_sources": ["logs", "config", "gradients"],
            "max_steps":        11,   # 3 required × 3 + 2
        })

        # ── 3-6. Inspect actions ─────────────────────────────────────────────
        # Each entry is (action_type, checkpoint label, expected state subset).
        # The final entry repeats inspect_logs: step_count is expected to
        # advance while inspection_order must not gain a duplicate "logs".
        inspections = [
            ("inspect_logs", "after inspect_logs", {
                "step_count":       1,
                "inspection_order": ["logs"],
                "required_sources": ["logs", "config", "gradients"],
            }),
            ("inspect_config", "after inspect_config", {
                "step_count":       2,
                "inspection_order": ["logs", "config"],
            }),
            ("inspect_gradients", "after inspect_gradients", {
                "step_count":       3,
                "inspection_order": ["logs", "config", "gradients"],
            }),
            ("inspect_logs", "after re-inspect logs", {
                "step_count":       4,
                "inspection_order": ["logs", "config", "gradients"],   # no duplicate
            }),
        ]
        for action_type, label, expected in inspections:
            await env.step(WhyDidItFailAction(
                action_type=action_type,
                diagnosis=None,
                suggested_fix=None,
                reasoning=None,
            ))
            state = await env.state()
            _check(label, state, expected)

        # ── 7. After submit — episode done ───────────────────────────────────
        await env.step(WhyDidItFailAction(
            action_type="submit_diagnosis",
            diagnosis="vanishing gradients",
            suggested_fix="switch activation to relu and add batch normalization",
            reasoning="gradient norms decay from 0.21 at output to 1e-8 at layer_1; config shows activation=sigmoid",
        ))
        state = await env.state()
        _check("after submit", state, {
            "step_count":       5,
            "scenario_key":     SCENARIO,
            "inspection_order": ["logs", "config", "gradients"],
        })
    finally:
        # Always release the client, including when a call above raised.
        await env.close()

    print("\nDone.")


# Script entry point: run the async verification sequence when executed directly.
if __name__ == "__main__":
    asyncio.run(main())