GenoTriage / inference.py
fierce74's picture
Upload folder using huggingface_hub
35de6f4 verified
"""
Inference Script — Genetic Variant Classification (GenoTriage)
============================================================
MANDATORY ENVIRONMENT VARIABLES:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
LOCAL_IMAGE_NAME Docker image name for the environment.
STDOUT FORMAT (strictly followed):
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
This script runs 3 tasks (easy, medium, hard). Each task iterates over
all variants in that tier (8 variants = 8 single-step episodes).
Each episode: reset() → LLM classifies → step() → log → done.
Final score per task = average reward across all episodes in that tier.
Overall score = average across all 3 tasks.
"""
import asyncio
import json
import os
import textwrap
from pathlib import Path
from typing import List, Optional
from dotenv import load_dotenv
from openai import OpenAI
from genotriage import VepAction, VepEnv
# Load .env before reading env vars. The file is looked up next to this script
# (Path(__file__).resolve().parent), not in the current working directory, so
# the script finds its .env no matter where it is launched from. This must run
# before the config lookups below so values from .env are visible to them.
load_dotenv(Path(__file__).resolve().parent / ".env")
# --- Config ---
# Environment image: LOCAL_IMAGE_NAME takes precedence over IMAGE_NAME.
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") or os.environ.get("IMAGE_NAME")
# API credential: first non-empty of HF_TOKEN, API_KEY, OPENAI_API_KEY.
API_KEY = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
)
# LLM endpoint and model, each with a Hugging Face router default.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

BENCHMARK = "genotriage"
TEMPERATURE = 0.2  # kept low so the clinical reasoning is more deterministic
MAX_TOKENS = 600  # room for classification + reasoning + criteria list
SUCCESS_SCORE_THRESHOLD = 0.5  # a task succeeds when its score reaches this
# Episodes per task tier (matches the variants per tier in variants.json).
EPISODES_PER_TASK = 8
TASKS = ["easy", "medium", "hard"]
# ---------------------------------------------------------------------------
# Logging helpers — exact format required by hackathon judges
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a task's log.

    Args:
        task: Task tier name.
        env: Benchmark/environment name.
        model: Model identifier used for inference.
    """
    fields = (f"task={task}", f"env={env}", f"model={model}")
    print("[START]", *fields, sep=" ", flush=True)
def _single_line(text: str) -> str:
    """Return *text* with LF replaced by a space and CR removed, so it fits on one log line."""
    return text.replace("\n", " ").replace("\r", "")


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line in the judge-required format.

    Both the action and the error text are flattened to a single line: an
    exception message containing a newline would otherwise split the record
    and break the one-line-per-step stdout format. The action is additionally
    truncated to 120 characters.

    Args:
        step: 1-based step number.
        action: Action summary string.
        reward: Reward for the step, printed with two decimals.
        done: Episode-finished flag, printed lowercase.
        error: Error message, or a falsy value to print ``null``.
    """
    error_val = _single_line(str(error)) if error else "null"
    done_val = str(done).lower()
    action_clean = _single_line(action)[:120]
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` summary line that closes a task's log.

    Args:
        success: Outcome flag, rendered via ``str(...).lower()``.
        steps: Step count to report.
        score: Task score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals each.
    """
    reward_field = ",".join(map("{:.2f}".format, rewards))
    segments = [
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        f"rewards={reward_field}",
    ]
    print(" ".join(segments), flush=True)
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# System message sent with every classification request in get_model_action.
# It asks the model for a single JSON object whose keys (classification,
# reasoning, criteria_used) are exactly the ones get_model_action reads back.
# The five category strings listed here must stay in sync with the set of valid
# classifications checked in get_model_action.
# textwrap.dedent removes any common indentation; .strip() drops the leading and
# trailing newlines introduced by the triple-quoted literal.
SYSTEM_PROMPT = textwrap.dedent("""
You are an expert clinical geneticist specializing in ACMG/AMP variant classification.
You will be presented with a genetic SNP variant and must classify it into exactly
one of these five categories:
- Pathogenic
- Likely_pathogenic
- Uncertain_significance
- Likely_benign
- Benign
You MUST respond with a valid JSON object and nothing else. No markdown, no prose outside
the JSON. Use exactly this structure:
{
"classification": "<one of the five categories above>",
"reasoning": "<detailed explanation citing specific evidence values from the observation>",
"criteria_used": ["<criterion 1>", "<criterion 2>", ...]
}
CRITICAL RULES:
1. Read ALL the evidence before deciding. Do not default to any category without justification.
2. Your classification MUST be driven by the specific evidence in this variant's observation,
not by assumptions about genes or diseases in general.
3. Cite actual values: mention the population frequency number, the consequence type,
the gene name, and specific evidence snippets that support your conclusion.
4. If evidence is insufficient or conflicting, classify as Uncertain_significance.
Classification guidelines (apply to THIS variant's evidence, not general assumptions):
- Pathogenic / Likely_pathogenic: nonsense/frameshift/splice consequence, absent or
extremely rare in gnomAD (<0.01%), strong published functional/clinical evidence.
- Benign / Likely_benign: high population frequency (>0.1%), synonymous change,
non-coding with no regulatory evidence, observed in many unaffected individuals.
- Uncertain_significance: missense or regulatory variant in a disease gene where
functional data is absent, conflicting computational predictions, intermediate frequency.
""").strip()
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def build_user_prompt(obs) -> str:
    """Build the user-facing prompt from a VepObservation.

    Args:
        obs: Environment observation. Attributes read: task_description, gene,
            chromosome, position, ref, alt, hgvs, consequence, disease,
            population_frequency (float or None) and evidence_snippets
            (iterable of strings, or None).

    Returns:
        The prompt text. Template lines carry no leading indentation; evidence
        snippets are listed as ``  [n] text``.
    """
    freq = obs.population_frequency
    if freq is None:
        freq_str = "Not observed in gnomAD (frequency unavailable)"
    else:
        # General ("g") formatting keeps significant digits for ultra-rare
        # alleles (e.g. 2.5e-06); fixed ".6f" would flatten them to 0.000000
        # and hide exactly the rarity signal the classification depends on.
        freq_str = f"{freq:.6g} ({freq * 100:.4g}%)"

    snippets = list(obs.evidence_snippets or [])
    if snippets:
        snippet_lines = [
            f"  [{index}] {snippet}" for index, snippet in enumerate(snippets, start=1)
        ]
    else:
        snippet_lines = ["  (no evidence snippets provided)"]

    # Assemble line by line rather than textwrap.dedent over an f-string:
    # dedent runs *after* interpolation, so any multi-line value (the task
    # description, or the snippet block whose later lines are not indented)
    # defeats its common-margin detection and leaks the template's source
    # indentation into the prompt.
    lines = [
        f"{obs.task_description}",
        "",
        "=== VARIANT CASE ===",
        f"Gene: {obs.gene}",
        f"Chromosome: {obs.chromosome}",
        f"Position (GRCh38): {obs.position}",
        f"Reference allele: {obs.ref}",
        f"Alternate allele: {obs.alt}",
        f"HGVS notation: {obs.hgvs}",
        f"Molecular consequence: {obs.consequence or 'Not annotated'}",
        f"Associated disease: {obs.disease}",
        f"Population frequency: {freq_str}",
        "",
        "=== EVIDENCE ===",
        *snippet_lines,
        "",
        "Based on all the above, provide your ACMG/AMP classification as a JSON object.",
        "Remember: respond with ONLY the JSON object.",
    ]
    return "\n".join(lines).strip()
# ---------------------------------------------------------------------------
# LLM call + JSON parsing
# ---------------------------------------------------------------------------
# The five classification strings the environment accepts; must match the
# categories listed in SYSTEM_PROMPT.
_VALID_CLASSIFICATIONS = frozenset({
    "Pathogenic",
    "Likely_pathogenic",
    "Uncertain_significance",
    "Likely_benign",
    "Benign",
})


def _extract_json_object(text: str) -> dict:
    """Return the first decodable JSON object embedded in *text*.

    Tolerates markdown code fences and prose before/after the object by
    scanning each ``{`` and attempting to decode an object starting there.

    Raises:
        json.JSONDecodeError: if no JSON object can be decoded from *text*.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        # Decoding began at "{", so the value is always a dict.
        return obj
    raise json.JSONDecodeError("No JSON object found in model response", text, 0)


def _coerce_criteria(value) -> List[str]:
    """Normalise the model's ``criteria_used`` field to a list of strings."""
    if value is None:
        return []
    if isinstance(value, str):
        # list("PVS1") would split into characters; a bare string is one criterion.
        return [value] if value.strip() else []
    if isinstance(value, (list, tuple, set)):
        return [str(item) for item in value]
    return [str(value)]


def _fallback_action(reasoning: str) -> VepAction:
    """Return the conservative Uncertain_significance action used on any failure."""
    return VepAction(
        classification="Uncertain_significance",
        reasoning=reasoning,
        criteria_used=[],
    )


def get_model_action(client: OpenAI, obs) -> VepAction:
    """
    Call the LLM with the variant observation and parse its response into a VepAction.

    Falls back to Uncertain_significance (with an explanatory reasoning string)
    when the API call fails, when no JSON object can be parsed from the reply,
    or when the parsed fields cannot be turned into a VepAction. Errors raised
    while building the prompt itself are not caught.
    """
    user_prompt = build_user_prompt(obs)

    # 1. Call the model. Network/API failures are reported separately from
    #    parse failures so the debug output points at the real cause.
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_response = (completion.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"[DEBUG] LLM call failed: {e}", flush=True)
        return _fallback_action("LLM call failed — defaulting to Uncertain_significance.")

    # 2. Extract the JSON object (handles code fences and surrounding prose).
    try:
        parsed = _extract_json_object(raw_response)
    except json.JSONDecodeError as e:
        print(f"[DEBUG] JSON parse error: {e}. Raw: {raw_response[:200]}", flush=True)
        return _fallback_action(
            f"Failed to parse model response as JSON. Raw: {raw_response[:100]}"
        )

    # 3. Normalise the classification; non-string or off-list values go
    #    through the fuzzy matcher instead of crashing on .lower().
    raw_classification = parsed.get("classification")
    if raw_classification is None:
        classification = "Uncertain_significance"
    else:
        classification = str(raw_classification).strip()
        if classification not in _VALID_CLASSIFICATIONS:
            classification = _fuzzy_match_classification(classification)

    reasoning = parsed.get("reasoning")
    try:
        return VepAction(
            classification=classification,
            reasoning=str(reasoning) if reasoning else "No reasoning provided.",
            criteria_used=_coerce_criteria(parsed.get("criteria_used")),
        )
    except Exception as e:
        # Keep the never-raises contract if VepAction rejects the fields.
        print(f"[DEBUG] Could not build VepAction from model response: {e}", flush=True)
        return _fallback_action(
            "Model response failed validation — defaulting to Uncertain_significance."
        )
def _fuzzy_match_classification(raw: str) -> str:
    """Map common LLM output variations to valid classification strings.

    The input is lower-cased; hyphens and runs of whitespace are collapsed to
    single underscores. An exact match against the known keys wins. Otherwise
    the first key found as a substring is used, checking the specific
    ``likely_*`` / uncertain keys before the bare ``pathogenic`` / ``benign``
    keys. Order matters because "pathogenic" is a substring of
    "likely_pathogenic". Anything unrecognised maps to Uncertain_significance.
    """
    normalised = "_".join(str(raw).lower().replace("-", " ").split())
    # Insertion order is the substring-search priority (most specific first).
    mapping = {
        "likely_pathogenic": "Likely_pathogenic",
        "likely_benign": "Likely_benign",
        "uncertain_significance": "Uncertain_significance",
        "uncertain": "Uncertain_significance",
        "vus": "Uncertain_significance",
        "pathogenic": "Pathogenic",
        "benign": "Benign",
    }
    if normalised in mapping:
        return mapping[normalised]
    for key, value in mapping.items():
        if key in normalised:
            return value
    return "Uncertain_significance"
# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------
def _task_score(rewards: List[float]) -> float:
    """Return the mean of *rewards*, clamped to [0.0, 1.0] and rounded to 4 dp; 0.0 when empty."""
    if not rewards:
        return 0.0
    mean = sum(rewards) / len(rewards)
    return round(min(max(mean, 0.0), 1.0), 4)


async def run_task(task: str, env: VepEnv, client: OpenAI) -> float:
    """
    Run all episodes for one task tier.

    Each episode is single-step:
        reset() → observe variant → LLM classifies → step() → reward → done

    If an episode raises, the remaining episodes are logged as error steps with
    zero reward. The END line's ``steps`` then equals the number of STEP lines
    printed and the number of entries in ``rewards``. A START and an END line
    are always emitted.

    Returns:
        Average reward across all logged episodes (the task score, in [0.0, 1.0]).
    """
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    all_rewards: List[float] = []
    total_steps = 0
    success = False
    score = 0.0
    try:
        for episode in range(1, EPISODES_PER_TASK + 1):
            # Reset — get new variant
            result = await env.reset()
            obs = result.observation
            if result.done:
                # Shouldn't happen on reset, but guard anyway
                break
            # Get classification from LLM, then submit to environment
            action = get_model_action(client, obs)
            result = await env.step(action)
            reward = result.reward if result.reward is not None else 0.0
            all_rewards.append(reward)
            total_steps += 1
            # Log step — action summarised as "classification|gene"
            log_step(
                step=episode,
                action=f"{action.classification}|{obs.gene}",
                reward=reward,
                done=result.done,
                error=None,
            )
        score = _task_score(all_rewards)
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as e:
        print(f"[DEBUG] Task {task} error: {e}", flush=True)
        error_msg = str(e)[:80]
        # Pad missing episodes with zero reward so failures count against the score.
        for s in range(total_steps + 1, EPISODES_PER_TASK + 1):
            log_step(step=s, action="error", reward=0.0, done=True, error=error_msg)
            all_rewards.append(0.0)
        # Padded steps were logged as STEP lines, so they count toward `steps`.
        total_steps = len(all_rewards)
        score = _task_score(all_rewards)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        log_end(
            success=success,
            steps=total_steps,
            score=score,
            rewards=all_rewards,
        )
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _log_failed_task(task: str, error_msg: str) -> None:
    """Emit a complete START/STEP/END record for a task whose environment never started.

    Mirrors the zero-reward error padding run_task applies when an episode
    fails, so every tier in TASKS still produces the log lines the judges parse.
    """
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    rewards = [0.0] * EPISODES_PER_TASK
    for step in range(1, EPISODES_PER_TASK + 1):
        log_step(step=step, action="error", reward=0.0, done=True, error=error_msg)
    log_end(success=False, steps=len(rewards), score=0.0, rewards=rewards)


async def main() -> None:
    """
    Run all task tiers in TASKS sequentially and print an overall summary.

    Each tier gets its own VepEnv instance, created after VEP_TASK is set in
    this process's environment. If an environment fails to start, that tier is
    logged as zero-reward error steps and scored 0.0. The run continues, and
    the overall average still covers every tier.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    task_scores: dict[str, float] = {}

    for task in TASKS:
        # Set the task env var so the environment server samples from the right tier
        os.environ["VEP_TASK"] = task
        try:
            env = await VepEnv.from_docker_image(IMAGE_NAME)
        except Exception as e:
            print(f"[DEBUG] Failed to start environment for task {task}: {e}", flush=True)
            # Keep the error on one line so the STEP records stay well-formed.
            _log_failed_task(task, str(e).replace("\n", " ").replace("\r", "")[:80])
            task_scores[task] = 0.0
            continue
        try:
            task_scores[task] = await run_task(task=task, env=env, client=client)
        finally:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error for task {task}: {e}", flush=True)

    # Overall summary (one "name=score" field per tier, in TASKS order)
    if task_scores:
        overall = sum(task_scores.values()) / len(task_scores)
        per_task = " ".join(f"{name}={task_scores.get(name, 0):.3f}" for name in TASKS)
        print(f"[DEBUG] Overall scores — {per_task} overall={overall:.3f}", flush=True)
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the async main() via asyncio.run.
    asyncio.run(main())