# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule
# This file is part of MultiMolecule.
# MultiMolecule is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
# MultiMolecule is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# For additional terms and clarifications, please refer to our License FAQ at:
# <https://multimolecule.danling.org/about/license-faq>.
from __future__ import annotations
import json
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Any
from urllib.parse import parse_qs, urlparse
import gradio as gr
import matplotlib
import pandas as pd
import torch
from transformers import pipeline
matplotlib.use("Agg")
import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
import multimolecule.io as mmio # noqa: E402
from matplotlib import pyplot as plt # noqa: E402
# Dropdown label -> checkpoint id loaded when that label is selected.
MODEL_OPTIONS = {
    "OpenSpliceAI": "multimolecule/openspliceai-mane-400nt",
    "Pangolin": "multimolecule/pangolin",
    "SpTransformer": "multimolecule/sptransformer",
    "MaxEntScan": "multimolecule/maxentscan-score5",
}
# Reverse lookup (checkpoint id -> dropdown label). The second MaxEntScan
# checkpoint is not a dropdown entry of its own, so it is mapped explicitly
# to the same label.
MODEL_LABELS = {
    **{checkpoint: name for name, checkpoint in MODEL_OPTIONS.items()},
    "multimolecule/maxentscan-score3": "MaxEntScan",
}
# Per site type: the checkpoint to load, the length of each scored window,
# and the offset from a window's start to the position its score is stored at.
MAXENTSCAN_MODELS = {
    "donor": dict(
        model_id="multimolecule/maxentscan-score5",
        window=9,
        site_offset=3,
    ),
    "acceptor": dict(
        model_id="multimolecule/maxentscan-score3",
        window=23,
        site_offset=18,
    ),
}
# Dot-prefixed suffixes built from the entries of ``mmio.FASTA``.
# NOTE(review): not referenced anywhere else in the visible part of this file.
FASTA_SUFFIXES = {f".{suffix}" for suffix in mmio.FASTA}
# Symbols that ``clean_sequence`` keeps unchanged.
VALID_BASES = {"A", "C", "G", "T", "N"}
# Symbols that ``clean_sequence`` accepts on input but rewrites to "N".
AMBIGUOUS_BASES = {"R", "Y", "S", "W", "K", "M", "B", "D", "H", "V"}
# Channel names that ``site_channels`` always treats as splice-site channels.
SPLICE_SITE_CHANNELS = {"donor", "acceptor", "splice_site"}
# Sequence pre-filled in the input textbox.
DEFAULT_SEQUENCE = "".join(
    [
        "GCTGACCTGCTGCTGACCCAGGTGAGTCTGCACTCCTGGGCTCAGGTTTCTCTCTCTCTCTCTCTCTCTCTCTCCAG",
        "GATGATGCTGATGAGGAGGAGGAGCTGACTGATGCTGAGGCTGACCTGA",
    ]
)
def _device() -> int:
return 0 if torch.cuda.is_available() else -1
@lru_cache(maxsize=6)
def load_predictor(model_id: str):
    """Build the ``"splice-site"`` pipeline for a checkpoint.

    Results are memoized per ``model_id`` (up to six entries), so a repeated
    request for the same checkpoint reuses the pipeline object created earlier.

    Args:
        model_id: Checkpoint identifier forwarded as the pipeline ``model``.

    Returns:
        The pipeline object returned by ``transformers.pipeline``.
    """
    device = _device()
    return pipeline("splice-site", model=model_id, device=device)
def clean_sequence(sequence: str) -> str:
    """Normalize user-supplied sequence text to the ``ACGTN`` alphabet.

    All whitespace is removed, letters are upper-cased and ``U`` becomes ``T``.
    Symbols listed in ``AMBIGUOUS_BASES`` are accepted but rewritten to ``N``.

    Args:
        sequence: Raw sequence text; falsy values are treated as empty.

    Returns:
        The normalized sequence.

    Raises:
        gr.Error: If nothing is left after stripping whitespace, or if a symbol
            is in neither ``VALID_BASES`` nor ``AMBIGUOUS_BASES``.
    """
    text = str(sequence or "")
    compact = "".join(text.split()).upper().replace("U", "T")
    if not compact:
        raise gr.Error("Sequence is empty.")
    unsupported = sorted(set(compact).difference(VALID_BASES, AMBIGUOUS_BASES))
    if unsupported:
        raise gr.Error(f"DNA sequence contains unsupported symbols: {', '.join(unsupported)}.")
    normalized = [symbol if symbol in VALID_BASES else "N" for symbol in compact]
    return "".join(normalized)
def load_input_file(input_file: Any):
    """Read an uploaded FASTA file and return its single sequence, normalized.

    Args:
        input_file: The uploaded file value. Its ``name`` attribute is used as
            the path when present; otherwise the value itself is used.

    Returns:
        ``gr.update()`` (no change) when no file is given, otherwise the cleaned
        sequence of the only record in the file.

    Raises:
        gr.Error: If the file cannot be parsed as FASTA, contains no records,
            contains more than one record, or the sequence fails
            ``clean_sequence``.
    """
    if input_file is None:
        return gr.update()

    fasta_path = Path(getattr(input_file, "name", input_file))
    try:
        records = mmio.read_fasta_records(fasta_path)
    except mmio.InvalidStructureFile as error:
        raise gr.Error("Could not parse uploaded file as FASTA.") from error

    if not records:
        raise gr.Error("Could not parse uploaded file as FASTA.")
    if len(records) > 1:
        raise gr.Error(f"This demo supports one sequence at a time. Uploaded FASTA contains {len(records)} records.")

    only_record = records[0]
    return clean_sequence(only_record.sequence)
def normalize_prediction_result(result: Any, sequence: str) -> dict[str, Any]:
    """Convert raw predictor output into the dictionary shape used by this app.

    Args:
        result: A prediction dictionary, or a list holding exactly one.
        sequence: Fallback sequence used when ``result`` has no ``"sequence"``.

    Returns:
        A dictionary with the keys ``splice_sites``, ``scores``, ``channels``,
        ``position_index_base`` and ``sequence``. When ``result`` carries only a
        scalar ``"score"`` and no per-position scores, one score row and one
        site entry without a position are synthesized from it.

    Raises:
        gr.Error: If ``result`` is a list whose length is not one, is not a
            dictionary, or holds values the helper functions reject.
    """
    if isinstance(result, list):
        if len(result) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
        (result,) = result
    if not isinstance(result, dict):
        raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")

    channels = list(map(str, result.get("channels", [])))
    scores = _list_of_dicts(result.get("scores", []))
    splice_sites = _list_of_dicts(result.get("splice_sites", []))

    # Scalar-only output: expose the single value under the first channel name.
    if not scores and "score" in result:
        if not channels:
            channels = ["score"]
        label = channels[0]
        value = _safe_float(result["score"])
        scores = [{"position": None, "nucleotide": None, label: value}]
        splice_sites = [{"position": None, "nucleotide": None, "type": label, "score": value}]

    index_base = int(result.get("position_index_base", 0))
    cleaned = clean_sequence(str(result.get("sequence", sequence)))
    return {
        "splice_sites": splice_sites,
        "scores": scores,
        "channels": channels,
        "position_index_base": index_base,
        "sequence": cleaned,
    }
def _list_of_dicts(value: Any) -> list[dict[str, Any]]:
if value is None:
return []
if not isinstance(value, list):
raise gr.Error(f"Expected a list output, got {type(value).__name__}.")
return [dict(item) for item in value if isinstance(item, dict)]
def _safe_float(value: Any) -> float:
if isinstance(value, (list, tuple)):
if len(value) != 1:
raise gr.Error("Expected a scalar score.")
value = value[0]
return float(value)
def predict_maxentscan(sequence: str, threshold: float, top_k: int) -> dict[str, Any]:
    """Score every donor and acceptor window of ``sequence`` with MaxEntScan.

    For each entry of ``MAXENTSCAN_MODELS`` the sequence is cut into all
    overlapping windows of the configured length, the windows are scored in one
    predictor call, and each score is stored at ``window start + site_offset``.

    Args:
        sequence: Cleaned sequence to scan.
        threshold: Minimum score for a position to be reported as a site.
        top_k: Maximum number of reported sites, highest score first.

    Returns:
        A dictionary with ``splice_sites``, ``scores`` (one row per position,
        ``None`` where no window covers the position), ``channels``,
        ``position_index_base``, ``sequence`` and ``windows_scored``.
    """
    scores_by_position: list[dict[str, Any]] = [
        {"position": position, "nucleotide": nucleotide, "acceptor": None, "donor": None}
        for position, nucleotide in enumerate(sequence)
    ]
    splice_sites: list[dict[str, Any]] = []
    windows_scored = {"acceptor": 0, "donor": 0}

    for site_type, config in MAXENTSCAN_MODELS.items():
        window = int(config["window"])
        site_offset = int(config["site_offset"])
        # Check the length first so no checkpoint is loaded for a sequence
        # that is too short to contain a single window.
        if len(sequence) < window:
            continue
        predictor = load_predictor(config["model_id"])
        windows = [sequence[start : start + window] for start in range(len(sequence) - window + 1)]
        windows_scored[site_type] = len(windows)
        results = predictor(windows, output_scores=True)
        # A single window may come back as one dictionary rather than a list.
        if isinstance(results, dict):
            results = [results]
        for start, result in enumerate(results):
            score = _safe_float(result.get("score"))
            position = start + site_offset
            scores_by_position[position][site_type] = score
            if score >= threshold:
                splice_sites.append(
                    {
                        "position": position,
                        "nucleotide": sequence[position],
                        "type": site_type,
                        "score": score,
                    }
                )

    splice_sites.sort(key=lambda item: float(item["score"]), reverse=True)
    return {
        "splice_sites": splice_sites[:top_k],
        "scores": scores_by_position,
        "channels": ["acceptor", "donor"],
        # Positions come from ``enumerate(sequence)`` and are therefore 0-based.
        # ``predict`` reads this key for its metadata, so it must be present.
        "position_index_base": 0,
        "sequence": sequence,
        "windows_scored": windows_scored,
    }
def predict(
    model_label: str,
    sequence: str,
    threshold: float,
    top_k: int,
):
    """Run the selected checkpoint on ``sequence`` and build all UI outputs.

    Args:
        model_label: Key of ``MODEL_OPTIONS`` chosen in the checkpoint dropdown.
        sequence: Raw sequence text; normalized with ``clean_sequence`` first.
        threshold: Score cut-off forwarded to the predictor and used for the
            ``above_threshold`` column and the threshold line of the plot.
        top_k: Maximum number of ranked sites; coerced to ``int``.

    Returns:
        The tuple ``(top_sites, scores, metadata, figure, csv_path, json_path)``,
        in the order of the ``run.click`` outputs.

    Raises:
        gr.Error: If the sequence is empty or contains unsupported symbols, or
            the predictor output cannot be normalized.
        KeyError: If ``model_label`` is not a key of ``MODEL_OPTIONS``.
    """
    sequence = clean_sequence(sequence)
    top_k = int(top_k)
    model_id = MODEL_OPTIONS[model_label]

    if model_label == "MaxEntScan":
        # MaxEntScan uses two checkpoints (donor and acceptor), so report both.
        normalized = predict_maxentscan(sequence, threshold, top_k)
        model_ids: str | list[str] = [config["model_id"] for config in MAXENTSCAN_MODELS.values()]
    else:
        predictor = load_predictor(model_id)
        result = predictor(sequence, threshold=threshold, output_scores=True, top_k=top_k)
        normalized = normalize_prediction_result(result, sequence)
        model_ids = model_id

    top_sites = top_sites_dataframe(normalized, threshold=threshold, top_k=top_k)
    scores = scores_dataframe(normalized)
    figure = plot_score_track(normalized, threshold=threshold)
    metadata = {
        "model": model_ids,
        "model_label": model_label,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "length": len(normalized["sequence"]),
        # Not every prediction path sets this key; default to 0, the same
        # default ``normalize_prediction_result`` applies, instead of raising.
        "position_index_base": normalized.get("position_index_base", 0),
        "threshold": threshold,
        "top_k": top_k,
        "channels": normalized["channels"],
        "num_splice_sites": len(normalized["splice_sites"]),
    }
    if "windows_scored" in normalized:
        metadata["windows_scored"] = normalized["windows_scored"]
    csv_path, json_path = write_result_files(normalized, metadata, scores)
    return top_sites, scores, metadata, figure, csv_path, json_path
def top_sites_dataframe(normalized: dict[str, Any], *, threshold: float, top_k: int) -> pd.DataFrame:
    """Build the ranked-sites table from a normalized prediction.

    Uses ``normalized["splice_sites"]`` when it is non-empty; otherwise sites
    are derived from the per-position scores by ``rank_sites_from_scores``.

    Args:
        normalized: Normalized prediction dictionary.
        threshold: Cut-off used to fill the ``above_threshold`` column.
        top_k: Maximum number of rows.

    Returns:
        A frame with the columns ``position``, ``nucleotide``, ``type``,
        ``score`` and ``above_threshold``, highest score first.
    """
    candidates = normalized["splice_sites"] or rank_sites_from_scores(
        normalized, threshold=threshold, top_k=top_k
    )
    ranked = sorted(candidates, key=lambda item: float(item.get("score", 0.0)), reverse=True)

    records = []
    for site in ranked[:top_k]:
        score = _safe_float(site.get("score", 0.0))
        records.append(
            {
                "position": site.get("position"),
                "nucleotide": site.get("nucleotide"),
                "type": site.get("type"),
                "score": score,
                "above_threshold": score >= threshold,
            }
        )

    column_names = ["position", "nucleotide", "type", "score", "above_threshold"]
    return pd.DataFrame(records, columns=column_names)
def rank_sites_from_scores(normalized: dict[str, Any], *, threshold: float, top_k: int) -> list[dict[str, Any]]:
    """Derive ranked site entries from the per-position score rows.

    One candidate is produced for every (score row, site channel) pair whose
    value is not ``None``.

    Args:
        normalized: Normalized prediction dictionary.
        threshold: Cut-off recorded in each entry's ``above_threshold`` flag;
            it does not filter the candidates.
        top_k: Maximum number of entries returned.

    Returns:
        Up to ``top_k`` entries, highest score first.
    """
    channels = site_channels(normalized["channels"])
    candidates = []
    for entry in normalized["scores"]:
        for channel in channels:
            raw = entry.get(channel)
            if raw is not None:
                score = _safe_float(raw)
                candidates.append(
                    {
                        "position": entry.get("position"),
                        "nucleotide": entry.get("nucleotide"),
                        "type": channel,
                        "score": score,
                        "above_threshold": score >= threshold,
                    }
                )
    ranked = sorted(candidates, key=lambda item: float(item["score"]), reverse=True)
    return ranked[:top_k]
def scores_dataframe(normalized: dict[str, Any]) -> pd.DataFrame:
    """Build the per-position score table from a normalized prediction.

    Args:
        normalized: Dictionary with ``scores`` (a list of row dictionaries) and
            ``channels`` (the score column names).

    Returns:
        A frame with ``position``, ``nucleotide`` and one column per channel;
        a value missing from a row is ``None``.
    """
    column_names = ["position", "nucleotide", *normalized["channels"]]
    records = [{name: entry.get(name) for name in column_names} for entry in normalized["scores"]]
    return pd.DataFrame(records, columns=column_names)
def site_channels(channels: list[str]) -> list[str]:
    """Select the channels that are ranked and plotted as splice-site scores.

    Args:
        channels: Channel names reported by the predictor.

    Returns:
        The channels that are in ``SPLICE_SITE_CHANNELS``, end with
        ``"_splice_site"``, or are ``"acceptor"``/``"donor"``. When none match,
        the first six channels other than ``"no_splice"``.
    """

    def is_site_channel(name: str) -> bool:
        return name in SPLICE_SITE_CHANNELS or name.endswith("_splice_site") or name in {"acceptor", "donor"}

    selected = [name for name in channels if is_site_channel(name)]
    if selected:
        return selected
    remaining = [name for name in channels if name != "no_splice"]
    return remaining[:6]
