File size: 3,076 Bytes
aae7b06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Typed models for the RecallTrace OpenEnv environment."""

from __future__ import annotations

from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field


class ActionType(str, Enum):
    """Closed set of action kinds used as ``RecallAction.type``.

    Members also subclass ``str``, so each one compares equal to its
    lowercase string value (e.g. ``ActionType.NOTIFY == "notify"``).
    """

    # The string values below are the serialized form of each member;
    # changing one changes what is accepted and emitted for that action.
    INSPECT_NODE = "inspect_node"
    TRACE_LOT = "trace_lot"
    QUARANTINE = "quarantine"
    NOTIFY = "notify"
    FINALIZE = "finalize"


class RecallAction(BaseModel):
    """Action submitted by an agent.

    Only ``type`` is required; every other field defaults to ``None``.
    Fields that are not declared here are rejected (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    # Required; must be one of the ActionType members.
    type: ActionType
    # Optional targets. This model does not enforce which of them a given
    # action type needs — NOTE(review): presumably checked by the environment.
    node_id: Optional[str] = None
    lot_id: Optional[str] = None
    # When provided, must be an integer >= 1.
    quantity: Optional[int] = Field(default=None, ge=1)
    # Optional free-text field; no constraints are applied here.
    rationale: Optional[str] = None


class RewardSignal(BaseModel):
    """Typed reward payload.

    ``value`` is validated to lie in the closed interval [-1.0, 1.0];
    undeclared fields are rejected (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    # Required scalar reward, bounded to [-1.0, 1.0] inclusive.
    value: float = Field(ge=-1.0, le=1.0)
    # Required text accompanying the reward value.
    reason: str
    # Named float values; defaults to a new empty dict for each instance.
    # NOTE(review): presumably the per-component parts of `value` — confirm.
    components: Dict[str, float] = Field(default_factory=dict)


class InspectionEvidence(BaseModel):
    """Evidence revealed after inspecting a node.

    Unlike the models in this module that use ``extra="forbid"``, this one
    accepts undeclared fields (``extra="allow"``).
    """

    # Accept additional keys beyond the fields declared below.
    model_config = ConfigDict(extra="allow")

    # Required free-form string; allowed values are not constrained here.
    status: str
    # Required, must be >= 0.
    unsafe_quantity: int = Field(ge=0)
    # Required text describing the evidence.
    evidence: str
    # Optional; when provided, must be >= 0.
    safe_quantity: Optional[int] = Field(default=None, ge=0)


class TaskDefinition(BaseModel):
    """Static task descriptor.

    All fields are required and undeclared fields are rejected
    (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    task_id: str
    name: str
    # Free-form string; the set of difficulty labels is not constrained here.
    difficulty: str
    objective: str
    # Must be >= 1.
    max_steps: int = Field(ge=1)


class RecallObservation(BaseModel):
    """Observable state exposed to the agent.

    Every field is required (none declares a default) and undeclared fields
    are rejected (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    task_id: str
    # Plain int with no range constraint applied here.
    phase: int
    recall_notice: str
    # Plain strings, not ActionType members.
    # NOTE(review): presumably ActionType values — confirm against producer.
    available_actions: List[str]
    # Two-level mapping to integer amounts.
    # NOTE(review): presumably node_id -> lot_id -> quantity — confirm.
    inventory: Dict[str, Dict[str, int]]
    # NOTE(review): key/value meaning not shown here; presumably
    # node_id -> list of related ids — confirm.
    discovered_shipments: Dict[str, List[str]]
    inspected_nodes: List[str]
    # Two-level mapping whose leaves are InspectionEvidence models.
    # NOTE(review): presumably node_id -> lot_id -> evidence — confirm.
    inspection_results: Dict[str, Dict[str, InspectionEvidence]]
    # Outer key maps to an untyped payload dict; its schema is not defined here.
    trace_results: Dict[str, Dict[str, Any]]
    notified_nodes: List[str]
    # Same shape as `inventory` (two-level mapping to integer amounts).
    quarantined_inventory: Dict[str, Dict[str, int]]
    history: List[str]
    # Both counters must be >= 0.
    steps_taken: int = Field(ge=0)
    remaining_step_budget: int = Field(ge=0)


class StepInfo(BaseModel):
    """Structured info payload returned after each step.

    Undeclared fields are accepted (``extra="allow"``), so instances may
    carry keys beyond the ones declared here.
    """

    # Accept additional keys beyond the fields declared below.
    model_config = ConfigDict(extra="allow")

    message: str
    # Plain string rather than an ActionType member.
    action_type: str
    # Optional; when provided, must lie in [0.0, 1.0] inclusive.
    score: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    # Named float values; defaults to a new empty dict for each instance.
    reward_breakdown: Dict[str, float] = Field(default_factory=dict)


class EnvironmentState(BaseModel):
    """Full internal state for debugging and grading.

    All fields are required and undeclared fields are rejected
    (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    done: bool
    # Nested TaskDefinition model, validated with its own constraints.
    task: TaskDefinition
    # Must be >= 0.
    steps_taken: int = Field(ge=0)
    # Untyped payloads: the schema of these two dicts is not defined here.
    state_data: Dict[str, Any]
    ground_truth: Dict[str, Any]


class TaskGrade(BaseModel):
    """Deterministic grader output.

    All fields are required and undeclared fields are rejected
    (``extra="forbid"``).
    """

    # Reject any key that is not one of the fields declared below.
    model_config = ConfigDict(extra="forbid")

    task_id: str
    # Must lie in [0.0, 1.0] inclusive.
    score: float = Field(ge=0.0, le=1.0)
    success: bool
    # steps_taken >= 0 and max_steps >= 1; this model does not check that
    # steps_taken <= max_steps.
    steps_taken: int = Field(ge=0)
    max_steps: int = Field(ge=1)
    # Unbounded float (no ge/le constraint, unlike `score`).
    reward_total: float
    # Untyped payload; its schema is not defined here.
    final_info: Dict[str, Any]