# MultiMolecule # Copyright (C) 2024-Present MultiMolecule # This file is part of MultiMolecule. # MultiMolecule is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # any later version. # MultiMolecule is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # For additional terms and clarifications, please refer to our License FAQ at: # . from __future__ import annotations import json import tempfile from functools import lru_cache from pathlib import Path from typing import Any from urllib.parse import parse_qs, urlparse import gradio as gr import matplotlib import pandas as pd import torch from transformers import pipeline matplotlib.use("Agg") import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers import multimolecule.io as mmio # noqa: E402 from matplotlib import pyplot as plt # noqa: E402 MODEL_OPTIONS = { "OpenSpliceAI": "multimolecule/openspliceai-mane-400nt", "Pangolin": "multimolecule/pangolin", "SpTransformer": "multimolecule/sptransformer", "MaxEntScan": "multimolecule/maxentscan-score5", } MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()} MODEL_LABELS["multimolecule/maxentscan-score3"] = "MaxEntScan" MAXENTSCAN_MODELS = { "donor": { "model_id": "multimolecule/maxentscan-score5", "window": 9, "site_offset": 3, }, "acceptor": { "model_id": "multimolecule/maxentscan-score3", "window": 23, "site_offset": 18, }, } FASTA_SUFFIXES = {f".{suffix}" for suffix in mmio.FASTA} VALID_BASES = set("ACGTN") AMBIGUOUS_BASES = set("RYSWKMBDHV") 
# Channel names treated as splice-site score channels when ranking and plotting.
SPLICE_SITE_CHANNELS = {"acceptor", "donor", "splice_site"}

# Example input shown in the sequence textbox.
DEFAULT_SEQUENCE = (
    "GCTGACCTGCTGCTGACCCAGGTGAGTCTGCACTCCTGGGCTCAGGTTTCTCTCTCTCTCTCTCTCTCTCTCTCTCCAG"
    "GATGATGCTGATGAGGAGGAGGAGCTGACTGATGCTGAGGCTGACCTGA"
)


def _device() -> int:
    """Return the pipeline device index: ``0`` (first GPU) when CUDA is available, else ``-1`` (CPU)."""
    return 0 if torch.cuda.is_available() else -1


@lru_cache(maxsize=6)
def load_predictor(model_id: str):
    """Build a ``splice-site`` pipeline for ``model_id`` and cache it for later calls."""
    return pipeline("splice-site", model=model_id, device=_device())


def clean_sequence(sequence: str) -> str:
    """Normalize user input into an uppercase DNA string.

    Whitespace is removed, ``U`` becomes ``T``, and ambiguity codes in ``AMBIGUOUS_BASES`` become ``N``.

    Raises:
        gr.Error: If the sequence is empty or contains a symbol outside ``VALID_BASES | AMBIGUOUS_BASES``.
    """
    sequence = "".join(str(sequence or "").split()).upper().replace("U", "T")
    if not sequence:
        raise gr.Error("Sequence is empty.")
    invalid = sorted(set(sequence) - VALID_BASES - AMBIGUOUS_BASES)
    if invalid:
        raise gr.Error(f"DNA sequence contains unsupported symbols: {', '.join(invalid)}.")
    return "".join(base if base in VALID_BASES else "N" for base in sequence)


def load_input_file(input_file: Any):
    """Read a single-record FASTA upload and return its cleaned sequence for the textbox.

    Returns ``gr.update()`` (textbox left unchanged) when the upload is cleared.

    Raises:
        gr.Error: If the file cannot be parsed as FASTA, has no records, or has more than one record.
    """
    if input_file is None:
        return gr.update()
    # Accept either an object exposing ``.name`` or a plain path value.
    path = Path(getattr(input_file, "name", input_file))
    try:
        records = mmio.read_fasta_records(path)
    except mmio.InvalidStructureFile as error:
        raise gr.Error("Could not parse uploaded file as FASTA.") from error
    if not records:
        raise gr.Error("Could not parse uploaded file as FASTA.")
    if len(records) > 1:
        raise gr.Error(
            f"This demo supports one sequence at a time. Uploaded FASTA contains {len(records)} records."
        )
    return clean_sequence(records[0].sequence)


