"""
Inference Script — Genetic Variant Classification (GenoTriage)
==============================================================

MANDATORY ENVIRONMENT VARIABLES:
    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  Docker image name for the environment.
(API_BASE_URL and MODEL_NAME fall back to built-in defaults; the key may also
be supplied as API_KEY or OPENAI_API_KEY, and the image as IMAGE_NAME.)

STDOUT FORMAT (strictly followed):
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>

This script runs 3 tasks (easy, medium, hard). Each task iterates over
all variants in that tier (8 variants = 8 single-step episodes).
Each episode: reset() → LLM classifies → step() → log → done.
Final score per task = average reward across all episodes in that tier.
Overall score = average across all 3 tasks.
"""
import asyncio
import json
import os
import textwrap
from pathlib import Path
from typing import List, Optional
from dotenv import load_dotenv
from openai import OpenAI
from genotriage import VepAction, VepEnv
# Pull settings from a `.env` file located next to this script. This has to
# happen before the config section below calls os.getenv(), otherwise values
# defined only in that file would not be seen.
load_dotenv(dotenv_path=Path(__file__).resolve().with_name(".env"))
# --- Config ---
# Values taken from the process environment (after the .env load above).
# Docker image for the environment; LOCAL_IMAGE_NAME takes precedence.
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
# API credential: the first non-empty of HF_TOKEN / API_KEY, otherwise
# whatever OPENAI_API_KEY holds (None when it is unset).
API_KEY = (
    os.getenv("HF_TOKEN")
    or os.getenv("API_KEY")
    or os.getenv("OPENAI_API_KEY")
)
# OpenAI-compatible endpoint and model identifier, each with a default.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Fixed run settings.
BENCHMARK = "genotriage"  # reported as env=... in the [START] line
TASKS = ["easy", "medium", "hard"]  # difficulty tiers, run in this order
# Episodes per task; meant to match the number of variants per tier in
# variants.json.
EPISODES_PER_TASK = 8
# A task is reported as successful when its mean reward reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.5

