File size: 6,476 Bytes
35de6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02bbaae
 
 
35de6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02bbaae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35de6f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the GenoTriage Environment.

This environment presents genetic variant cases to an AI agent and asks it
to classify them using the standard 5-tier ACMG/AMP classification system.
Each episode is single-step: the agent receives a variant observation and
submits exactly one classification action.
"""

from openenv.core.env_server.types import Action, Observation
from pydantic import Field, field_validator
import json
from typing import Any, List, Literal, Optional


class VepAction(Action):
    """
    Action submitted by the agent to classify a genetic variant.

    The agent must provide a classification from the 5-tier ACMG system,
    along with a reasoning string and the specific criteria it used to
    reach its conclusion. Graders evaluate all three fields.
    """

    # Required: one of the five tier labels, spelled exactly as listed.
    classification: Literal[
        "Pathogenic",
        "Likely_pathogenic",
        "Uncertain_significance",
        "Likely_benign",
        "Benign",
    ] = Field(
        ...,
        description=(
            "ACMG/AMP classification for the variant. Must be exactly one of: "
            "Pathogenic, Likely_pathogenic, Uncertain_significance, "
            "Likely_benign, Benign."
        ),
    )

    # Required: free-text justification, at least 20 characters long.
    reasoning: str = Field(
        ...,
        min_length=20,
        description=(
            "A clear explanation of why this classification was chosen. "
            "Should reference the evidence presented in the observation, "
            "including population frequency, molecular consequence, and "
            "gene-disease association. Longer, well-supported reasoning "
            "receives higher partial credit."
        ),
    )

    # Optional: defaults to an empty list. String input is normalised by
    # ``validate_criteria_list`` below before the List[str] check runs.
    criteria_used: List[str] = Field(
        default_factory=list,
        description=(
            "List of specific criteria that informed the classification. "
            "Examples: 'high population frequency', 'nonsense variant', "
            "'no functional studies available', 'missense in disease gene', "
            "'absent from gnomAD'. Each criterion should be a short phrase."
        ),
    )

    @field_validator("criteria_used", mode="before")
    @classmethod
    def validate_criteria_list(cls, v: Any) -> Any:
        """
        Coerce a string ``criteria_used`` value into a list of criteria.

        Registered with ``mode="before"``, so it sees the raw input ahead of
        the ``List[str]`` type check. This avoids validation errors when
        users type lists manually into the web UI.

        Accepted string forms:
          * a JSON array such as ``'["a", "b"]'`` -- returned as parsed;
          * a bracketed list that is not valid JSON, such as ``'[af, rare]'``
            or ``"['af', 'rare']"``;
          * a plain comma-separated list such as ``'af, "rare"'``.

        For the two comma-separated forms every item is trimmed of
        surrounding whitespace and quote characters, and items that end up
        empty (``""``, ``''``, stray commas) are dropped.

        Args:
            v: Raw value supplied for ``criteria_used``.

        Returns:
            A list when ``v`` is a string (possibly empty); otherwise ``v``
            unchanged, leaving validation to the field's declared type.
        """
        if not isinstance(v, str):
            return v

        text = v.strip()
        if text.startswith("[") and text.endswith("]"):
            try:
                return json.loads(text)
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass. Input such as
                # "[af]" is not valid JSON, so drop the brackets and treat the
                # remainder as a comma-separated list.
                text = text[1:-1]

        # Clean each item fully *before* testing for emptiness, so that
        # quote-only items like "" do not survive as empty criteria.
        cleaned = (part.strip().strip("\"'").strip() for part in text.split(","))
        return [item for item in cleaned if item]


class VepObservation(Observation):
    """
    Observation presented to the agent describing a genetic variant case.

    On reset(), the agent receives a complete variant case including all
    available evidence. After step(), the feedback and reward fields are
    populated with grader results. done=True after the first step since
    this is a single-step environment.
    """

    # NOTE(review): the docstring above is kept verbatim on purpose --
    # presumably the base class is a pydantic model, in which case the text
    # may be surfaced as the schema description; confirm before rewording.

    # ------------------------------------------------------------------
    # Which variant is this?
    # ------------------------------------------------------------------
    gene: str = Field(default="", description="HGNC gene symbol (e.g. BRCA1, TP53, CFTR).")
    chromosome: str = Field(
        default="", description="Chromosome on which the variant resides (e.g. '17', 'X')."
    )
    position: int = Field(default=0, description="Genomic position of the variant on GRCh38.")
    ref: str = Field(default="", description="Reference allele (single nucleotide for SNPs).")
    alt: str = Field(
        default="",
        description="Alternate allele observed in the patient (single nucleotide for SNPs).",
    )
    hgvs: str = Field(
        default="",
        description="HGVS genomic notation for the variant (e.g. NC_000017.11:g.43094692A>G).",
    )

    # ------------------------------------------------------------------
    # Predicted molecular effect (may be missing)
    # ------------------------------------------------------------------
    consequence: Optional[str] = Field(
        default=None,
        description=(
            "Predicted molecular consequence of the variant (e.g. missense_variant, "
            "nonsense, splice_donor_variant, synonymous_variant). "
            "None if not annotated."
        ),
    )

    # ------------------------------------------------------------------
    # Disease association and population data
    # ------------------------------------------------------------------
    disease: str = Field(
        default="",
        description=(
            "Primary disease associated with this gene in ClinVar (e.g. "
            "'Hereditary Breast and Ovarian Cancer syndrome')."
        ),
    )
    population_frequency: Optional[float] = Field(
        default=None,
        description=(
            "Allele frequency in the gnomAD v4 population database "
            "(0.0–1.0). None if the variant was not observed in gnomAD."
        ),
    )

    # ------------------------------------------------------------------
    # Supporting evidence shown to the agent
    # ------------------------------------------------------------------
    evidence_snippets: List[str] = Field(
        default_factory=list,
        description=(
            "A list of 3–4 evidence snippets providing clinical and functional context "
            "for the variant. May include gene-disease association, consequence "
            "interpretation, population frequency context, and available "
            "functional/literature evidence."
        ),
    )

    # ------------------------------------------------------------------
    # What the agent is asked to do
    # ------------------------------------------------------------------
    task_description: str = Field(
        default="",
        description=(
            "Instructions for the agent describing what it must do in this "
            "episode. Includes classification categories and grading criteria "
            "summary."
        ),
    )

    # ------------------------------------------------------------------
    # Grader output: defaults to "" until step() fills it in
    # ------------------------------------------------------------------
    feedback: str = Field(
        default="",
        description=(
            "Grader feedback returned after step(). Empty string on reset(). After "
            "step(), contains the correct classification, score breakdown, and notes "
            "on which criteria were correctly identified."
        ),
    )