Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Commit ·
dc2e64b
1
Parent(s): 513a2e2
fix: align client with openenv payload
Browse files- fusion_lab/client.py +6 -3
fusion_lab/client.py
CHANGED
|
@@ -13,11 +13,14 @@ class FusionLabClient(EnvClient[StellaratorAction, StellaratorObservation, Stell
|
|
| 13 |
return action.model_dump(exclude_none=True)
|
| 14 |
|
| 15 |
def _parse_result(self, payload: dict[str, object]) -> StepResult[StellaratorObservation]:
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
return StepResult(
|
| 18 |
observation=observation,
|
| 19 |
-
reward=
|
| 20 |
-
done=
|
| 21 |
)
|
| 22 |
|
| 23 |
def _parse_state(self, payload: dict[str, object]) -> StellaratorState:
|
|
|
|
| 13 |
return action.model_dump(exclude_none=True)
|
| 14 |
|
| 15 |
def _parse_result(self, payload: dict[str, object]) -> StepResult[StellaratorObservation]:
|
| 16 |
+
observation_payload = dict(payload.get("observation", {}))
|
| 17 |
+
observation_payload["reward"] = payload.get("reward")
|
| 18 |
+
observation_payload["done"] = payload.get("done", False)
|
| 19 |
+
observation = StellaratorObservation.model_validate(observation_payload)
|
| 20 |
return StepResult(
|
| 21 |
observation=observation,
|
| 22 |
+
reward=payload.get("reward"),
|
| 23 |
+
done=payload.get("done", False),
|
| 24 |
)
|
| 25 |
|
| 26 |
def _parse_state(self, payload: dict[str, object]) -> StellaratorState:
|