Spaces:
Sleeping
Sleeping
File size: 3,268 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Typed OpenEnv client for TemporalBenchEnv."""
from typing import Any, Dict
from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
try:
from env.models import (
TemporalBenchAction,
TemporalBenchObservation,
TemporalBenchState,
)
except ImportError:
from TemporalBenchEnv.env.models import (
TemporalBenchAction,
TemporalBenchObservation,
TemporalBenchState,
)
class TemporalBenchEnvClient(
    EnvClient[
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    ]
):
    """WebSocket client for TemporalBench MCQ episodes.

    Specialises ``EnvClient`` with the TemporalBench action, observation and
    state models, and provides the three conversion hooks between those
    models and plain-dict payloads.
    """

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Build the step payload dict for *action*.

        ``answer`` is always included; ``confidence`` and ``reasoning`` are
        added only when they are not ``None``.
        """
        payload: Dict[str, Any] = {"answer": action.answer}
        if action.confidence is not None:
            payload["confidence"] = action.confidence
        if action.reasoning is not None:
            payload["reasoning"] = action.reasoning
        return payload

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[TemporalBenchObservation]:
        """Convert a result payload into a ``StepResult``.

        Observation fields are read from ``payload["observation"]`` when that
        value is a dict, otherwise from ``payload`` itself (flat layout).
        ``done`` and ``reward`` prefer the top-level keys and fall back to the
        observation's own values, then to ``False`` / ``None``.

        A payload that is not a dict (for example ``None``) is treated as an
        empty one, so every field takes its default instead of the call
        failing on ``payload.get``.
        """
        # Normalise first: the nested lookup below must never run on a
        # non-dict payload.
        envelope: Dict[str, Any] = payload if isinstance(payload, dict) else {}

        nested = envelope.get("observation")
        obs_data: Dict[str, Any] = nested if isinstance(nested, dict) else envelope

        # Top-level values win over the ones embedded in the observation.
        done = envelope.get("done", obs_data.get("done", False))
        reward = envelope.get("reward", obs_data.get("reward"))

        observation = TemporalBenchObservation(
            step_idx=int(obs_data.get("step_idx", 0)),
            steps_remaining=int(obs_data.get("steps_remaining", 0)),
            max_steps=int(obs_data.get("max_steps", 9)),
            question=str(obs_data.get("question", "")),
            options=list(obs_data.get("options", [])),
            task_type=str(obs_data.get("task_type", "")),
            dataset=str(obs_data.get("dataset", "")),
            history=list(obs_data.get("history", [])),
            accuracy_so_far=float(obs_data.get("accuracy_so_far", 0.0)),
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Convert a state payload into a ``TemporalBenchState``.

        State fields are read from ``payload["state"]`` when that value is a
        dict, otherwise from ``payload`` itself (flat layout). Missing keys
        fall back to the defaults spelled out below. A payload that is not a
        dict is treated as an empty one.
        """
        # Same normalisation as in ``_parse_result``: guard before ``.get``.
        envelope: Dict[str, Any] = payload if isinstance(payload, dict) else {}

        nested = envelope.get("state")
        state_data: Dict[str, Any] = nested if isinstance(nested, dict) else envelope

        return TemporalBenchState(
            episode_id=state_data.get("episode_id"),
            step_count=int(state_data.get("step_count", 0)),
            total_correct=int(state_data.get("total_correct", 0)),
            total_questions=int(state_data.get("total_questions", 9)),
            current_accuracy=float(state_data.get("current_accuracy", 0.0)),
            primary_domain=str(state_data.get("primary_domain", "PSML")),
            per_task_type_accuracy=dict(state_data.get("per_task_type_accuracy", {})),
            total_reward=float(state_data.get("total_reward", 0.0)),
        )
# Alias: the client class is also available under the name ``TemporalbenchenvEnv``.
# NOTE(review): presumably kept for code that imports this scaffold-style
# name — confirm against callers before removing.
TemporalbenchenvEnv = TemporalBenchEnvClient
|