File size: 2,118 Bytes
0387a1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import annotations

from typing import Any, Dict, List, Optional

from fastapi import APIRouter
from pydantic import BaseModel

from .env import EmailEnv
from .models import Action, Observation, StepInfo, StepResult
from .tasks import EmailTask, TaskId, get_all_tasks


# Every route in this module is mounted under the "/openenv" path prefix and
# grouped under the "openenv" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["openenv"],
    prefix="/openenv",
)

# Process-wide environment shared by the /reset, /step and /state endpoints.
# Created lazily by _get_env() and replaced by reset_env() when the requested
# configuration flags change.
_ENV: Optional[EmailEnv] = None


def _get_env() -> EmailEnv:
    """Return the module-level ``EmailEnv``, creating it on first use.

    The lazily created instance is built with ``deterministic_tools=True`` and
    ``dry_run_send=True``; later calls return the same object until
    ``reset_env`` replaces it.
    """
    global _ENV
    if _ENV is not None:
        return _ENV
    _ENV = EmailEnv(deterministic_tools=True, dry_run_send=True)
    return _ENV


class TaskSummary(BaseModel):
    # Response item for GET /openenv/tasks: a flattened, JSON-friendly view of
    # one task as built by list_tasks().
    task_id: TaskId
    difficulty: str  # filled from str(task.difficulty) in list_tasks()
    description: str
    max_steps: int  # filled from int(task.max_steps) in list_tasks()


class ResetRequest(BaseModel):
    # Request body for POST /openenv/reset.
    # task_id is forwarded as-is to EmailEnv.reset(task_id=...); None is passed
    # through unchanged, so the environment decides what that means.
    task_id: Optional[TaskId] = None
    # When either flag differs from the current environment's attribute of the
    # same name, reset_env() rebuilds the environment with these values.
    deterministic_tools: bool = True
    dry_run_send: bool = True


class ResetResponse(BaseModel):
    # Response body for POST /openenv/reset.
    observation: Observation  # value returned by env.reset(task_id=...)
    state: Dict[str, Any]  # env.state.model_dump() taken right after the reset


def _summarize_task(task: EmailTask) -> TaskSummary:
    """Convert one task definition into its public ``TaskSummary`` form."""
    # str()/int() coerce the task's attributes to the primitive field types
    # declared on TaskSummary.
    return TaskSummary(
        task_id=task.task_id,
        difficulty=str(task.difficulty),
        description=task.description,
        max_steps=int(task.max_steps),
    )


@router.get("/tasks", response_model=List[TaskSummary])
def list_tasks() -> List[TaskSummary]:
    """Return a summary (id, difficulty, description, max_steps) of every task."""
    return [_summarize_task(task) for task in get_all_tasks()]


@router.post("/reset", response_model=ResetResponse)
def reset_env(payload: ResetRequest) -> ResetResponse:
    """Reset the shared environment, optionally selecting a task.

    If no environment exists yet, or the requested ``deterministic_tools`` /
    ``dry_run_send`` flags differ from the current environment's, a new
    environment is built with the requested flags before it is reset.
    Returns the initial observation together with a dump of the new state.
    """
    global _ENV
    env = _ENV
    # Construct the environment exactly once: previously the first reset with
    # non-default flags built a throwaway default env via _get_env() and then
    # immediately replaced it.
    if (
        env is None
        or env.deterministic_tools != payload.deterministic_tools
        or env.dry_run_send != payload.dry_run_send
    ):
        env = EmailEnv(
            deterministic_tools=payload.deterministic_tools,
            dry_run_send=payload.dry_run_send,
        )
        _ENV = env
    obs = env.reset(task_id=payload.task_id)
    return ResetResponse(observation=obs, state=env.state.model_dump())


@router.post("/step", response_model=StepResult)
async def step_env(action: Action) -> StepResult:
    """Apply a single action to the shared environment and return the result."""
    # NOTE(review): if /reset was never called, _get_env() creates a fresh
    # default environment and this steps it directly; confirm EmailEnv
    # tolerates astep() before reset().
    observation, reward, done, raw_info = await _get_env().astep(action)
    return StepResult(
        observation=observation,
        reward=reward,
        done=done,
        info=StepInfo(**raw_info),
    )


@router.get("/state")
def get_state() -> Dict[str, Any]:
    """Return the shared environment's current state as a plain dict."""
    snapshot: Dict[str, Any] = _get_env().state.model_dump()
    return snapshot