File size: 2,819 Bytes
ea03c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37bfd28
 
 
ea03c8c
37bfd28
 
 
ea03c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""WebSocket client for ChargebackOps."""

from __future__ import annotations

from typing import Any

from openenv.core import EnvClient
from openenv.core.client_types import StepResult

from .models import (
    ActionTraceItem,
    CaseQueueItem,
    CaseResolutionState,
    ChargebackOpsAction,
    ChargebackOpsObservation,
    ChargebackOpsState,
    EvidenceCard,
    GraderReport,
    PolicyView,
    VisibleCase,
)


def _parse_evidence(payload: dict[str, Any]) -> EvidenceCard:
    """Build an ``EvidenceCard`` from its serialized field mapping.

    Args:
        payload: Mapping of ``EvidenceCard`` field names to values; it is
            splatted directly into the model constructor.

    Returns:
        The constructed ``EvidenceCard``.
    """
    fields = payload
    card = EvidenceCard(**fields)
    return card


def _parse_policy(payload: dict[str, Any] | None) -> PolicyView | None:
    if payload is None:
        return None
    return PolicyView(**payload)


def _parse_visible_case(payload: dict[str, Any] | None) -> VisibleCase | None:
    if payload is None:
        return None
    data = dict(payload)
    data["retrieved_evidence"] = [
        _parse_evidence(item) for item in data.get("retrieved_evidence", [])
    ]
    data["attached_evidence"] = [
        _parse_evidence(item) for item in data.get("attached_evidence", [])
    ]
    data["policy"] = _parse_policy(data.get("policy"))
    return VisibleCase(**data)


def _parse_grader(payload: dict[str, Any] | None) -> GraderReport | None:
    if payload is None:
        return None
    return GraderReport(**payload)


class ChargebackOpsEnv(
    EnvClient[ChargebackOpsAction, ChargebackOpsObservation, ChargebackOpsState]
):
    """Typed client for the ChargebackOps environment.

    Converts between the typed ChargebackOps models and the plain ``dict``
    payloads handled by the ``EnvClient`` hooks below.
    """

    def _step_payload(self, action: ChargebackOpsAction) -> dict[str, Any]:
        """Serialize ``action`` into a plain dict via ``model_dump()``."""
        return action.model_dump()

    def _parse_result(
        self, payload: dict[str, Any]
    ) -> StepResult[ChargebackOpsObservation]:
        """Convert a step response mapping into a typed ``StepResult``.

        Args:
            payload: Mapping with an ``"observation"`` sub-mapping plus
                top-level ``"done"`` (default ``False``) and ``"reward"``
                (default ``None``) entries.

        Returns:
            A ``StepResult`` whose observation carries the same ``done`` and
            ``reward`` values as the result itself.
        """
        done = payload.get("done", False)
        reward = payload.get("reward")

        # ``or {}`` tolerates an explicit ``"observation": null``, which would
        # otherwise make ``dict(None)`` raise TypeError.
        obs_data = dict(payload.get("observation") or {})
        # The top-level done/reward are authoritative. Drop any copies nested in
        # the observation so they are not passed twice to the constructor,
        # which would raise "got multiple values for keyword argument".
        obs_data.pop("done", None)
        obs_data.pop("reward", None)

        obs_data["queue"] = [
            CaseQueueItem(**item) for item in obs_data.get("queue") or []
        ]
        obs_data["visible_case"] = _parse_visible_case(obs_data.get("visible_case"))
        obs_data["grader_report"] = _parse_grader(obs_data.get("grader_report"))
        observation = ChargebackOpsObservation(**obs_data, done=done, reward=reward)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: dict[str, Any]) -> ChargebackOpsState:
        """Convert a state response mapping into a typed ``ChargebackOpsState``.

        ``queue_state``, ``action_history`` and ``grader_report`` are parsed
        into their models; all other keys are passed through unchanged. Missing
        or null list fields become empty lists.
        """
        # Shallow copy so the caller's mapping is not mutated below.
        data = dict(payload)
        data["queue_state"] = [
            CaseResolutionState(**item) for item in data.get("queue_state") or []
        ]
        data["action_history"] = [
            ActionTraceItem(**item) for item in data.get("action_history") or []
        ]
        data["grader_report"] = _parse_grader(data.get("grader_report"))
        return ChargebackOpsState(**data)