File size: 897 Bytes
e9315b2
 
6e3f176
e9315b2
 
 
 
 
 
 
 
 
 
 
6e3f176
e9315b2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from openenv_core import HTTPEnvClient, StepResult

from .models import JobObservation, JSSPAction, JSSPObservation, MachineObservation


class JSSPEnvClient(HTTPEnvClient[JSSPAction, JSSPObservation]):
    """HTTP client for the job-shop scheduling (JSSP) environment.

    Implements the serialization hooks ``_step_payload``, ``_parse_result``
    and ``_parse_state`` for ``HTTPEnvClient`` (presumably invoked by the
    base class around each HTTP request — confirm against openenv_core).
    """

    def _step_payload(self, action: JSSPAction) -> dict:
        """Serialize ``action`` into the request body for a step.

        Returns:
            A dict with a single ``"job_ids"`` key holding ``action.job_ids``.
        """
        return {"job_ids": action.job_ids}

    def _parse_result(self, payload: dict) -> StepResult[JSSPObservation]:
        """Build a ``StepResult`` from a decoded step response.

        Args:
            payload: Response dict with an ``"observation"`` mapping that
                holds ``"machines"`` and ``"jobs"`` lists (each entry a dict
                of constructor kwargs) plus any remaining ``JSSPObservation``
                fields, and optional top-level ``"reward"`` and ``"done"``.
                ``payload`` is not modified.

        Returns:
            A ``StepResult`` whose observation has every machine/job dict
            converted to ``MachineObservation`` / ``JobObservation``.
            ``reward`` is ``None`` and ``done`` is ``False`` when absent.

        Raises:
            KeyError: if ``"observation"``, ``"machines"`` or ``"jobs"``
                is missing.
        """
        # Shallow-copy so that popping the nested lists below does not mutate
        # the caller's payload; otherwise re-parsing the same response (retry,
        # logging, caching) would fail with KeyError on "machines".
        obs_data = dict(payload["observation"])
        machines = [MachineObservation(**machine) for machine in obs_data.pop("machines")]
        jobs = [JobObservation(**job) for job in obs_data.pop("jobs")]
        return StepResult[JSSPObservation](
            # After the pops, obs_data holds only the remaining scalar fields.
            observation=JSSPObservation(machines=machines, jobs=jobs, **obs_data),
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict) -> dict:
        """Return the decoded state response unchanged, as a plain dict."""
        return payload