"""OpenEnv compatibility endpoints exposed at root-level paths."""
from __future__ import annotations

import logging
from typing import Any

from fastapi import APIRouter, Body, HTTPException, status
from pydantic import BaseModel, Field

from app.api.deps import SettingsDep
from app.api.routes.episode import (
    EpisodeState,
    ResetRequest,
    ResetResponse,
    StepRequest,
    get_episode_state,
    reset_episode,
    step_episode,
)
from app.core.action import Action, ActionType
# Router for the root-level OpenEnv alias endpoints declared in this module;
# every route it carries is grouped under the "OpenEnv" tag.
router = APIRouter(
    tags=["OpenEnv"],
)
class OpenEnvResetRequest(BaseModel):
    """
    Permissive request body for the OpenEnv reset endpoint.

    Every field is optional, so evaluators may send an empty or partial body.
    The task can be named through any of the ``task_id``, ``task`` or
    ``task_name`` aliases.
    """

    # Task identifier aliases; the reset handler picks the first non-empty one.
    task_id: str | None = None
    task: str | None = None
    task_name: str | None = None
    # Passed through unchanged into the internal ResetRequest.
    seed: int | None = None
    config: dict[str, Any] | None = None
class OpenEnvStepRequest(BaseModel):
    """
    Permissive request body for the OpenEnv step endpoint.

    The episode may be identified as ``episode_id``, ``episode`` or
    ``session_id``. ``action`` accepts any JSON value; the step handler
    normalizes it into the internal action model.
    """

    # Episode identifier aliases, tried in this order by the step handler.
    episode_id: str | None = None
    episode: str | None = None
    session_id: str | None = None
    # Free-form action payload (string, object, ...); an empty dict by default.
    action: Any = Field(default_factory=dict)
_logger = logging.getLogger(__name__)


def _coerce_action(action_payload: Any) -> Action:
    """
    Coerce an OpenEnv-style action payload into the internal ``Action`` model.

    Accepted shapes:

    * an ``Action`` instance, which is returned unchanged;
    * a string naming an ``ActionType`` (surrounding whitespace and case are
      ignored), which becomes that action with empty parameters;
    * a dict. ``action_type`` may instead be supplied as ``action``, ``type``
      or ``name``, ``parameters`` as ``params``, and ``reasoning`` as
      ``thought``.

    Payloads that cannot be interpreted fall back to ``Action.wait()`` instead
    of raising. For a dict, an unknown or missing action type becomes WAIT and
    its parameters are cleared. Every fallback is logged as a warning so that
    misbehaving clients can be diagnosed.

    Args:
        action_payload: Raw ``action`` value taken from the request body.

    Returns:
        The coerced ``Action``, or ``Action.wait()`` when coercion fails.
    """
    if isinstance(action_payload, Action):
        return action_payload

    if isinstance(action_payload, str):
        action_type = action_payload.strip().lower()
        try:
            return Action(action_type=ActionType(action_type), parameters={})
        except ValueError:
            # Raised for an unknown ActionType value; pydantic's ValidationError
            # from Action(...) is also a ValueError subclass.
            _logger.warning(
                "Could not build OpenEnv action from %r; falling back to WAIT",
                action_type,
            )
            return Action.wait()

    if not isinstance(action_payload, dict):
        _logger.warning(
            "Unsupported OpenEnv action payload type %s; falling back to WAIT",
            type(action_payload).__name__,
        )
        return Action.wait()

    # Shallow copy so the caller's dict is never mutated.
    payload = dict(action_payload)

    # Fill canonical keys from common aliases, but only when they are absent.
    if "action_type" not in payload:
        for alias in ("action", "type", "name"):
            alias_value = payload.get(alias)
            if isinstance(alias_value, str) and alias_value.strip():
                payload["action_type"] = alias_value.strip().lower()
                break
    if "parameters" not in payload:
        params = payload.get("params")
        payload["parameters"] = params if isinstance(params, dict) else {}
    if "reasoning" not in payload and isinstance(payload.get("thought"), str):
        payload["reasoning"] = payload["thought"]

    # Accept the action type only if it is a string naming a known ActionType;
    # otherwise degrade to WAIT and clear the parameters.
    action_type = payload.get("action_type")
    normalized = action_type.strip().lower() if isinstance(action_type, str) else None
    if normalized is not None:
        try:
            ActionType(normalized)
        except ValueError:
            normalized = None
    if normalized is None:
        _logger.warning(
            "Unrecognised OpenEnv action type %r; falling back to WAIT",
            action_type,
        )
        payload["action_type"] = ActionType.WAIT.value
        payload["parameters"] = {}
    else:
        payload["action_type"] = normalized

    try:
        return Action.model_validate(payload)
    except Exception:
        # Deliberately best-effort: a bad client payload degrades to WAIT
        # rather than failing the request, but the cause is now recorded.
        _logger.warning(
            "OpenEnv action failed validation; falling back to WAIT",
            exc_info=True,
        )
        return Action.wait()
@router.post(
    "/reset",
    response_model=ResetResponse,
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible reset endpoint",
)
@router.post(
    "/api/reset",
    response_model=ResetResponse,
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_reset(
    settings: SettingsDep,
    request: OpenEnvResetRequest | None = Body(default=None),
) -> ResetResponse:
    """
    Reset an episode through the root-level path used by OpenEnv evaluators.

    The request body is optional. The task is taken from the first non-empty
    value among ``task_id``, ``task`` and ``task_name``, and defaults to
    ``task_001`` when none is given. ``seed`` and ``config`` are forwarded
    unchanged.
    """
    body = request if request else OpenEnvResetRequest()
    # First truthy alias wins, mirroring an `a or b or c or default` chain.
    aliases = (body.task_id, body.task, body.task_name)
    chosen_task = next((name for name in aliases if name), "task_001")
    return await reset_episode(
        ResetRequest(task_id=chosen_task, seed=body.seed, config=body.config),
        settings,
    )
@router.post(
    "/step",
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible step endpoint",
)
@router.post(
    "/api/step",
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_step(
    request: OpenEnvStepRequest = Body(default_factory=OpenEnvStepRequest),
) -> dict[str, Any]:
    """
    Advance an episode through the root-level path used by OpenEnv evaluators.

    The episode is identified by the first non-empty value among
    ``episode_id``, ``episode`` and ``session_id``; a request without one is
    rejected with HTTP 400. The response is the dumped step result plus a
    ``done`` flag that is true when the episode terminated or was truncated.
    """
    id_aliases = (request.episode_id, request.episode, request.session_id)
    resolved_id = next((ident for ident in id_aliases if ident), None)
    if not resolved_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Missing episode_id",
        )

    step_request = StepRequest(
        episode_id=resolved_id,
        action=_coerce_action(request.action),
    )
    outcome = await step_episode(step_request)
    return {
        **outcome.model_dump(),
        "done": bool(outcome.terminated or outcome.truncated),
    }
@router.get(
    "/state/{episode_id}",
    response_model=EpisodeState,
    status_code=status.HTTP_200_OK,
    summary="OpenEnv-compatible state endpoint",
)
@router.get(
    "/api/state/{episode_id}",
    response_model=EpisodeState,
    status_code=status.HTTP_200_OK,
    include_in_schema=False,
)
async def openenv_state(episode_id: str) -> EpisodeState:
    """Return the current state of ``episode_id`` via the root-level OpenEnv path."""
    state = await get_episode_state(episode_id)
    return state