File size: 5,060 Bytes
8cdb02b
 
 
 
d4ab0f1
8cdb02b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""SchemaShift — typed contracts for actions, observations, rewards."""
from __future__ import annotations
from typing import Literal, Optional, Any
from pydantic import BaseModel, Field


# ─────────────────────────────────────────────────────────────────
# ACTION SPACE
# ─────────────────────────────────────────────────────────────────

# Closed set of tool identifiers accepted wherever a `tool` field appears.
ToolName = Literal["mail", "calendar", "crm", "chat", "docs"]

# Allowed values for `Action.type`.
ActionType = Literal[
    "call_tool",
    "inspect_schema",
    "retry_with_variant",
    "report_drift",
    "complete_task",
]

class ToolCallParams(BaseModel):
    """Payload describing one tool invocation: target tool, endpoint, and arguments."""
    tool: ToolName
    endpoint: str
    # Arguments for the endpoint; a fresh empty dict is used when omitted.
    params: dict[str, Any] = Field(default_factory=dict)

class InspectParams(BaseModel):
    """Payload naming the tool to inspect.

    Presumably paired with the "inspect_schema" action type — confirm against
    the code that dispatches actions.
    """
    tool: ToolName

class RetryParams(BaseModel):
    """Payload for retrying a tool endpoint with an explicit argument set."""
    tool: ToolName
    endpoint: str
    # Required here (no default), unlike ToolCallParams.params which defaults to {}.
    params: dict[str, Any]

class DriftReportParams(BaseModel):
    """Payload for a drift report: which tool, what kind of drift, and a description."""
    tool: ToolName
    # Category of drift being claimed.
    # NOTE(review): this literal set is duplicated verbatim in DriftEvent.kind;
    # the two must be kept in sync by hand.
    drift_kind: Literal[
        "field_rename", "endpoint_deprecation", "response_restructure",
        "new_required_param", "error_code_remap", "tool_removal",
        "rate_limit_tightening"
    ]
    # Free-text explanation accompanying the report.
    description: str

class CompleteParams(BaseModel):
    """Payload carrying a free-text summary.

    Presumably sent with the "complete_task" action type — confirm against
    the code that dispatches actions.
    """
    summary: str

class Action(BaseModel):
    """One agent action: a ``type`` tag plus one optional payload slot per action kind.

    Every payload slot defaults to ``None``. Nothing in this model enforces that
    the slot matching ``type`` is populated.
    """
    type: ActionType
    # Payload slots; each one is independently optional.
    tool_call: ToolCallParams | None = None
    inspect: InspectParams | None = None
    retry: RetryParams | None = None
    report: DriftReportParams | None = None
    complete: CompleteParams | None = None


# ─────────────────────────────────────────────────────────────────
# OBSERVATION
# ─────────────────────────────────────────────────────────────────

class ToolResponse(BaseModel):
    """Outcome of a tool interaction: success flag, status code, and optional body/error."""
    ok: bool
    status: int  # presumably an HTTP-style status code — TODO confirm
    # Both default to None; this model does not tie either one to the value of `ok`.
    body: dict[str, Any] | None = None
    error: str | None = None

class HistoryStep(BaseModel):
    """One history entry: the action taken at a step and, optionally, what came back."""
    step: int
    action: Action
    response: ToolResponse | None = None
    # Reward components keyed by name; presumably mirrors RewardBreakdown's
    # field names — TODO confirm against the code that fills it in.
    reward_breakdown: dict[str, float] | None = None

class Observation(BaseModel):
    """Snapshot of an episode: task description, budgets, tool schemas, and history.

    Only ``drift_events_visible``, ``done`` and ``feedback`` carry defaults;
    every other field is declared without one.
    """
    episode_id: str
    task_id: str
    difficulty: Literal["easy", "medium", "hard"]
    step: int
    max_steps: int
    token_budget_remaining: int
    task_description: str
    success_criteria: list[str]
    tool_schemas: dict[str, dict]                  # current (possibly drifted)
    known_state: dict[str, Any]
    history: list[HistoryStep]
    # Declared without a default, unlike HistoryStep.response.
    last_response: ToolResponse | None
    # default_factory matches how the file's other mutable defaults are declared
    # (ToolCallParams.params, EpisodeState.agent_state, EpisodeState.history).
    drift_events_visible: list[dict] = Field(default_factory=list)
    done: bool = False
    feedback: str = ""


# ─────────────────────────────────────────────────────────────────
# DRIFT
# ─────────────────────────────────────────────────────────────────

class DriftEvent(BaseModel):
    """A drift event for one tool: its kind, the step it fires at, and its details."""
    tool: ToolName
    # Affected endpoint; None presumably means the event is tool-wide — TODO confirm.
    endpoint: str | None = None
    # NOTE(review): same literal set as DriftReportParams.drift_kind; keep in sync.
    kind: Literal[
        "field_rename", "endpoint_deprecation", "response_restructure",
        "new_required_param", "error_code_remap", "tool_removal",
        "rate_limit_tightening"
    ]
    # Step at which the event fires (0- vs 1-based is not visible in this file).
    fires_at_step: int
    # Kind-specific payload; its schema is not defined in this file.
    details: dict[str, Any]
    # Starts False; presumably set once the agent reports this drift — TODO confirm.
    detected_by_agent: bool = False


# ─────────────────────────────────────────────────────────────────
# REWARD — with dense shaping
# ─────────────────────────────────────────────────────────────────

class RewardBreakdown(BaseModel):
    """Per-component reward values plus aggregate totals.

    Every field has a default, and the model defines no validators or methods:
    the totals are plain stored values, not computed here.
    """
    # Terminal-ish rubric dimensions
    task_completion: float = 0.0
    drift_detection: float = 0.0
    adaptation_quality: float = 0.0
    efficiency: float = 0.0
    # Gates
    catastrophic_gate: float = 1.0
    correct_final_gate: float = 1.0
    # Step-level dense shaping (NEW in v2)
    step_shaping: float = 0.0
    # Totals
    shaped_total: float = 0.0                      # rubric × gates + step_shaping
    binary: float = 0.0                            # {0,1} for GRPO

class EpisodeState(BaseModel):
    """Full bookkeeping for one episode.

    Includes ``drift_plan`` and ``ground_truth_final_state``, which have no
    counterpart on ``Observation``.
    """
    episode_id: str
    task_id: str
    difficulty: Literal["easy", "medium", "hard"]
    step: int = 0
    max_steps: int
    # Both budget fields are required; presumably token_budget is the initial
    # allowance and token_budget_remaining the running balance — TODO confirm.
    token_budget: int
    token_budget_remaining: int
    drift_plan: list[DriftEvent]
    ground_truth_final_state: dict[str, Any]
    # Mutable containers use default_factory, giving each instance its own object.
    agent_state: dict[str, Any] = Field(default_factory=dict)
    history: list[HistoryStep] = Field(default_factory=list)
    done: bool = False
    cumulative_reward: float = 0.0