# ZhiyuanChen's picture
# implement splice-variant-effect app
# ae141c8 unverified
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
# This file is part of MultiMolecule.
# MultiMolecule is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
# MultiMolecule is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# For additional terms and clarifications, please refer to our License FAQ at:
# <https://multimolecule.danling.org/about/license-faq>.
from __future__ import annotations
import json
import math
import tempfile
import time
from collections.abc import Mapping
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from transformers import pipeline
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
# Dropdown label -> Hugging Face checkpoint id; insertion order is the dropdown order.
MODEL_OPTIONS = {
    "MMSplice": "multimolecule/mmsplice",
    "MTSplice": "multimolecule/mtsplice",
    "HAL": "multimolecule/hal",
    "MaxEntScan score5": "multimolecule/maxentscan-score5",
    "MaxEntScan score3": "multimolecule/maxentscan-score3",
    "Pangolin": "multimolecule/pangolin",
    "SpTransformer": "multimolecule/sptransformer",
}
# Inverse of MODEL_OPTIONS: checkpoint id -> dropdown label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# Filename extensions accepted for the paired-FASTA upload (compared case-insensitively).
FASTA_SUFFIXES = {".fa", ".fasta", ".fna"}
# Accepted sequence alphabet: A/C/G/T, the IUPAC ambiguity codes N R Y S W K M B D H V, and X.
VALID_DNA = set("ACGT") | set("NRYSWKMBDHV") | set("X")
# Row-identifying columns: kept verbatim when list-valued and never plotted as scores.
META_COLUMNS = {"type", "label", "sequence", "nucleotide", "position", "scope"}
# Built-in example pair: equal-length windows differing by one C>T substitution in the central C run.
DEFAULT_REFERENCE = "ACGT" * 25 + "CCCCCCCCCCCCCCCCCCCC" + "TGCA" * 25
DEFAULT_ALTERNATIVE = "ACGT" * 25 + "CCCCCCCCCCCTCCCCCCCC" + "TGCA" * 25
def _device() -> int:
return 0 if torch.cuda.is_available() else -1
@lru_cache(maxsize=len(MODEL_OPTIONS))
def load_predictor(model_id: str):
    """Build the ``splice-variant-effect`` pipeline for ``model_id``, once per checkpoint.

    The cache has one slot per entry in ``MODEL_OPTIONS``, so every selectable
    checkpoint can stay loaded at the same time. The device comes from ``_device()``.
    """
    device = _device()
    return pipeline(task="splice-variant-effect", model=model_id, device=device)
def clean_sequence(sequence: str, label: str) -> str:
    """Normalise one user-supplied sequence.

    All whitespace is removed, the text is upper-cased and RNA ``U`` is rewritten
    to ``T``.

    Args:
        sequence: Raw text; ``None`` is treated as empty.
        label: Human-readable name ("Reference"/"Alternative") used in error messages.

    Returns:
        The cleaned sequence.

    Raises:
        gr.Error: If the cleaned sequence is empty or contains symbols outside ``VALID_DNA``.
    """
    compact = "".join(str(sequence or "").split())
    normalized = compact.upper().replace("U", "T")
    if not normalized:
        raise gr.Error(f"{label} sequence is empty.")
    unsupported = sorted(set(normalized).difference(VALID_DNA))
    if unsupported:
        raise gr.Error(f"{label} sequence contains unsupported DNA symbols: {', '.join(unsupported)}.")
    return normalized
def validate_pair(reference: str, alternative: str) -> tuple[str, str]:
    """Clean both sequences and require them to have equal length.

    Returns:
        ``(reference, alternative)`` after ``clean_sequence``.

    Raises:
        gr.Error: If either sequence is invalid or the cleaned lengths differ.
    """
    cleaned_reference = clean_sequence(reference, "Reference")
    cleaned_alternative = clean_sequence(alternative, "Alternative")
    if len(cleaned_reference) == len(cleaned_alternative):
        return cleaned_reference, cleaned_alternative
    raise gr.Error(
        "Reference and alternative sequences must have the same length. "
        "This app does not perform genome-coordinate lookup or sequence reconstruction."
    )
