# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Typed OpenEnv client for TemporalBenchEnv."""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

try:
    from env.models import (
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    )
except ImportError:
    # Fall back to the package-qualified path when ``env`` is not importable
    # as a top-level module.
    from TemporalBenchEnv.env.models import (
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    )


class TemporalBenchEnvClient(
    EnvClient[
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    ]
):
    """WebSocket client for TemporalBench MCQ episodes."""

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Convert an action into the dict sent with a step request.

        Args:
            action: The action to serialize.

        Returns:
            A dict that always contains ``"answer"``; ``"confidence"`` and
            ``"reasoning"`` are included only when the corresponding
            attribute on ``action`` is not ``None``.
        """
        payload: Dict[str, Any] = {"answer": action.answer}
        if action.confidence is not None:
            payload["confidence"] = action.confidence
        if action.reasoning is not None:
            payload["reasoning"] = action.reasoning
        return payload

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[TemporalBenchObservation]:
        """Build a ``StepResult`` from a decoded response payload.

        Observation fields are read from ``payload["observation"]`` when that
        value is a dict; otherwise the payload itself is treated as the
        observation. ``done`` and ``reward`` are taken from the top level of
        the payload first and fall back to the observation dict. Fields that
        are absent get the defaults spelled out below.

        Args:
            payload: Decoded response. A non-dict value is treated as an
                empty dict, so every field falls back to its default.

        Returns:
            A ``StepResult`` wrapping the parsed observation, with the same
            ``reward`` and ``done`` values that were stored on it.
        """
        # Normalise once, before any ``.get`` call, so a non-dict payload
        # reaches the defaults instead of raising AttributeError.
        envelope: Dict[str, Any] = payload if isinstance(payload, dict) else {}

        obs_data = envelope.get("observation")
        if not isinstance(obs_data, dict):
            obs_data = envelope

        # Top-level values win; the observation dict is only a fallback.
        done = envelope.get("done", obs_data.get("done", False))
        reward = envelope.get("reward", obs_data.get("reward"))

        observation = TemporalBenchObservation(
            step_idx=int(obs_data.get("step_idx", 0)),
            steps_remaining=int(obs_data.get("steps_remaining", 0)),
            max_steps=int(obs_data.get("max_steps", 9)),
            question=str(obs_data.get("question", "")),
            options=list(obs_data.get("options", [])),
            task_type=str(obs_data.get("task_type", "")),
            dataset=str(obs_data.get("dataset", "")),
            history=list(obs_data.get("history", [])),
            accuracy_so_far=float(obs_data.get("accuracy_so_far", 0.0)),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Build a ``TemporalBenchState`` from a decoded response payload.

        State fields are read from ``payload["state"]`` when that value is a
        dict; otherwise the payload itself is treated as the state. Fields
        that are absent get the defaults spelled out below.

        Args:
            payload: Decoded response. A non-dict value is treated as an
                empty dict, so every field falls back to its default.

        Returns:
            The parsed ``TemporalBenchState``.
        """
        # Normalise once, before any ``.get`` call, so a non-dict payload
        # reaches the defaults instead of raising AttributeError.
        envelope: Dict[str, Any] = payload if isinstance(payload, dict) else {}

        state_data = envelope.get("state")
        if not isinstance(state_data, dict):
            state_data = envelope

        return TemporalBenchState(
            episode_id=state_data.get("episode_id"),
            step_count=int(state_data.get("step_count", 0)),
            total_correct=int(state_data.get("total_correct", 0)),
            total_questions=int(state_data.get("total_questions", 9)),
            current_accuracy=float(state_data.get("current_accuracy", 0.0)),
            primary_domain=str(state_data.get("primary_domain", "PSML")),
            per_task_type_accuracy=dict(
                state_data.get("per_task_type_accuracy", {})
            ),
            total_reward=float(state_data.get("total_reward", 0.0)),
        )


# Second module-level name bound to the same client class.
TemporalbenchenvEnv = TemporalBenchEnvClient