def normalize_prediction_result(result: Any, sequence: str) -> dict[str, Any]:
    """Coerce a splice-site pipeline output into the dict shape used by the rest of the app.

    Accepts a dict or a one-element list of dicts. An output that carries only a scalar ``score``
    (no per-position ``scores``) is wrapped into a single position-less score row and site.

    Returns:
        A dict with ``splice_sites``, ``scores``, ``channels``, ``position_index_base`` and ``sequence``.

    Raises:
        gr.Error: If the output is not exactly one prediction dictionary.
    """
    if isinstance(result, list):
        if len(result) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
        result = result[0]
    if not isinstance(result, dict):
        raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")

    channels = [str(channel) for channel in result.get("channels") or []]
    scores = _list_of_dicts(result.get("scores", []))
    splice_sites = _list_of_dicts(result.get("splice_sites", []))
    if "score" in result and not scores:
        channels = channels or ["score"]
        score = _safe_float(result["score"])
        scores = [{"position": None, "nucleotide": None, channels[0]: score}]
        splice_sites = [{"position": None, "nucleotide": None, "type": channels[0], "score": score}]

    # Missing or null fields fall back to 0-based positions and the submitted sequence,
    # rather than crashing on int(None) or validating the literal string "None".
    position_index_base = result.get("position_index_base")
    return {
        "splice_sites": splice_sites,
        "scores": scores,
        "channels": channels,
        "position_index_base": 0 if position_index_base is None else int(position_index_base),
        "sequence": clean_sequence(str(result.get("sequence") or sequence)),
    }


def _list_of_dicts(value: Any) -> list[dict[str, Any]]:
    """Return shallow copies of the dict items in ``value``; ``None`` yields ``[]`` and non-dict items are dropped.

    Raises:
        gr.Error: If ``value`` is neither ``None`` nor a list.
    """
    if value is None:
        return []
    if not isinstance(value, list):
        raise gr.Error(f"Expected a list output, got {type(value).__name__}.")
    return [dict(item) for item in value if isinstance(item, dict)]