def load_fasta_pair(input_file: Any):
    """Fill the reference/alternative textboxes from an uploaded two-record FASTA file.

    Args:
        input_file: Value of the ``gr.File`` component. This is either a filepath or an
            object exposing ``.name``. It is ``None`` when the upload is cleared.

    Returns:
        ``(reference, alternative)`` as cleaned sequences. When no file is present,
        returns two no-op ``gr.update()`` values instead.

    Raises:
        gr.Error: If the extension is not a FASTA suffix, the file cannot be read or
            parsed, it does not hold exactly two records, or the pair fails validation.
    """
    if input_file is None:
        # Clearing the upload must not wipe sequences the user typed by hand.
        return gr.update(), gr.update()
    path = Path(getattr(input_file, "name", input_file))
    if path.suffix.lower() not in FASTA_SUFFIXES:
        raise gr.Error("Upload a FASTA file with two records: reference first, alternative second.")
    try:
        records = list(SeqIO.parse(path, "fasta"))
    except (OSError, ValueError) as exc:
        # ValueError also covers UnicodeDecodeError from binary uploads. Re-raise as
        # gr.Error so the user sees a readable message, not Gradio's generic error.
        raise gr.Error(f"Could not read FASTA file {path.name}: {exc}") from exc
    if len(records) != 2:
        raise gr.Error(f"Expected exactly two FASTA records, found {len(records)}.")
    return validate_pair(str(records[0].seq), str(records[1].seq))
