Spaces:
Sleeping
Sleeping
File size: 5,060 Bytes
8cdb02b d4ab0f1 8cdb02b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | """SchemaShift — typed contracts for actions, observations, rewards."""
from __future__ import annotations
from typing import Literal, Optional, Any
from pydantic import BaseModel, Field
# ─────────────────────────────────────────────────────────────────
# ACTION SPACE
# ─────────────────────────────────────────────────────────────────
# Closed set of tool identifiers accepted wherever a `tool` field appears.
ToolName = Literal["mail", "calendar", "crm", "chat", "docs"]

# Closed set of values accepted by `Action.type`.
ActionType = Literal[
    "call_tool",
    "inspect_schema",
    "retry_with_variant",
    "report_drift",
    "complete_task",
]
class ToolCallParams(BaseModel):
    """Payload carried in `Action.tool_call`: a tool, an endpoint, and its params."""

    # Which tool to address; must be one of the ToolName literals.
    tool: ToolName
    # Endpoint name on that tool (free-form string, not validated here).
    endpoint: str
    # Arbitrary keyword payload; a fresh empty dict is created when omitted.
    params: dict[str, Any] = Field(default_factory=dict)
class InspectParams(BaseModel):
    """Payload carried in `Action.inspect`: names the tool concerned."""

    # Must be one of the ToolName literals.
    tool: ToolName
class RetryParams(BaseModel):
    """Payload carried in `Action.retry`.

    Has the same three fields as `ToolCallParams`, except that `params`
    declares no default here and therefore has to be supplied.
    """

    # Which tool to address; must be one of the ToolName literals.
    tool: ToolName
    # Endpoint name on that tool (free-form string, not validated here).
    endpoint: str
    # Arbitrary keyword payload; required (no default).
    params: dict[str, Any]
class DriftReportParams(BaseModel):
    """Payload carried in `Action.report`: a drift claim about one tool."""

    # Tool the report is about; must be one of the ToolName literals.
    tool: ToolName
    # Category of drift being reported. The accepted values are the same
    # seven literals as `DriftEvent.kind`; keep the two lists in sync.
    drift_kind: Literal[
        "field_rename",
        "endpoint_deprecation",
        "response_restructure",
        "new_required_param",
        "error_code_remap",
        "tool_removal",
        "rate_limit_tightening",
    ]
    # Free-text explanation accompanying the report.
    description: str
class CompleteParams(BaseModel):
    """Payload carried in `Action.complete`: a free-text summary."""

    summary: str
class Action(BaseModel):
    """A single action: a `type` tag plus one optional payload slot per payload model.

    Every payload slot defaults to ``None``. This model does not itself
    check that the populated slot corresponds to `type`.
    """

    # Discriminating tag; must be one of the ActionType literals.
    type: ActionType

    # Optional payloads, one slot per payload model.
    tool_call: Optional[ToolCallParams] = None
    inspect: Optional[InspectParams] = None
    retry: Optional[RetryParams] = None
    report: Optional[DriftReportParams] = None
    complete: Optional[CompleteParams] = None
# ─────────────────────────────────────────────────────────────────
# OBSERVATION
# ─────────────────────────────────────────────────────────────────
class ToolResponse(BaseModel):
    """Result record for a tool interaction: success flag, status, body, error."""

    # Success flag.
    ok: bool
    # Integer status code (presumably HTTP-like; confirm against the producer).
    status: int
    # Response payload, if any.
    body: dict[str, Any] | None = None
    # Error message, if any.
    error: str | None = None
class HistoryStep(BaseModel):
    """One history entry: step number, the action, and optional outcome data."""

    # Step counter value for this entry.
    step: int
    # The action recorded at this step.
    action: Action
    # Tool response associated with the action, when there is one.
    response: ToolResponse | None = None
    # Named reward components as a plain str -> float mapping
    # (a dict, not a RewardBreakdown instance).
    reward_breakdown: dict[str, float] | None = None
class Observation(BaseModel):
    """Typed observation record for one step of an episode.

    Bundles episode/task identifiers, step and token-budget counters, the
    task text and success criteria, the tool schemas, known state, the
    history so far, and the most recent tool response.
    """

    # Identification.
    episode_id: str
    task_id: str
    difficulty: Literal["easy", "medium", "hard"]

    # Progress and budget counters.
    step: int
    max_steps: int
    token_budget_remaining: int

    # Task statement.
    task_description: str
    success_criteria: list[str]

    # Environment view.
    tool_schemas: dict[str, dict]  # current (possibly drifted)
    known_state: dict[str, Any]
    history: list[HistoryStep]
    # Nullable, but declared without a default.
    last_response: ToolResponse | None

    # Uses default_factory rather than a bare `[]` literal so the mutable
    # default follows the same convention as the other models in this
    # module (e.g. `ToolCallParams.params`, `EpisodeState.history`).
    drift_events_visible: list[dict] = Field(default_factory=list)
    done: bool = False
    feedback: str = ""
# ─────────────────────────────────────────────────────────────────
# DRIFT
# ─────────────────────────────────────────────────────────────────
class DriftEvent(BaseModel):
    """Description of one drift event: target, kind, firing step, and details."""

    # Tool affected; must be one of the ToolName literals.
    tool: ToolName
    # Specific endpoint affected, or None when not given.
    endpoint: str | None = None
    # Category of drift. The accepted values are the same seven literals as
    # `DriftReportParams.drift_kind`; keep the two lists in sync.
    kind: Literal[
        "field_rename",
        "endpoint_deprecation",
        "response_restructure",
        "new_required_param",
        "error_code_remap",
        "tool_removal",
        "rate_limit_tightening",
    ]
    # Step number associated with the event firing.
    fires_at_step: int
    # Kind-specific parameters; structure is not constrained by this model.
    details: dict[str, Any]
    # Starts False; nothing in this model sets it.
    detected_by_agent: bool = False
# ─────────────────────────────────────────────────────────────────
# REWARD — with dense shaping
# ─────────────────────────────────────────────────────────────────
class RewardBreakdown(BaseModel):
    """Per-component reward record; every field is a float with a default.

    This model only stores the numbers. The arithmetic that combines them
    is not defined here.
    """

    # Terminal-ish rubric dimensions
    task_completion: float = 0.0
    drift_detection: float = 0.0
    adaptation_quality: float = 0.0
    efficiency: float = 0.0

    # Gates (default 1.0)
    catastrophic_gate: float = 1.0
    correct_final_gate: float = 1.0

    # Step-level dense shaping (NEW in v2)
    step_shaping: float = 0.0

    # Totals
    shaped_total: float = 0.0  # rubric × gates + step_shaping
    binary: float = 0.0  # {0,1} for GRPO
class EpisodeState(BaseModel):
    """Full state record for one episode.

    Holds identifiers, step and token-budget counters, the drift plan, the
    ground-truth final state, the agent's state, the step history, and the
    running reward total.
    """

    # Identification.
    episode_id: str
    task_id: str
    difficulty: Literal["easy", "medium", "hard"]

    # Progress and budget counters. `token_budget` and
    # `token_budget_remaining` are both required; neither is derived from
    # the other by this model.
    step: int = 0
    max_steps: int
    token_budget: int
    token_budget_remaining: int

    # Drift events attached to the episode.
    drift_plan: list[DriftEvent]
    # Reference final state (structure not constrained by this model).
    ground_truth_final_state: dict[str, Any]

    # Mutable containers get a fresh instance per episode via default_factory.
    agent_state: dict[str, Any] = Field(default_factory=dict)
    history: list[HistoryStep] = Field(default_factory=list)

    done: bool = False
    cumulative_reward: float = 0.0
|