File size: 4,131 Bytes
d57737f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
AWM-specific Pydantic models for action and observation types.
"""

from typing import Annotated, Any

from openenv.core.env_server.mcp_types import (
    CallToolAction,
    ListToolsAction,
    ListToolsObservation,
)
from openenv.core.env_server.types import Action, Observation
from pydantic import ConfigDict, Field, field_validator, TypeAdapter


# The concrete MCP actions an AWM client may send. Pydantic picks the union
# member from the literal value of each payload's ``type`` field.
_AWMActionUnion = Annotated[ListToolsAction | CallToolAction, Field(discriminator="type")]

# Built once at import time; AWMAction's classmethods delegate to this adapter.
_awm_action_adapter = TypeAdapter(_AWMActionUnion)


class AWMAction(Action):
    """Discriminated union action type for AWM.

    Validating through this class dispatches on the payload's ``type`` field:
    ``model_validate()`` returns the concrete ``ListToolsAction`` or
    ``CallToolAction`` (not an ``AWMAction`` instance), which is what
    ``AWMEnvironment.step()`` expects.
    """

    @classmethod
    def model_validate(cls, obj: Any, **kwargs: Any) -> Action:  # type: ignore[override]
        """Validate ``obj`` against the AWM action union.

        Args:
            obj: Raw action payload, e.g. a dict carrying a ``type`` key.
            **kwargs: Validation options of ``BaseModel.model_validate``
                (``strict``, ``from_attributes``, ``context``, ...). They are
                forwarded to the union's ``TypeAdapter`` rather than being
                silently ignored.

        Returns:
            The concrete ``ListToolsAction`` or ``CallToolAction`` instance.

        Raises:
            pydantic.ValidationError: If ``obj`` matches no union member.
        """
        return _awm_action_adapter.validate_python(obj, **kwargs)

    @classmethod
    def model_json_schema(cls, **kwargs: Any) -> dict[str, Any]:  # type: ignore[override]
        """Return the JSON schema of the action union, not of the empty base model.

        ``kwargs`` (``by_alias``, ``ref_template``, ``mode``, ...) are passed
        through to ``TypeAdapter.json_schema``.
        """
        return _awm_action_adapter.json_schema(**kwargs)


class AWMObservation(Observation):
    """
    Observation with AWM-specific fields promoted to top level.

    ``model_dump()`` and ``model_dump_json()`` exclude None-valued fields by
    default, so keys like ``tool_name=None`` do not appear in the wire payload.
    This matters because the generic ``MCPToolClient._parse_result()`` routes
    observations by key presence (e.g. ``"tool_name" in obs_data``). Stripping
    None values here avoids modifying openenv code; revisit if
    ``MCPToolClient`` changes. Pass ``exclude_none=False`` explicitly to keep
    None-valued keys.
    """

    model_config = ConfigDict(extra="forbid")

    reward_type: str | None = Field(
        default=None,
        description="Reward classification label for this step/episode outcome",
    )
    scenario: str | None = Field(default=None, description="Current scenario name")
    task: str | None = Field(default=None, description="Current task description")
    task_idx: int | None = Field(default=None, description="Current task index")
    # ``bool`` stays in the annotation for schema compatibility with legacy
    # payloads; the "before" validator below turns any bool into dict/None
    # before type validation runs.
    has_verifier: dict | bool | None = Field(
        default=None,
        description="Verifier support info: {sql: bool, code: bool} or legacy bool",
    )
    num_tools: int | None = Field(
        default=None, description="Number of tools discovered"
    )
    tool_name: str | None = Field(default=None, description="Name of the tool called")
    tool_result: Any = Field(default=None, description="Result from the tool call")
    error: str | None = Field(default=None, description="Error message if any")
    warning: str | None = Field(default=None, description="Warning message if any")
    verify_result: dict | None = Field(
        default=None, description="Verifier output on episode end"
    )
    steps_taken: int | None = Field(
        default=None, description="Steps taken in this episode"
    )
    scenarios: list | None = Field(
        default=None, description="List of all scenarios (from __list_scenarios__)"
    )
    total: int | None = Field(default=None, description="Total number of scenarios")
    trajectory_path: str | None = Field(
        default=None, description="Path to saved trajectory JSON file"
    )
    session_dir: str | None = Field(
        default=None, description="Session directory path (when keep_session=True)"
    )

    @field_validator("has_verifier", mode="before")
    @classmethod
    def _convert_bool_to_dict(cls, v: Any) -> Any:
        """Convert the legacy bool format of ``has_verifier`` to the dict format.

        ``True`` becomes ``{"sql": True, "code": True}`` (conservatively assume
        both verifier modes are available); ``False`` becomes ``None``. Any
        other value (including ``None``) is passed through unchanged for
        normal field validation.
        """
        if isinstance(v, bool):
            return {"sql": v, "code": v} if v else None
        return v

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump to a dict, excluding None-valued fields unless told otherwise."""
        kwargs.setdefault("exclude_none", True)
        return super().model_dump(**kwargs)

    def model_dump_json(self, **kwargs: Any) -> str:
        """Dump to JSON with the same None-exclusion default as ``model_dump``.

        Without this override the JSON path would still emit ``null`` keys
        (e.g. ``"tool_name": null``) and defeat key-presence routing.
        """
        kwargs.setdefault("exclude_none", True)
        return super().model_dump_json(**kwargs)


class AWMListToolsObservation(ListToolsObservation):
    """
    Tool-listing observation that also carries a top-level ``error`` field.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    error: str | None = Field(
        default=None,
        description="Error message if any",
    )