# LLM sampling settings.
TEMPERATURE = 0.2  # low temperature for more deterministic clinical reasoning
MAX_TOKENS = 600  # enough for classification + reasoning + criteria
# ---------------------------------------------------------------------------
# Logging helpers — exact format required by hackathon judges
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Emit the `[START]` line that opens a task's log block.

    Args:
        task: Task tier name.
        env: Benchmark / environment name.
        model: Model identifier used for inference.
    """
    fields = (f"task={task}", f"env={env}", f"model={model}")
    # print() joins the tag and fields with single spaces.
    print("[START]", *fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one `[STEP]` line in the judge format.

    Args:
        step: Step number within the task.
        action: Action summary. Newlines become spaces, carriage returns are
            dropped, and the result is truncated to 120 characters.
        reward: Step reward, printed with two decimals.
        done: Episode-finished flag, printed as ``true``/``false``.
        error: Error message, or a falsy value to print ``null``. It is
            flattened the same way as ``action``, so a multi-line exception
            message cannot split the record across several lines.
    """

    def _one_line(text: str) -> str:
        # The stdout format requires each record on a single line.
        return text.replace("\n", " ").replace("\r", "")

    action_clean = _one_line(action)[:120]
    error_val = _one_line(str(error)) if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the `[END]` summary line for a finished task.

    Args:
        success: Whether the task met the success threshold.
        steps: Step count to report.
        score: Task score, printed with three decimals.
        rewards: Per-step rewards, each printed with two decimals and joined
            by commas (an empty list yields ``rewards=`` with nothing after).
    """
    joined = ",".join(format(value, ".2f") for value in rewards)
    summary = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(), steps, score, joined
    )
    print(summary, flush=True)
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# Sent as the "system" message on every classification request. The text is
# indented uniformly below and dedented at import time, so the final string
# has no leading indentation and no surrounding blank lines.
SYSTEM_PROMPT = textwrap.dedent(
    """
    You are an expert clinical geneticist specializing in ACMG/AMP variant classification.
    You will be presented with a genetic SNP variant and must classify it into exactly
    one of these five categories:
    - Pathogenic
    - Likely_pathogenic
    - Uncertain_significance
    - Likely_benign
    - Benign
    You MUST respond with a valid JSON object and nothing else. No markdown, no prose outside
    the JSON. Use exactly this structure:
    {
    "classification": "<one of the five categories above>",
    "reasoning": "<detailed explanation citing specific evidence values from the observation>",
    "criteria_used": ["<criterion 1>", "<criterion 2>", ...]
    }
    CRITICAL RULES:
    1. Read ALL the evidence before deciding. Do not default to any category without justification.
    2. Your classification MUST be driven by the specific evidence in this variant's observation,
    not by assumptions about genes or diseases in general.
    3. Cite actual values: mention the population frequency number, the consequence type,
    the gene name, and specific evidence snippets that support your conclusion.
    4. If evidence is insufficient or conflicting, classify as Uncertain_significance.
    Classification guidelines (apply to THIS variant's evidence, not general assumptions):
    - Pathogenic / Likely_pathogenic: nonsense/frameshift/splice consequence, absent or
    extremely rare in gnomAD (<0.01%), strong published functional/clinical evidence.
    - Benign / Likely_benign: high population frequency (>0.1%), synonymous change,
    non-coding with no regulatory evidence, observed in many unaffected individuals.
    - Uncertain_significance: missense or regulatory variant in a disease gene where
    functional data is absent, conflicting computational predictions, intermediate frequency.
    """
).strip()
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def build_user_prompt(obs) -> str:
    """Build the user-facing prompt from a VepObservation.

    The prompt is assembled line by line. It deliberately does not run
    ``textwrap.dedent`` over an interpolated f-string: dedent measures the
    common indentation *after* interpolation, so any multi-line value (task
    description, evidence snippet) defeats it and leaks source-code
    indentation into the prompt.

    Args:
        obs: Observation exposing ``task_description``, ``gene``,
            ``chromosome``, ``position``, ``ref``, ``alt``, ``hgvs``,
            ``consequence``, ``disease``, ``population_frequency`` and
            ``evidence_snippets``.

    Returns:
        The prompt text with leading/trailing whitespace stripped. A ``None``
        frequency is reported as not observed in gnomAD, a falsy consequence
        as "Not annotated", and ``evidence_snippets=None`` as no evidence.
    """
    freq = obs.population_frequency
    if freq is None:
        freq_str = "Not observed in gnomAD (frequency unavailable)"
    else:
        freq_str = f"{freq:.6f} ({freq * 100:.4f}%)"

    # Numbered evidence lines: " [1] ...", " [2] ...", ...
    snippets_block = "\n".join(
        f" [{i}] {snippet}"
        for i, snippet in enumerate(obs.evidence_snippets or [], start=1)
    )

    lines = [
        f"{obs.task_description}",
        "=== VARIANT CASE ===",
        f"Gene: {obs.gene}",
        f"Chromosome: {obs.chromosome}",
        f"Position (GRCh38): {obs.position}",
        f"Reference allele: {obs.ref}",
        f"Alternate allele: {obs.alt}",
        f"HGVS notation: {obs.hgvs}",
        f"Molecular consequence: {obs.consequence or 'Not annotated'}",
        f"Associated disease: {obs.disease}",
        f"Population frequency: {freq_str}",
        "=== EVIDENCE ===",
        snippets_block,
        "Based on all the above, provide your ACMG/AMP classification as a JSON object.",
        "Remember: respond with ONLY the JSON object.",
    ]
    return "\n".join(lines).strip()
# ---------------------------------------------------------------------------
# LLM call + JSON parsing
# ---------------------------------------------------------------------------
# The five labels accepted as a classification.
_VALID_CLASSIFICATIONS = frozenset((
    "Pathogenic",
    "Likely_pathogenic",
    "Uncertain_significance",
    "Likely_benign",
    "Benign",
))


def _extract_json_object(raw: str) -> dict:
    """Parse an LLM reply into a JSON object (dict).

    If the reply starts with a Markdown code fence, every fence line is
    dropped first. The whole text is then parsed; if that fails, the span from
    the first ``{`` to the last ``}`` is parsed instead, to recover JSON the
    model wrapped in prose.

    Raises:
        json.JSONDecodeError: no parseable JSON was found.
        ValueError: the JSON parsed but is not an object.
    """
    text = raw.strip()
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n") if not line.strip().startswith("```")
        ).strip()
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def _normalize_criteria(value) -> List[str]:
    """Coerce the model's ``criteria_used`` field into a list of strings.

    ``None`` becomes ``[]``. A bare string such as ``"PVS1, PM2"`` is split on
    commas; plain ``list()`` would explode it into single characters. Any other
    scalar is wrapped in a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [part.strip() for part in value.split(",") if part.strip()]
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value]
    return [str(value)]


def get_model_action(client: OpenAI, obs) -> VepAction:
    """
    Call the LLM with the variant observation and parse its response into a VepAction.

    Falls back to ``Uncertain_significance`` with an explanatory ``reasoning``
    when the reply cannot be parsed as a JSON object, or when the LLM call (or
    building the action) raises.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: The environment observation, rendered via ``build_user_prompt``.

    Returns:
        A ``VepAction``. Its classification is taken from the reply when it is
        one of the five allowed labels, otherwise it is mapped through
        ``_fuzzy_match_classification``.
    """
    user_prompt = build_user_prompt(obs)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_response = (completion.choices[0].message.content or "").strip()

        try:
            parsed = _extract_json_object(raw_response)
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            print(f"[DEBUG] JSON parse error: {e}. Raw: {raw_response[:200]}", flush=True)
            return VepAction(
                classification="Uncertain_significance",
                reasoning=f"Failed to parse model response as JSON. Raw: {raw_response[:100]}",
                criteria_used=[],
            )

        classification = parsed.get("classification", "Uncertain_significance")
        # The isinstance check guards the set lookup against unhashable values.
        if not isinstance(classification, str) or classification not in _VALID_CLASSIFICATIONS:
            classification = _fuzzy_match_classification(str(classification))

        return VepAction(
            classification=classification,
            reasoning=str(parsed.get("reasoning") or "No reasoning provided."),
            criteria_used=_normalize_criteria(parsed.get("criteria_used")),
        )
    except Exception as e:
        # Network/API errors (or unexpected failures building the action) must
        # not abort the episode.
        print(f"[DEBUG] LLM call failed: {e}", flush=True)
        return VepAction(
            classification="Uncertain_significance",
            reasoning="LLM call failed β defaulting to Uncertain_significance.",
            criteria_used=[],
        )
