# ZhiyuanChen's picture
# implement regulatory-activity app
# 4065188 unverified
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
# This file is part of MultiMolecule.
# MultiMolecule is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
# MultiMolecule is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# For additional terms and clarifications, please refer to our License FAQ at:
# <https://multimolecule.danling.org/about/license-faq>.
from __future__ import annotations
import csv
import json
import re
import tempfile
from datetime import datetime, timezone
from functools import lru_cache
from typing import Any, Mapping
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import numpy as np
import torch
from transformers import pipeline
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
# Display label shown in the checkpoint dropdown -> model id handed to the pipeline.
MODEL_OPTIONS = {
    "A2Z Chromatin": "multimolecule/a2zchromatin",
    "Basset": "multimolecule/basset",
    "DeepMEL": "multimolecule/deepmel",
    "DeepSEA": "multimolecule/deepsea",
    "DeepSTARR": "multimolecule/deepstarr",
    "Malinois": "multimolecule/malinois",
    "MPRA-DragoNN": "multimolecule/mpradragonn",
    "scBasset": "multimolecule/scbasset",
    "Xpresso": "multimolecule/xpresso",
}
# Reverse lookup: model id -> display label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# Label selected in the dropdown when nothing else is requested.
DEFAULT_MODEL_LABEL = "DeepSEA"
# 600-nucleotide placeholder sequence pre-filled in the sequence textbox.
DEFAULT_SEQUENCE = "".join(["ACGT"] * 150)
# Characters accepted by clean_sequence after upper-casing and U -> T replacement.
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
def _device() -> int:
    """Return the device index passed to the pipeline: ``0`` with CUDA, else ``-1``."""
    if torch.cuda.is_available():
        return 0
    return -1
def _device_label() -> str:
    """Return ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build the ``regulatory-activity`` pipeline for ``model_id``.

    The two most recently used pipelines are kept by the LRU cache, so
    repeated predictions with the same checkpoint reuse the built object.

    Args:
        model_id: Model identifier forwarded as ``model=`` to ``pipeline``.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    device_index = _device()
    return pipeline("regulatory-activity", model=model_id, device=device_index)
def clean_sequence(sequence: str) -> str:
    """Normalise user-supplied sequence text into an upper-case DNA string.

    Lines are stripped; blank lines and lines starting with ``>`` are dropped.
    The remaining text is joined, all whitespace is removed, the result is
    upper-cased and every ``U`` is replaced by ``T``.

    Args:
        sequence: Raw textbox content; ``None`` or empty is treated as empty.

    Returns:
        The cleaned sequence, containing only characters in ``DNA_ALPHABET``.

    Raises:
        gr.Error: If nothing is left after cleaning, or if a character outside
            ``DNA_ALPHABET`` remains.
    """
    stripped_lines = (raw.strip() for raw in (sequence or "").splitlines())
    kept = [text for text in stripped_lines if text and not text.startswith(">")]
    sequence = re.sub(r"\s+", "", "".join(kept)).upper().replace("U", "T")
    if not sequence:
        raise gr.Error("Sequence is empty.")
    invalid = sorted(set(sequence) - DNA_ALPHABET)
    if invalid:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.")
    return sequence
def parse_features(features_text: str | None) -> tuple[Any | None, list[int] | None]:
    """Parse the optional auxiliary-features textbox into nested lists.

    The text is first tried as JSON (a number, a list, or an object with a
    ``"features"`` key). If it is not valid JSON it is split on whitespace,
    commas and semicolons and every token is read as a float.

    Args:
        features_text: Raw textbox content; ``None`` or blank means "no features".

    Returns:
        ``(None, None)`` when no features were given, otherwise a pair of the
        float32-rounded values as plain Python numbers/lists and their shape
        as a list of ints.

    Raises:
        gr.Error: If the text cannot be parsed, is not numeric, is empty, has
            more than two dimensions, or contains non-finite values.
    """
    text = (features_text or "").strip()
    if not text:
        return None, None

    try:
        features = json.loads(text)
    except json.JSONDecodeError:
        # Not JSON: fall back to a delimiter-separated list of numbers.
        try:
            features = [float(token) for token in re.split(r"[\s,;]+", text) if token]
        except ValueError as error:
            raise gr.Error(
                "Auxiliary features must be a JSON numeric value/list or comma-separated numbers."
            ) from error
    else:
        if isinstance(features, Mapping):
            if "features" not in features:
                raise gr.Error('JSON object features must use a "features" key, for example {"features": [0, 0]}.')
            features = features["features"]

    try:
        array = np.asarray(features, dtype=np.float32)
    except (TypeError, ValueError) as error:
        raise gr.Error("Auxiliary features must contain only numeric values.") from error

    if array.size == 0:
        raise gr.Error("Auxiliary features are empty.")
    if array.ndim > 2:
        raise gr.Error("Auxiliary features must be a number, a 1-D list, or a 2-D batch-sized list.")
    if not np.isfinite(array).all():
        raise gr.Error("Auxiliary features must be finite numbers.")

    shape = list(array.shape)
    return array.tolist(), shape
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Reduce a pipeline return value to a single prediction dictionary.

    Args:
        result: Either a dict, or a list holding exactly one dict.

    Returns:
        The prediction dictionary.

    Raises:
        gr.Error: If a list does not hold exactly one item, or the (unwrapped)
            value is not a dict.
    """
    if isinstance(result, list):
        if len(result) == 1:
            result = result[0]
        else:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