def _json_safe(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return _json_safe(value.detach().cpu().tolist())
if isinstance(value, np.ndarray):
return _json_safe(value.tolist())
if isinstance(value, np.generic):
return value.item()
if isinstance(value, Mapping):
return {str(key): _json_safe(item) for key, item in value.items()}
if isinstance(value, (list, tuple)):
return [_json_safe(item) for item in value]
return value
def _is_scalar(value: Any) -> bool:
if isinstance(value, (str, bytes)) or value is None:
return False
try:
float(value)
except (TypeError, ValueError):
return False
return True
def _number(value: Any) -> float | Any:
    """Coerce a numeric value to a finite Python ``float``; return anything else unchanged.

    Non-scalar values (per ``_is_scalar``) and non-finite numbers (NaN, ±inf) are
    passed through as-is.
    """
    if _is_scalar(value):
        converted = float(value)
        if math.isfinite(converted):
            return converted
    return value
def _position_key(key: Any) -> bool:
try:
int(str(key))
except ValueError:
return False
return True
def _vector_row(values: list[Any], channels: list[str], scalar_column: str, scope: str = "sequence") -> dict[str, Any]:
    """Turn one flat score vector into a single row dict tagged with ``scope``.

    Columns are chosen as follows:

    * When the vector length equals the number of ``channels``, each value goes
      under its channel name.
    * A single value goes under ``scalar_column``.
    * Otherwise values are stored as ``<scalar_column>_<index>``.
    """
    row: dict[str, Any] = {"scope": scope}
    if channels and len(channels) == len(values):
        for name, raw in zip(channels, values):
            row[name] = _number(raw)
    elif len(values) == 1:
        row[scalar_column] = _number(values[0])
    else:
        for offset, raw in enumerate(values):
            row[f"{scalar_column}_{offset}"] = _number(raw)
    return row
def _flatten_mapping(
    mapping: Mapping[str, Any],
    channels: list[str],
    scalar_column: str,
    prefix: str | None = None,
) -> dict[str, Any]:
    """Flatten a (possibly nested) score mapping into a single row dict.

    ``column`` is ``<prefix>_<key>``, or just ``<key>`` at the top level. Values map
    to columns as follows:

    * Numbers, ``None`` and strings go under ``column``; numbers are coerced by ``_number``.
    * Nested mappings are flattened recursively, using ``column`` as the prefix.
    * Scalar lists under a ``META_COLUMNS`` key are stored unchanged.
    * Scalar lists whose length equals ``len(channels)`` get one column per channel.
    * Other scalar lists become ``<column>_<index>`` columns.
    * Anything else is stored unchanged under ``column``.

    Channel columns (and columns merged up from nested mappings) use their bare name
    while it is still free. If an earlier key already produced that column, the
    qualified name ``<column>_<name>`` is used instead, so no score is silently
    overwritten.
    """
    row: dict[str, Any] = {}

    def place(name: str, fallback: str, item: Any) -> None:
        # Avoid clobbering a column written by an earlier key of this mapping.
        row[fallback if name in row else name] = item

    for key, value in mapping.items():
        key = str(key)
        column = f"{prefix}_{key}" if prefix else key
        value = _json_safe(value)
        if value is None or isinstance(value, str) or _is_scalar(value):
            row[column] = _number(value)
        elif isinstance(value, Mapping):
            nested = _flatten_mapping(value, channels, scalar_column, prefix=column)
            for name, item in nested.items():
                place(name, f"{column}_{name}", item)
        elif isinstance(value, list) and all(_is_scalar(item) for item in value):
            if key in META_COLUMNS:
                row[column] = value
            elif channels and len(value) == len(channels):
                for channel, item in zip(channels, value):
                    place(channel, f"{column}_{channel}", _number(item))
            else:
                row.update({f"{column}_{index}": _number(item) for index, item in enumerate(value)})
        else:
            row[column] = value
    return row
def normalize_score_rows(score_value: Any, channels: list[str], scalar_column: str) -> list[dict[str, Any]]:
    """Convert one pipeline score field into a list of flat table rows.

    The field's layout is not fixed, so several shapes are recognised, checked in
    this order:

    * ``None`` -> no rows.
    * A single number -> one ``scope="sequence"`` row holding ``scalar_column``.
    * A mapping whose keys are not all integer-like, whose values are all lists, and
      whose scalar-only lists share one length > 1 -> one row per index, with
      ``position`` plus one column per key.
    * A mapping whose keys are all integer-like -> one row per key, with
      ``position = int(key)``.
    * Any other mapping -> a single row from ``_flatten_mapping``.
    * An empty list -> no rows.
    * A list of numbers -> a single row from ``_vector_row``.
    * Any other list -> one row per element. Mappings are flattened (no
      ``position`` added), while lists and numbers get ``position = index``. Other
      element types are dropped.
    * Anything else -> one ``scope="sequence"`` row holding the raw value.

    Args:
        score_value: Raw score field from the pipeline result (tensors/arrays allowed).
        channels: Channel names used to label channel-length vectors.
        scalar_column: Column name for unlabeled values, e.g. ``"delta_score"``.

    Returns:
        A list of row dicts, possibly empty.
    """
    score_value = _json_safe(score_value)
    if score_value is None:
        return []
    if _is_scalar(score_value):
        return [{"scope": "sequence", scalar_column: _number(score_value)}]
    if isinstance(score_value, Mapping):
        # Column-oriented layout: named series, e.g. {"acceptor": [...], "donor": [...]}.
        if score_value and not all(_position_key(key) for key in score_value):
            series_lengths = {
                len(value)
                for value in score_value.values()
                if isinstance(value, list) and all(_is_scalar(item) for item in value)
            }
            if len(series_lengths) == 1:
                length = series_lengths.pop()
                # NOTE(review): values that are lists of non-scalars are not included in
                # series_lengths but still pass the all-lists check below. If such a list
                # is shorter than `length`, value[position] raises IndexError.
                if length > 1 and all(isinstance(value, list) for value in score_value.values()):
                    return [
                        {
                            "position": position,
                            **{str(key): _number(value[position]) for key, value in score_value.items()},
                        }
                        for position in range(length)
                    ]
        # Row-oriented layout keyed by position, e.g. {"0": {...}, "1": {...}}.
        if score_value and all(_position_key(key) for key in score_value):
            rows = []
            for key, value in score_value.items():
                row = {"position": int(str(key))}
                if isinstance(value, Mapping):
                    row.update(_flatten_mapping(value, channels, scalar_column))
                elif isinstance(value, list):
                    row.update(_vector_row(value, channels, scalar_column, scope="position"))
                    # Per-position rows carry `position` instead of a scope tag.
                    row.pop("scope", None)
                else:
                    row[scalar_column] = _number(value)
                rows.append(row)
            return rows
        # Mixed or empty mapping: collapse into a single row.
        return [_flatten_mapping(score_value, channels, scalar_column)]
    if isinstance(score_value, list):
        if not score_value:
            return []
        # A flat numeric list is treated as one vector (e.g. per channel), not as positions.
        if all(_is_scalar(item) for item in score_value):
            return [_vector_row(score_value, channels, scalar_column)]
        rows = []
        for index, item in enumerate(score_value):
            # Appears redundant: score_value was already converted at the top.
            item = _json_safe(item)
            if isinstance(item, Mapping):
                rows.append(_flatten_mapping(item, channels, scalar_column))
            elif isinstance(item, list):
                row = {"position": index}
                row.update(_vector_row(item, channels, scalar_column, scope="position"))
                row.pop("scope", None)
                rows.append(row)
            elif _is_scalar(item):
                rows.append({"position": index, scalar_column: _number(item)})
            # Elements of any other type (None, strings, ...) are silently skipped.
        return rows
    # Unrecognised type (e.g. a string): keep the raw value in one row.
    return [{"scope": "sequence", scalar_column: score_value}]
def result_table(result: Mapping[str, Any], score_key: str, scores_key: str, scalar_column: str) -> pd.DataFrame:
    """Build the display table for one score family (delta, reference or alternative).

    Args:
        result: Unpacked pipeline output.
        score_key: Singular field name, e.g. ``"delta_score"``.
        scores_key: Plural field name, e.g. ``"delta_scores"``. It is preferred when
            present and not ``None``; otherwise ``score_key`` is used.
        scalar_column: Column name used for unlabeled values.

    Returns:
        A DataFrame with the metadata columns first, in a fixed order, followed by the
        score columns. It is empty when the result has no values for this family.
    """
    # `channels` may be missing or explicitly None for models without named channels.
    channels = [str(channel) for channel in result.get("channels") or []]
    score_value = result.get(scores_key)
    if score_value is None:
        score_value = result.get(score_key)
    rows = normalize_score_rows(score_value, channels, scalar_column)
    if not rows:
        return pd.DataFrame()
    table = pd.DataFrame(rows)
    ordered = [column for column in ("scope", "position", "nucleotide", "sequence", "label", "type") if column in table]
    remaining = [column for column in table.columns if column not in ordered]
    return table[ordered + remaining]
def dataframe_records(table: pd.DataFrame) -> list[dict[str, Any]]:
    """Return ``table`` as a list of JSON-ready row dicts.

    The round trip goes through ``DataFrame.to_json``, so NaN becomes ``None``. An
    empty table yields an empty list.
    """
    return [] if table.empty else json.loads(table.to_json(orient="records"))
def difference_summary(reference: str, alternative: str, limit: int = 25) -> dict[str, Any]:
    """Summarise the positions where ``alternative`` differs from ``reference``.

    Args:
        reference: Reference sequence.
        alternative: Alternative sequence. It is compared position by position via
            ``zip``, so any extra trailing characters of the longer input are ignored.
        limit: Maximum number of differing positions listed under ``positions``.
            ``count`` always reports the total. Negative values are treated as 0.

    Returns:
        ``{"count": int, "positions": [{"position", "reference", "alternative"}, ...],
        "positions_truncated": bool}``, with 0-based positions.
    """
    limit = max(limit, 0)
    differences = [
        {
            "position": index,
            "reference": ref_base,
            "alternative": alt_base,
        }
        for index, (ref_base, alt_base) in enumerate(zip(reference, alternative))
        if ref_base != alt_base
    ]
    return {
        "count": len(differences),
        "positions": differences[:limit],
        "positions_truncated": len(differences) > limit,
    }
def make_delta_plot(delta_table: pd.DataFrame, model_label: str):
    """Plot the up-to-20 delta scores with the largest magnitude as horizontal bars.

    Every numeric, non-``META_COLUMNS`` column of every row contributes one bar. Bars
    are labelled ``<column>@<position>`` when the row has a position. Position-less
    rows in a multi-row table are labelled ``<column>#<row>``; otherwise the label is
    just ``<column>``. Non-negative scores are blue and negative scores red.

    Args:
        delta_table: Delta-score table from ``result_table``; may be empty.
        model_label: Dropdown label used in the title.

    Returns:
        A ``matplotlib.figure.Figure``. It shows a placeholder message when there is
        nothing to plot.
    """
    # Build the figure without pyplot. pyplot keeps every figure it creates alive
    # until plt.close() is called, so a long-running app would leak one figure per run.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(7, 2.8))
    ax = fig.subplots()
    values: list[tuple[str, float]] = []
    if not delta_table.empty:
        numeric_columns = [
            column
            for column in delta_table.columns
            if column not in META_COLUMNS and pd.api.types.is_numeric_dtype(delta_table[column])
        ]
        several_rows = len(delta_table) > 1
        for row_number, (_, row) in enumerate(delta_table.iterrows()):
            position = row.get("position")
            if position is not None and pd.notna(position):
                suffix = f"@{int(position)}"
            elif several_rows:
                # Without a distinguishing suffix, bars of different rows would share a
                # categorical label and be drawn on top of each other.
                suffix = f"#{row_number}"
            else:
                suffix = ""
            for column in numeric_columns:
                value = row.get(column)
                if pd.notna(value):
                    values.append((f"{column}{suffix}", float(value)))
    # Keep the 20 largest |score|. Reverse so the largest bar is drawn at the top.
    values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:20]
    values.reverse()
    if not values:
        ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center")
        ax.set_axis_off()
        fig.tight_layout()
        return fig
    labels, scores = zip(*values)
    colors = ["#2563eb" if score >= 0 else "#dc2626" for score in scores]
    ax.barh(labels, scores, color=colors)
    ax.axvline(0, color="#111827", linewidth=0.8)
    ax.set_title(f"{model_label} top delta scores")
    ax.set_xlabel("alternative - reference")
    ax.tick_params(axis="y", labelsize=8)
    fig.tight_layout()
    return fig
