File size: 3,268 Bytes
d954568
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Typed OpenEnv client for TemporalBenchEnv."""

from typing import Any, Dict

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

try:
    from env.models import (
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    )
except ImportError:
    from TemporalBenchEnv.env.models import (
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    )


def _as_dict(value: Any) -> Dict[str, Any]:
    """Return *value* unchanged when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _lookup(key: str, default: Any, *sources: Dict[str, Any]) -> Any:
    """Return the first non-null value stored under *key* in *sources*.

    ``dict.get(key, default)`` only falls back when the key is absent; a key
    that is present with an explicit ``None`` (JSON ``null``) would otherwise
    reach ``int()``/``float()``/``list()`` and raise. Here ``None`` is treated
    the same as a missing key, and *default* is returned when no source has a
    usable value.
    """
    for source in sources:
        value = source.get(key)
        if value is not None:
            return value
    return default


class TemporalBenchEnvClient(
    EnvClient[
        TemporalBenchAction,
        TemporalBenchObservation,
        TemporalBenchState,
    ]
):
    """WebSocket client for TemporalBench MCQ episodes."""

    def _step_payload(self, action: TemporalBenchAction) -> Dict[str, Any]:
        """Build the step request body from *action*.

        ``answer`` is always included; ``confidence`` and ``reasoning`` are
        only added when they are not ``None``.
        """
        payload: Dict[str, Any] = {"answer": action.answer}
        if action.confidence is not None:
            payload["confidence"] = action.confidence
        if action.reasoning is not None:
            payload["reasoning"] = action.reasoning
        return payload

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[TemporalBenchObservation]:
        """Convert a step/reset response into a typed ``StepResult``.

        Observation fields are read from ``payload["observation"]`` when that
        value is a dict, otherwise from *payload* itself. ``done`` and
        ``reward`` are taken from the top level first and then from the
        observation data. Missing or null fields fall back to defaults, and a
        non-dict *payload* is treated as empty rather than raising.
        """
        # Normalise before any ``.get`` call so a non-dict payload cannot raise.
        root = _as_dict(payload)
        nested = root.get("observation")
        obs_data = nested if isinstance(nested, dict) else root

        done = bool(_lookup("done", False, root, obs_data))
        reward = _lookup("reward", None, root, obs_data)

        observation = TemporalBenchObservation(
            step_idx=int(_lookup("step_idx", 0, obs_data)),
            steps_remaining=int(_lookup("steps_remaining", 0, obs_data)),
            max_steps=int(_lookup("max_steps", 9, obs_data)),
            question=str(_lookup("question", "", obs_data)),
            options=list(_lookup("options", [], obs_data)),
            task_type=str(_lookup("task_type", "", obs_data)),
            dataset=str(_lookup("dataset", "", obs_data)),
            history=list(_lookup("history", [], obs_data)),
            accuracy_so_far=float(_lookup("accuracy_so_far", 0.0, obs_data)),
            done=done,
            reward=reward,
            metadata=_lookup("metadata", {}, obs_data),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> TemporalBenchState:
        """Convert a state response into a typed ``TemporalBenchState``.

        Fields are read from ``payload["state"]`` when that value is a dict,
        otherwise from *payload* itself. Missing or null fields fall back to
        defaults, and a non-dict *payload* is treated as empty rather than
        raising. ``episode_id`` is passed through as-is (``None`` if absent).
        """
        root = _as_dict(payload)
        nested = root.get("state")
        state_data = nested if isinstance(nested, dict) else root

        return TemporalBenchState(
            episode_id=state_data.get("episode_id"),
            step_count=int(_lookup("step_count", 0, state_data)),
            total_correct=int(_lookup("total_correct", 0, state_data)),
            total_questions=int(_lookup("total_questions", 9, state_data)),
            current_accuracy=float(_lookup("current_accuracy", 0.0, state_data)),
            primary_domain=str(_lookup("primary_domain", "PSML", state_data)),
            per_task_type_accuracy=dict(_lookup("per_task_type_accuracy", {}, state_data)),
            total_reward=float(_lookup("total_reward", 0.0, state_data)),
        )


# Second module-level name bound to the same client class.
# NOTE(review): presumably kept so tooling/imports expecting this generated-style
# name keep working — confirm before removing.
TemporalbenchenvEnv = TemporalBenchEnvClient