# methylation / app.py
# Author: ZhiyuanChen
# Commit 2c6f65c: "implement methylation app" (unverified)
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
from __future__ import annotations
import csv
import json
import re
import tempfile
import time
from functools import lru_cache
from typing import Any, Mapping
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import numpy as np
import torch
from transformers import pipeline
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
# Display label shown in the "Checkpoint" dropdown -> Hugging Face Hub model id.
# Dict insertion order is the order the dropdown lists the choices in.
MODEL_OPTIONS = {
    "DeepCpG-DNA Smallwood 2014 serum mESC": "multimolecule/deepcpgdna-smallwood2014-serum",
    "DeepCpG-DNA Smallwood 2014 2i mESC": "multimolecule/deepcpgdna-smallwood2014-2i",
    "DeepCpG-DNA Hou 2016 HCC": "multimolecule/deepcpgdna-hou2016-hcc",
    "DeepCpG-DNA Hou 2016 HepG2": "multimolecule/deepcpgdna-hou2016-hepg2",
    "DeepCpG-DNA Hou 2016 mESC": "multimolecule/deepcpgdna-hou2016-mesc",
}
# Reverse lookup (model id -> display label), used to honour a ``?model=<id>`` URL parameter.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
DEFAULT_MODEL_LABEL = "DeepCpG-DNA Smallwood 2014 serum mESC"
# 1001-nt example input: 499 nt of ACGT repeats, a "CG" dinucleotide at 0-based indices
# 499-500, then 500 nt of TGCA repeats. (The ACGT flank itself also contains CG steps.)
DEFAULT_SEQUENCE = ("ACGT" * 125)[:-1] + "CG" + "TGCA" * 125
# Characters accepted after input cleaning (U is rewritten to T before validation).
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
def _device() -> int:
return 0 if torch.cuda.is_available() else -1
def _device_label() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build a Transformers ``"methylation"`` pipeline for ``model_id`` and memoise it.

    The task name is registered by the ``multimolecule`` import. The LRU cache keeps at most
    two pipelines alive, so alternating between two checkpoints avoids reloading weights.
    """
    device_index = _device()
    return pipeline("methylation", model=model_id, device=device_index)
def clean_sequence(sequence: str) -> str:
    """Normalise user input into an upper-case DNA string.

    Accepts bare sequence text or a single FASTA record. The ``>`` header line is dropped,
    the remaining lines are joined and all whitespace is removed. Letters are upper-cased
    and RNA ``U`` is rewritten as ``T``. ``None`` is treated as empty input.

    Raises:
        gr.Error: if the text holds more than one FASTA record (joining them would silently
            score a chimeric sequence), if nothing is left after cleaning, or if any character
            falls outside ``DNA_ALPHABET``.
    """
    sequence_lines = []
    record_count = 0
    for line in str(sequence or "").splitlines():
        line = line.strip()
        if line.startswith(">"):
            record_count += 1
        elif line:
            sequence_lines.append(line)
    if record_count > 1:
        raise gr.Error(f"Input contains {record_count} FASTA records; provide a single sequence.")
    # ``strip`` above only trims line ends; this also removes whitespace inside lines.
    cleaned = re.sub(r"\s+", "", "".join(sequence_lines)).upper().replace("U", "T")
    if not cleaned:
        raise gr.Error("Sequence is empty.")
    invalid = sorted(set(cleaned) - DNA_ALPHABET)
    if invalid:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.")
    return cleaned
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary from a pipeline call.

    A one-element list wrapper is unwrapped. Any other list length, or a payload that is not
    a ``dict``, raises ``gr.Error``.
    """
    if isinstance(result, list):
        count = len(result)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (result,) = result
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
def _as_python(value: Any) -> Any:
    """Convert array-likes exposing ``tolist()`` to plain Python values.

    This covers NumPy arrays and scalars and torch tensors. Other objects are returned unchanged.
    """
    tolist = getattr(value, "tolist", None)
    return tolist() if callable(tolist) else value


def score_rows_from_result(result: Mapping[str, Any]) -> list[list[Any]]:
    """Convert a methylation prediction dictionary into ``[channel, score]`` rows.

    Accepted layouts:

    * ``"score"``: a single value or a sequence of values, labelled by ``result["channels"]``.
      The label falls back to ``"methylation"`` when no channels are given.
    * ``"scores"``: either a mapping of channel name to value, or a sequence of values
      labelled by ``result["channels"]``.

    Sequences may be lists, tuples or array-likes with ``tolist()`` (NumPy/torch).
    A missing or ``None`` ``"channels"`` entry means "no names". A single string is one name.

    Raises:
        gr.Error: if neither key is present, ``"scores"`` has an unsupported type, or a
            score is non-numeric / non-finite.
    """
    raw_channels = result.get("channels")
    if raw_channels is None:
        channels: list[str] = []
    elif isinstance(raw_channels, str):
        # A lone channel name must not be split into one "channel" per character.
        channels = [raw_channels]
    else:
        channels = [str(channel) for channel in raw_channels]

    if "score" in result:
        return rows_from_values(result["score"], channels or ["methylation"])
    if "scores" in result:
        scores = _as_python(result["scores"])
        if isinstance(scores, Mapping):
            return [[str(channel), number_value(score)] for channel, score in scores.items()]
        if isinstance(scores, (list, tuple)):
            return rows_from_values(scores, channels)
    raise gr.Error("The selected model did not return methylation scores.")


def rows_from_values(values: Any, channels: list[str]) -> list[list[Any]]:
    """Pair score values with channel names.

    A sequence (list, tuple or array-like) yields one row per value. If the number of names
    does not match the number of values, the names are replaced by ``methylation_<index>``.
    A scalar yields a single row named ``channels[0]``, or ``"methylation"`` if no names exist.
    """
    values = _as_python(values)
    if isinstance(values, (list, tuple)):
        if len(channels) != len(values):
            channels = [f"methylation_{index}" for index in range(len(values))]
        return [[channel, number_value(value)] for channel, value in zip(channels, values)]
    return [[channels[0] if channels else "methylation", number_value(values)]]


def number_value(value: Any) -> float:
    """Return ``value`` as a finite ``float``.

    Raises:
        gr.Error: if it cannot be converted, or is NaN/infinite.
    """
    try:
        number = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    if not np.isfinite(number):
        raise gr.Error(f"Score value {value!r} is not finite.")
    return number
def plot_scores(rows: list[list[Any]], top_n: int | float):
    """Render a horizontal bar chart of the highest-scoring channels.

    Args:
        rows: ``[channel, score]`` pairs.
        top_n: Maximum number of bars. A falsy value (``None``/``0``) falls back to 25, and
            anything that truncates below 1 is clamped to 1.

    Returns:
        A matplotlib ``Figure``. With no rows the figure is blank (axes hidden). The x-axis is
        fixed to [0, 1] when every plotted score lies in that range. Labels longer than
        58 characters are shortened with "...".
    """
    top_n = max(1, int(top_n or 25))
    values = [(str(channel), float(score)) for channel, score in rows]
    values = sorted(values, key=lambda item: item[1], reverse=True)[:top_n]
    # Grow the figure with the number of bars, clamped to 3-12 inches tall.
    height = max(3.0, min(12.0, 1.2 + 0.34 * len(values)))
    fig, ax = plt.subplots(figsize=(8.0, height))
    # pyplot keeps every figure it creates in a global registry until it is closed. Without
    # closing it, each prediction leaked one figure. Closing only deregisters it from pyplot,
    # so the returned Figure object remains usable (e.g. for ``savefig``).
    try:
        if not values:
            ax.set_axis_off()
            return fig
        labels = [label if len(label) <= 58 else f"{label[:55]}..." for label, _ in values]
        scores = [score for _, score in values]
        y_positions = np.arange(len(values))
        ax.barh(y_positions, scores, color="#2f6f9f")
        ax.set_yticks(y_positions, labels)
        ax.invert_yaxis()  # highest score at the top
        if all(0.0 <= score <= 1.0 for score in scores):
            ax.set_xlim(0.0, 1.0)
        ax.set_xlabel("Methylation score")
        ax.grid(axis="x", alpha=0.2)
        fig.tight_layout()
        return fig
    finally:
        plt.close(fig)
def _json_default(value: Any) -> Any:
    """Fallback for ``json.dump``: array-likes via ``tolist()``, anything else via ``str()``."""
    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        return tolist()
    return str(value)


def write_result_files(
    metadata: Mapping[str, Any], result: Mapping[str, Any], rows: list[list[Any]]
) -> tuple[str, str]:
    """Write the score table to a CSV file and the full run record to a JSON file.

    Both temporary files are created with ``delete=False`` so they outlive this call; the
    caller hands the paths to download widgets. Nothing here removes them afterwards.

    Args:
        metadata: Run metadata, stored under ``"metadata"`` in the JSON file.
        result: The raw prediction dictionary, stored under ``"raw_result"``.
            Values ``json`` cannot encode natively (e.g. NumPy/torch arrays and scalars) are
            converted with ``tolist()``. Any other unsupported object is written as ``str()``.
        rows: ``[channel, score]`` pairs. They form the CSV body and the JSON ``"scores"`` list.

    Returns:
        ``(csv_path, json_path)``.
    """
    # ``with`` guarantees the handles are closed even if writing fails; UTF-8 avoids
    # depending on the platform's locale encoding.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["channel", "score"])
        writer.writerows(rows)
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "scores": [{"channel": channel, "score": score} for channel, score in rows],
                "raw_result": result,
            },
            json_file,
            indent=2,
            default=_json_default,
        )
    return csv_file.name, json_file.name
def predict(model_label: str, sequence: str, top_n: int | float):
    """Run one methylation prediction and package the results for the UI.

    Args:
        model_label: A key of ``MODEL_OPTIONS``.
        sequence: Raw user input, cleaned by ``clean_sequence``.
        top_n: Number of bars to draw in the score plot.

    Returns:
        ``(rows, metadata, figure, csv_path, json_path)``. ``elapsed_seconds`` in the metadata
        covers pipeline loading (uncached on first use of a checkpoint), inference and result
        parsing. It does not include sequence cleaning or file writing.

    Raises:
        gr.Error: for an unknown checkpoint label, an invalid sequence, a pipeline failure, or
            an unexpected result format.
    """
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        # A label outside MODEL_OPTIONS (e.g. None) used to surface as a bare KeyError.
        choices = ", ".join(MODEL_OPTIONS)
        raise gr.Error(f"Unknown checkpoint {model_label!r}. Choose one of: {choices}.")
    sequence = clean_sequence(sequence)
    started = time.perf_counter()
    try:
        result = load_predictor(model_id)(sequence)
    except gr.Error:
        raise
    except Exception as error:
        # Present any loading/inference failure to the user as a Gradio error message.
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error
    result = unpack_prediction_result(result)
    rows = score_rows_from_result(result)
    metadata = {
        "task": "methylation",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "score_count": len(rows),
        "channels": result.get("channels", []),
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(metadata, result, rows)
    return rows, metadata, plot_scores(rows, top_n), csv_path, json_path
def initial_model(request: gr.Request):
    """Choose the checkpoint preselected on page load from a ``?model=<hub id>`` parameter.

    The parameter is read from ``request.query_params``, or else parsed from ``request.url``.
    Falls back to ``DEFAULT_MODEL_LABEL`` when there is no request, no ``model`` parameter,
    or the id is not a value of ``MODEL_OPTIONS``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL
    params = getattr(request, "query_params", None)
    requested = None if params is None else params.get("model")
    if not requested:
        url = getattr(request, "url", None)
        if url:
            found = parse_qs(urlparse(str(url)).query).get("model")
            requested = found[0] if found else None
    return MODEL_LABELS.get(requested, DEFAULT_MODEL_LABEL)
# Gradio UI. Components render in the order they are created, so statement order here
# defines the page layout.
with gr.Blocks(title="Methylation") as demo:
    gr.Markdown(
        "# Methylation\n" "Run MultiMolecule DNA methylation checkpoints and inspect per-cell methylation scores."
    )
    # Inputs: checkpoint picker and the number of bars to plot.
    with gr.Row():
        model = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value=DEFAULT_MODEL_LABEL, label="Checkpoint")
        top_n = gr.Slider(1, 50, value=25, step=1, label="Bar count")
    sequence = gr.Textbox(label="DNA sequence", value=DEFAULT_SEQUENCE, lines=7)
    run = gr.Button("Run prediction", variant="primary")
    # Outputs: score table and run metadata, the bar plot, then the download files.
    with gr.Row():
        scores = gr.Dataframe(headers=["channel", "score"], datatype=["str", "number"], label="Score table")
        metadata = gr.JSON(label="Run metadata")
    score_plot = gr.Plot(label="Score bar plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")
    # The outputs list must match the order of predict()'s return tuple.
    run.click(
        predict, inputs=[model, sequence, top_n], outputs=[scores, metadata, score_plot, csv_download, json_download]
    )
    # On page load, preselect the checkpoint named by a ``?model=`` query parameter (see initial_model).
    demo.load(initial_model, outputs=model)
# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()