ZhiyuanChen's picture
implement regulatory-variant-effect app
915fbd6 unverified
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
# This file is part of MultiMolecule.
# MultiMolecule is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
# MultiMolecule is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# For additional terms and clarifications, please refer to our License FAQ at:
# <https://multimolecule.danling.org/about/license-faq>.
from __future__ import annotations
import csv
import json
import re
import tempfile
import time
from functools import lru_cache
from typing import Any, Mapping
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import numpy as np
import torch
from transformers import pipeline
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
# Demo inputs: a 1,000 bp ACGT repeat and a copy whose 4-mer at offset 500 reads TCGA instead of ACGT.
DEFAULT_REFERENCE_SEQUENCE = "ACGT" * 250
DEFAULT_ALTERNATIVE_SEQUENCE = "".join(["ACGT" * 125, "TCGA", "ACGT" * 124])
DEFAULT_MODEL_LABEL = "DeepSEA"
# Dropdown label -> model id handed to load_predictor; insertion order is the dropdown order.
MODEL_OPTIONS = {
    "A2Z Chromatin": "multimolecule/a2zchromatin",
    "Basset": "multimolecule/basset",
    "DeepMEL": "multimolecule/deepmel",
    "DeepSEA": "multimolecule/deepsea",
    "DeepSTARR": "multimolecule/deepstarr",
    "Malinois": "multimolecule/malinois",
    "MPRA-DragoNN": "multimolecule/mpradragonn",
    "scBasset": "multimolecule/scbasset",
    "Xpresso": "multimolecule/xpresso",
}
# Reverse lookup (model id -> label), used to resolve a ``?model=<id>`` query parameter.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# Column order shared by the on-screen table, the CSV export and the JSON ``delta_table``.
TABLE_HEADERS = ["position", "nucleotide", "channel", "delta_score", "reference_score", "alternative_score"]
# Symbols accepted after sequence cleaning (``U`` has already been rewritten to ``T``).
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
# Signed decimal or scientific-notation numbers, used when feature text is not valid JSON.
FLOAT_PATTERN = re.compile(r"[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[eE][-+]?\d+)?")
def _device() -> int:
    """Return the pipeline device index: ``0`` (first CUDA device) if CUDA is available, otherwise ``-1`` (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Return a ``regulatory-variant-effect`` pipeline for ``model_id``.

    Results are memoised with ``lru_cache(maxsize=2)``: only the two most recently used checkpoints
    stay cached, and requesting a third builds a new pipeline and evicts the least recently used one.
    """
    target_device = _device()
    return pipeline("regulatory-variant-effect", model=model_id, device=target_device)
def clean_sequence(sequence: str, label: str) -> str:
    """Normalize a user-supplied sequence and validate its alphabet.

    All whitespace is removed, letters are upper-cased and ``U`` is mapped to ``T`` so RNA input is
    accepted. ``label`` only prefixes the error messages.

    Raises:
        gr.Error: if the cleaned sequence is empty or contains symbols outside ``DNA_ALPHABET``.
    """
    compact = "".join(str(sequence or "").split())
    normalized = compact.upper().replace("U", "T")
    if not normalized:
        raise gr.Error(f"{label} sequence is empty.")
    unexpected = set(normalized).difference(DNA_ALPHABET)
    if unexpected:
        invalid_text = ", ".join(sorted(unexpected))
        raise gr.Error(f"{label} sequence contains unsupported symbols: {invalid_text}. Use A, C, G, T, or N.")
    return normalized