def _safe_float(value: Any) -> float:
    """Convert a scalar, or a one-element list/tuple, to ``float``.

    Raises:
        gr.Error: If ``value`` is a list/tuple whose length is not 1.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != 1:
            raise gr.Error("Expected a scalar score.")
        value = value[0]
    return float(value)


def predict_maxentscan(sequence: str, threshold: float, top_k: int) -> dict[str, Any]:
    """Score every sliding window of ``sequence`` with the MaxEntScan donor and acceptor checkpoints.

    Each window's score is assigned to position ``start + site_offset`` (see ``MAXENTSCAN_MODELS``).
    A position that no window reaches keeps ``None`` for that channel.

    Returns:
        A normalized result dict with the same keys as ``normalize_prediction_result`` plus
        ``windows_scored``. ``splice_sites`` holds at most ``top_k`` sites scoring at or above
        ``threshold``, ordered from highest score to lowest.
    """
    scores_by_position: list[dict[str, Any]] = [
        {"position": position, "nucleotide": nucleotide, "acceptor": None, "donor": None}
        for position, nucleotide in enumerate(sequence)
    ]
    splice_sites: list[dict[str, Any]] = []
    windows_scored = {"acceptor": 0, "donor": 0}

    for site_type, config in MAXENTSCAN_MODELS.items():
        window = int(config["window"])
        site_offset = int(config["site_offset"])
        if len(sequence) < window:
            # Too short for even one window, so skip without loading the checkpoint.
            continue
        predictor = load_predictor(config["model_id"])
        windows = [sequence[start : start + window] for start in range(len(sequence) - window + 1)]
        windows_scored[site_type] = len(windows)
        results = predictor(windows, output_scores=True)
        if isinstance(results, dict):
            results = [results]
        for start, result in enumerate(results):
            score = _safe_float(result.get("score"))
            position = start + site_offset
            scores_by_position[position][site_type] = score
            if score >= threshold:
                splice_sites.append(
                    {
                        "position": position,
                        "nucleotide": sequence[position],
                        "type": site_type,
                        "score": score,
                    }
                )

    splice_sites.sort(key=lambda item: float(item["score"]), reverse=True)
    return {
        "splice_sites": splice_sites[:top_k],
        "scores": scores_by_position,
        "channels": ["acceptor", "donor"],
        # Positions come from enumerate(sequence), so they are 0-based. predict() reports this key.
        "position_index_base": 0,
        "sequence": sequence,
        "windows_scored": windows_scored,
    }


def predict(
    model_label: str,
    sequence: str,
    threshold: float,
    top_k: int,
):
    """Run the selected checkpoint on ``sequence`` and build every UI output.

    Returns:
        ``(top_sites, scores, metadata, figure, csv_path, json_path)``, in the order of the
        ``outputs`` wired to the run button.
    """
    sequence = clean_sequence(sequence)
    top_k = int(top_k)
    model_id = MODEL_OPTIONS[model_label]

    model_ids: str | list[str]
    if model_label == "MaxEntScan":
        # MaxEntScan runs separate donor (score5) and acceptor (score3) checkpoints.
        normalized = predict_maxentscan(sequence, threshold, top_k)
        model_ids = [config["model_id"] for config in MAXENTSCAN_MODELS.values()]
    else:
        predictor = load_predictor(model_id)
        result = predictor(sequence, threshold=threshold, output_scores=True, top_k=top_k)
        normalized = normalize_prediction_result(result, sequence)
        model_ids = model_id

    top_sites = top_sites_dataframe(normalized, threshold=threshold, top_k=top_k)
    scores = scores_dataframe(normalized)
    figure = plot_score_track(normalized, threshold=threshold)
    metadata = {
        "model": model_ids,
        "model_label": model_label,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "length": len(normalized["sequence"]),
        "position_index_base": normalized.get("position_index_base", 0),
        "threshold": threshold,
        "top_k": top_k,
        "channels": normalized["channels"],
        "num_splice_sites": len(normalized["splice_sites"]),
    }
    if "windows_scored" in normalized:
        metadata["windows_scored"] = normalized["windows_scored"]
    csv_path, json_path = write_result_files(normalized, metadata, scores)
    return top_sites, scores, metadata, figure, csv_path, json_path


def top_sites_dataframe(normalized: dict[str, Any], *, threshold: float, top_k: int) -> pd.DataFrame:
    """Tabulate up to ``top_k`` sites, highest score first.

    When the model reported no splice sites, every per-position channel score is ranked instead
    (see ``rank_sites_from_scores``).
    """
    sites = normalized["splice_sites"] or rank_sites_from_scores(normalized, threshold=threshold, top_k=top_k)
    rows = []
    for site in sites:
        score = _safe_float(site.get("score", 0.0))
        rows.append(
            {
                "position": site.get("position"),
                "nucleotide": site.get("nucleotide"),
                "type": site.get("type"),
                "score": score,
                "above_threshold": score >= threshold,
            }
        )
    rows.sort(key=lambda row: row["score"], reverse=True)
    return pd.DataFrame(
        rows[:top_k],
        columns=["position", "nucleotide", "type", "score", "above_threshold"],
    )


def rank_sites_from_scores(normalized: dict[str, Any], *, threshold: float, top_k: int) -> list[dict[str, Any]]:
    """Turn every non-null (position, splice-site channel) score into a candidate site and return the top ``top_k``."""
    channels = site_channels(normalized["channels"])
    rows = []
    for score_row in normalized["scores"]:
        for channel in channels:
            value = score_row.get(channel)
            if value is None:
                continue
            score = _safe_float(value)
            rows.append(
                {
                    "position": score_row.get("position"),
                    "nucleotide": score_row.get("nucleotide"),
                    "type": channel,
                    "score": score,
                    "above_threshold": score >= threshold,
                }
            )
    rows.sort(key=lambda item: float(item["score"]), reverse=True)
    return rows[:top_k]


def scores_dataframe(normalized: dict[str, Any]) -> pd.DataFrame:
    """Tabulate per-position scores: ``position``, ``nucleotide``, then one column per channel."""
    columns = ["position", "nucleotide", *normalized["channels"]]
    rows = [{column: score_row.get(column) for column in columns} for score_row in normalized["scores"]]
    return pd.DataFrame(rows, columns=columns)


def site_channels(channels: list[str]) -> list[str]:
    """Select the channels that represent splice sites.

    If no channel matches ``SPLICE_SITE_CHANNELS`` or ends with ``_splice_site``, fall back to the
    first six channels other than ``no_splice``.
    """
    candidates = [
        channel for channel in channels if channel in SPLICE_SITE_CHANNELS or channel.endswith("_splice_site")
    ]
    if candidates:
        return candidates
    return [channel for channel in channels if channel != "no_splice"][:6]


def plot_score_track(normalized: dict[str, Any], *, threshold: float):
    """Plot per-position scores for up to six splice-site channels, with a dashed threshold line.

    The figure is removed from pyplot's global figure registry before it is returned, so repeated
    requests do not accumulate open figures. The ``Figure`` object itself is still returned for
    rendering.
    """
    fig, ax = plt.subplots(figsize=(10, 3.2))
    try:
        _draw_score_track(fig, ax, normalized, threshold=threshold)
    finally:
        plt.close(fig)
    return fig


def _draw_score_track(fig, ax, normalized: dict[str, Any], *, threshold: float) -> None:
    """Draw the score track onto ``ax``, or a placeholder message when nothing can be plotted."""
    channels = site_channels(normalized["channels"])
    score_rows = normalized["scores"]
    if not score_rows or not channels:
        ax.text(0.5, 0.5, "No per-position scores returned", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return

    # Use the row index when a row has no integer position (e.g. wrapped scalar-score outputs).
    x = [
        row["position"] if isinstance(row.get("position"), int) else index
        for index, row in enumerate(score_rows)
    ]
    plotted = 0
    lowest = None
    for channel in channels[:6]:
        y = [row.get(channel) for row in score_rows]
        present = [value for value in y if value is not None]
        if not present:
            continue
        ax.plot(x, y, linewidth=1.4, label=channel)
        plotted += 1
        channel_min = min(present)
        lowest = channel_min if lowest is None else min(lowest, channel_min)

    if plotted == 0:
        ax.text(0.5, 0.5, "No plottable score channels", ha="center", va="center", transform=ax.transAxes)
    else:
        ax.axhline(threshold, color="0.3", linestyle="--", linewidth=0.9, label="threshold")
        ax.legend(loc="upper right", ncols=min(plotted + 1, 3), fontsize=8)
    position_index_base = normalized.get("position_index_base", 0)
    ax.set_xlabel(f"Position ({position_index_base}-based)")
    ax.set_ylabel("Score")
    # Anchor the y-axis at 0 only when nothing plotted is negative. Scores that are not
    # probabilities (e.g. MaxEntScan log-odds) may go below 0 and would otherwise be clipped.
    if lowest is None or lowest >= 0:
        ax.set_ylim(bottom=0)
    ax.margins(x=0.01)
    fig.tight_layout()


def _json_default(value: Any) -> Any:
    """Make array-like scalars JSON-serializable.

    Pipeline outputs may contain numpy/torch scalars or arrays (unconfirmed, depends on the
    pipeline); anything exposing ``tolist()`` is converted to plain Python values.

    Raises:
        TypeError: For any other non-serializable object, mirroring ``json``'s default behavior.
    """
    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        return tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def write_result_files(normalized: dict[str, Any], metadata: dict[str, Any], scores: pd.DataFrame):
    """Write the score table to a temporary CSV and the full result plus metadata to a temporary JSON.

    The files are created with ``delete=False`` so Gradio can serve them after this function returns.

    Returns:
        ``(csv_path, json_path)``.
    """
    # Write through the open handle rather than reopening by name: reopening a file that is still
    # open fails on Windows, and the ``with`` block closes the handle even if writing raises.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        scores.to_csv(csv_file, index=False)

    payload = {
        **normalized,
        "metadata": metadata,
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json.dump(payload, json_file, indent=2, default=_json_default)
    return csv_file.name, json_file.name


def initial_model(request: gr.Request):
    """Choose the dropdown's initial label from the ``?model=`` query parameter.

    The value may be a checkpoint id (e.g. ``multimolecule/pangolin``) or a dropdown label
    (e.g. ``Pangolin``). Missing or unknown values fall back to ``"OpenSpliceAI"``.
    """
    if request is None:
        return "OpenSpliceAI"
    model_id = None
    query_params = getattr(request, "query_params", None)
    if query_params is not None:
        model_id = query_params.get("model")
    if not model_id and getattr(request, "url", None):
        model_values = parse_qs(urlparse(str(request.url)).query).get("model")
        model_id = model_values[0] if model_values else None
    if model_id in MODEL_OPTIONS:
        return model_id
    return MODEL_LABELS.get(model_id, "OpenSpliceAI")


with gr.Blocks(title="Splice Site") as demo:
    gr.Markdown(
        "# Splice Site\n"
        "Run MultiMolecule splice-site checkpoints on a DNA sequence and inspect ranked site calls, "
        "per-position scores, score tracks, and normalized JSON output."
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="OpenSpliceAI",
            label="Checkpoint",
        )
        threshold = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Site threshold")
        top_k = gr.Slider(1, 100, value=25, step=1, label="Top sites")
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    # Accept exactly the FASTA suffixes derived from multimolecule.io instead of a hand-kept copy.
    input_file = gr.File(
        label="Upload FASTA",
        file_types=sorted(FASTA_SUFFIXES),
    )
    run = gr.Button("Run prediction", variant="primary")
    with gr.Row():
        top_sites = gr.Dataframe(label="Top predicted splice sites (0-based positions)")
        metadata = gr.JSON(label="Run metadata")
    score_track = gr.Plot(label="Per-position score track")
    scores = gr.Dataframe(label="Per-position scores (0-based positions)")
    with gr.Row():
        csv_download = gr.File(label="Download scores CSV")
        json_download = gr.File(label="Download JSON")

    run.click(
        predict,
        inputs=[model, sequence, threshold, top_k],
        outputs=[top_sites, scores, metadata, score_track, csv_download, json_download],
    )
    # Uploading a FASTA replaces the textbox contents with the cleaned sequence.
    input_file.change(load_input_file, inputs=input_file, outputs=sequence)
    # Preselect the checkpoint from the page URL's ``?model=`` parameter.
    demo.load(initial_model, outputs=model)


if __name__ == "__main__":
    demo.launch()