# GenoTriage / client.py
# fierce74's picture
# Upload folder using huggingface_hub
# 35de6f4 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""GenoTriage Environment Client."""
from typing import Dict, List, Optional
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import VepAction, VepObservation
class VepEnv(EnvClient[VepAction, VepObservation, State]):
    """
    Client for the Genetic Variant Classification (GenoTriage) environment.

    The client keeps a persistent WebSocket connection to the environment
    server. Episodes are single-step: ``reset()`` hands the agent one variant
    case, and exactly one ``VepAction`` submitted through ``step()`` yields
    the reward.

    Example (async):
        >>> async with VepEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     obs = result.observation
        ...     print(obs.gene, obs.disease)
        ...
        ...     action = VepAction(
        ...         classification="Pathogenic",
        ...         reasoning="Nonsense variant in a known disease gene, absent from gnomAD.",
        ...         criteria_used=["nonsense variant", "absent from gnomAD", "disease gene"],
        ...     )
        ...     result = await client.step(action)
        ...     print(result.reward, result.observation.feedback)

    Example (sync wrapper):
        >>> with VepEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     ))

    Example with Docker:
        >>> env = await VepEnv.from_docker_image("genotriage:latest")
        >>> try:
        ...     result = await env.reset()
        ...     result = await env.step(action)
        ... finally:
        ...     await env.close()
    """

    def _step_payload(self, action: VepAction) -> Dict:
        """
        Turn a VepAction into the JSON payload of a step message.

        Args:
            action: VepAction carrying classification, reasoning, and
                criteria_used.

        Returns:
            Dictionary with exactly those three keys, in that order, ready
            to be JSON-encoded and sent to the server.
        """
        wire_fields = ("classification", "reasoning", "criteria_used")
        return {name: getattr(action, name) for name in wire_fields}

    def _parse_result(self, payload: Dict) -> StepResult[VepObservation]:
        """
        Build a StepResult[VepObservation] from a step/reset response.

        Args:
            payload: Raw JSON response dict from the environment server.

        Returns:
            StepResult holding the VepObservation, the reward, and the done
            flag. Keys missing from the payload fall back to neutral
            defaults (empty string, 0, None, empty list/dict, 0.0, False).
        """
        raw_obs = payload.get("observation", {})
        # Episode-level values live on the top-level payload, not inside the
        # observation dict; they are mirrored onto the observation below.
        reward = payload.get("reward", 0.0)
        finished = payload.get("done", False)

        lookup = raw_obs.get
        obs_kwargs = {
            # Variant identity
            "gene": lookup("gene", ""),
            "chromosome": lookup("chromosome", ""),
            "position": lookup("position", 0),
            "ref": lookup("ref", ""),
            "alt": lookup("alt", ""),
            "hgvs": lookup("hgvs", ""),
            # Functional annotation
            "consequence": lookup("consequence", None),
            # Clinical context
            "disease": lookup("disease", ""),
            "population_frequency": lookup("population_frequency", None),
            # Evidence (fresh list per call so defaults are never shared)
            "evidence_snippets": lookup("evidence_snippets", []),
            # Task instructions
            "task_description": lookup("task_description", ""),
            # Post-step feedback
            "feedback": lookup("feedback", ""),
            # Episode state
            "done": finished,
            "reward": reward,
            "metadata": lookup("metadata", {}),
        }
        parsed_obs = VepObservation(**obs_kwargs)

        return StepResult(observation=parsed_obs, reward=reward, done=finished)

    def _parse_state(self, payload: Dict) -> State:
        """
        Build a State object from the server's state response.

        Args:
            payload: Raw JSON response dict from the /state endpoint.

        Returns:
            State populated with episode_id (None when absent) and
            step_count (0 when absent).
        """
        episode = payload.get("episode_id")
        steps_taken = payload.get("step_count", 0)
        return State(episode_id=episode, step_count=steps_taken)