def parse_features(features_text: str) -> Any | None:
text = str(features_text or "").strip()
if not text:
return None
try:
parsed = json.loads(text)
except json.JSONDecodeError:
values = FLOAT_PATTERN.findall(text)
if not values:
raise gr.Error("Features must be JSON or comma/space-separated numbers.")
return [float(value) for value in values]
if isinstance(parsed, Mapping):
for key in ("features", "values", "reference_features", "alternative_features"):
if key in parsed:
return parsed[key]
if all(isinstance(value, int | float) for value in parsed.values()):
return list(parsed.values())
raise gr.Error("Feature JSON objects must contain a features/values list or only numeric values.")
if isinstance(parsed, str):
return parse_features(parsed)
return parsed
def feature_summary(features: Any | None) -> dict[str, Any]:
    """Describe parsed features for the run metadata without echoing their values.

    Returns ``{"provided": False}`` for ``None``. Otherwise it returns ``{"provided": True, "shape": ...}``,
    where ``shape`` is the float-array shape as a list, or ``None`` if the value cannot be converted.
    """
    summary: dict[str, Any] = {"provided": features is not None}
    if features is None:
        return summary
    try:
        shape = list(np.asarray(features, dtype=float).shape)
    except (TypeError, ValueError):
        shape = None
    summary["shape"] = shape
    return summary
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary from a pipeline output.

    A one-element list is unwrapped; a bare dictionary is returned as-is.

    Raises:
        gr.Error: if a list does not hold exactly one item, or the item is not a ``dict``.
    """
    candidate = result
    if isinstance(candidate, list):
        count = len(candidate)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (candidate,) = candidate
    if isinstance(candidate, dict):
        return candidate
    raise gr.Error(f"Expected a prediction dictionary, got {type(candidate).__name__}.")
def build_delta_rows(result: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Flatten a prediction result into rows keyed by the ``TABLE_HEADERS`` names.

    Supported layouts, checked in this order:

    * ``delta_score`` present -> a single row for channel ``"score"``;
    * ``delta_scores`` is a mapping of channel -> delta -> one row per channel. Each row is filled from
      ``reference_scores``/``alternative_scores`` when those are mappings as well;
    * ``delta_scores`` is a list of per-position rows -> delegated to ``build_axis_delta_rows``.

    Raises:
        gr.Error: if none of these layouts matches.
    """
    if "delta_score" in result:
        return [
            dict(
                position="",
                nucleotide="",
                channel="score",
                delta_score=result.get("delta_score"),
                reference_score=result.get("reference_score", ""),
                alternative_score=result.get("alternative_score", ""),
            )
        ]

    delta_scores = result.get("delta_scores")
    if isinstance(delta_scores, list):
        return build_axis_delta_rows(result, delta_scores)
    if not isinstance(delta_scores, Mapping):
        raise gr.Error("The selected model did not return delta scores.")

    def scores_for(key: str) -> Mapping[Any, Any]:
        # Per-channel reference/alternative scores are optional; non-mappings count as absent.
        candidate = result.get(key)
        return candidate if isinstance(candidate, Mapping) else {}

    ref_by_channel = scores_for("reference_scores")
    alt_by_channel = scores_for("alternative_scores")
    rows: list[dict[str, Any]] = []
    for channel, delta in delta_scores.items():
        rows.append(
            dict(
                position="",
                nucleotide="",
                channel=str(channel),
                delta_score=delta,
                reference_score=ref_by_channel.get(channel, ""),
                alternative_score=alt_by_channel.get(channel, ""),
            )
        )
    return rows
def build_axis_delta_rows(result: Mapping[str, Any], delta_scores: list[Any]) -> list[dict[str, Any]]:
    """Expand per-position (or per-bin) delta rows into one table row per channel.

    Entries of ``delta_scores`` that are not mappings are skipped. Each remaining entry's position is
    determined in this order:

    * its ``position`` value;
    * else its ``bin`` value;
    * else its list index.

    Channels come from ``result["channels"]`` when that is non-empty. Otherwise every numeric field of
    the entry, other than ``position``/``bin``/``nucleotide``, is treated as a channel. Channel values
    are looked up by ``str(channel)`` first and then by the raw key. Reference and alternative values
    are matched by position in the ``reference_scores``/``alternative_scores`` lists and default to
    ``""`` when absent.
    """
    # ``or []`` also tolerates an explicit ``"channels": None`` in the result.
    declared_channels = list(result.get("channels") or [])
    reference_scores = _index_axis_rows(result.get("reference_scores"))
    alternative_scores = _index_axis_rows(result.get("alternative_scores"))
    index_fields = {"position", "bin", "nucleotide"}
    missing = object()  # sentinel: distinguishes "channel absent" from any stored value
    output_rows: list[dict[str, Any]] = []
    for row_index, row in enumerate(delta_scores):
        if not isinstance(row, Mapping):
            continue
        position = row.get("position", row.get("bin", row_index))
        # Keep raw keys here; converting to ``str`` first would make non-string keys unfindable below.
        channels = declared_channels or [key for key in row if key not in index_fields and _is_number(row[key])]
        ref_row = reference_scores.get(position, {})
        alt_row = alternative_scores.get(position, {})
        for channel in channels:
            delta = _lookup_channel(row, channel, missing)
            if delta is missing:
                continue
            output_rows.append(
                {
                    "position": position,
                    "nucleotide": row.get("nucleotide", ""),
                    "channel": str(channel),
                    "delta_score": delta,
                    "reference_score": _lookup_channel(ref_row, channel, ""),
                    "alternative_score": _lookup_channel(alt_row, channel, ""),
                }
            )
    return output_rows


