File size: 2,906 Bytes
b003bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

from typing import Literal
from uuid import uuid4

from fastapi.responses import RedirectResponse
from openenv.core.env_server import create_app
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import Action, Observation, State
from pydantic import Field


class MinimalAction(Action):
    # Action payload for the minimal demo environment.
    # NOTE: documented with comments rather than a class docstring so the
    # model's generated schema is left exactly as it was.

    # Which operation to perform; limited to these three literals, "noop" by default.
    action_type: Literal["noop", "increment", "finish"] = "noop"
    # Validated to the inclusive range 1..3 (ge=1, le=3), default 1.
    # MinimalEnvironment.step only reads it for "increment" actions.
    amount: int = Field(default=1, ge=1, le=3)


class MinimalObservation(Observation):
    # Observation returned by MinimalEnvironment.reset and MinimalEnvironment.step.
    # NOTE: documented with comments rather than a class docstring so the
    # model's generated schema is left exactly as it was.

    # Short status tag; MinimalEnvironment emits "ready", "ok", "finished" or "done".
    status: str
    # Counter value at the time the observation was built.
    counter: int
    # Human-readable description of the current counter/step and the available actions.
    summary: str
    # Reward for the step that produced this observation (0.0 by default).
    reward: float = 0.0
    # True once the episode has ended.
    done: bool = False


class MinimalState(State):
    # Episode state tracked by MinimalEnvironment.
    # `episode_id` and `step_count` are passed at construction time and are
    # presumably declared on the openenv `State` base class — confirm there.

    # Running total of the amounts applied by "increment" actions; starts at 0.
    counter: int = 0


class MinimalEnvironment(Environment[MinimalAction, MinimalObservation, MinimalState]):
    """Tiny counter environment used as a minimal demo.

    Each episode starts with a counter of 0. ``increment`` actions add their
    ``amount`` to the counter (and yield that amount as reward), ``noop`` does
    nothing, and ``finish`` ends the episode. The episode also ends
    automatically on the step that brings ``step_count`` to ``MAX_STEPS``.
    """

    # NOTE(review): presumably read by the openenv server to decide whether one
    # instance may serve several sessions at once — confirm against openenv.
    SUPPORTS_CONCURRENT_SESSIONS = False

    # Maximum number of counted steps per episode; reaching it ends the episode.
    MAX_STEPS = 8

    def __init__(self):
        """Create the environment with a fresh, not-yet-finished episode."""
        super().__init__()
        self._done = False
        self._state = self._new_state()

    @staticmethod
    def _new_state(episode_id: str | None = None) -> MinimalState:
        """Build a zeroed state, generating a UUID when no episode id is given."""
        return MinimalState(
            # A falsy id (None or "") is replaced by a random UUID string.
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            counter=0,
        )

    def reset(self, seed: int | None = None, episode_id: str | None = None, **kwargs) -> MinimalObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Accepted for interface compatibility; ignored (the
                environment has no randomness besides the generated id).
            episode_id: Identifier for the new episode; a random UUID is used
                when it is ``None`` or empty.
            **kwargs: Ignored.

        Returns:
            An observation with status ``"ready"``, reward 0.0 and ``done=False``.
        """
        del seed, kwargs
        self._done = False
        self._state = self._new_state(episode_id)
        return self._observation(status="ready", reward=0.0, done=False)

    def step(self, action: MinimalAction, timeout_s: float | None = None, **kwargs) -> MinimalObservation:
        """Apply one action and return the resulting observation.

        Args:
            action: The action to apply.
            timeout_s: Accepted for interface compatibility; ignored.
            **kwargs: Ignored.

        Returns:
            The observation after the action. Once the episode has ended,
            further calls return status ``"done"`` with reward 0.0 and leave
            the state (including ``step_count``) untouched.
        """
        del timeout_s, kwargs
        if self._done:
            # Episode already over: report it without counting another step.
            return self._observation(status="done", reward=0.0, done=True)

        self._state.step_count += 1
        reward = 0.0
        status = "ok"

        if action.action_type == "increment":
            self._state.counter += action.amount
            reward = float(action.amount)
        elif action.action_type == "finish":
            self._done = True
            status = "finished"

        # The step cap applies to every action type, so an increment on the
        # last allowed step still earns its reward but ends the episode.
        if self._state.step_count >= self.MAX_STEPS:
            self._done = True
            status = "finished"

        return self._observation(status=status, reward=reward, done=self._done)

    @property
    def state(self) -> MinimalState:
        """The live state object of the current episode (not a copy)."""
        return self._state

    def close(self) -> None:
        """No-op: the environment holds no resources to release."""
        return None

    def _observation(self, *, status: str, reward: float, done: bool) -> MinimalObservation:
        """Build an observation from the current state plus the given status fields."""
        return MinimalObservation(
            status=status,
            counter=self._state.counter,
            summary=(
                f"Minimal demo environment. Counter={self._state.counter}. "
                f"Step={self._state.step_count}. "
                "Choose noop, increment, or finish."
            ),
            reward=reward,
            done=done,
        )


# Server application built by openenv's `create_app` from the environment class
# and its action/observation models; extra routes are attached to it via `app.get`.
app = create_app(
    MinimalEnvironment,
    MinimalAction,
    MinimalObservation,
    env_name="minimal_space",
)


@app.get("/", include_in_schema=False)
def root() -> RedirectResponse:
    """Redirect requests for the bare root path to ``/web``."""
    redirect = RedirectResponse(url="/web")
    return redirect