File size: 1,223 Bytes
c71bf62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""Request models for Dispatch Arena's REST API."""

from __future__ import annotations

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from dispatch_arena.models import Action, Config, Mode


class CreateSessionRequest(BaseModel):
    """Body of a request that creates a new session.

    ``config`` carries raw ``Config`` field overrides. They are merged with
    the top-level ``mode`` and a caller-supplied ``max_ticks`` default, then
    validated by :meth:`resolved_config`.
    """

    # Mode for the new session; ``Mode.MINI`` when the client omits it.
    mode: Mode = Mode.MINI
    # Optional seed; not consulted by ``resolved_config``.
    seed: Optional[int] = None
    # Raw ``Config`` field overrides; validated in ``resolved_config``.
    config: Dict[str, Any] = Field(default_factory=dict)

    def resolved_config(self, default_max_ticks: int) -> Config:
        """Merge this request into a validated :class:`Config`.

        ``mode`` precedence, highest first:

        1. the top-level ``mode`` field, if the client set it explicitly;
        2. ``config["mode"]``;
        3. the ``Mode.MINI`` field default.

        This matches ``ResetRequest.resolved_config``, where an explicit
        ``mode`` also overrides ``config``. ``max_ticks`` comes from
        ``config`` when present; otherwise ``default_max_ticks`` is used.

        Args:
            default_max_ticks: ``max_ticks`` value used when ``config`` omits it.

        Returns:
            The ``Config`` produced by ``Config.model_validate`` on the merged
            payload. Any validation error it raises propagates unchanged.
        """
        # Copy so the request's own ``config`` dict is never mutated.
        payload = dict(self.config)
        if "mode" in self.model_fields_set:
            # The client set ``mode`` explicitly; do not let a conflicting
            # ``config["mode"]`` silently override it.
            payload["mode"] = self.mode
        else:
            # Only the field default is available; let ``config`` supply a mode.
            payload.setdefault("mode", self.mode)
        payload.setdefault("max_ticks", default_max_ticks)
        return Config.model_validate(payload)


class ResetRequest(BaseModel):
    """Body of a reset request.

    Every field is optional. ``resolved_config`` layers ``config`` and
    ``mode`` on top of a fallback ``Config`` supplied by the caller.
    """

    # Optional seed; not consulted by ``resolved_config``.
    seed: Optional[int] = None
    # Optional identifiers; not consulted by ``resolved_config``.
    episode_id: Optional[str] = None
    session_id: Optional[str] = None
    # When not None, replaces any mode from ``fallback`` or ``config``.
    mode: Optional[Mode] = None
    # Raw ``Config`` field overrides applied on top of the fallback.
    config: Dict[str, Any] = Field(default_factory=dict)

    def resolved_config(self, fallback: Config) -> Config:
        """Return ``fallback`` with this request's overrides applied.

        Precedence, lowest first: ``fallback.to_dict()``, then the entries
        of ``config``, then ``mode`` (only when it is not None).

        Args:
            fallback: Config whose values are used for any key the request
                does not override.

        Returns:
            The ``Config`` produced by ``Config.model_validate`` on the merged
            payload.
        """
        merged = fallback.to_dict()
        # Gather every override first, with the explicit mode taking
        # priority over any "mode" key inside ``config``.
        overrides = dict(self.config)
        if self.mode is not None:
            overrides["mode"] = self.mode
        merged.update(overrides)
        return Config.model_validate(merged)


class StepRequest(BaseModel):
    """Body of a step request: one action to apply in a session.

    ``action`` accepts an ``Action``, a plain string, or a raw dict. Turning
    the latter two into an ``Action`` is not done here; presumably the route
    handler does it (TODO confirm).
    """

    # Optional session identifier.
    session_id: Optional[str] = None
    # Required; the action to apply, in any of the three accepted shapes.
    action: Action | str | Dict[str, Any]