Spaces:
Sleeping
Sleeping
| """SchemaShift β typed contracts for actions, observations, rewards.""" | |
| from __future__ import annotations | |
| from typing import Literal, Optional, Any | |
| from pydantic import BaseModel, Field | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ACTION SPACE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ToolName = Literal["mail", "calendar", "crm", "chat", "docs"] | |
| ActionType = Literal[ | |
| "call_tool", "inspect_schema", "retry_with_variant", | |
| "report_drift", "complete_task", | |
| ] | |
| class ToolCallParams(BaseModel): | |
| tool: ToolName | |
| endpoint: str | |
| params: dict[str, Any] = Field(default_factory=dict) | |
| class InspectParams(BaseModel): | |
| tool: ToolName | |
| class RetryParams(BaseModel): | |
| tool: ToolName | |
| endpoint: str | |
| params: dict[str, Any] | |
| class DriftReportParams(BaseModel): | |
| tool: ToolName | |
| drift_kind: Literal[ | |
| "field_rename", "endpoint_deprecation", "response_restructure", | |
| "new_required_param", "error_code_remap", "tool_removal", | |
| "rate_limit_tightening" | |
| ] | |
| description: str | |
| class CompleteParams(BaseModel): | |
| summary: str | |
| class Action(BaseModel): | |
| type: ActionType | |
| tool_call: Optional[ToolCallParams] = None | |
| inspect: Optional[InspectParams] = None | |
| retry: Optional[RetryParams] = None | |
| report: Optional[DriftReportParams] = None | |
| complete: Optional[CompleteParams] = None | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # OBSERVATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class ToolResponse(BaseModel): | |
| ok: bool | |
| status: int | |
| body: dict[str, Any] | None = None | |
| error: str | None = None | |
| class HistoryStep(BaseModel): | |
| step: int | |
| action: Action | |
| response: ToolResponse | None = None | |
| reward_breakdown: dict[str, float] | None = None | |
| class Observation(BaseModel): | |
| episode_id: str | |
| task_id: str | |
| difficulty: Literal["easy", "medium", "hard"] | |
| step: int | |
| max_steps: int | |
| token_budget_remaining: int | |
| task_description: str | |
| success_criteria: list[str] | |
| tool_schemas: dict[str, dict] # current (possibly drifted) | |
| known_state: dict[str, Any] | |
| history: list[HistoryStep] | |
| last_response: ToolResponse | None | |
| drift_events_visible: list[dict] = [] | |
| done: bool = False | |
| feedback: str = "" | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DRIFT | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class DriftEvent(BaseModel): | |
| tool: ToolName | |
| endpoint: str | None = None | |
| kind: Literal[ | |
| "field_rename", "endpoint_deprecation", "response_restructure", | |
| "new_required_param", "error_code_remap", "tool_removal", | |
| "rate_limit_tightening" | |
| ] | |
| fires_at_step: int | |
| details: dict[str, Any] | |
| detected_by_agent: bool = False | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # REWARD β with dense shaping | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class RewardBreakdown(BaseModel): | |
| # Terminal-ish rubric dimensions | |
| task_completion: float = 0.0 | |
| drift_detection: float = 0.0 | |
| adaptation_quality: float = 0.0 | |
| efficiency: float = 0.0 | |
| # Gates | |
| catastrophic_gate: float = 1.0 | |
| correct_final_gate: float = 1.0 | |
| # Step-level dense shaping (NEW in v2) | |
| step_shaping: float = 0.0 | |
| # Totals | |
| shaped_total: float = 0.0 # rubric Γ gates + step_shaping | |
| binary: float = 0.0 # {0,1} for GRPO | |
| class EpisodeState(BaseModel): | |
| episode_id: str | |
| task_id: str | |
| difficulty: Literal["easy", "medium", "hard"] | |
| step: int = 0 | |
| max_steps: int | |
| token_budget: int | |
| token_budget_remaining: int | |
| drift_plan: list[DriftEvent] | |
| ground_truth_final_state: dict[str, Any] | |
| agent_state: dict[str, Any] = Field(default_factory=dict) | |
| history: list[HistoryStep] = Field(default_factory=list) | |
| done: bool = False | |
| cumulative_reward: float = 0.0 | |