Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Typed OpenEnv client for TemporalBenchEnv.""" | |
| from typing import Any, Dict | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_client import EnvClient | |
| try: | |
| from env.models import ( | |
| TemporalBenchAction, | |
| TemporalBenchObservation, | |
| TemporalBenchState, | |
| ) | |
| except ImportError: | |
| from TemporalBenchEnv.env.models import ( | |
| TemporalBenchAction, | |
| TemporalBenchObservation, | |
| TemporalBenchState, | |
| ) | |
class TemporalBenchEnvClient(
    EnvClient[
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    ]
):
    """WebSocket client for TemporalBench MCQ episodes.

    Converts between the typed TemporalBench models and plain-dict payloads:
    ``_step_payload`` serialises an action, while ``_parse_result`` and
    ``_parse_state`` decode response payloads.

    Both decoders are lenient. A payload that is not a dict is treated as
    empty, and a field that is missing *or* explicitly ``None`` falls back to
    its default instead of raising during type coercion.
    """

    # Fallback for ``max_steps`` / ``total_questions`` when the payload does
    # not carry them.
    _DEFAULT_EPISODE_LENGTH = 9

    @staticmethod
    def _payload_section(payload: Any, key: str) -> Dict[str, Any]:
        """Return the dict stored under ``key``, else the payload itself.

        Payloads may nest their fields under ``key`` or carry them at the top
        level. Anything that is not a dict yields an empty dict so callers
        can use ``.get`` unconditionally.
        """
        if not isinstance(payload, dict):
            return {}
        nested = payload.get(key)
        return nested if isinstance(nested, dict) else payload

    @staticmethod
    def _value_or(data: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``data[key]``, or ``default`` when it is absent or ``None``.

        ``dict.get(key, default)`` only covers the *absent* case; an explicit
        null would otherwise reach ``int()`` / ``list()`` / ``dict()`` and
        raise ``TypeError``.
        """
        value = data.get(key)
        return default if value is None else value

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Serialise ``action`` into the dict sent for a step.

        ``answer`` is always included; ``confidence`` and ``reasoning`` are
        included only when they are not ``None``.
        """
        payload: Dict[str, Any] = {"answer": action.answer}
        if action.confidence is not None:
            payload["confidence"] = action.confidence
        if action.reasoning is not None:
            payload["reasoning"] = action.reasoning
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[TemporalBenchObservation]:
        """Decode a step/reset response into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` when that
        is a dict, otherwise from the top level of ``payload``. ``done`` and
        ``reward`` are taken from the top level first and from the
        observation data second.
        """
        if not isinstance(payload, dict):
            payload = {}
        obs_data = self._payload_section(payload, "observation")
        read = self._value_or

        done = bool(read(payload, "done", read(obs_data, "done", False)))
        reward = read(payload, "reward", obs_data.get("reward"))

        observation = TemporalBenchObservation(
            step_idx=int(read(obs_data, "step_idx", 0)),
            steps_remaining=int(read(obs_data, "steps_remaining", 0)),
            max_steps=int(read(obs_data, "max_steps", self._DEFAULT_EPISODE_LENGTH)),
            question=str(read(obs_data, "question", "")),
            options=list(read(obs_data, "options", [])),
            task_type=str(read(obs_data, "task_type", "")),
            dataset=str(read(obs_data, "dataset", "")),
            history=list(read(obs_data, "history", [])),
            accuracy_so_far=float(read(obs_data, "accuracy_so_far", 0.0)),
            done=done,
            reward=reward,
            metadata=read(obs_data, "metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Decode a state response into a ``TemporalBenchState``.

        State fields are read from ``payload["state"]`` when that is a dict,
        otherwise from the top level of ``payload``.
        """
        state_data = self._payload_section(payload, "state")
        read = self._value_or

        return TemporalBenchState(
            episode_id=state_data.get("episode_id"),
            step_count=int(read(state_data, "step_count", 0)),
            total_correct=int(read(state_data, "total_correct", 0)),
            total_questions=int(
                read(state_data, "total_questions", self._DEFAULT_EPISODE_LENGTH)
            ),
            current_accuracy=float(read(state_data, "current_accuracy", 0.0)),
            primary_domain=str(read(state_data, "primary_domain", "PSML")),
            per_task_type_accuracy=dict(read(state_data, "per_task_type_accuracy", {})),
            total_reward=float(read(state_data, "total_reward", 0.0)),
        )


# Alternate name bound to the same client class.
TemporalbenchenvEnv = TemporalBenchEnvClient