"""Shader environment client."""

from typing import Any, Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import ShaderAction, ShaderObservation


class ShaderEnv(EnvClient[ShaderAction, ShaderObservation, State]):
    """
    Client for the shader environment.

    Example:
        >>> with ShaderEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(ShaderAction(code="void mainImage(...)"))
        ...     print(result.observation.ssim)
    """

    def _step_payload(self, action: ShaderAction) -> Dict[str, Any]:
        """Build the request body for a step.

        Only the shader source held in ``action.code`` is sent.
        """
        return {"code": action.code}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ShaderObservation]:
        """Convert a server response payload into a ``StepResult``.

        The top-level ``done`` and ``reward`` values are copied onto the
        observation (overriding any same-named keys inside
        ``payload["observation"]``) and mirrored on the returned result.
        ``done`` defaults to ``False`` when absent; ``reward`` may be ``None``.
        """
        # Treat a missing *or* explicit-null observation as empty so the
        # dict unpacking below never fails on None; model validation still
        # reports any required observation fields that are absent.
        obs_data = payload.get("observation") or {}
        done = payload.get("done", False)
        reward = payload.get("reward")
        observation = ShaderObservation.model_validate(
            {**obs_data, "done": done, "reward": reward}
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Validate a state payload into the generic ``State`` model."""
        return State.model_validate(payload)