TemporalBenchEnv / client.py
yashu2000's picture
Upload folder using huggingface_hub
d954568 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Typed OpenEnv client for TemporalBenchEnv."""
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
try:
from env.models import (
TemporalBenchAction,
TemporalBenchObservation,
TemporalBenchState,
)
except ImportError:
from TemporalBenchEnv.env.models import (
TemporalBenchAction,
TemporalBenchObservation,
TemporalBenchState,
)
def _section(payload: Any, key: str) -> Dict[str, Any]:
    """Return the dict holding the fields of interest inside ``payload``.

    The data may arrive nested under ``key`` (e.g. ``{"observation": {...}}``)
    or flat at the top level; both shapes are accepted. A non-dict
    ``payload`` (e.g. ``None``) yields an empty dict instead of raising.
    """
    if not isinstance(payload, dict):
        return {}
    nested = payload.get(key)
    return nested if isinstance(nested, dict) else payload


def _field(data: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``data[key]``, or ``default`` when the key is missing or ``None``.

    Treating an explicit JSON ``null`` like a missing key keeps the type
    coercions below from producing ``"None"`` strings or raising
    ``TypeError`` on ``int(None)`` / ``list(None)``.
    """
    value = data.get(key)
    return default if value is None else value


class TemporalBenchEnvClient(
    EnvClient[
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    ]
):
    """WebSocket client for TemporalBench MCQ episodes.

    Implements the three serialization hooks: turning an action into a
    request payload, and turning response payloads back into typed
    observation/step-result and state objects.
    """

    # Fallback used for both ``max_steps`` and ``total_questions`` when the
    # server omits them; presumably mirrors the server's default episode
    # length -- TODO confirm against the environment implementation.
    _DEFAULT_NUM_QUESTIONS = 9

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Serialize ``action`` into the request body for one step.

        ``answer`` is always sent. ``confidence`` and ``reasoning`` are
        included only when set, so unset values are omitted from the payload
        rather than sent as ``null``.
        """
        payload: Dict[str, Any] = {"answer": action.answer}
        if action.confidence is not None:
            payload["confidence"] = action.confidence
        if action.reasoning is not None:
            payload["reasoning"] = action.reasoning
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[TemporalBenchObservation]:
        """Build a ``StepResult`` from a step/reset response payload.

        Args:
            payload: Response dict, either with the observation nested under
                ``"observation"`` or with observation fields at top level.

        Returns:
            A ``StepResult`` whose ``done``/``reward`` prefer top-level
            values and fall back to the observation's own values.
        """
        obs_data = _section(payload, "observation")
        top = payload if isinstance(payload, dict) else {}

        # Top-level done/reward win; a missing or null top-level value falls
        # back to the observation-level one.
        done = bool(_field(top, "done", _field(obs_data, "done", False)))
        reward = _field(top, "reward", obs_data.get("reward"))

        observation = TemporalBenchObservation(
            step_idx=int(_field(obs_data, "step_idx", 0)),
            steps_remaining=int(_field(obs_data, "steps_remaining", 0)),
            max_steps=int(_field(obs_data, "max_steps", self._DEFAULT_NUM_QUESTIONS)),
            question=str(_field(obs_data, "question", "")),
            options=list(_field(obs_data, "options", [])),
            task_type=str(_field(obs_data, "task_type", "")),
            dataset=str(_field(obs_data, "dataset", "")),
            history=list(_field(obs_data, "history", [])),
            accuracy_so_far=float(_field(obs_data, "accuracy_so_far", 0.0)),
            done=done,
            reward=reward,
            metadata=dict(_field(obs_data, "metadata", {})),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Build a ``TemporalBenchState`` from a state response payload.

        Args:
            payload: Response dict, either with the state nested under
                ``"state"`` or with state fields at top level.

        Returns:
            The parsed state; missing or null fields take the defaults below.
        """
        state_data = _section(payload, "state")
        return TemporalBenchState(
            episode_id=state_data.get("episode_id"),
            step_count=int(_field(state_data, "step_count", 0)),
            total_correct=int(_field(state_data, "total_correct", 0)),
            total_questions=int(
                _field(state_data, "total_questions", self._DEFAULT_NUM_QUESTIONS)
            ),
            current_accuracy=float(_field(state_data, "current_accuracy", 0.0)),
            primary_domain=str(_field(state_data, "primary_domain", "PSML")),
            per_task_type_accuracy=dict(_field(state_data, "per_task_type_accuracy", {})),
            total_reward=float(_field(state_data, "total_reward", 0.0)),
        )
# Alias for TemporalBenchEnvClient. The spelling looks template-generated;
# presumably kept so existing imports of this name keep working -- confirm
# there are no remaining users before removing it.
TemporalbenchenvEnv = TemporalBenchEnvClient