File size: 8,890 Bytes
2de6dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468d518
2de6dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468d518
 
 
 
 
 
 
 
 
 
 
2de6dc1
 
 
 
 
 
468d518
 
 
 
2de6dc1
 
 
 
 
 
468d518
 
 
 
 
 
 
 
 
 
 
 
 
 
2de6dc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""Additional functionality tests for the Navis web environment."""

from __future__ import annotations

import os
import sys

from fastapi.testclient import TestClient

ROOT = os.path.join(os.path.dirname(__file__), "..")
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)

from navis_web_env.server.app import app
from navis_web_env.server.navis_web_environment import NavisWebEnvironment
from navis_web_env.site_loader import list_curriculum_task_ids, list_task_ids, shortest_path_length, load_task


def _post_step(client: TestClient, click_link_id: str):
    """Send one ``/step`` request, tolerating either request envelope.

    The flat body ``{"click_link_id": ...}`` is tried first. If the server
    rejects it with HTTP 422, the request is retried exactly once with the
    action nested under an ``"action"`` key. The last response is returned
    unchanged, whatever its status.
    """
    flat_attempt = client.post("/step", json={"click_link_id": click_link_id})
    if flat_attempt.status_code != 422:
        return flat_attempt
    # 422 means the flat schema was not accepted; fall back to the nested envelope.
    return client.post("/step", json={"action": {"click_link_id": click_link_id}})


def _post_step_with_optional_session(client: TestClient, click_link_id: str, session_id: str | None):
    """Send one ``/step`` request, attaching ``session_id`` when it is truthy.

    The flat body is tried first. On HTTP 422 the request is retried once with
    the action nested under ``"action"``. In both envelopes the session id, if
    given, sits at the top level of the JSON body. The last response is returned.
    """

    def _attach_session(body: dict) -> dict:
        # Empty strings and None are both treated as "no session".
        if session_id:
            body["session_id"] = session_id
        return body

    outcome = client.post("/step", json=_attach_session({"click_link_id": click_link_id}))
    if outcome.status_code == 422:
        nested_body = _attach_session({"action": {"click_link_id": click_link_id}})
        outcome = client.post("/step", json=nested_body)
    return outcome


def _unwrap_observation_payload(payload: dict):
    """Return the observation carried by an HTTP response body.

    Looks for a top-level ``"observation"`` key first, then for one nested
    inside a dict-valued ``"result"``. If neither exists, the payload itself
    is treated as the observation and returned as-is.
    """
    if "observation" in payload:
        return payload["observation"]
    wrapped = payload["result"] if "result" in payload else None
    if isinstance(wrapped, dict) and "observation" in wrapped:
        return wrapped["observation"]
    return payload


def _unwrap_info_payload(payload: dict):
    """Return the ``info`` dict from an HTTP response body, or ``{}``.

    A dict-valued top-level ``"info"`` wins. Otherwise a dict-valued ``"info"``
    nested inside a dict-valued ``"result"`` is used. Any non-dict ``info``
    value is ignored.
    """
    top_level = payload["info"] if "info" in payload else None
    if isinstance(top_level, dict):
        return top_level
    wrapped = payload["result"] if "result" in payload else None
    if isinstance(wrapped, dict):
        nested = wrapped.get("info")
        if isinstance(nested, dict):
            return nested
    return {}


def _looks_like_state(payload: dict) -> bool:
    """Heuristically decide whether ``payload`` is an environment state body.

    True as soon as any one of the well-known state keys is present.
    """
    state_keys = ("task_id", "current_page_id", "page_id", "visited_pages", "step_count")
    for key in state_keys:
        if key in payload:
            return True
    return False


def test_state_tracks_task_metadata_after_reset():
    """Resetting onto ``hard`` seeds state with its start page, target and distance."""
    environment = NavisWebEnvironment(default_task_id="hard")
    environment.reset(task_id="hard")
    snapshot = environment.state

    # Episode identity, where it starts, and where it must end.
    assert snapshot.task_id == "hard"
    assert snapshot.current_page_id == "dashboard"
    assert snapshot.target_page_id == "emergency_access_reset_playbook"
    # Only the start page has been seen so far, exactly once.
    assert snapshot.visited_pages == ["dashboard"]
    assert snapshot.visited_counts == {"dashboard": 1}
    assert snapshot.shortest_distance_to_target == 5


def test_state_updates_after_valid_transition():
    """One valid click on ``easy`` advances state by a single step toward the target."""
    environment = NavisWebEnvironment(default_task_id="easy")
    environment.reset(task_id="easy")
    # Ad-hoc object exposing only a ``click_link_id`` attribute, used as the step action.
    support_click = type("Action", (), {"click_link_id": "home_support"})()
    environment.step(action=support_click)

    snapshot = environment.state
    assert snapshot.step_count == 1
    assert snapshot.current_page_id == "support_center"
    assert snapshot.visited_pages == ["home", "support_center"]
    assert snapshot.visited_counts["support_center"] == 1
    assert snapshot.last_action_valid is True
    assert snapshot.shortest_distance_to_target == 1


def test_loop_cap_termination_sets_reason_and_penalty():
    """Cycling the same two links on ``hard`` must end the episode via the loop cap."""
    environment = NavisWebEnvironment(default_task_id="hard")
    environment.reset(task_id="hard")

    # Repeat the same two-link cycle three times; the episode is expected to
    # terminate (and the loop stops early) once the loop cap trips.
    cycle = ["dash_remote_work", "remote_dashboard"] * 3
    last_observation = None
    for link_id in cycle:
        action = type("Action", (), {"click_link_id": link_id})()
        last_observation = environment.step(action=action)
        if last_observation.done:
            break

    assert last_observation is not None
    assert last_observation.done is True
    # The reason is exposed both on state and on the last info dict.
    # NOTE(review): the test name mentions a penalty, but no reward/penalty value is asserted.
    assert environment.state.termination_reason == "loop_cap_exceeded"
    assert environment.get_last_info()["termination_reason"] == "loop_cap_exceeded"


def test_task_catalog_and_shortest_paths_are_deterministic():
    """Task catalogs list in a fixed order and each task has a known shortest path."""
    expected_catalog = [
        "easy",
        "medium",
        "hard",
        "expert",
        "adversarial",
        "bank_dispute",
        "it_access_recovery",
        "permit_renewal",
    ]
    assert list_task_ids() == expected_catalog
    assert list_curriculum_task_ids() == ["curriculum_easy", "curriculum_medium", "curriculum_hard"]

    # Shortest start-page-to-target distance per task, in load order.
    expected_lengths = {
        "easy": 2,
        "medium": 4,
        "hard": 5,
        "expert": 6,
        "adversarial": 6,
        "bank_dispute": 4,
        "it_access_recovery": 4,
        "permit_renewal": 4,
        "curriculum_medium": 5,
    }
    # Load every task first, then measure each one.
    loaded = {task_id: load_task(task_id) for task_id in expected_lengths}
    for task_id, length in expected_lengths.items():
        task = loaded[task_id]
        assert shortest_path_length(task, task.start_page_id) == length, task_id


def test_curriculum_tasks_are_deterministic_and_include_required_checkpoints():
    """Loading ``curriculum_hard`` twice gives the same goal/target and fixed checkpoints."""
    first_load, second_load = load_task("curriculum_hard"), load_task("curriculum_hard")

    # Both loads must agree on the generated goal text and target page.
    for attribute in ("goal_instruction", "target_page_id"):
        assert getattr(first_load, attribute) == getattr(second_load, attribute)
    assert first_load.required_page_ids == ["intake_hub", "workflow_center"]
    assert first_load.difficulty == "hard"


def test_http_endpoints_expose_health_schema_and_state():
    """Smoke-test ``/health``, ``/schema``, ``/reset`` and ``/state`` on the app."""
    client = TestClient(app)

    health = client.get("/health")
    assert health.status_code == 200
    assert health.json()["status"] in {"ok", "healthy"}

    schema = client.get("/schema")
    assert schema.status_code == 200
    schema_body = schema.json()
    # Accept either naming convention for the action/observation schema keys.
    has_long_keys = "action_schema" in schema_body and "observation_schema" in schema_body
    has_short_keys = "action" in schema_body and "observation" in schema_body
    assert has_long_keys or has_short_keys

    reset = client.post("/reset", json={"task_id": "easy"})
    assert reset.status_code == 200
    assert _unwrap_observation_payload(reset.json())["page_id"] == "home"

    state = client.get("/state")
    assert state.status_code == 200
    state_body = state.json()
    assert isinstance(state_body, dict)
    assert _looks_like_state(state_body)


def test_http_step_returns_info_summary_on_success():
    """A valid HTTP step returns the next observation, echoes the session, and summarizes info.

    The session id issued by ``/reset`` (if any) is sent back on ``/step``.
    The test then checks the response echoes the same id.
    """
    client = TestClient(app)
    reset_response = client.post("/reset", json={"task_id": "easy"})
    assert reset_response.status_code == 200
    session_id = reset_response.json().get("session_id")
    # Pass the session explicitly rather than relying on the server to associate
    # this step with the last reset; with no session id this is a plain step.
    step_response = _post_step_with_optional_session(client, "home_support", session_id)

    assert step_response.status_code == 200
    payload = step_response.json()
    observation_payload = _unwrap_observation_payload(payload)
    info_payload = _unwrap_info_payload(payload)
    assert payload.get("session_id") == session_id
    assert observation_payload["page_id"] in {"support_center", "contact_support"}
    # Info is optional in the response; when present it must carry a summary field.
    if info_payload:
        assert "grade" in info_payload or "reached_target" in info_payload or "task_id" in info_payload


def test_http_session_id_persists_state_across_steps():
    """Steps sent with the reset's session id accumulate in that session's state.

    If the server issues no session/episode id, only a single step is checked
    and the test returns early.
    """
    client = TestClient(app)

    reset_response = client.post("/reset", json={"task_id": "easy"})
    assert reset_response.status_code == 200
    reset_payload = reset_response.json()
    session_id = reset_payload.get("session_id") or reset_payload.get("episode_id")

    if not session_id:
        # Session-less server: just verify one step lands on the expected page.
        first_step = _post_step_with_optional_session(client, "home_support", None)
        assert first_step.status_code == 200
        assert _unwrap_observation_payload(first_step.json())["page_id"] == "support_center"
        return

    first_step = _post_step_with_optional_session(client, "home_support", session_id)
    assert first_step.status_code == 200
    first_payload = first_step.json()
    assert first_payload.get("session_id") == session_id
    assert _unwrap_observation_payload(first_payload)["page_id"] == "support_center"

    second_step = _post_step_with_optional_session(client, "support_contact", session_id)
    assert second_step.status_code == 200
    second_observation = _unwrap_observation_payload(second_step.json())
    assert second_observation["page_id"] == "contact_support"
    assert second_observation["done"] is True

    # session_id is guaranteed truthy here (the falsy case returned above), so
    # the state lookup always targets this session.
    state_response = client.get("/state", params={"session_id": session_id})
    assert state_response.status_code == 200
    assert state_response.json()["step_count"] == 2