File size: 6,281 Bytes
13d2477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Batch variant scoring using AlphaGenome for genomic variant analysis.

This MCP Server provides 1 tool:
1. score_batch_variants: Score variants in batch across modalities using AlphaGenome

All tools extracted from `AlphaPOP/score_batch.ipynb`.
"""

# Standard imports
from typing import Annotated, Literal
import pandas as pd
from pathlib import Path
import os
from fastmcp import FastMCP
from datetime import datetime
from tqdm import tqdm
from alphagenome.data import genome
from alphagenome.models import dna_client, variant_scorers

# Project structure
PROJECT_ROOT = Path(__file__).parent.parent.parent.resolve()
DEFAULT_INPUT_DIR = PROJECT_ROOT / "tmp" / "inputs"
DEFAULT_OUTPUT_DIR = PROJECT_ROOT / "tmp" / "outputs"

INPUT_DIR = Path(os.environ.get("SCORE_BATCH_INPUT_DIR", DEFAULT_INPUT_DIR))
OUTPUT_DIR = Path(os.environ.get("SCORE_BATCH_OUTPUT_DIR", DEFAULT_OUTPUT_DIR))

# Ensure directories exist
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Timestamp for unique outputs
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# MCP server instance
score_batch_mcp = FastMCP(name="score_batch")

@score_batch_mcp.tool
def score_batch_variants(
    api_key: Annotated[str, "API key for the AlphaGenome model"],
    vcf_file: Annotated[str | None, "Path to VCF/TSV/CSV file with extension .vcf, .tsv, or .csv. The header should include columns: variant_id, CHROM, POS, REF, ALT"] = None,
    organism: Annotated[Literal["human", "mouse"], "Organism to score against"] = "human",
    sequence_length: Annotated[Literal["2KB", "16KB", "100KB", "500KB", "1MB"], "Context window"] = "1MB",
    score_rna_seq: Annotated[bool, "Include RNA-seq signal prediction"] = True,
    score_cage: Annotated[bool, "Include CAGE"] = True,
    score_procap: Annotated[bool, "Include PRO-cap (human only)"] = True,
    score_atac: Annotated[bool, "Include ATAC"] = True,
    score_dnase: Annotated[bool, "Include DNase"] = True,
    score_chip_histone: Annotated[bool, "Include ChIP-histone"] = True,
    score_chip_tf: Annotated[bool, "Include ChIP-transcription-factor"] = True,
    score_polyadenylation: Annotated[bool, "Include polyadenylation"] = True,
    score_splice_sites: Annotated[bool, "Include splice sites"] = True,
    score_splice_site_usage: Annotated[bool, "Include splice site usage"] = True,
    score_splice_junctions: Annotated[bool, "Include splice junctions"] = True,
    out_prefix: Annotated[str | None, "Output file prefix"] = None,
) -> dict:
    """
    Score genetic variants in batch across multiple regulatory modalities using AlphaGenome.
    Input is VCF/TSV/CSV file with variant information and output is variant scores table.
    """
    # Input file validation only
    if vcf_file is None:
        raise ValueError("Path to VCF/TSV/CSV file must be provided")

    # File existence validation
    vcf_path = Path(vcf_file)
    if not vcf_path.exists():
        raise FileNotFoundError(f"Input file not found: {vcf_file}")

    # Load data
    sep = "\t" if vcf_path.suffix.lower() in {".vcf", ".tsv"} else ","
    vcf = pd.read_csv(str(vcf_path), sep=sep)

    # Create model
    dna_model = dna_client.create(api_key)

    # Parse organism specification
    organism_map = {
        "human": dna_client.Organism.HOMO_SAPIENS,
        "mouse": dna_client.Organism.MUS_MUSCULUS,
    }
    organism_enum = organism_map[organism]

    # Parse sequence length specification
    sequence_length_enum = dna_client.SUPPORTED_SEQUENCE_LENGTHS[
        f"SEQUENCE_LENGTH_{sequence_length}"
    ]

    # Parse scorer specification
    scorer_selections = {
        "rna_seq": score_rna_seq,
        "cage": score_cage,
        "procap": score_procap,
        "atac": score_atac,
        "dnase": score_dnase,
        "chip_histone": score_chip_histone,
        "chip_tf": score_chip_tf,
        "polyadenylation": score_polyadenylation,
        "splice_sites": score_splice_sites,
        "splice_site_usage": score_splice_site_usage,
        "splice_junctions": score_splice_junctions,
    }

    all_scorers = variant_scorers.RECOMMENDED_VARIANT_SCORERS
    selected_scorers = [
        all_scorers[key]
        for key in all_scorers
        if scorer_selections.get(key.lower(), False)
    ]

    # Remove any scorers that are not supported for the chosen organism
    unsupported_scorers = [
        scorer
        for scorer in selected_scorers
        if (
            organism_enum.value
            not in variant_scorers.SUPPORTED_ORGANISMS[scorer.base_variant_scorer]
        )
        or (
            (scorer.requested_output == dna_client.OutputType.PROCAP)
            and (organism_enum == dna_client.Organism.MUS_MUSCULUS)
        )
    ]
    if len(unsupported_scorers) > 0:
        for unsupported_scorer in unsupported_scorers:
            selected_scorers.remove(unsupported_scorer)

    # Score variants in the VCF file
    results = []
    for _, vcf_row in tqdm(vcf.iterrows(), total=len(vcf), desc="Scoring variants"):
        variant = genome.Variant(
            chromosome=str(vcf_row.CHROM),
            position=int(vcf_row.POS),
            reference_bases=vcf_row.REF,
            alternate_bases=vcf_row.ALT,
            name=vcf_row.variant_id,
        )
        interval = variant.reference_interval.resize(sequence_length_enum)

        variant_scores = dna_model.score_variant(
            interval=interval,
            variant=variant,
            variant_scorers=selected_scorers,
            organism=organism_enum,
        )
        results.append(variant_scores)

    # Process results
    df_scores = variant_scorers.tidy_scores(results)

    # Set output prefix
    if out_prefix is None:
        out_prefix = f"score_batch_variants_{timestamp}"

    # Save results
    download_path = OUTPUT_DIR / f"{out_prefix}.csv"
    download_path.write_text(df_scores.to_csv(index=False))

    # Return standardized format
    return {
        "message": f"Scored {len(vcf)} variants and saved results table",
        "reference": "https://github.com/AlphaPOP/blob/main/score_batch.ipynb",
        "artifacts": [
            {
                "description": "Variant scores results table",
                "path": str(download_path.resolve())
            }
        ]
    }