def _lookup_channel(row: Mapping[Any, Any], channel: Any, default: Any) -> Any:
    """Return ``row``'s value for ``channel``, trying ``str(channel)`` first and then the raw key."""
    for key in (str(channel), channel):
        try:
            if key in row:
                return row[key]
        except TypeError:  # an unhashable channel identifier cannot be a mapping key
            continue
    return default
def _index_axis_rows(rows: Any) -> dict[Any, Mapping[str, Any]]:
    """Index per-position score rows by their ``position`` (else ``bin``, else list index).

    Anything other than a ``list`` yields an empty index. Non-mapping entries are ignored, and later
    rows overwrite earlier ones that share a key.
    """
    if not isinstance(rows, list):
        return {}
    return {
        entry.get("position", entry.get("bin", offset)): entry
        for offset, entry in enumerate(rows)
        if isinstance(entry, Mapping)
    }
def _is_number(value: Any) -> bool:
return isinstance(value, int | float | np.number)
def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
    """Convert row dictionaries into value lists ordered like ``TABLE_HEADERS``; missing cells become ``""``."""
    values: list[list[Any]] = []
    for row in rows:
        values.append([row.get(header, "") for header in TABLE_HEADERS])
    return values
def plot_delta_rows(rows: list[Mapping[str, Any]], max_bars: int = 24):
    """Plot the largest-magnitude delta scores as a horizontal bar chart.

    Args:
        rows: Table rows as produced by ``build_delta_rows``.
        max_bars: Maximum number of bars. Rows are ranked by ``abs(delta_score)``, largest at the top.

    Returns:
        A ``matplotlib.figure.Figure``. Positive deltas are green and negative ones orange. When no row
        has a finite numeric ``delta_score``, the figure only shows a "No numeric delta scores" note.
    """
    # A standalone Figure is not registered with pyplot's global figure manager. Figures created on
    # every request can therefore be garbage-collected instead of piling up until ``plt.close``.
    from matplotlib.figure import Figure

    # NaN makes the ``sorted`` ranking below unreliable (every NaN comparison is False), and NaN/inf
    # cannot be drawn as bars, so only finite scores are plotted.
    numeric_rows = [
        row for row in rows if _is_number(row.get("delta_score")) and np.isfinite(float(row["delta_score"]))
    ]
    fig = Figure(figsize=(7.0, 2.4))
    ax = fig.subplots()
    if not numeric_rows:
        ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        fig.tight_layout()
        return fig
    top_rows = sorted(numeric_rows, key=lambda row: abs(float(row["delta_score"])), reverse=True)[:max_bars]
    labels = [_row_label(row) for row in top_rows]
    values = [float(row["delta_score"]) for row in top_rows]
    colors = ["#1b9e77" if value >= 0 else "#d95f02" for value in values]
    # Grow the figure with the number of bars, clamped to 2.4-7.0 inches tall.
    height = min(7.0, max(2.4, 0.28 * len(top_rows) + 1.2))
    fig.set_size_inches(7.0, height, forward=True)
    ax.barh(range(len(top_rows)), values, color=colors)
    ax.axvline(0, color="#333333", linewidth=0.8)
    ax.set_yticks(range(len(top_rows)), labels)
    ax.invert_yaxis()  # largest |delta| at the top
    ax.set_xlabel("Alternative - reference")
    ax.set_title("Largest absolute delta scores")
    ax.tick_params(axis="y", labelsize=8)
    fig.tight_layout()
    return fig