def _fuzzy_match_classification(raw: str) -> str:
"""Map common LLM output variations to valid classification strings."""
raw_lower = raw.lower().replace(" ", "_").replace("-", "_")
mapping = {
"pathogenic": "Pathogenic",
"likely_pathogenic": "Likely_pathogenic",
"uncertain_significance": "Uncertain_significance",
"vus": "Uncertain_significance",
"uncertain": "Uncertain_significance",
"likely_benign": "Likely_benign",
"benign": "Benign",
}
for key, value in mapping.items():
if key in raw_lower:
return value
return "Uncertain_significance"
# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------
def _task_score(rewards: List[float]) -> float:
    """Return the mean of *rewards* clamped to [0.0, 1.0] and rounded to 4 dp (0.0 if empty)."""
    if not rewards:
        return 0.0
    mean = sum(rewards) / len(rewards)
    return round(min(max(mean, 0.0), 1.0), 4)


async def run_task(task: str, env: VepEnv, client: OpenAI) -> float:
    """
    Run all episodes for one task tier and log them in the judge format.

    Each episode is single-step:
        reset() → observe variant → LLM classifies → step() → reward → done

    If an exception interrupts the task, every episode that did not complete
    is logged as a zero-reward ``error`` step. The tier is therefore still
    averaged over ``EPISODES_PER_TASK`` episodes. An ``[END]`` line is always
    emitted, and its ``steps`` count equals the number of rewards it lists.

    Returns:
        Average reward across all logged episodes (the task score, in [0.0, 1.0]).
    """
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    all_rewards: List[float] = []
    try:
        for episode in range(1, EPISODES_PER_TASK + 1):
            # Reset → get a new variant.
            result = await env.reset()
            obs = result.observation
            if result.done:
                # Shouldn't happen on reset, but guard anyway.
                break

            action = get_model_action(client, obs)
            result = await env.step(action)
            reward = result.reward if result.reward is not None else 0.0
            all_rewards.append(reward)

            # Action summarised as "classification|gene".
            log_step(
                step=episode,
                action=f"{action.classification}|{obs.gene}",
                reward=reward,
                done=result.done,
                error=None,
            )
    except Exception as e:
        print(f"[DEBUG] Task {task} error: {e}", flush=True)
        # Collapse whitespace so the message stays on one [STEP] line; fall back
        # to the exception type when its message is empty.
        error_msg = " ".join(str(e).split())[:80] or type(e).__name__
        # Pad the episodes that never completed with zero reward.
        for s in range(len(all_rewards) + 1, EPISODES_PER_TASK + 1):
            log_step(step=s, action="error", reward=0.0, done=True, error=error_msg)
            all_rewards.append(0.0)
    finally:
        score = _task_score(all_rewards)
        success = score >= SUCCESS_SCORE_THRESHOLD
        log_end(
            success=success,
            steps=len(all_rewards),
            score=score,
            rewards=all_rewards,
        )
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def main() -> None:
    """
    Run every task tier in ``TASKS`` sequentially and print an overall summary.

    Each task gets its own VepEnv instance. ``VEP_TASK`` is set in this
    process's environment before that instance is created. If the environment
    cannot be started for a task, the task is logged as a failed, zero-score
    run (``[START]`` + ``[END]``) and the remaining tasks still execute.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    task_scores: dict[str, float] = {}
    for task in TASKS:
        # Set the task env var so the environment server samples from the right tier.
        # NOTE(review): this only changes os.environ of *this* process; confirm
        # that VepEnv.from_docker_image forwards VEP_TASK into the container.
        os.environ["VEP_TASK"] = task
        try:
            env = await VepEnv.from_docker_image(IMAGE_NAME)
        except Exception as e:
            # Keep going so every task still gets its [START]/[END] lines.
            print(f"[DEBUG] env start failed for task {task}: {e}", flush=True)
            log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
            log_end(success=False, steps=0, score=0.0, rewards=[])
            task_scores[task] = 0.0
            continue
        try:
            task_scores[task] = await run_task(task=task, env=env, client=client)
        finally:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error for task {task}: {e}", flush=True)

    # Overall summary
    if task_scores:
        overall = sum(task_scores.values()) / len(task_scores)
        print(
            f"[DEBUG] Overall scores β "
            f"easy={task_scores.get('easy', 0):.3f} "
            f"medium={task_scores.get('medium', 0):.3f} "
            f"hard={task_scores.get('hard', 0):.3f} "
            f"overall={overall:.3f}",
            flush=True,
        )
# Script entry point: run all tiers in a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())