def score_rows_from_result(result: Mapping[str, Any]) -> list[list[Any]]:
    """Convert a prediction dictionary into ``[channel, score]`` table rows.

    A ``"score"`` entry takes precedence over ``"scores"``. ``"scores"`` may
    be a mapping of channel -> score or a list handled by
    ``rows_from_score_list``. Channel names come from ``result["channels"]``
    when present.

    Args:
        result: Prediction dictionary returned by the pipeline.

    Returns:
        One ``[label, float]`` row per score.

    Raises:
        gr.Error: If neither a usable ``"score"`` nor ``"scores"`` entry exists.
    """
    channels = list(map(str, result.get("channels", [])))
    if "score" in result:
        return rows_from_values(result["score"], channels or ["score"])

    # A missing "scores" key yields None, which matches neither branch below.
    scores = result.get("scores")
    if isinstance(scores, Mapping):
        return [[str(name), number_value(value)] for name, value in scores.items()]
    if isinstance(scores, list):
        return rows_from_score_list(scores, channels)
    raise gr.Error("The selected model did not return sequence-level score output.")
def rows_from_values(values: Any, channels: list[str]) -> list[list[Any]]:
    """Pair score value(s) with channel names.

    Args:
        values: A single score, or a list/tuple of scores.
        channels: Candidate channel names.

    Returns:
        For a scalar, one row labelled with the first channel (or ``"score"``
        when ``channels`` is empty). For a list/tuple, one row per value; if
        the number of channels does not match, generic ``score_<index>``
        labels are used instead.
    """
    if not isinstance(values, (list, tuple)):
        label = channels[0] if channels else "score"
        return [[label, number_value(values)]]

    if len(channels) != len(values):
        # Names cannot be aligned with the values; fall back to positional labels.
        channels = [f"score_{index}" for index in range(len(values))]
    return [[channel, number_value(value)] for channel, value in zip(channels, values)]
def rows_from_score_list(scores: list[Any], channels: list[str]) -> list[list[Any]]:
    """Turn a list-valued ``scores`` entry into ``[label, score]`` rows.

    A non-empty list of plain ``int``/``float`` values is delegated to
    ``rows_from_values``. Otherwise each item is handled on its own:
    non-mapping items become ``score_<index>`` rows, and mapping items yield
    one row per key other than ``position``/``bin``/``nucleotide``, with those
    locator keys (when present) rendered as a ``key=value`` label prefix.

    Args:
        scores: List of numbers, mappings, or other score-like items.
        channels: Channel names used only for the flat numeric case.

    Returns:
        The flattened list of rows.
    """
    if scores and all(isinstance(score, (int, float)) for score in scores):
        return rows_from_values(scores, channels)

    locator_keys = ("position", "bin", "nucleotide")
    rows: list[list[Any]] = []
    for index, item in enumerate(scores):
        if isinstance(item, Mapping):
            prefix = " ".join(f"{key}={item[key]}" for key in locator_keys if key in item)
            rows.extend(
                [f"{prefix} {key}" if prefix else str(key), number_value(value)]
                for key, value in item.items()
                if key not in locator_keys
            )
        else:
            rows.append([f"score_{index}", number_value(item)])
    return rows
def number_value(value: Any) -> float:
    """Coerce ``value`` to a finite Python ``float``.

    Args:
        value: Anything accepted by ``float()``.

    Returns:
        The value as a ``float``.

    Raises:
        gr.Error: If ``value`` cannot be converted, or converts to NaN/infinity.
    """
    try:
        number = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    if np.isfinite(number):
        return number
    raise gr.Error(f"Score value {value!r} is not finite.")
def plot_scores(rows: list[list[Any]], top_n: int | float) -> Any:
    """Draw a horizontal bar chart of the scores with the largest magnitude.

    The figure is built directly from ``matplotlib.figure.Figure`` rather than
    through ``pyplot``. ``pyplot`` keeps a global reference to every figure it
    creates until it is explicitly closed, so creating one per request in a
    long-running app would accumulate figures indefinitely.

    Args:
        rows: ``[channel, score]`` pairs.
        top_n: Maximum number of bars; falsy values fall back to 20 and the
            result is clamped to at least 1.

    Returns:
        A Matplotlib ``Figure``. When ``rows`` is empty the figure contains a
        single hidden axis.
    """
    # Local import keeps this block self-contained; matplotlib is already a
    # dependency of the module.
    from matplotlib.figure import Figure

    top_n = max(1, int(top_n or 20))
    values = [(str(channel), float(score)) for channel, score in rows]
    values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:top_n]

    # Grow the figure with the number of bars, bounded to a sensible range.
    height = max(3.0, min(12.0, 1.2 + 0.36 * len(values)))
    fig = Figure(figsize=(9, height))
    ax = fig.subplots()
    if not values:
        ax.set_axis_off()
        return fig

    # Truncate long labels so the y-axis does not squeeze the bars.
    labels = [label if len(label) <= 54 else f"{label[:51]}..." for label, _ in values]
    scores = [score for _, score in values]
    y_positions = np.arange(len(values))
    colors = ["#2f6f9f" if score >= 0 else "#c75146" for score in scores]

    ax.barh(y_positions, scores, color=colors)
    ax.set_yticks(y_positions, labels)
    ax.invert_yaxis()  # largest magnitude at the top
    ax.axvline(0, color="#555555", linewidth=0.8)
    ax.set_xlabel("Score")
    ax.set_title(f"Top {len(values)} channels by absolute score")
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig
def _json_default(value: Any) -> Any:
    """Serialise objects exposing ``tolist()`` (e.g. NumPy scalars/arrays, tensors)."""
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def write_result_files(
    metadata: Mapping[str, Any],
    result: Mapping[str, Any],
    rows: list[list[Any]],
) -> tuple[str, str]:
    """Write the score table to a CSV file and the full run to a JSON file.

    Both files are named temporary files created with ``delete=False`` so the
    returned paths stay valid after this function returns. Each file is
    written inside a ``with`` block, so its handle is closed even if writing
    fails.

    Args:
        metadata: Run metadata stored under the ``"metadata"`` key.
        result: Raw prediction dictionary stored under ``"raw_result"``.
            Values exposing ``tolist()`` are converted for JSON.
        rows: ``[channel, score]`` pairs written to the CSV and mirrored under
            the ``"scores"`` key.

    Returns:
        ``(csv_path, json_path)``.

    Raises:
        TypeError: If ``metadata`` or ``result`` holds a value that is neither
            JSON-serialisable nor convertible via ``tolist()``.
    """
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["channel", "score"])
        writer.writerows(rows)

    payload = {
        "metadata": metadata,
        "scores": [{"channel": channel, "score": score} for channel, score in rows],
        "raw_result": result,
    }
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as json_file:
        json.dump(payload, json_file, indent=2, default=_json_default)

    return csv_file.name, json_file.name
def predict(
    model_label: str,
    sequence: str,
    features_text: str,
    top_n: int | float,
):
    """Run one prediction and assemble every output shown in the UI.

    Args:
        model_label: Key of ``MODEL_OPTIONS`` selecting the checkpoint.
        sequence: Raw sequence text; cleaned with ``clean_sequence``.
        features_text: Optional auxiliary features; parsed with ``parse_features``.
        top_n: Maximum number of bars in the plot.

    Returns:
        ``(rows, metadata, figure, csv_path, json_path)`` in the order the
        click handler's outputs expect.

    Raises:
        gr.Error: For invalid input, unexpected output, or any exception raised
            while loading or calling the predictor (re-raised with its message).
    """
    model_id = MODEL_OPTIONS[model_label]
    sequence = clean_sequence(sequence)
    features, features_shape = parse_features(features_text)

    # The ``features`` keyword is only passed when the user supplied features.
    extra_inputs = {} if features is None else {"features": features}
    try:
        predictor = load_predictor(model_id)
        result = predictor(sequence, **extra_inputs)
    except gr.Error:
        raise
    except Exception as error:
        # Surface loader/model failures in the UI instead of a generic error.
        raise gr.Error(str(error)) from error

    result = unpack_prediction_result(result)
    rows = score_rows_from_result(result)
    metadata = {
        "task": "regulatory-activity",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "features_provided": features is not None,
        "features_shape": features_shape,
        "score_count": len(rows),
        "channels": result.get("channels", []),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    figure = plot_scores(rows, top_n)
    csv_path, json_path = write_result_files(metadata, result, rows)
    return rows, metadata, figure, csv_path, json_path
def initial_model(request: gr.Request):
    """Choose the dropdown value on page load from the ``model`` query parameter.

    The parameter is read from ``request.query_params`` when available and
    otherwise parsed out of ``request.url``. It may be either a display label
    (a key of ``MODEL_OPTIONS``) or a model id (a key of ``MODEL_LABELS``).

    Args:
        request: The incoming request, or ``None``.

    Returns:
        The matching display label, or ``DEFAULT_MODEL_LABEL`` when the request
        is missing or the parameter is absent or unrecognised.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    query_params = getattr(request, "query_params", None)
    model_id = query_params.get("model") if query_params is not None else None

    if not model_id and getattr(request, "url", None):
        # Fall back to parsing the raw URL's query string.
        model_values = parse_qs(urlparse(str(request.url)).query).get("model")
        model_id = model_values[0] if model_values else None

    if model_id in MODEL_OPTIONS:
        return model_id
    return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)
# Build the Gradio interface: inputs, the run button, result components, and
# the event wiring between them.
# NOTE(review): the Row grouping below is reconstructed from a flattened
# source in which indentation was lost — confirm the intended layout.
with gr.Blocks(title="Regulatory Activity") as demo:
    gr.Markdown(
        "# Regulatory Activity\n"
        "Run MultiMolecule sequence-level DNA regulatory checkpoints and inspect the returned activity scores."
    )
    # Checkpoint selector and the number of bars shown in the plot.
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value=DEFAULT_MODEL_LABEL,
            label="Checkpoint",
        )
        top_n = gr.Slider(1, 50, value=20, step=1, label="Bar count")
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    # Free-text features; the accepted formats are those parsed by parse_features.
    features = gr.Textbox(
        label="Auxiliary numeric features (optional)",
        placeholder='JSON list, {"features": [...]}, or comma-separated numbers',
        lines=2,
    )
    run = gr.Button("Run prediction", variant="primary")
    # Result components, filled by predict in the order listed in run.click below.
    with gr.Row():
        scores = gr.Dataframe(
            headers=["channel", "score"],
            datatype=["str", "number"],
            label="Score table",
        )
        metadata = gr.JSON(label="Run metadata")
    score_plot = gr.Plot(label="Score bar plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")
    # predict returns (rows, metadata, figure, csv_path, json_path), matching outputs.
    run.click(
        predict,
        inputs=[model, sequence, features, top_n],
        outputs=[scores, metadata, score_plot, csv_download, json_download],
    )
    # On page load, set the dropdown from the value returned by initial_model.
    demo.load(initial_model, outputs=model)
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()