Gaurav3134's picture
Upload 43 files
0387a1c verified
from __future__ import annotations
from typing import Any, Dict, List, Optional
from fastapi import APIRouter
from pydantic import BaseModel
from .env import EmailEnv
from .models import Action, Observation, StepInfo, StepResult
from .tasks import EmailTask, TaskId, get_all_tasks
# Every OpenEnv endpoint defined below is mounted under the "/openenv" prefix.
router = APIRouter(tags=["openenv"], prefix="/openenv")

# Process-wide environment shared by all endpoints. It starts unset; _get_env()
# creates it lazily and reset_env() may replace it with a new instance.
_ENV: Optional[EmailEnv] = None
def _get_env() -> EmailEnv:
    """Return the shared EmailEnv, constructing it on first use.

    The lazily created instance is built with ``deterministic_tools=True`` and
    ``dry_run_send=True``. Later calls return whatever instance is currently
    stored in the module-level ``_ENV``, including one installed by
    ``reset_env()``.
    """
    global _ENV
    if _ENV is not None:
        return _ENV
    _ENV = EmailEnv(deterministic_tools=True, dry_run_send=True)
    return _ENV
class TaskSummary(BaseModel):
    """Serializable summary of one task, as returned by ``GET /openenv/tasks``."""

    # Task identifier (the TaskId type comes from .tasks).
    task_id: TaskId
    # The task's difficulty converted with str() by list_tasks().
    difficulty: str
    # Copied from the task's ``description`` attribute.
    description: str
    # The task's ``max_steps`` attribute, coerced with int() by list_tasks().
    max_steps: int
class ResetRequest(BaseModel):
    """Request body for ``POST /openenv/reset``."""

    # Task to reset into; forwarded unchanged to EmailEnv.reset(task_id=...).
    # How None is handled is up to EmailEnv (not visible in this module).
    task_id: Optional[TaskId] = None
    # These defaults match the settings _get_env() uses for the shared env.
    # If either value differs from the current env's setting, reset_env()
    # replaces the shared env with a newly configured one.
    deterministic_tools: bool = True
    dry_run_send: bool = True
class ResetResponse(BaseModel):
    """Response body for ``POST /openenv/reset``."""

    # Observation returned by EmailEnv.reset().
    observation: Observation
    # Snapshot of the env state after the reset, from env.state.model_dump().
    state: Dict[str, Any]
@router.get("/tasks", response_model=List[TaskSummary])
def list_tasks() -> List[TaskSummary]:
    """Return one TaskSummary for every task reported by ``get_all_tasks()``.

    Each task's ``difficulty`` is converted with ``str()`` and its
    ``max_steps`` with ``int()`` so the entries fit the TaskSummary schema.
    """
    summaries: List[TaskSummary] = []
    for task in get_all_tasks():
        summary = TaskSummary(
            task_id=task.task_id,
            difficulty=str(task.difficulty),
            description=task.description,
            max_steps=int(task.max_steps),
        )
        summaries.append(summary)
    return summaries
@router.post("/reset", response_model=ResetResponse)
def reset_env(payload: ResetRequest) -> ResetResponse:
    """Reset the shared environment and return its initial observation and state.

    A new :class:`EmailEnv` is built with the requested settings, and becomes
    the shared instance used by the other endpoints, when either:

    * no environment exists yet, or
    * the requested ``deterministic_tools`` / ``dry_run_send`` settings differ
      from the current environment's.

    Otherwise the existing environment is reused.

    Args:
        payload: Reset options (task to load and environment settings).

    Returns:
        The observation produced by ``EmailEnv.reset`` together with a
        ``model_dump()`` snapshot of the environment state.
    """
    global _ENV
    env = _ENV
    # Build directly with the requested settings. Going through _get_env()
    # here would first construct a default env, only to throw it away when
    # the requested settings differ from the defaults.
    if (
        env is None
        or env.deterministic_tools != payload.deterministic_tools
        or env.dry_run_send != payload.dry_run_send
    ):
        env = EmailEnv(
            deterministic_tools=payload.deterministic_tools,
            dry_run_send=payload.dry_run_send,
        )
        _ENV = env
    obs = env.reset(task_id=payload.task_id)
    return ResetResponse(observation=obs, state=env.state.model_dump())
@router.post("/step", response_model=StepResult)
async def step_env(action: Action) -> StepResult:
    """Apply one action to the shared environment and report the outcome.

    The environment comes from ``_get_env()``, so one is created lazily if
    none exists yet.

    Args:
        action: The action to execute.

    Returns:
        A StepResult with the observation, reward and done flag returned by
        ``EmailEnv.astep``, plus its info dict converted into a StepInfo.
    """
    environment = _get_env()
    observation, reward, done, raw_info = await environment.astep(action)
    step_info = StepInfo(**raw_info)
    return StepResult(
        observation=observation,
        reward=reward,
        done=done,
        info=step_info,
    )
@router.get("/state")
def get_state() -> Dict[str, Any]:
    """Return a plain-dict snapshot of the shared environment's state.

    The environment is obtained via ``_get_env()``, so one is created on
    demand if it does not exist yet.
    """
    current_state = _get_env().state
    return current_state.model_dump()