File size: 5,186 Bytes
b4b210e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""OpenEnv compatibility endpoints exposed at root-level paths."""

from __future__ import annotations

from typing import Any

from fastapi import APIRouter, Body, HTTPException, status
from pydantic import BaseModel, Field

from app.api.deps import SettingsDep
from app.api.routes.episode import (
    EpisodeState,
    ResetRequest,
    ResetResponse,
    StepRequest,
    get_episode_state,
    reset_episode,
    step_episode,
)
from app.core.action import Action, ActionType

# Router collecting the OpenEnv compatibility endpoints declared in this
# module; every route registered on it carries the "OpenEnv" tag.
router = APIRouter(tags=["OpenEnv"])


class OpenEnvResetRequest(BaseModel):
    """Lenient reset request supporting common OpenEnv field aliases."""

    # Interchangeable spellings of the task identifier. All are optional;
    # the reset endpoint picks the first one that is set.
    task_id: str | None = None
    task: str | None = None
    task_name: str | None = None

    # Optional values handed to the underlying reset request as-is.
    seed: int | None = None
    config: dict[str, Any] | None = None


class OpenEnvStepRequest(BaseModel):
    """Lenient step request supporting common OpenEnv field aliases."""

    # Interchangeable spellings of the episode identifier. All are optional;
    # the step endpoint rejects the request when none of them is set.
    episode_id: str | None = None
    episode: str | None = None
    session_id: str | None = None

    # Deliberately untyped: the action may arrive as a string, a dict, or
    # anything else, and is normalised later by `_coerce_action`. A fresh
    # empty dict is used when the client omits it.
    action: Any = Field(default_factory=dict)


def _coerce_action(action_payload: Any) -> Action:
    """Turn a loosely-shaped OpenEnv action into an internal ``Action``.

    Accepted shapes:

    * an ``Action`` instance, which is returned untouched;
    * a string, interpreted (trimmed, lower-cased) as an ``ActionType``
      value with empty parameters;
    * a dict, where ``action`` / ``type`` / ``name`` may stand in for
      ``action_type``, ``params`` for ``parameters`` and ``thought`` for
      ``reasoning``.

    This function never raises for bad input: anything that cannot be
    interpreted, including an unknown action type or a dict that fails
    model validation, results in ``Action.wait()``.
    """
    if isinstance(action_payload, Action):
        return action_payload

    if isinstance(action_payload, str):
        try:
            kind = ActionType(action_payload.strip().lower())
            return Action(action_type=kind, parameters={})
        except ValueError:
            return Action.wait()

    if not isinstance(action_payload, dict):
        return Action.wait()

    # Work on a shallow copy so the caller's dict is never modified.
    data = {**action_payload}

    # Fill in "action_type" from the first alias holding a non-blank string.
    if "action_type" not in data:
        aliased_type = next(
            (
                data[key].strip().lower()
                for key in ("action", "type", "name")
                if isinstance(data.get(key), str) and data[key].strip()
            ),
            None,
        )
        if aliased_type is not None:
            data["action_type"] = aliased_type

    # "params" is honoured only when "parameters" is absent and only if it
    # is a dict; otherwise the parameters default to empty.
    alt_params = data.get("params")
    data.setdefault("parameters", alt_params if isinstance(alt_params, dict) else {})

    # "thought" doubles as "reasoning" when the latter is not supplied.
    thought = data.get("thought")
    if isinstance(thought, str) and "reasoning" not in data:
        data["reasoning"] = thought

    # Canonicalise the action type; unknown or non-string types degrade to
    # a parameterless WAIT.
    raw_type = data.get("action_type")
    recognised = False
    if isinstance(raw_type, str):
        canonical = raw_type.strip().lower()
        try:
            ActionType(canonical)
        except ValueError:
            pass
        else:
            data["action_type"] = canonical
            recognised = True
    if not recognised:
        data["action_type"] = ActionType.WAIT.value
        data["parameters"] = {}

    # Best-effort by design: any validation failure falls back to waiting
    # rather than surfacing an error to the evaluator.
    try:
        return Action.model_validate(data)
    except Exception:
        return Action.wait()


@router.post(
    "/reset",
    response_model=ResetResponse,
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible reset endpoint",
)
@router.post(
    "/api/reset",
    response_model=ResetResponse,
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_reset(
    settings: SettingsDep,
    request: OpenEnvResetRequest | None = Body(default=None),
) -> ResetResponse:
    """
    Root-level reset alias used by OpenEnv evaluators.

    Defaults to `task_001` when no explicit task identifier is provided.
    """
    # A missing body is treated exactly like an empty one.
    body = OpenEnvResetRequest() if request is None else request

    # First alias carrying a non-empty value wins; otherwise use the default.
    aliases = (body.task_id, body.task, body.task_name)
    task_id = next((value for value in aliases if value), "task_001")

    # Translate into the request type the underlying reset handler expects
    # and delegate to it.
    normalized_request = ResetRequest(
        task_id=task_id,
        seed=body.seed,
        config=body.config,
    )
    response = await reset_episode(normalized_request, settings)
    return response


@router.post(
    "/step",
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible step endpoint",
)
@router.post(
    "/api/step",
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_step(
    request: OpenEnvStepRequest = Body(default_factory=OpenEnvStepRequest),
) -> dict[str, Any]:
    """Root-level step alias used by OpenEnv evaluators."""
    # The episode may be named through any of three aliases; the first one
    # with a non-empty value is used.
    aliases = (request.episode_id, request.episode, request.session_id)
    episode_id = next((value for value in aliases if value), None)
    if episode_id is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing episode_id",
        )

    # Normalise the free-form action before handing it to the step handler.
    step_request = StepRequest(
        episode_id=episode_id,
        action=_coerce_action(request.action),
    )
    result = await step_episode(step_request)

    # Return the dumped result plus a single "done" flag that is true when
    # the episode was either terminated or truncated.
    return {
        **result.model_dump(),
        "done": bool(result.terminated or result.truncated),
    }


@router.get(
    "/state/{episode_id}",
    response_model=EpisodeState,
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible state endpoint",
)
@router.get(
    "/api/state/{episode_id}",
    response_model=EpisodeState,
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_state(episode_id: str) -> EpisodeState:
    """Root-level state alias used by OpenEnv evaluators."""
    # Pure delegation: the episode route's handler does the lookup.
    state = await get_episode_state(episode_id)
    return state