def _row_label(row: Mapping[str, Any]) -> str:
    """Return the bar label for a delta row.

    Rows with a position (anything other than ``""``/``None``) read
    ``"<position> [<nucleotide> ]<channel>"``. Other rows use just the channel (default ``"score"``).
    """
    channel = str(row.get("channel", "score"))
    position = row.get("position")
    if position in ("", None):
        return channel
    nucleotide = row.get("nucleotide")
    if nucleotide in ("", None):
        return f"{position} {channel}"
    return f"{position} {nucleotide} {channel}"
def write_result_files(
    model_id: str,
    result: Mapping[str, Any],
    rows: list[Mapping[str, Any]],
    metadata: Mapping[str, Any],
) -> tuple[str, str]:
    """Write the run to downloadable CSV and JSON files and return ``(csv_path, json_path)``.

    The CSV holds the delta table with ``TABLE_HEADERS`` columns. The JSON bundles the metadata, model
    id, raw prediction result and the same table. NumPy values are converted via ``_json_default``.
    Files are created with ``delete=False`` so they outlive this call, and they are encoded as UTF-8.
    """
    # Normalize once so the CSV and the JSON ``delta_table`` are guaranteed to match.
    table = [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows]
    # ``with`` closes each handle even if writing/serialisation fails part-way; an explicit UTF-8
    # encoding avoids platform-dependent failures on non-ASCII channel names.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", newline="", encoding="utf-8", delete=False
    ) as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
        writer.writeheader()
        writer.writerows(table)
    with tempfile.NamedTemporaryFile("w", suffix=".json", encoding="utf-8", delete=False) as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "model": model_id,
                "result": result,
                "delta_table": table,
            },
            json_file,
            indent=2,
            default=_json_default,
        )
    return csv_file.name, json_file.name
def _json_default(value: Any):
    """``json.dump`` fallback that converts NumPy arrays and scalars into native Python values.

    Raises:
        TypeError: for any other type, with the same wording as ``json``'s own error.
    """
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
def predict(
    model_label: str,
    reference_sequence: str,
    alternative_sequence: str,
    reference_features_text: str,
    alternative_features_text: str,
):
    """Score a reference/alternative sequence pair and package the results for the UI.

    Args:
        model_label: Dropdown label; must be a key of ``MODEL_OPTIONS``.
        reference_sequence: Reference sequence, cleaned with ``clean_sequence``.
        alternative_sequence: Alternative sequence; must match the reference length after cleaning.
        reference_features_text: Optional reference feature text (see ``parse_features``).
        alternative_features_text: Optional alternative feature text (see ``parse_features``).

    Returns:
        ``(table_rows, metadata, figure, csv_path, json_path)``, in the order of the ``run.click`` outputs.

    Raises:
        gr.Error: for an unknown checkpoint, invalid input, a model that fails to load or predict,
            or a result without usable delta scores.
    """
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        raise gr.Error(f"Unknown checkpoint: {model_label!r}.")
    reference_sequence = clean_sequence(reference_sequence, "Reference")
    alternative_sequence = clean_sequence(alternative_sequence, "Alternative")
    if len(reference_sequence) != len(alternative_sequence):
        raise gr.Error(
            f"Reference and alternative sequences must have the same length. "
            f"Got {len(reference_sequence)} and {len(alternative_sequence)}."
        )
    reference_features = parse_features(reference_features_text)
    alternative_features = parse_features(alternative_features_text)

    # Timing starts before load_predictor, so it includes model loading when the checkpoint is not cached.
    started = time.perf_counter()
    # Gradio only shows the message of ``gr.Error`` exceptions to users by default, so load failures
    # (e.g. a checkpoint that cannot be fetched) are converted rather than surfacing as a bare "Error".
    try:
        predictor = load_predictor(model_id)
    except Exception as error:
        raise gr.Error(f"Failed to load {model_id}: {error}") from error
    try:
        result = predictor(
            reference_sequence,
            alternative=alternative_sequence,
            features=reference_features,
            alternative_features=alternative_features,
        )
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error
    result = unpack_prediction_result(result)
    rows = build_delta_rows(result)
    if not rows:
        raise gr.Error("The selected model returned no tabular delta scores.")
    metadata = {
        "task": "regulatory-variant-effect",
        "model": model_id,
        # Derived from _device() so the reported device always matches the one the pipeline uses.
        "device": "cuda" if _device() >= 0 else "cpu",
        "reference_length": len(reference_sequence),
        "alternative_length": len(alternative_sequence),
        "reference_features": feature_summary(reference_features),
        "alternative_features": feature_summary(alternative_features),
        "alternative_features_inherit_reference": alternative_features is None and reference_features is not None,
        "score_definition": "alternative_minus_reference",
        "num_delta_rows": len(rows),
        "has_reference_scores": any(row.get("reference_score") not in ("", None) for row in rows),
        "has_alternative_scores": any(row.get("alternative_score") not in ("", None) for row in rows),
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(model_id, result, rows, metadata)
    return (
        table_values(rows),
        metadata,
        plot_delta_rows(rows),
        csv_path,
        json_path,
    )
def initial_model(request: gr.Request):
    """Choose the checkpoint label to preselect when the page loads.

    The ``model`` query parameter is read from ``request.query_params`` when available, with
    ``request.url`` parsed as a fallback. Its value must be a model id (a ``MODEL_OPTIONS`` value).
    A missing request, a missing parameter or an unknown id yields ``DEFAULT_MODEL_LABEL``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL
    params = getattr(request, "query_params", None)
    model_id = params.get("model") if params is not None else None
    if not model_id:
        url = getattr(request, "url", None)
        if url:
            model_id = (parse_qs(urlparse(str(url)).query).get("model") or [None])[0]
    return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)
# Module-level Gradio UI; running this file directly launches it (see the ``__main__`` guard at the end).
with gr.Blocks(title="Regulatory Variant Effect") as demo:
    gr.Markdown(
        "# Regulatory Variant Effect\n"
        "Score matched reference and alternative DNA windows with MultiMolecule regulatory variant-effect models."
    )
    # Choices are the MODEL_OPTIONS labels; predict() maps the selected label back to a model id.
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value=DEFAULT_MODEL_LABEL,
        label="Checkpoint",
    )
    # predict() cleans both sequences and rejects pairs whose cleaned lengths differ.
    with gr.Row():
        reference_sequence = gr.Textbox(label="Reference DNA sequence", value=DEFAULT_REFERENCE_SEQUENCE, lines=5)
        alternative_sequence = gr.Textbox(label="Alternative DNA sequence", value=DEFAULT_ALTERNATIVE_SEQUENCE, lines=5)
    # Collapsed by default; blank boxes are parsed to ``None`` by parse_features().
    with gr.Accordion("Optional numeric features", open=False), gr.Row():
        reference_features = gr.Textbox(
            label="Reference features JSON/text",
            placeholder='[0.1, 0.2, 0.3] or {"features": [0.1, 0.2, 0.3]}',
            lines=3,
        )
        # NOTE(review): predict() only passes ``None`` for a blank box; reusing the reference
        # features is presumably done inside the pipeline — confirm.
        alternative_features = gr.Textbox(
            label="Alternative features JSON/text",
            placeholder="Leave blank to reuse reference features when provided.",
            lines=3,
        )
    run = gr.Button("Run prediction", variant="primary")
    # Output components; the ``outputs`` order below must match the tuple returned by predict().
    delta_table = gr.Dataframe(headers=TABLE_HEADERS, label="Delta scores", interactive=False, wrap=True)
    with gr.Row():
        metadata = gr.JSON(label="Run metadata")
        delta_plot = gr.Plot(label="Delta plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")
    run.click(
        predict,
        inputs=[model, reference_sequence, alternative_sequence, reference_features, alternative_features],
        outputs=[delta_table, metadata, delta_plot, csv_download, json_download],
    )
    # On page load, preselect the checkpoint named by a ``?model=<model id>`` query parameter.
    demo.load(initial_model, outputs=model)
if __name__ == "__main__":
    demo.launch()