"""

Inference Script: Genetic Variant Classification (GenoTriage)
============================================================

MANDATORY ENVIRONMENT VARIABLES:
    API_BASE_URL        The API endpoint for the LLM.
    MODEL_NAME          The model identifier to use for inference.
    HF_TOKEN            Your Hugging Face / API key.
    LOCAL_IMAGE_NAME    Docker image name for the environment.

STDOUT FORMAT (strictly followed):
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP]  step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END]   success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>

This script runs 3 tasks (easy, medium, hard). Each task iterates over
all variants in that tier (8 variants = 8 single-step episodes).
Each episode: reset() -> LLM classifies -> step() -> log -> done.

Final score per task = average reward across all episodes in that tier.
Overall score = average across all 3 tasks.
"""

import asyncio
import json
import os
import textwrap
from pathlib import Path
from typing import List, Optional

from dotenv import load_dotenv
from openai import OpenAI

from genotriage import VepAction, VepEnv

# Load .env before reading env vars
load_dotenv(Path(__file__).resolve().parent / ".env")

# --- Config ---
IMAGE_NAME   = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
API_KEY      = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME   = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
BENCHMARK    = "genotriage"

TEMPERATURE          = 0.2   # low temp for more deterministic clinical reasoning
MAX_TOKENS           = 600   # enough for classification + reasoning + criteria
SUCCESS_SCORE_THRESHOLD = 0.5

# Number of episodes per task (matches variants per tier in variants.json)
EPISODES_PER_TASK = 8

TASKS = ["easy", "medium", "hard"]

# ---------------------------------------------------------------------------
# Logging helpers: exact format required by hackathon judges
# ---------------------------------------------------------------------------

def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    done_val = str(done).lower()
    # Sanitise free-text fields (action and error): remove newlines so each [STEP]
    # record stays on a single line
    action_clean = action.replace("\n", " ").replace("\r", "")[:120]
    error_val = error.replace("\n", " ").replace("\r", "") if error else "null"
    print(
        f"[STEP]  step={step} action={action_clean} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END]   success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
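

# Example (hypothetical helper, not called anywhere in this script): a minimal sketch of
# how a consumer could parse the one-line [START]/[STEP]/[END] records printed above into
# key/value pairs. Assumptions: keys are word characters followed by "=", and no value
# contains a "<word>=" sequence (a free-text error message containing one would be split).
import re
from typing import Dict, Tuple

_EXAMPLE_RECORD_RE = re.compile(r"^\[(START|STEP|END)\]\s+(.*)$")
# Lazily capture each value up to the next " key=" or the end of the line
_EXAMPLE_FIELD_RE = re.compile(r"(\w+)=(.*?)(?=\s+\w+=|$)")


def _example_parse_log_line(line: str) -> Tuple[str, Dict[str, str]]:
    """Split a log record into its tag and a dict of raw string fields."""
    match = _EXAMPLE_RECORD_RE.match(line.strip())
    if match is None:
        raise ValueError(f"not a [START]/[STEP]/[END] record: {line!r}")
    tag, body = match.groups()
    return tag, dict(_EXAMPLE_FIELD_RE.findall(body))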


# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = textwrap.dedent("""
    You are an expert clinical geneticist specializing in ACMG/AMP variant classification.
    You will be presented with a genetic SNP variant and must classify it into exactly
    one of these five categories:
        - Pathogenic
        - Likely_pathogenic
        - Uncertain_significance
        - Likely_benign
        - Benign

    You MUST respond with a valid JSON object and nothing else. No markdown, no prose outside
    the JSON. Use exactly this structure:
    {
        "classification": "<one of the five categories above>",
        "reasoning": "<detailed explanation citing specific evidence values from the observation>",
        "criteria_used": ["<criterion 1>", "<criterion 2>", ...]
    }

    CRITICAL RULES:
    1. Read ALL the evidence before deciding. Do not default to any category without justification.
    2. Your classification MUST be driven by the specific evidence in this variant's observation,
       not by assumptions about genes or diseases in general.
    3. Cite actual values: mention the population frequency number, the consequence type,
       the gene name, and specific evidence snippets that support your conclusion.
    4. If evidence is insufficient or conflicting, classify as Uncertain_significance.

    Classification guidelines (apply to THIS variant's evidence, not general assumptions):
    - Pathogenic / Likely_pathogenic: nonsense/frameshift/splice consequence, absent or
      extremely rare in gnomAD (<0.01%), strong published functional/clinical evidence.
    - Benign / Likely_benign: high population frequency (>0.1%), synonymous change,
      non-coding with no regulatory evidence, observed in many unaffected individuals.
    - Uncertain_significance: missense or regulatory variant in a disease gene where
      functional data is absent, conflicting computational predictions, intermediate frequency.
""").strip()

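
# Example (hypothetical helper, not called anywhere in this script): the frequency cut-offs
# in the prompt above are percentages, while obs.population_frequency appears to be a
# fraction (build_user_prompt multiplies it by 100 for display). This sketch converts the
# prompt's thresholds to fractions: "<0.01%" -> 1e-4 and ">0.1%" -> 1e-3. The band names
# are illustrative labels only, not ACMG criteria, and frequency alone never decides a class.
from typing import Optional

_EXAMPLE_RARE_MAX_FRACTION = 0.01 / 100    # "<0.01%" in the prompt
_EXAMPLE_COMMON_MIN_FRACTION = 0.1 / 100   # ">0.1%" in the prompt


def _example_frequency_band(freq: Optional[float]) -> str:
    """Bucket an allele frequency (fraction in [0, 1]) using the prompt's thresholds."""
    if freq is None:
        return "absent"          # not observed in gnomAD
    if freq < _EXAMPLE_RARE_MAX_FRACTION:
        return "extremely_rare"  # frequency side of the Pathogenic / Likely_pathogenic guidance
    if freq > _EXAMPLE_COMMON_MIN_FRACTION:
        return "common"          # frequency side of the Benign / Likely_benign guidance
    return "intermediate"        # the prompt lists this under Uncertain_significance
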

# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------

def build_user_prompt(obs) -> str:
    """Build the user-facing prompt from a VepObservation."""
    freq_str = (
        f"{obs.population_frequency:.6f} ({obs.population_frequency * 100:.4f}%)"
        if obs.population_frequency is not None
        else "Not observed in gnomAD (frequency unavailable)"
    )

    snippets_block = "\n".join(
        f"  [{i+1}] {snippet}"
        for i, snippet in enumerate(obs.evidence_snippets)
    )

    # Build the prompt line by line instead of dedenting an f-string: dedent() runs
    # after interpolation, so the multi-line snippets_block would shrink the common
    # margin and leave most of the prompt indented.
    lines = [
        str(obs.task_description),
        "",
        "=== VARIANT CASE ===",
        f"Gene:                   {obs.gene}",
        f"Chromosome:             {obs.chromosome}",
        f"Position (GRCh38):      {obs.position}",
        f"Reference allele:       {obs.ref}",
        f"Alternate allele:       {obs.alt}",
        f"HGVS notation:          {obs.hgvs}",
        f"Molecular consequence:  {obs.consequence or 'Not annotated'}",
        f"Associated disease:     {obs.disease}",
        f"Population frequency:   {freq_str}",
        "",
        "=== EVIDENCE ===",
        snippets_block,
        "",
        "Based on all the above, provide your ACMG/AMP classification as a JSON object.",
        "Remember: respond with ONLY the JSON object.",
    ]
    return "\n".join(lines).strip()

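
# Example (illustrative only, not called anywhere in this script): textwrap.dedent() runs
# *after* f-string interpolation, so a multi-line value whose later lines carry less
# indentation than the template shrinks the common margin, and the remaining template
# lines keep extra leading spaces. Joining a list of lines avoids the problem.
import textwrap
from typing import Tuple


def _example_dedent_pitfall() -> Tuple[str, str]:
    """Return (naive, safe) renderings of a header, a two-line block and a footer."""
    block = "\n".join(["  [1] a", "  [2] b"])
    # The second block line has only 2 leading spaces, so dedent() can strip just 2
    naive = textwrap.dedent(f"""
        Header
        {block}
        Footer
    """).strip()
    safe = "\n".join(["Header", block, "Footer"])
    return naive, safe
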

# ---------------------------------------------------------------------------
# LLM call + JSON parsing
# ---------------------------------------------------------------------------

def get_model_action(client: OpenAI, obs) -> VepAction:
    """
    Call the LLM with the variant observation and parse its response into a VepAction.
    Falls back to Uncertain_significance (with the failure noted in `reasoning`) if the
    LLM call or JSON parsing fails.
    """
    user_prompt = build_user_prompt(obs)
    raw_response = ""

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user",   "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_response = (completion.choices[0].message.content or "").strip()

        # Strip markdown code fences if present
        if raw_response.startswith("```"):
            lines = raw_response.split("\n")
            raw_response = "\n".join(
                l for l in lines if not l.strip().startswith("```")
            ).strip()

        parsed = json.loads(raw_response)

        # Validate classification is one of the 5 allowed values
        valid_classifications = {
            "Pathogenic", "Likely_pathogenic",
            "Uncertain_significance", "Likely_benign", "Benign",
        }
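        classification = str(parsed.get("classification", "Uncertain_significance"))
        if classification not in valid_classifications:
            # Try to fuzzy-match common variants (e.g. "Likely pathogenic", "VUS")
            classification = _fuzzy_match_classification(classification)

        # Accept a bare string or a missing/null value for criteria_used;
        # list("PVS1") would otherwise split a single criterion into characters
        criteria = parsed.get("criteria_used") or []
        if isinstance(criteria, str):
            criteria = [criteria]

        return VepAction(
            classification=classification,
            reasoning=str(parsed.get("reasoning", "No reasoning provided.")),
            criteria_used=[str(c) for c in criteria],
        )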

    except json.JSONDecodeError as e:
        print(f"[DEBUG] JSON parse error: {e}. Raw: {raw_response[:200]}", flush=True)
        return VepAction(
            classification="Uncertain_significance",
            reasoning=f"Failed to parse model response as JSON. Raw: {raw_response[:100]}",
            criteria_used=[],
        )
    except Exception as e:
        print(f"[DEBUG] LLM call failed: {e}", flush=True)
        return VepAction(
            classification="Uncertain_significance",
            reasoning="LLM call failed; defaulting to Uncertain_significance.",
            criteria_used=[],
        )
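

# Example (hypothetical helper, not called anywhere in this script): a sketch of a more
# tolerant parser than fence-stripping followed by json.loads. It scans for each "{" and
# asks json.JSONDecoder.raw_decode() to decode one object starting there, ignoring any
# prose or code fences around it. Assumption: the first decodable object is the answer.
import json
from typing import Any, Dict, Optional


def _example_extract_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in text, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text[start:])
            return obj  # a successful decode that starts at "{" is always a dict
        except json.JSONDecodeError:
            start = text.find("{", start + 1)  # try the next opening brace
    return None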


def _fuzzy_match_classification(raw: str) -> str:
    """Map common LLM output variations to valid classification strings."""
    raw_lower = str(raw).lower().strip().replace(" ", "_").replace("-", "_")
    # Order matters for substring matching: check the "likely_*" labels before the
    # bare "pathogenic"/"benign" keys, which are substrings of them.
    mapping = {
        "likely_pathogenic":      "Likely_pathogenic",
        "likely_benign":          "Likely_benign",
        "uncertain_significance": "Uncertain_significance",
        "vus":                    "Uncertain_significance",
        "uncertain":              "Uncertain_significance",
        "pathogenic":             "Pathogenic",
        "benign":                 "Benign",
    }
    for key, value in mapping.items():
        if key in raw_lower:
            return value
    return "Uncertain_significance"
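

# Example (illustrative only, not called anywhere in this script): substring matching is
# order sensitive. "pathogenic" is a substring of "likely_pathogenic", so a matcher that
# tries the bare label first maps "Likely pathogenic" to "Pathogenic". Trying longer,
# more specific keys first avoids this. The two key lists are a hypothetical reduced set.
from typing import Sequence, Tuple


def _example_first_match(raw: str, ordered: Sequence[Tuple[str, str]]) -> str:
    """Return the label of the first key found in the normalised string."""
    norm = raw.lower().replace(" ", "_").replace("-", "_")
    for key, label in ordered:
        if key in norm:
            return label
    return "Uncertain_significance"


_EXAMPLE_BARE_FIRST = [("pathogenic", "Pathogenic"), ("likely_pathogenic", "Likely_pathogenic")]
# Longest key first, so specific labels win over their substrings
_EXAMPLE_SPECIFIC_FIRST = sorted(_EXAMPLE_BARE_FIRST, key=lambda kv: len(kv[0]), reverse=True)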


# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------

async def run_task(task: str, env: VepEnv, client: OpenAI) -> float:
    """
    Run all episodes for one task tier.

    Each episode is single-step:
        reset() -> observe variant -> LLM classifies -> step() -> reward -> done

    Returns:
        Average reward across all episodes (the task score, in [0.0, 1.0]).
    """
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)

    all_rewards: List[float] = []
    total_steps = 0
    success = False
    score = 0.0

    try:
        for episode in range(1, EPISODES_PER_TASK + 1):
            # Reset: get a new variant
            result = await env.reset()
            obs = result.observation

            if result.done:
                # Shouldn't happen on reset, but guard anyway
                break

            # Get classification from LLM
            action = get_model_action(client, obs)

            # Submit to environment
            result = await env.step(action)

            reward = result.reward if result.reward is not None else 0.0
            done   = result.done
            error  = None

            all_rewards.append(reward)
            total_steps += 1

            # Log step; action summarised as "classification|gene"
            action_summary = f"{action.classification}|{obs.gene}"
            log_step(
                step=episode,
                action=action_summary,
                reward=reward,
                done=done,
                error=error,
            )

        # Task score = average reward across all episodes
        score = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        score = round(min(max(score, 0.0), 1.0), 4)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as e:
        print(f"[DEBUG] Task {task} error: {e}", flush=True)
        error_msg = str(e)[:80]
        if total_steps < EPISODES_PER_TASK:
            # Pad missing steps with zero reward
            for s in range(total_steps + 1, EPISODES_PER_TASK + 1):
                log_step(step=s, action="error", reward=0.0, done=True, error=error_msg)
                all_rewards.append(0.0)
        score = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        score = round(min(max(score, 0.0), 1.0), 4)
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        log_end(
            success=success,
            steps=total_steps,
            score=score,
            rewards=all_rewards,
        )

    return score
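

# Example (hypothetical helpers, not called anywhere in this script): a worked version of
# the scoring rules in the module docstring. A task score is the mean episode reward,
# clamped to [0, 1] and rounded to 4 places; as in run_task's error path, missing episodes
# can be padded with 0.0 rewards. The overall score is the plain mean of the task scores.
from typing import Dict, List, Optional, Tuple


def _example_task_score(
    rewards: List[float], pad_to: Optional[int] = None, threshold: float = 0.5
) -> Tuple[float, bool]:
    """Return (score, success) for one task tier."""
    padded = list(rewards)
    if pad_to is not None and len(padded) < pad_to:
        padded += [0.0] * (pad_to - len(padded))  # unfinished episodes count as zero
    score = sum(padded) / len(padded) if padded else 0.0
    score = round(min(max(score, 0.0), 1.0), 4)
    return score, score >= threshold


def _example_overall_score(task_scores: Dict[str, float]) -> float:
    """Unweighted mean across task tiers (0.0 if no tier ran)."""
    return sum(task_scores.values()) / len(task_scores) if task_scores else 0.0

# For instance, a task that crashes after 3 perfect episodes out of 8 scores
# 3 / 8 = 0.375 once padded, which is below the 0.5 success threshold.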


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def main() -> None:
    """
    Run all 3 task tiers sequentially.
    Each task gets its own VepEnv instance; the tier is selected by setting the
    VEP_TASK env var before that instance is created.
    Prints an overall summary at the end.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    task_scores: dict[str, float] = {}

    for task in TASKS:
        # Set the task env var so the environment server samples from the right tier
        os.environ["VEP_TASK"] = task

        env = await VepEnv.from_docker_image(IMAGE_NAME)

        try:
            task_score = await run_task(task=task, env=env, client=client)
            task_scores[task] = task_score
        finally:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error for task {task}: {e}", flush=True)

    # Overall summary
    if task_scores:
        overall = sum(task_scores.values()) / len(task_scores)
        print(
            f"[DEBUG] Overall scores: "
            f"easy={task_scores.get('easy', 0):.3f} "
            f"medium={task_scores.get('medium', 0):.3f} "
            f"hard={task_scores.get('hard', 0):.3f} "
            f"overall={overall:.3f}",
            flush=True,
        )
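

# Example (hypothetical helper, not called anywhere in this script): main() assigns
# os.environ["VEP_TASK"] and leaves it set after the loop finishes. This sketch scopes
# such an assignment to a with-block and restores the previous value (or removes the
# variable) afterwards. It assumes VepEnv only needs VEP_TASK while the block is open,
# e.g. at container start-up; that behaviour is not shown in this script.
import os
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def _example_scoped_env(name: str, value: str) -> Iterator[None]:
    """Temporarily set an environment variable for the duration of a with-block."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)  # variable did not exist before: remove it
        else:
            os.environ[name] = previous  # restore the earlier value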


if __name__ == "__main__":
    asyncio.run(main())