File size: 4,752 Bytes
35de6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""GenoTriage Environment Client."""

from typing import Dict, List, Optional

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import VepAction, VepObservation


class VepEnv(EnvClient[VepAction, VepObservation, State]):
    """
    Client for the Genetic Variant Classification Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each episode is single-step: the agent receives a variant case on reset()
    and submits exactly one VepAction via step() to receive its reward.

    Example (async):
        >>> async with VepEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     obs = result.observation
        ...     print(obs.gene, obs.disease)
        ...
        ...     action = VepAction(
        ...         classification="Pathogenic",
        ...         reasoning="Nonsense variant in a known disease gene, absent from gnomAD.",
        ...         criteria_used=["nonsense variant", "absent from gnomAD", "disease gene"],
        ...     )
        ...     result = await client.step(action)
        ...     print(result.reward, result.observation.feedback)

    Example (sync wrapper):
        >>> with VepEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     ))

    Example with Docker:
        >>> env = await VepEnv.from_docker_image("genotriage:latest")
        >>> try:
        ...     result = await env.reset()
        ...     result = await env.step(action)
        ... finally:
        ...     await env.close()
    """

    def _step_payload(self, action: VepAction) -> Dict:
        """
        Build the JSON-serializable payload for a step message.

        Args:
            action: VepAction whose ``classification``, ``reasoning`` and
                ``criteria_used`` attributes are copied into the payload.

        Returns:
            Dictionary with exactly the keys ``classification``,
            ``reasoning`` and ``criteria_used``.
        """
        return {
            "classification": action.classification,
            "reasoning": action.reasoning,
            "criteria_used": action.criteria_used,
        }

    def _parse_result(self, payload: Dict) -> StepResult[VepObservation]:
        """
        Convert a step/reset response dict into a StepResult[VepObservation].

        Observation fields are read from ``payload["observation"]``; any field
        that is absent falls back to the default shown in the body below.
        ``reward`` and ``done`` are read from the top level of ``payload`` and
        are stored both on the observation and on the returned StepResult.

        Args:
            payload: Raw response dict from the environment server.

        Returns:
            StepResult containing the VepObservation, reward, and done flag.
        """
        # ``or {}`` (rather than a ``.get`` default) so that an explicit
        # ``"observation": null`` is treated like a missing observation
        # instead of raising AttributeError on the lookups below.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult always agree.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = VepObservation(
            # Variant identity
            gene=obs_data.get("gene", ""),
            chromosome=obs_data.get("chromosome", ""),
            position=obs_data.get("position", 0),
            ref=obs_data.get("ref", ""),
            alt=obs_data.get("alt", ""),
            hgvs=obs_data.get("hgvs", ""),
            # Functional annotation
            consequence=obs_data.get("consequence", None),
            # Clinical context
            disease=obs_data.get("disease", ""),
            population_frequency=obs_data.get("population_frequency", None),
            # Evidence
            evidence_snippets=obs_data.get("evidence_snippets", []),
            # Task instructions
            task_description=obs_data.get("task_description", ""),
            # Post-step feedback
            feedback=obs_data.get("feedback", ""),
            # Episode state
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Convert a state response dict into a State object.

        Args:
            payload: Raw response dict describing the server-side state.

        Returns:
            State built from ``payload["episode_id"]`` (``None`` when absent)
            and ``payload["step_count"]`` (``0`` when absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )