polyadenylation / app.py
ZhiyuanChen's picture
implement polyadenylation app
7b1dcd7 unverified
Raw
History Blame Contribute Delete
8.64 kB
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
from __future__ import annotations
import csv
import json
import re
import tempfile
import time
from functools import lru_cache
from typing import Any, Mapping
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import numpy as np
import torch
from transformers import pipeline
matplotlib.use("Agg")
import matplotlib.pyplot as plt # noqa: E402
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
# Dropdown label -> Hugging Face model id. Insertion order is the order shown in the UI dropdown.
MODEL_OPTIONS = {
    "APARENT2": "multimolecule/aparent2",
    "APARENT": "multimolecule/aparent",
}
# Reverse lookup (model id -> dropdown label); used to map a ``?model=<id>`` query value to a label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
DEFAULT_MODEL_LABEL = "APARENT2"
# 205-nt example input: an AATAAA hexamer (presumably the canonical poly(A) signal) at offset 70,
# padded with A on both sides.
DEFAULT_SEQUENCE = f"{'A' * 70}AATAAA{'A' * 129}"
# Characters accepted after cleaning (``U`` has already been rewritten to ``T`` by then).
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
# Column order for the results table, the CSV export and the per-row dictionaries.
TABLE_HEADERS = ["event", "position", "probability"]
def _device() -> int:
    """Return the ``device`` argument passed to ``transformers.pipeline``.

    ``0`` selects the first CUDA device when CUDA is available; ``-1`` keeps the
    pipeline on the CPU.
    """
    if torch.cuda.is_available():
        return 0
    return -1
def _device_label() -> str:
    """Return ``"cuda"`` or ``"cpu"`` for the run metadata, mirroring ``_device``'s choice."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build, or reuse, the Transformers ``polyadenylation`` pipeline for *model_id*.

    The ``polyadenylation`` task name is available because ``multimolecule`` is
    imported at module level and registers its pipelines with Transformers.
    ``lru_cache(maxsize=2)`` memoises one pipeline per model id, so both
    checkpoints listed in ``MODEL_OPTIONS`` can stay loaded at the same time.
    """
    target_device = _device()
    return pipeline(task="polyadenylation", model=model_id, device=target_device)
def clean_sequence(sequence: str) -> str:
    """Normalise raw user input into an uppercase DNA string.

    Processing steps:

    * blank lines and FASTA header lines (starting with ``>``) are dropped;
    * the remaining lines are joined and all whitespace is removed;
    * the text is uppercased and ``U`` is rewritten as ``T``, so RNA input is accepted.

    Args:
        sequence: Raw textbox content. ``None`` or an empty value is treated as ``""``.

    Returns:
        The cleaned sequence.

    Raises:
        gr.Error: If nothing remains after cleaning, or if any character is
            outside ``DNA_ALPHABET``.
    """
    raw_text = str(sequence or "")
    stripped_lines = (raw_line.strip() for raw_line in raw_text.splitlines())
    body = "".join(text for text in stripped_lines if text and not text.startswith(">"))
    normalized = re.sub(r"\s+", "", body).upper().replace("U", "T")
    if not normalized:
        raise gr.Error("Sequence is empty.")
    unsupported = sorted(set(normalized).difference(DNA_ALPHABET))
    if unsupported:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(unsupported)}.")
    return normalized
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary produced by the pipeline.

    A one-element list is unwrapped; a bare dictionary is returned unchanged (the
    same object, not a copy).

    Raises:
        gr.Error: If a list does not contain exactly one element, or if the
            (unwrapped) value is not a ``dict``.
    """
    if isinstance(result, list):
        count = len(result)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (result,) = result
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
def rows_from_result(result: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Convert a prediction dictionary into row dicts keyed by ``TABLE_HEADERS``.

    Recognised result shapes, tried in this order:

    * ``cleavage_distribution`` is a list: one row per entry (see ``_cleavage_row``);
    * a ``score`` key is present: a single row named after ``channel``
      (default ``"polyadenylation"``) with an empty position;
    * ``scores`` is a mapping: one row per ``channel -> score`` pair.

    Every score is validated with ``number_value``.

    Raises:
        gr.Error: If no shape matches, or if a score is not a finite number.
    """
    distribution = result.get("cleavage_distribution")
    if isinstance(distribution, list):
        return [_cleavage_row(entry) for entry in distribution]

    if "score" in result:
        single_row = {
            "event": str(result.get("channel", "polyadenylation")),
            "position": "",
            "probability": number_value(result["score"]),
        }
        return [single_row]

    channel_scores = result.get("scores")
    if isinstance(channel_scores, Mapping):
        return [
            {"event": str(name), "position": "", "probability": number_value(value)}
            for name, value in channel_scores.items()
        ]

    raise gr.Error("The selected model did not return polyadenylation scores.")
def _cleavage_row(row: Any) -> dict[str, Any]:
    """Normalise one ``cleavage_distribution`` entry into a table row.

    Entries with an ``event`` key become a named row with an empty position. These
    are presumably summary entries such as ``no_cleavage``; confirm against the
    pipeline output. All other entries become ``"cleavage"`` rows that carry their
    ``position`` (``""`` if absent). The probability is validated with
    ``number_value``.

    Raises:
        gr.Error: If *row* is not a mapping, or if its probability is not a
            finite number.
    """
    if not isinstance(row, Mapping):
        raise gr.Error("Cleavage distribution rows must be dictionaries.")
    is_event_row = "event" in row
    label = str(row["event"]) if is_event_row else "cleavage"
    position = "" if is_event_row else row.get("position", "")
    return {"event": label, "position": position, "probability": number_value(row.get("probability"))}
def number_value(value: Any) -> float:
    """Coerce a score to a finite Python ``float``.

    Accepts anything ``float()`` accepts: ints, floats, numeric strings and NumPy
    scalars.

    Raises:
        gr.Error: If the value cannot be converted, or if it converts to NaN or
            an infinity.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    if np.isfinite(converted):
        return converted
    raise gr.Error(f"Score value {value!r} is not finite.")
def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
    """Project row dicts onto ``TABLE_HEADERS`` as a list of lists for ``gr.Dataframe``.

    Columns missing from a row are filled with ``""``.
    """

    def project(record: Mapping[str, Any]) -> list[Any]:
        return [record.get(column, "") for column in TABLE_HEADERS]

    return list(map(project, rows))
def plot_polyadenylation(rows: list[Mapping[str, Any]]):
    """Render prediction rows as a Matplotlib figure for ``gr.Plot``.

    Two layouts are possible:

    * If any row has a non-empty ``position`` and a numeric ``probability``, those
      rows are drawn as a filled cleavage-probability curve ordered by position.
      Positions are converted with ``int()``. The probability of a ``no_cleavage``
      event row, if present, is printed in the top-right corner.
    * Otherwise every row is drawn as a horizontal bar labelled by its ``event``.

    The figure is created with ``matplotlib.figure.Figure`` rather than
    ``pyplot.subplots``. Pyplot registers every figure in its global figure manager
    and keeps it alive until ``plt.close`` is called, so creating one per request
    in a long-running server accumulates figures. A bare ``Figure`` is
    garbage-collected once nothing references it.

    Returns:
        matplotlib.figure.Figure: The rendered figure.
    """
    # Local import keeps this fix self-contained; matplotlib is already a module dependency.
    from matplotlib.figure import Figure

    position_rows = sorted(
        (int(row["position"]), float(row["probability"]))
        for row in rows
        if row.get("position") not in ("", None) and _is_number(row.get("probability"))
    )
    no_cleavage = next(
        (float(row["probability"]) for row in rows if row.get("event") == "no_cleavage"),
        None,
    )

    fig = Figure(figsize=(8.0, 3.2))
    ax = fig.subplots()
    if position_rows:
        positions = [position for position, _ in position_rows]
        probabilities = [probability for _, probability in position_rows]
        ax.plot(positions, probabilities, color="#2f6f9f", linewidth=1.8)
        ax.fill_between(positions, probabilities, color="#9dcbec", alpha=0.35)
        ax.set_xlabel("Position")
        ax.set_ylabel("Cleavage probability")
        # Gridlines along the value (probability) axis.
        ax.grid(axis="y", alpha=0.2)
        if no_cleavage is not None:
            ax.text(
                0.99,
                0.95,
                f"no cleavage: {no_cleavage:.3f}",
                ha="right",
                va="top",
                transform=ax.transAxes,
            )
    else:
        labels = [str(row.get("event", "score")) for row in rows]
        values = [float(row.get("probability", 0.0)) for row in rows]
        ticks = np.arange(len(values))
        ax.barh(ticks, values, color="#2f6f9f")
        ax.set_yticks(ticks, labels)
        # First row at the top, matching the table order.
        ax.invert_yaxis()
        ax.set_xlabel("Score")
        # For horizontal bars the value axis is x; y-gridlines would only mark the categories.
        ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig
def _is_number(value: Any) -> bool:
    """Return whether *value* is a Python ``int``/``float`` or a NumPy numeric scalar.

    ``bool`` counts as a number because it subclasses ``int``. Numeric-looking
    strings and ``None`` do not count.
    """
    numeric_types = (int, float, np.number)
    return isinstance(value, numeric_types)
def _json_default(value: Any) -> Any:
    """``json.dump`` fallback for values the encoder cannot serialise natively.

    Objects with a callable ``tolist()`` method (NumPy scalars and arrays, PyTorch
    tensors) are converted to plain Python numbers or lists. Anything else raises
    ``TypeError``, exactly as ``json`` would without a fallback, so unknown objects
    are never silently stringified.
    """
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def write_result_files(
    metadata: Mapping[str, Any],
    result: Mapping[str, Any],
    rows: list[Mapping[str, Any]],
) -> tuple[str, str]:
    """Write a prediction to temporary CSV and JSON files for download.

    The CSV holds the rows projected onto ``TABLE_HEADERS``. The JSON holds
    ``metadata``, the same projected ``rows`` and the unmodified ``raw_result``.
    NumPy/PyTorch values inside the raw result are converted via ``tolist()``
    instead of making the export fail.

    Both files are UTF-8. They are created with ``delete=False`` so they remain on
    disk after this function returns; nothing here removes them afterwards.

    Returns:
        ``(csv_path, json_path)``.
    """
    # Project once; the CSV and the JSON "rows" share the same data.
    table_rows = [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows]

    # ``with`` guarantees the handle is closed even if writing fails (delete=False keeps the file).
    with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
        writer.writeheader()
        writer.writerows(table_rows)

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "rows": table_rows,
                "raw_result": result,
            },
            json_file,
            indent=2,
            default=_json_default,
        )

    return csv_file.name, json_file.name
def predict(model_label: str, sequence: str):
    """Run one polyadenylation prediction and build every output of the Run button.

    Args:
        model_label: Dropdown label; looked up in ``MODEL_OPTIONS``.
        sequence: Raw textbox content (FASTA headers and whitespace allowed).

    Returns:
        A 5-tuple ``(table, metadata, figure, csv_path, json_path)``.

    Raises:
        KeyError: If *model_label* is not a key of ``MODEL_OPTIONS``.
        gr.Error: For invalid input, a failed pipeline call or an unusable result.
            Any non-``gr.Error`` exception from loading or running the pipeline is
            wrapped into one.
    """
    model_id = MODEL_OPTIONS[model_label]
    dna = clean_sequence(sequence)

    # Timing starts after validation. It includes pipeline construction on a
    # ``load_predictor`` cache miss, as well as post-processing.
    t0 = time.perf_counter()
    try:
        predictor = load_predictor(model_id)
        raw_output = predictor(dna)
    except gr.Error:
        raise
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error

    prediction = unpack_prediction_result(raw_output)
    rows = rows_from_result(prediction)
    run_info = {
        "task": "polyadenylation",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(dna),
        "row_count": len(rows),
        "elapsed_seconds": round(time.perf_counter() - t0, 3),
    }
    csv_path, json_path = write_result_files(run_info, prediction, rows)
    table = table_values(rows)
    figure = plot_polyadenylation(rows)
    return table, run_info, figure, csv_path, json_path
def initial_model(request: gr.Request):
    """Choose the checkpoint dropdown's initial label from a ``?model=<model id>`` query.

    The ``gr.Request`` annotation is what lets Gradio inject the request. The query
    value is read from ``request.query_params`` first. If that gives nothing, the
    query string of ``request.url`` is parsed instead. Unknown or missing ids fall
    back to ``DEFAULT_MODEL_LABEL``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    requested = None
    params = getattr(request, "query_params", None)
    if params is not None:
        requested = params.get("model")

    if not requested and getattr(request, "url", None):
        url_values = parse_qs(urlparse(str(request.url)).query).get("model")
        requested = url_values[0] if url_values else None

    return MODEL_LABELS.get(requested, DEFAULT_MODEL_LABEL)
with gr.Blocks(title="Polyadenylation") as demo:
    # Page header.
    gr.Markdown(
        "# Polyadenylation\n"
        "Run MultiMolecule polyadenylation checkpoints and inspect APA isoform or cleavage-position scores."
    )

    # Inputs: checkpoint picker, sequence box and the trigger button.
    model = gr.Dropdown(
        label="Checkpoint",
        choices=list(MODEL_OPTIONS.keys()),
        value=DEFAULT_MODEL_LABEL,
    )
    sequence = gr.Textbox(label="DNA sequence", lines=5, value=DEFAULT_SEQUENCE)
    run = gr.Button("Run prediction", variant="primary")

    # Outputs: score table beside the run metadata, then the plot, then the downloads.
    # NOTE(review): the original indentation was lost; placing the plot outside the
    # first row is a reconstruction - confirm the intended layout.
    with gr.Row():
        table = gr.Dataframe(label="Polyadenylation scores", headers=TABLE_HEADERS, interactive=False)
        metadata = gr.JSON(label="Run metadata")
    plot = gr.Plot(label="Polyadenylation plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # Event wiring: the button runs ``predict``; page load preselects a checkpoint from the URL.
    run.click(
        fn=predict,
        inputs=[model, sequence],
        outputs=[table, metadata, plot, csv_download, json_download],
    )
    demo.load(fn=initial_model, outputs=model)
if __name__ == "__main__":
    # Launch the Gradio server only when run as a script; importing the module
    # merely builds the module-level ``demo`` without starting it.
    demo.launch()