# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""GenoTriage Environment Client."""

from typing import Dict, List, Optional

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import VepAction, VepObservation


class VepEnv(EnvClient[VepAction, VepObservation, State]):
    """
    Client for the Genetic Variant Classification Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each episode is single-step: the agent receives a variant case on reset()
    and submits exactly one VepAction via step() to receive its reward.

    Example (async):
        >>> async with VepEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     obs = result.observation
        ...     print(obs.gene, obs.disease)
        ...
        ...     action = VepAction(
        ...         classification="Pathogenic",
        ...         reasoning="Nonsense variant in a known disease gene, absent from gnomAD.",
        ...         criteria_used=["nonsense variant", "absent from gnomAD", "disease gene"],
        ...     )
        ...     result = await client.step(action)
        ...     print(result.reward, result.observation.feedback)

    Example (sync wrapper):
        >>> with VepEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     ))

    Example with Docker:
        >>> env = await VepEnv.from_docker_image("genotriage:latest")
        >>> try:
        ...     result = await env.reset()
        ...     action = VepAction(
        ...         classification="Benign",
        ...         reasoning="Common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     )
        ...     result = await env.step(action)
        ... finally:
        ...     await env.close()
    """

    def _step_payload(self, action: VepAction) -> Dict:
        """
        Serialize VepAction to JSON payload for the step WebSocket message.

        Args:
            action: VepAction with classification, reasoning, and criteria_used.

        Returns:
            Dictionary ready for JSON encoding and transmission to the server.
        """
        return {
            "classification": action.classification,
            "reasoning": action.reasoning,
            "criteria_used": action.criteria_used,
        }

    def _parse_result(self, payload: Dict) -> StepResult[VepObservation]:
        """
        Parse the server's step/reset response into a StepResult[VepObservation].

        Missing keys fall back to empty/neutral defaults. An explicit ``null``
        for ``observation``, ``evidence_snippets`` or ``metadata`` is treated
        the same as a missing key, so it does not crash the parser.

        Args:
            payload: Raw JSON response dict from the environment server.

        Returns:
            StepResult containing the VepObservation, reward, and done flag.
        """
        # ``payload.get("observation", {})`` would return None for an explicit
        # JSON null and break every lookup below, so coerce falsy to {}.
        obs_data = payload.get("observation") or {}

        # Read these once so the observation and the StepResult agree.
        done = payload.get("done", False)
        reward = payload.get("reward", 0.0)

        observation = VepObservation(
            # Variant identity
            gene=obs_data.get("gene", ""),
            chromosome=obs_data.get("chromosome", ""),
            position=obs_data.get("position", 0),
            ref=obs_data.get("ref", ""),
            alt=obs_data.get("alt", ""),
            hgvs=obs_data.get("hgvs", ""),
            # Functional annotation
            consequence=obs_data.get("consequence", None),
            # Clinical context
            disease=obs_data.get("disease", ""),
            population_frequency=obs_data.get("population_frequency", None),
            # Evidence (null -> empty list)
            evidence_snippets=obs_data.get("evidence_snippets") or [],
            # Task instructions
            task_description=obs_data.get("task_description", ""),
            # Post-step feedback
            feedback=obs_data.get("feedback", ""),
            # Episode state
            done=done,
            reward=reward,
            # null -> empty dict
            metadata=obs_data.get("metadata") or {},
        )

        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse the server's state response into a State object.

        Args:
            payload: Raw JSON response dict from the /state endpoint.

        Returns:
            State object with episode_id and step_count.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )