Spaces:
Build error
Build error
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional | |
| from fastapi import APIRouter | |
| from pydantic import BaseModel | |
| from .env import EmailEnv | |
| from .models import Action, Observation, StepInfo, StepResult | |
| from .tasks import EmailTask, TaskId, get_all_tasks | |
# Router for this module; everything attached to it lives under "/openenv"
# and carries the "openenv" tag.
# NOTE(review): no route decorators are visible on the functions below in this
# view -- confirm they are registered on `router` (here or elsewhere).
router = APIRouter(prefix="/openenv", tags=["openenv"])

# Module-level environment instance. Stays None until _get_env() builds it on
# first use; reset_env() may rebind it to a freshly configured EmailEnv.
_ENV: Optional[EmailEnv] = None
def _get_env() -> EmailEnv:
    """Return the shared module-level EmailEnv, building it on first access.

    The first call constructs the environment with ``deterministic_tools=True``
    and ``dry_run_send=True`` and stores it in ``_ENV``; later calls hand back
    whatever ``_ENV`` currently holds.
    """
    global _ENV
    # Fast path: an instance already exists, reuse it untouched.
    if _ENV is not None:
        return _ENV
    _ENV = EmailEnv(deterministic_tools=True, dry_run_send=True)
    return _ENV
class TaskSummary(BaseModel):
    # One entry of the list produced by list_tasks(): a flat view of a task.
    task_id: TaskId
    difficulty: str  # list_tasks() fills this with str(task.difficulty)
    description: str
    max_steps: int  # list_tasks() fills this with int(task.max_steps)
class ResetRequest(BaseModel):
    # Payload consumed by reset_env(); every field has a default, so an empty
    # body is accepted.
    task_id: Optional[TaskId] = None  # forwarded as env.reset(task_id=...)
    # The two flags below are compared against the current EmailEnv by
    # reset_env(); if either differs, a new EmailEnv is built with these values.
    deterministic_tools: bool = True
    dry_run_send: bool = True
class ResetResponse(BaseModel):
    # Value returned by reset_env().
    observation: Observation  # what env.reset(...) returned
    state: Dict[str, Any]  # env.state.model_dump() taken after the reset
def list_tasks() -> List[TaskSummary]:
    """Return a TaskSummary for every task reported by ``get_all_tasks()``.

    ``difficulty`` is passed through ``str()`` and ``max_steps`` through
    ``int()`` before being placed in the summary; order follows the order of
    ``get_all_tasks()``.
    """
    summaries = []
    for task in get_all_tasks():
        summary = TaskSummary(
            task_id=task.task_id,
            difficulty=str(task.difficulty),
            description=task.description,
            max_steps=int(task.max_steps),
        )
        summaries.append(summary)
    return summaries
def reset_env(payload: ResetRequest) -> ResetResponse:
    """Reset the shared environment and return its first observation and state.

    If the requested ``deterministic_tools`` / ``dry_run_send`` flags differ
    from those of the current environment, the module-level ``_ENV`` is
    replaced by a new ``EmailEnv`` built with the requested flags before the
    reset is performed.

    Args:
        payload: Requested task id and environment flags.

    Returns:
        A ResetResponse holding the observation returned by ``env.reset`` and
        ``env.state.model_dump()`` taken after the reset.
    """
    global _ENV
    active = _get_env()
    want_deterministic = payload.deterministic_tools
    want_dry_run = payload.dry_run_send

    # Rebuild only when the existing instance was configured differently.
    if active.deterministic_tools != want_deterministic or active.dry_run_send != want_dry_run:
        active = _ENV = EmailEnv(
            deterministic_tools=want_deterministic,
            dry_run_send=want_dry_run,
        )

    first_observation = active.reset(task_id=payload.task_id)
    # Snapshot the state only after reset() has run.
    state_snapshot = active.state.model_dump()
    return ResetResponse(observation=first_observation, state=state_snapshot)
async def step_env(action: Action) -> StepResult:
    """Apply one action to the shared environment and package the outcome.

    Awaits ``env.astep(action)``, which is unpacked as
    ``(observation, reward, done, info)``; the info mapping is expanded into a
    ``StepInfo`` via keyword arguments.

    Args:
        action: The action to apply.

    Returns:
        A StepResult built from the four values produced by ``astep``.
    """
    outcome = await _get_env().astep(action)
    observation, reward, finished, raw_info = outcome
    step_info = StepInfo(**raw_info)
    return StepResult(
        observation=observation,
        reward=reward,
        done=finished,
        info=step_info,
    )
def get_state() -> Dict[str, Any]:
    """Return ``state.model_dump()`` of the shared environment.

    The environment is obtained through ``_get_env()``, so it is created with
    default flags if none exists yet.
    """
    return _get_env().state.model_dump()