def plot_score_track(normalized: dict[str, Any], *, threshold: float):
    """Draw the per-position score track for a normalized prediction.

    Up to six site channels are plotted as lines and ``threshold`` as a dashed
    horizontal line. When there is nothing to plot, the figure carries an
    explanatory message instead.

    Args:
        normalized: Normalized prediction dictionary (``scores``, ``channels``).
        threshold: Height of the dashed threshold line.

    Returns:
        The Matplotlib figure. It is already deregistered from pyplot but can
        still be rendered or saved.
    """
    channels = site_channels(normalized["channels"])
    fig, ax = plt.subplots(figsize=(10, 3.2))
    try:
        if not normalized["scores"] or not channels:
            ax.text(0.5, 0.5, "No per-position scores returned", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
            return fig
        # Fall back to the row index when a row has no integer position.
        x = [
            row["position"] if isinstance(row.get("position"), int) else index
            for index, row in enumerate(normalized["scores"])
        ]
        plotted = 0
        for channel in channels[:6]:
            y = [row.get(channel) for row in normalized["scores"]]
            if all(value is None for value in y):
                continue
            ax.plot(x, y, linewidth=1.4, label=channel)
            plotted += 1
        if plotted == 0:
            ax.text(0.5, 0.5, "No plottable score channels", ha="center", va="center", transform=ax.transAxes)
        else:
            ax.axhline(threshold, color="0.3", linestyle="--", linewidth=0.9, label="threshold")
            ax.legend(loc="upper right", ncols=min(plotted + 1, 3), fontsize=8)
        ax.set_xlabel("Position (0-based)")
        ax.set_ylabel("Score")
        # NOTE(review): the lower limit of 0 hides negative scores — confirm
        # this is intended for every checkpoint.
        ax.set_ylim(bottom=0)
        ax.margins(x=0.01)
        fig.tight_layout()
        return fig
    finally:
        # pyplot keeps a reference to every figure it creates until the figure
        # is closed. Without this, each request would leave one more figure
        # alive for the lifetime of the server process.
        plt.close(fig)
def write_result_files(normalized: dict[str, Any], metadata: dict[str, Any], scores: pd.DataFrame):
    """Write the downloadable CSV and JSON files for one prediction.

    Both files are created with ``delete=False``, so they stay on disk after
    this function returns and can be served by path.

    Args:
        normalized: Normalized prediction dictionary; its items form the JSON
            payload together with ``metadata``.
        metadata: Run metadata stored under the ``"metadata"`` key of the JSON.
        scores: Per-position score table written to the CSV without its index.

    Returns:
        The tuple ``(csv_path, json_path)``.
    """
    # ``with`` closes each handle even when serialization raises. The CSV is
    # written through the open handle, so its path is never reopened while the
    # temporary file is still open.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        scores.to_csv(csv_file, index=False)

    payload = {
        **normalized,
        "metadata": metadata,
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json.dump(payload, json_file, indent=2)

    return csv_file.name, json_file.name
def initial_model(request: gr.Request):
    """Pick the dropdown label to preselect from the ``model`` query parameter.

    The parameter is read from ``request.query_params`` first and, when that
    yields nothing, parsed out of ``request.url``.

    Args:
        request: The page-load request, or ``None``.

    Returns:
        The label that ``MODEL_LABELS`` maps the requested checkpoint id to, or
        ``"OpenSpliceAI"`` when there is no request, no parameter, or no match.
    """
    if request is None:
        return "OpenSpliceAI"

    params = getattr(request, "query_params", None)
    model_id = params.get("model") if params is not None else None

    if not model_id and getattr(request, "url", None):
        query = urlparse(str(request.url)).query
        candidates = parse_qs(query).get("model")
        model_id = candidates[0] if candidates else None

    return MODEL_LABELS.get(model_id, "OpenSpliceAI")
# Page layout and event wiring for the demo.
with gr.Blocks(title="Splice Site") as demo:
    gr.Markdown(
        "# Splice Site\n"
        "Run MultiMolecule splice-site checkpoints on a DNA sequence and inspect ranked site calls, "
        "per-position scores, score tracks, and normalized JSON output."
    )
    # Run settings: checkpoint, site threshold and number of ranked sites.
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="OpenSpliceAI",
            label="Checkpoint",
        )
        threshold = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Site threshold")
        top_k = gr.Slider(1, 100, value=25, step=1, label="Top sites")
    # Sequence input: typed or pasted text, or filled in from an uploaded FASTA file.
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    input_file = gr.File(
        label="Upload FASTA",
        file_types=[".fa", ".fas", ".fasta", ".ffn", ".fna"],
    )
    run = gr.Button("Run prediction", variant="primary")
    # Outputs, filled in by ``predict``.
    with gr.Row():
        top_sites = gr.Dataframe(label="Top predicted splice sites (0-based positions)")
        metadata = gr.JSON(label="Run metadata")
    score_track = gr.Plot(label="Per-position score track")
    scores = gr.Dataframe(label="Per-position scores (0-based positions)")
    with gr.Row():
        csv_download = gr.File(label="Download scores CSV")
        json_download = gr.File(label="Download JSON")
    # The outputs list must stay in the order of the tuple returned by ``predict``.
    run.click(
        predict,
        inputs=[model, sequence, threshold, top_k],
        outputs=[top_sites, scores, metadata, score_track, csv_download, json_download],
    )
    # An uploaded file replaces the content of the sequence textbox.
    input_file.change(load_input_file, inputs=input_file, outputs=sequence)
    # On page load, preselect the checkpoint named by the ``model`` query parameter.
    demo.load(initial_model, outputs=model)
# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()