Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """GenoTriage Environment Client.""" | |
| from typing import Dict, List, Optional | |
| from openenv.core import EnvClient | |
| from openenv.core.client_types import StepResult | |
| from openenv.core.env_server.types import State | |
| from .models import VepAction, VepObservation | |
class VepEnv(EnvClient[VepAction, VepObservation, State]):
    """
    Client for the Genetic Variant Classification Environment.

    Maintains a persistent WebSocket connection to the environment server.
    Each episode is single-step: the agent receives a variant case on reset()
    and submits exactly one VepAction via step() to receive its reward.

    Example (async):
        >>> async with VepEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset()
        ...     obs = result.observation
        ...     print(obs.gene, obs.disease)
        ...
        ...     action = VepAction(
        ...         classification="Pathogenic",
        ...         reasoning="Nonsense variant in a known disease gene, absent from gnomAD.",
        ...         criteria_used=["nonsense variant", "absent from gnomAD", "disease gene"],
        ...     )
        ...     result = await client.step(action)
        ...     print(result.reward, result.observation.feedback)

    Example (sync wrapper):
        >>> with VepEnv(base_url="http://localhost:8000").sync() as client:
        ...     result = client.reset()
        ...     result = client.step(VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     ))

    Example with Docker:
        >>> env = await VepEnv.from_docker_image("genotriage:latest")
        >>> try:
        ...     result = await env.reset()
        ...     action = VepAction(
        ...         classification="Benign",
        ...         reasoning="High population frequency strongly suggests common polymorphism.",
        ...         criteria_used=["high population frequency"],
        ...     )
        ...     result = await env.step(action)
        ... finally:
        ...     await env.close()
    """

    def _step_payload(self, action: VepAction) -> Dict:
        """
        Serialize a VepAction into the JSON payload for the step WebSocket message.

        Only classification, reasoning, and criteria_used are copied into the
        payload; any other attributes the action may carry are not included.

        Args:
            action: VepAction with classification, reasoning, and criteria_used.

        Returns:
            Dictionary ready for JSON encoding and transmission to the server.
        """
        return {
            "classification": action.classification,
            "reasoning": action.reasoning,
            "criteria_used": action.criteria_used,
        }

    def _parse_result(self, payload: Dict) -> StepResult[VepObservation]:
        """
        Parse the server's step/reset response into a StepResult[VepObservation].

        Every observation field falls back to a neutral default when its key is
        absent, so a payload missing some keys still parses. An explicit
        ``"observation": null`` is treated the same as a missing observation.

        Args:
            payload: Raw JSON response dict from the environment server.

        Returns:
            StepResult containing the VepObservation, reward, and done flag.
        """
        # `or {}` also covers an explicit null: `.get(key, {})` would return
        # None in that case and the field lookups below would raise
        # AttributeError.
        obs_data = payload.get("observation") or {}

        # Read once so the observation and the StepResult carry identical values.
        reward = payload.get("reward", 0.0)
        done = payload.get("done", False)

        observation = VepObservation(
            # Variant identity
            gene=obs_data.get("gene", ""),
            chromosome=obs_data.get("chromosome", ""),
            position=obs_data.get("position", 0),
            ref=obs_data.get("ref", ""),
            alt=obs_data.get("alt", ""),
            hgvs=obs_data.get("hgvs", ""),
            # Functional annotation (None when the server omits it)
            consequence=obs_data.get("consequence"),
            # Clinical context
            disease=obs_data.get("disease", ""),
            population_frequency=obs_data.get("population_frequency"),
            # Evidence
            evidence_snippets=obs_data.get("evidence_snippets", []),
            # Task instructions
            task_description=obs_data.get("task_description", ""),
            # Post-step feedback
            feedback=obs_data.get("feedback", ""),
            # Episode state (taken from the top-level payload, not the observation)
            done=done,
            reward=reward,
            metadata=obs_data.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse the server's state response into a State object.

        Only episode_id and step_count are read from the payload; other keys
        are ignored.

        Args:
            payload: Raw JSON response dict from the /state endpoint.

        Returns:
            State object with episode_id (None if absent) and step_count
            (0 if absent).
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )