Spaces:
Sleeping
Sleeping
"""
Inference Script — Genetic Variant Classification (GenoTriage)
============================================================

MANDATORY ENVIRONMENT VARIABLES:
    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  Docker image name for the environment.

STDOUT FORMAT (strictly followed):
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>

This script runs 3 tasks (easy, medium, hard). Each task iterates over
all variants in that tier (8 variants = 8 single-step episodes).
Each episode: reset() → LLM classifies → step() → log → done.

Final score per task = average reward across all episodes in that tier.
Overall score = average across all 3 tasks.
"""
| import asyncio | |
| import json | |
| import os | |
| import textwrap | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from genotriage import VepAction, VepEnv | |
# Read the .env file that sits next to this script first, so the lookups
# below already see any values it defines.
load_dotenv(Path(__file__).resolve().with_name(".env"))

# --- Config ---
# Environment-provided settings; the first non-empty candidate wins.
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") or os.environ.get("IMAGE_NAME")
API_KEY = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("API_KEY")
    or os.environ.get("OPENAI_API_KEY")
)
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Fixed run parameters.
BENCHMARK = "genotriage"
TEMPERATURE = 0.2  # low temp for more deterministic clinical reasoning
MAX_TOKENS = 600  # enough for classification + reasoning + criteria
SUCCESS_SCORE_THRESHOLD = 0.5

# Number of episodes per task (matches variants per tier in variants.json)
EPISODES_PER_TASK = 8
TASKS = ["easy", "medium", "hard"]
# ---------------------------------------------------------------------------
# Logging helpers — exact format required by hackathon judges
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Write the opening ``[START]`` record for a task to stdout.

    Args:
        task: Task tier name.
        env: Benchmark / environment name.
        model: Identifier of the model under evaluation.
    """
    fields = ("[START]", f"task={task}", f"env={env}", f"model={model}")
    print(" ".join(fields), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Write one ``[STEP]`` record to stdout, guaranteed to fit on one line.

    Args:
        step: 1-based step (episode) number.
        action: Short description of the action taken; newlines are removed
            and the text is truncated to 120 characters.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Error text, or ``None``/empty for no error (printed as
            ``null``). Newlines are removed the same way as for ``action``.
    """
    # Sanitise action string — remove newlines to keep the record on one line.
    action_clean = action.replace("\n", " ").replace("\r", "")[:120]

    # Exception messages can span several lines too; flatten them so a single
    # failing step cannot split the record the output parser expects.
    error_clean = error.replace("\n", " ").replace("\r", "").strip() if error else ""
    error_val = error_clean or "null"

    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Write the closing ``[END]`` record for a task to stdout.

    Args:
        success: Whether the task met the success threshold; printed lowercase.
        steps: Number of steps reported for the task.
        score: Task score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    record = "[END] success={} steps={} score={:.3f} rewards={}".format(
        str(success).lower(),
        steps,
        score,
        ",".join(formatted_rewards),
    )
    print(record, flush=True)
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# System message sent with every chat-completion request made by
# get_model_action. It instructs the model to answer with a single JSON object
# holding "classification", "reasoning" and "criteria_used".
# The five category labels listed in the prompt are the same strings that
# get_model_action accepts as valid classifications — keep the two in sync.
# textwrap.dedent(...).strip() removes the literal's common indentation and its
# leading/trailing blank lines at import time.
| SYSTEM_PROMPT = textwrap.dedent(""" | |
| You are an expert clinical geneticist specializing in ACMG/AMP variant classification. | |
| You will be presented with a genetic SNP variant and must classify it into exactly | |
| one of these five categories: | |
| - Pathogenic | |
| - Likely_pathogenic | |
| - Uncertain_significance | |
| - Likely_benign | |
| - Benign | |
| You MUST respond with a valid JSON object and nothing else. No markdown, no prose outside | |
| the JSON. Use exactly this structure: | |
| { | |
| "classification": "<one of the five categories above>", | |
| "reasoning": "<detailed explanation citing specific evidence values from the observation>", | |
| "criteria_used": ["<criterion 1>", "<criterion 2>", ...] | |
| } | |
| CRITICAL RULES: | |
| 1. Read ALL the evidence before deciding. Do not default to any category without justification. | |
| 2. Your classification MUST be driven by the specific evidence in this variant's observation, | |
| not by assumptions about genes or diseases in general. | |
| 3. Cite actual values: mention the population frequency number, the consequence type, | |
| the gene name, and specific evidence snippets that support your conclusion. | |
| 4. If evidence is insufficient or conflicting, classify as Uncertain_significance. | |
| Classification guidelines (apply to THIS variant's evidence, not general assumptions): | |
| - Pathogenic / Likely_pathogenic: nonsense/frameshift/splice consequence, absent or | |
| extremely rare in gnomAD (<0.01%), strong published functional/clinical evidence. | |
| - Benign / Likely_benign: high population frequency (>0.1%), synonymous change, | |
| non-coding with no regulatory evidence, observed in many unaffected individuals. | |
| - Uncertain_significance: missense or regulatory variant in a disease gene where | |
| functional data is absent, conflicting computational predictions, intermediate frequency. | |
| """).strip() | |
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def build_user_prompt(obs) -> str:
    """Build the user-facing prompt from a VepObservation.

    The prompt is assembled line by line instead of dedenting an f-string
    template. Interpolation happens before ``textwrap.dedent`` runs, so
    multi-line values (the task description, the evidence block) would take
    part in the common-indent calculation and could leave the template's own
    indentation in the text sent to the model.

    Args:
        obs: Observation exposing ``task_description``, ``gene``,
            ``chromosome``, ``position``, ``ref``, ``alt``, ``hgvs``,
            ``consequence``, ``disease``, ``population_frequency`` (a number
            or ``None``) and ``evidence_snippets`` (iterable of strings; may
            be empty or ``None``).

    Returns:
        The prompt text with no leading or trailing whitespace.
    """
    frequency = obs.population_frequency
    if frequency is None:
        freq_str = "Not observed in gnomAD (frequency unavailable)"
    else:
        freq_str = f"{frequency:.6f} ({frequency * 100:.4f}%)"

    snippet_lines = [
        f" [{index}] {snippet}"
        for index, snippet in enumerate(obs.evidence_snippets or [], start=1)
    ]
    if not snippet_lines:
        # Say so explicitly rather than leaving an empty section the model
        # might read as truncated input.
        snippet_lines = [" (no evidence snippets provided)"]

    prompt_lines = [
        f"{obs.task_description}",
        "=== VARIANT CASE ===",
        f"Gene: {obs.gene}",
        f"Chromosome: {obs.chromosome}",
        f"Position (GRCh38): {obs.position}",
        f"Reference allele: {obs.ref}",
        f"Alternate allele: {obs.alt}",
        f"HGVS notation: {obs.hgvs}",
        f"Molecular consequence: {obs.consequence or 'Not annotated'}",
        f"Associated disease: {obs.disease}",
        f"Population frequency: {freq_str}",
        "=== EVIDENCE ===",
        *snippet_lines,
        "Based on all the above, provide your ACMG/AMP classification as a JSON object.",
        "Remember: respond with ONLY the JSON object.",
    ]
    return "\n".join(prompt_lines).strip()
# ---------------------------------------------------------------------------
# LLM call + JSON parsing
# ---------------------------------------------------------------------------
def _load_json_object(text: str) -> dict:
    """Parse *text* as a JSON object, tolerating code fences and stray prose.

    Raises:
        ValueError: If no JSON object can be decoded from *text*
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    cleaned = text.strip()
    # Strip markdown code fences if present.
    if cleaned.startswith("```"):
        cleaned = "\n".join(
            line for line in cleaned.split("\n") if not line.strip().startswith("```")
        ).strip()
    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Models sometimes wrap the object in a sentence; retry on the
        # outermost {...} span before giving up.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(cleaned[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def _normalise_criteria(value) -> list:
    """Coerce the model's ``criteria_used`` value into a list of strings."""
    if value is None:
        return []
    if isinstance(value, (list, tuple, set)):
        return [str(item) for item in value]
    # A bare string (or any other scalar) is one criterion; calling list() on
    # a string would split it into single characters.
    return [str(value)]


def get_model_action(client: OpenAI, obs) -> VepAction:
    """Ask the LLM to classify the variant and turn its reply into a VepAction.

    Any failure (API error, unparseable reply, invalid fields) is logged as a
    ``[DEBUG]`` line and degrades to an ``Uncertain_significance`` action, so
    the caller always receives an action to submit.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Observation for the current variant (see ``build_user_prompt``).

    Returns:
        The parsed action, or an ``Uncertain_significance`` fallback.
    """
    fallback_label = "Uncertain_significance"
    # Must mirror the five category labels listed in SYSTEM_PROMPT.
    valid_classifications = {
        "Pathogenic", "Likely_pathogenic",
        "Uncertain_significance", "Likely_benign", "Benign",
    }

    user_prompt = build_user_prompt(obs)
    raw_response = ""

    # Phase 1: call the model. This is the network boundary, so every
    # exception type is treated as "no answer".
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        raw_response = (completion.choices[0].message.content or "").strip()
    except Exception as e:
        print(f"[DEBUG] LLM call failed: {e}", flush=True)
        return VepAction(
            classification=fallback_label,
            reasoning="LLM call failed — defaulting to Uncertain_significance.",
            criteria_used=[],
        )

    # Phase 2: decode the reply into a JSON object.
    try:
        parsed = _load_json_object(raw_response)
    except ValueError as e:
        print(f"[DEBUG] JSON parse error: {e}. Raw: {raw_response[:200]}", flush=True)
        return VepAction(
            classification=fallback_label,
            reasoning=f"Failed to parse model response as JSON. Raw: {raw_response[:100]}",
            criteria_used=[],
        )

    # Phase 3: validate the fields and build the action.
    try:
        classification = parsed.get("classification", fallback_label)
        if not isinstance(classification, str):
            classification = fallback_label
        elif classification not in valid_classifications:
            # Try to fuzzy-match common variants of the label.
            classification = _fuzzy_match_classification(classification)
        return VepAction(
            classification=classification,
            reasoning=str(parsed.get("reasoning", "No reasoning provided.")),
            criteria_used=_normalise_criteria(parsed.get("criteria_used")),
        )
    except Exception as e:
        print(f"[DEBUG] Invalid action fields: {e}", flush=True)
        return VepAction(
            classification=fallback_label,
            reasoning="Model response had invalid fields — defaulting to Uncertain_significance.",
            criteria_used=[],
        )
def _fuzzy_match_classification(raw: str) -> str:
    """Map common LLM output variations to valid classification strings.

    Matching is by substring on a normalised form of *raw* (lowercase, with
    spaces, hyphens and underscores unified). Patterns are tried from most to
    least specific: ``"pathogenic"`` is a substring of ``"likely_pathogenic"``
    and ``"benign"`` of ``"likely_benign"``, so the short patterns must come
    last or every "likely" answer collapses into the stronger category.

    Args:
        raw: Classification text produced by the model.

    Returns:
        One of the five valid labels; ``"Uncertain_significance"`` when
        nothing matches or *raw* is not a string.
    """
    if not isinstance(raw, str):
        return "Uncertain_significance"

    # Treat "-", "_" and any run of whitespace as one separator.
    raw_lower = "_".join(raw.lower().replace("-", " ").replace("_", " ").split())

    ordered_patterns = (
        ("likely_pathogenic", "Likely_pathogenic"),
        ("likely_benign", "Likely_benign"),
        ("uncertain_significance", "Uncertain_significance"),
        ("vus", "Uncertain_significance"),
        ("uncertain", "Uncertain_significance"),
        ("pathogenic", "Pathogenic"),
        ("benign", "Benign"),
    )
    for pattern, label in ordered_patterns:
        if pattern in raw_lower:
            return label
    return "Uncertain_significance"
# ---------------------------------------------------------------------------
# Single task runner
# ---------------------------------------------------------------------------
async def run_task(task: str, env: VepEnv, client: OpenAI) -> float:
    """Run all episodes for one task tier and log them in the required format.

    Each episode is single-step:
        reset() → observe variant → LLM classifies → step() → reward → done

    If an episode raises, the remaining episodes are logged as zero-reward
    error steps so the tier is still scored over ``EPISODES_PER_TASK`` steps.
    The ``[END]`` record is always written, and its ``steps`` value equals
    the number of ``[STEP]`` records and rewards emitted.

    Args:
        task: Tier name, used for logging.
        env: Environment client already bound to this tier.
        client: OpenAI-compatible client passed to ``get_model_action``.

    Returns:
        Average reward across all logged episodes, clamped to [0.0, 1.0]
        and rounded to four decimals.
    """
    log_start(task=task, env=BENCHMARK, model=MODEL_NAME)
    all_rewards: List[float] = []
    score = 0.0
    try:
        loop = asyncio.get_running_loop()
        for episode in range(1, EPISODES_PER_TASK + 1):
            # Reset — get new variant.
            result = await env.reset()
            obs = result.observation
            if result.done:
                # Shouldn't happen on reset, but guard anyway.
                break

            # get_model_action performs a blocking HTTP request; run it in the
            # default executor so the event loop (and the environment
            # connection it services) is not stalled while the model answers.
            action = await loop.run_in_executor(None, get_model_action, client, obs)

            # Submit to environment.
            result = await env.step(action)
            reward = result.reward if result.reward is not None else 0.0
            all_rewards.append(reward)

            # Log step — action summarised as "classification|gene".
            log_step(
                step=episode,
                action=f"{action.classification}|{obs.gene}",
                reward=reward,
                done=result.done,
                error=None,
            )
    except Exception as e:
        print(f"[DEBUG] Task {task} error: {e}", flush=True)
        # Flatten the message so it cannot break the one-line [STEP] record.
        error_msg = " ".join(str(e).split())[:80]
        # Pad missing steps with zero reward.
        for step in range(len(all_rewards) + 1, EPISODES_PER_TASK + 1):
            log_step(step=step, action="error", reward=0.0, done=True, error=error_msg)
            all_rewards.append(0.0)
    finally:
        # Task score = average reward across all logged episodes.
        if all_rewards:
            score = sum(all_rewards) / len(all_rewards)
        score = round(min(max(score, 0.0), 1.0), 4)
        log_end(
            success=score >= SUCCESS_SCORE_THRESHOLD,
            steps=len(all_rewards),
            score=score,
            rewards=all_rewards,
        )
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def main() -> None:
    """Run every tier in ``TASKS`` sequentially and print an overall summary.

    Each task gets its own VepEnv instance; ``VEP_TASK`` is set in the process
    environment before the instance is created so the environment server
    samples from the right tier. The environment is closed even when the task
    fails, and a failure while closing is logged rather than raised.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    task_scores: dict[str, float] = {}

    for task in TASKS:
        # Set the task env var so the environment server samples from the right tier.
        os.environ["VEP_TASK"] = task
        env = await VepEnv.from_docker_image(IMAGE_NAME)
        try:
            task_scores[task] = await run_task(task=task, env=env, client=client)
        finally:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error for task {task}: {e}", flush=True)

    # Overall summary — derived from the tiers actually run, so it stays
    # correct if TASKS is edited.
    if task_scores:
        overall = sum(task_scores.values()) / len(task_scores)
        per_task = " ".join(f"{name}={value:.3f}" for name, value in task_scores.items())
        print(
            f"[DEBUG] Overall scores — {per_task} overall={overall:.3f}",
            flush=True,
        )


if __name__ == "__main__":
    asyncio.run(main())