# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""GenoTriage Environment Client."""
from typing import Dict, List, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import VepAction, VepObservation
class VepEnv(EnvClient[VepAction, VepObservation, State]):
    """
    Client for the Genetic Variant Classification Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each episode is single-step: the agent receives a variant case on reset()
    and submits exactly one VepAction via step() to receive its reward.

    Example (async):
        >>> async with VepEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     obs = result.observation
        ...     print(obs.gene, obs.disease)
        ...
        ...     action = VepAction(
        ...         classification="Pathogenic",
        ...         reasoning="Nonsense variant in a known disease gene, absent from gnomAD.",
        ...         criteria_used=["nonsense variant", "absent from gnomAD", "disease gene"],
        ...     )
        ...     result = await client.step(action)
        ...     print(result.reward, result.observation.feedback)

    Example (sync wrapper):
        >>> with VepEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     ))

    Example with Docker:
        >>> env = await VepEnv.from_docker_image("genotriage:latest")
        >>> try:
        ...     result = await env.reset()
        ...     # `action` is a VepAction built as in the async example above.
        ...     result = await env.step(action)
        ... finally:
        ...     await env.close()
    """

    def _step_payload(self, action: VepAction) -> Dict:
        """
        Serialize a VepAction into the JSON payload for a step message.

        Args:
            action: VepAction with classification, reasoning, and criteria_used.

        Returns:
            Dictionary with exactly the keys "classification", "reasoning" and
            "criteria_used", ready for JSON encoding and transmission to the
            server.
        """
        return {
            "classification": action.classification,
            "reasoning": action.reasoning,
            "criteria_used": action.criteria_used,
        }

    def _parse_result(self, payload: Dict) -> StepResult[VepObservation]:
        """
        Parse the server's step/reset response into a StepResult[VepObservation].

        Observation fields are read from ``payload["observation"]``. Any
        missing field falls back to an empty/zero default. ``reward`` and
        ``done`` are read from the top level of ``payload``.

        Args:
            payload: Raw JSON response dict from the environment server.

        Returns:
            StepResult containing the VepObservation, reward, and done flag.
            ``reward`` defaults to 0.0 and ``done`` to False only when the
            key is absent. An explicit null value is passed through as-is.
        """
        # Use `or {}` rather than a .get() default. A .get() default only
        # covers a missing key; an explicit JSON null would otherwise give
        # None and crash on the .get() calls below.
        obs_data = payload.get("observation") or {}

        # Read these once so the observation and the StepResult always agree.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = VepObservation(
            # Variant identity
            gene=obs_data.get("gene", ""),
            chromosome=obs_data.get("chromosome", ""),
            position=obs_data.get("position", 0),
            ref=obs_data.get("ref", ""),
            alt=obs_data.get("alt", ""),
            hgvs=obs_data.get("hgvs", ""),
            # Functional annotation
            consequence=obs_data.get("consequence", None),
            # Clinical context
            disease=obs_data.get("disease", ""),
            population_frequency=obs_data.get("population_frequency", None),
            # Evidence
            evidence_snippets=obs_data.get("evidence_snippets", []),
            # Task instructions
            task_description=obs_data.get("task_description", ""),
            # Post-step feedback
            feedback=obs_data.get("feedback", ""),
            # Episode state (top-level in the payload, mirrored onto the observation)
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse the server's state response into a State object.

        Args:
            payload: Raw JSON response dict describing the episode state.

        Returns:
            State with ``episode_id`` (None if absent) and ``step_count``
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )