File size: 5,050 Bytes
8cdb02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Validation tests for models.py — Phase 1 acceptance."""
from __future__ import annotations

import pytest
from pydantic import ValidationError

from models import (
    Action,
    CompleteParams,
    DriftEvent,
    DriftReportParams,
    EpisodeState,
    HistoryStep,
    InspectParams,
    Observation,
    RetryParams,
    RewardBreakdown,
    ToolCallParams,
    ToolResponse,
)


def test_imports() -> None:
    """Check that every public model is importable and exposes the pydantic v2 API.

    Each class must provide ``model_validate`` and ``model_dump_json``.
    """
    exported_models = (
        ToolCallParams,
        InspectParams,
        RetryParams,
        DriftReportParams,
        CompleteParams,
        Action,
        ToolResponse,
        HistoryStep,
        Observation,
        DriftEvent,
        RewardBreakdown,
        EpisodeState,
    )
    # Guards against a model being dropped from the tuple above by accident.
    assert len(exported_models) == 12
    required_api = ("model_validate", "model_dump_json")
    for model_cls in exported_models:
        for attr_name in required_api:
            assert hasattr(model_cls, attr_name)


def test_action_serialization() -> None:
    """Check that one Action per exercised ``type`` survives a JSON round trip unchanged."""
    call_tool = Action(
        type="call_tool",
        tool_call=ToolCallParams(
            tool="mail",
            endpoint="send_message",
            params={"to": "a@b.com", "subject": "hi", "body": "hello"},
        ),
    )
    inspect_schema = Action(type="inspect_schema", inspect=InspectParams(tool="calendar"))
    retry_variant = Action(
        type="retry_with_variant",
        retry=RetryParams(
            tool="crm",
            endpoint="search_contacts",
            params={"email_address": "x@y.com"},
        ),
    )
    drift_report = Action(
        type="report_drift",
        report=DriftReportParams(
            tool="calendar",
            drift_kind="field_rename",
            description="attendees renamed to participants",
        ),
    )
    completion = Action(type="complete_task", complete=CompleteParams(summary="all done"))

    variants = [call_tool, inspect_schema, retry_variant, drift_report, completion]
    assert len(variants) == 5

    for action in variants:
        encoded = action.model_dump_json()
        # The serialized form must be a non-empty JSON string.
        assert isinstance(encoded, str)
        assert len(encoded) > 0
        # Parsing it back must reproduce an equal model.
        decoded = Action.model_validate_json(encoded)
        assert decoded == action


def test_observation_validation() -> None:
    """Check that a fully populated Observation validates and keeps its nested models."""
    inspect_mail = Action(type="inspect_schema", inspect=InspectParams(tool="mail"))
    ok_response = ToolResponse(ok=True, status=200, body={"messages": []}, error=None)
    prior_steps = [
        HistoryStep(
            step=0,
            action=inspect_mail,
            response=ok_response,
            reward_breakdown={"shaped_total": 0.10},
        ),
    ]

    observation = Observation(
        episode_id="ep-1",
        task_id="E1_onboard_new_hire",
        difficulty="easy",
        step=1,
        max_steps=8,
        token_budget_remaining=3900,
        task_description="Send welcome email and create orientation event.",
        success_criteria=["welcome email sent", "calendar event created"],
        tool_schemas={"mail": {"send_message": {"required": ["to", "subject", "body"]}}},
        known_state={"mail.sent_count": 0},
        history=prior_steps,
        last_response=ok_response,
        drift_events_visible=[{"tool": "calendar", "kind": "field_rename"}],
        done=False,
        feedback="ok",
    )

    assert observation.step == 1
    first_step = observation.history[0]
    assert first_step.response is not None
    latest = observation.last_response
    assert latest is not None
    assert latest.ok is True


def test_episode_state_with_drifts() -> None:
    """Check that EpisodeState accepts a drift plan and fills in its defaults.

    Defaults asserted here: ``step`` is 0, ``agent_state`` and ``history`` are
    empty, and ``cumulative_reward`` is 0.0. Each DriftEvent starts with
    ``detected_by_agent`` set to False.
    """
    mail_deprecation = DriftEvent(
        tool="mail",
        endpoint="send_message",
        kind="endpoint_deprecation",
        fires_at_step=1,
        details={"replacement": "messages.send"},
    )
    calendar_rename = DriftEvent(
        tool="calendar",
        endpoint="create_event",
        kind="field_rename",
        fires_at_step=3,
        details={"from": "attendees", "to": "participants"},
    )

    state = EpisodeState(
        episode_id="ep-xyz",
        task_id="E1_onboard_new_hire",
        difficulty="easy",
        max_steps=8,
        token_budget=4000,
        token_budget_remaining=4000,
        drift_plan=[mail_deprecation, calendar_rename],
        ground_truth_final_state={"mail.sent_count": 1, "calendar.events_count": 1},
    )

    assert state.step == 0
    plan = state.drift_plan
    assert len(plan) == 2
    assert plan[0].tool == "mail"
    assert plan[1].fires_at_step == 3
    assert plan[0].detected_by_agent is False
    assert state.agent_state == {}
    assert state.history == []
    assert state.cumulative_reward == 0.0


def test_reward_breakdown_defaults() -> None:
    """Check the default values of a bare RewardBreakdown.

    The two ``*_gate`` fields default to 1.0. Every other component defaults
    to 0.0.
    """
    breakdown = RewardBreakdown()
    expected_defaults = {
        "task_completion": 0.0,
        "drift_detection": 0.0,
        "adaptation_quality": 0.0,
        "efficiency": 0.0,
        "catastrophic_gate": 1.0,
        "correct_final_gate": 1.0,
        "step_shaping": 0.0,
        "shaped_total": 0.0,
        "binary": 0.0,
    }
    for field_name, default_value in expected_defaults.items():
        assert getattr(breakdown, field_name) == default_value, field_name


def test_literal_enforcement() -> None:
    """Check that out-of-vocabulary literal values are rejected on the offending field.

    A bare ``pytest.raises(ValidationError)`` would also be satisfied by an
    unrelated failure, such as a missing required field. Each case therefore
    also checks that one reported error is located on the field carrying the
    bad value.
    """
    with pytest.raises(ValidationError) as action_exc:
        Action(type="teleport")  # type: ignore[arg-type]
    assert any(err["loc"] == ("type",) for err in action_exc.value.errors())

    with pytest.raises(ValidationError) as inspect_exc:
        InspectParams(tool="salesforce")  # type: ignore[arg-type]
    assert any(err["loc"] == ("tool",) for err in inspect_exc.value.errors())

    with pytest.raises(ValidationError) as report_exc:
        DriftReportParams(
            tool="mail",
            drift_kind="meteor_strike",  # type: ignore[arg-type]
            description="not a real drift kind",
        )
    assert any(err["loc"] == ("drift_kind",) for err in report_exc.value.errors())