def _replace_non_finite(value: Any) -> Any:
    """Recursively replace NaN/±inf floats with ``None`` so the payload is valid JSON."""
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if isinstance(value, Mapping):
        return {key: _replace_non_finite(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_replace_non_finite(item) for item in value]
    return value


def write_result_files(
    metadata: dict[str, Any],
    result: Mapping[str, Any],
    delta_table: pd.DataFrame,
    reference_table: pd.DataFrame,
    alternative_table: pd.DataFrame,
) -> tuple[str, str]:
    """Write the run's downloadable CSV and JSON files.

    The CSV stacks the non-empty tables, tagging each row with a leading
    ``score_set`` column (``delta``, ``reference`` or ``alternative``). The JSON
    holds the metadata, the JSON-safe raw result and the record-oriented tables.
    Non-finite numbers are written as ``null``, so the file is strict JSON.

    The files are created with ``delete=False`` so Gradio can serve them after this
    returns. NOTE(review): nothing in this function removes them afterwards.

    Returns:
        ``(csv_path, json_path)``.
    """
    tables = {"delta": delta_table, "reference": reference_table, "alternative": alternative_table}
    csv_tables = []
    for score_set, table in tables.items():
        if table.empty:
            continue
        csv_table = table.copy()
        csv_table.insert(0, "score_set", score_set)
        csv_tables.append(csv_table)
    csv_payload = pd.concat(csv_tables, ignore_index=True, sort=False) if csv_tables else pd.DataFrame()
    # newline="" is what pandas expects for text handles passed to to_csv.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", prefix="splice-variant-effect-", delete=False, newline="", encoding="utf-8"
    ) as csv_handle:
        csv_payload.to_csv(csv_handle, index=False)
        csv_path = csv_handle.name

    json_payload = {
        "metadata": metadata,
        "result": _json_safe(result),
        "tables": {score_set: dataframe_records(table) for score_set, table in tables.items()},
    }
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", prefix="splice-variant-effect-", delete=False, encoding="utf-8"
    ) as json_handle:
        json.dump(_replace_non_finite(json_payload), json_handle, indent=2)
        json_path = json_handle.name
    return csv_path, json_path
def unpack_prediction_result(result: Any) -> Mapping[str, Any]:
    """Reduce raw pipeline output to one JSON-safe result mapping.

    A one-element list is unwrapped to its single item.

    Raises:
        gr.Error: If a list holds anything other than exactly one result, or the
            result is not a mapping.
    """
    payload = _json_safe(result)
    if isinstance(payload, list):
        count = len(payload)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (payload,) = payload
    if isinstance(payload, Mapping):
        return payload
    raise gr.Error(f"Expected a prediction dictionary, got {type(payload).__name__}.")
def predict(model_label: str, reference: str, alternative: str):
    """Run the selected checkpoint on one reference/alternative pair and build all outputs.

    Args:
        model_label: Dropdown label; must be a key of ``MODEL_OPTIONS``.
        reference: Reference sequence text.
        alternative: Alternative sequence text.

    Returns:
        A tuple in the order of the ``run.click`` outputs: (delta table, reference
        table, alternative table, metadata dict, delta plot figure, CSV path, JSON path).

    Raises:
        gr.Error: If the checkpoint label is unknown, the sequences are invalid, or
            loading or running the pipeline fails.
    """
    started = time.perf_counter()
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        # The dropdown can be cleared, or an arbitrary value can be sent through the API.
        raise gr.Error(f"Unknown checkpoint {model_label!r}. Choose one of: {', '.join(MODEL_OPTIONS)}.")
    reference, alternative = validate_pair(reference, alternative)
    try:
        predictor = load_predictor(model_id)
        result = unpack_prediction_result(predictor(reference, alternative=alternative))
    except gr.Error:
        raise
    except Exception as exc:
        # UI boundary: report loader/inference failures with the model name.
        raise gr.Error(f"Prediction failed for {model_label}: {exc}") from exc
    delta_table = result_table(result, "delta_score", "delta_scores", "delta_score")
    reference_table = result_table(result, "reference_score", "reference_scores", "reference_score")
    alternative_table = result_table(result, "alternative_score", "alternative_scores", "alternative_score")
    metadata = {
        "task": "splice-variant-effect",
        "model": model_id,
        "model_label": model_label,
        # Same device decision the pipeline was built with.
        "device": "cuda" if _device() >= 0 else "cpu",
        "reference_length": len(reference),
        "alternative_length": len(alternative),
        "differences": difference_summary(reference, alternative),
        "channels": result.get("channels", []),
        "output_fields": sorted(result.keys()),
        "runtime_seconds": round(time.perf_counter() - started, 3),
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
    }
    csv_path, json_path = write_result_files(metadata, result, delta_table, reference_table, alternative_table)
    delta_plot = make_delta_plot(delta_table, model_label)
    return delta_table, reference_table, alternative_table, metadata, delta_plot, csv_path, json_path
def initial_model(request: gr.Request):
    """Choose the dropdown's initial checkpoint from the page's ``?model=`` query parameter.

    The parameter may name either a checkpoint id
    (``?model=multimolecule/pangolin``) or a dropdown label (``?model=Pangolin``).
    Surrounding whitespace is ignored. Missing or unknown values fall back to
    ``"MMSplice"``.

    Args:
        request: Incoming Gradio request; may be ``None``.

    Returns:
        A key of ``MODEL_OPTIONS``.
    """
    fallback = "MMSplice"
    if request is None:
        return fallback
    model_id = None
    query_params = getattr(request, "query_params", None)
    if query_params is not None:
        model_id = query_params.get("model")
    if not model_id and getattr(request, "url", None):
        # Presumably for request objects that expose only the URL: parse it directly.
        model_values = parse_qs(urlparse(str(request.url)).query).get("model")
        model_id = model_values[0] if model_values else None
    if isinstance(model_id, str):
        model_id = model_id.strip()
    if model_id in MODEL_OPTIONS:
        return model_id
    return MODEL_LABELS.get(model_id, fallback)
# Gradio UI: checkpoint picker, paired sequence inputs (typed or from FASTA), result tables, plot and downloads.
with gr.Blocks(title="Splice Variant Effect") as demo:
    gr.Markdown(
        "# Splice Variant Effect\n"
        "Score paired reference and alternative DNA windows with MultiMolecule splice variant-effect models."
    )
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value="MMSplice",
        label="Checkpoint",
    )
    with gr.Row():
        reference = gr.Textbox(
            label="Reference DNA sequence",
            value=DEFAULT_REFERENCE,
            lines=5,
        )
        alternative = gr.Textbox(
            label="Alternative DNA sequence",
            value=DEFAULT_ALTERNATIVE,
            lines=5,
        )
    input_file = gr.File(
        label="Upload paired FASTA (reference record first, alternative record second)",
        # Derived from the set load_fasta_pair validates against, so the two cannot drift apart.
        file_types=sorted(FASTA_SUFFIXES),
    )
    run = gr.Button("Run variant effect", variant="primary")
    with gr.Row():
        delta_scores = gr.Dataframe(label="Delta scores")
        run_metadata = gr.JSON(label="Run metadata")
    with gr.Row():
        reference_scores = gr.Dataframe(label="Reference scores")
        alternative_scores = gr.Dataframe(label="Alternative scores")
    delta_plot = gr.Plot(label="Top delta scores")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")
    # The outputs order must match the tuple returned by predict().
    run.click(
        predict,
        inputs=[model, reference, alternative],
        outputs=[
            delta_scores,
            reference_scores,
            alternative_scores,
            run_metadata,
            delta_plot,
            csv_download,
            json_download,
        ],
    )
    input_file.change(load_fasta_pair, inputs=input_file, outputs=[reference, alternative])
    # Honour ?model=... links when the page loads.
    demo.load(initial_model, outputs=model)

if __name__ == "__main__":
    demo.launch()