Spaces:
Running
Running
| # MultiMolecule | |
| # Copyright (C) 2024-Present MultiMolecule | |
| # This file is part of MultiMolecule. | |
| # MultiMolecule is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU Affero General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # any later version. | |
| # MultiMolecule is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU Affero General Public License for more details. | |
| # You should have received a copy of the GNU Affero General Public License | |
| # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| # For additional terms and clarifications, please refer to our License FAQ at: | |
| # <https://multimolecule.danling.org/about/license-faq>. | |
| from __future__ import annotations | |
| import json | |
| import tempfile | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any | |
| from urllib.parse import parse_qs, urlparse | |
| import gradio as gr | |
| import matplotlib | |
| import pandas as pd | |
| import torch | |
| from transformers import pipeline | |
| matplotlib.use("Agg") | |
| import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers | |
| import multimolecule.io as mmio # noqa: E402 | |
| from matplotlib import pyplot as plt # noqa: E402 | |
# Checkpoints offered in the UI dropdown: display label -> model id.
MODEL_OPTIONS = {
    "OpenSpliceAI": "multimolecule/openspliceai-mane-400nt",
    "Pangolin": "multimolecule/pangolin",
    "SpTransformer": "multimolecule/sptransformer",
    "MaxEntScan": "multimolecule/maxentscan-score5",
}

# Reverse lookup (model id -> display label), used to resolve a ``model`` query parameter.
MODEL_LABELS = {checkpoint: name for name, checkpoint in MODEL_OPTIONS.items()}
# The score3 checkpoint has no dropdown entry of its own; it resolves to the shared label.
MODEL_LABELS["multimolecule/maxentscan-score3"] = "MaxEntScan"

# Sliding-window settings per site type. ``window`` is the length of each scored
# subsequence; ``site_offset`` is added to the window start to get the reported position.
MAXENTSCAN_MODELS = {
    "donor": {
        "model_id": "multimolecule/maxentscan-score5",
        "window": 9,
        "site_offset": 3,
    },
    "acceptor": {
        "model_id": "multimolecule/maxentscan-score3",
        "window": 23,
        "site_offset": 18,
    },
}

# Dotted file suffixes derived from ``mmio.FASTA``.
FASTA_SUFFIXES = {f".{extension}" for extension in mmio.FASTA}

# Symbols kept as-is by ``clean_sequence``.
VALID_BASES = {"A", "C", "G", "T", "N"}
# Symbols accepted by ``clean_sequence`` but rewritten to "N".
AMBIGUOUS_BASES = {"R", "Y", "S", "W", "K", "M", "B", "D", "H", "V"}

# Channel names treated as splice-site tracks when ranking and plotting.
SPLICE_SITE_CHANNELS = {"acceptor", "donor", "splice_site"}

# Sequence pre-filled in the input textbox.
DEFAULT_SEQUENCE = (
    "GCTGACCTGCTGCTGACCCAGGTGAGTCTGCACTCCTGGGCTCAGGTTTCTCTCTCTCTCTCTCTCTCTCTCTCCAG"
    "GATGATGCTGATGAGGAGGAGGAGCTGACTGATGCTGAGGCTGACCTGA"
)
| def _device() -> int: | |
| return 0 if torch.cuda.is_available() else -1 | |
@lru_cache(maxsize=8)
def load_predictor(model_id: str):
    """Build (once per model id) the ``splice-site`` pipeline for ``model_id``.

    The result is memoized so repeated requests reuse the already-constructed
    pipeline instead of rebuilding it on every call. The cache is bounded; the
    least recently used pipeline is dropped once more than eight distinct model
    ids have been requested.

    Args:
        model_id: Model identifier passed to ``transformers.pipeline`` as ``model``.

    Returns:
        The pipeline object created for ``model_id`` on the device chosen by ``_device()``.
    """
    return pipeline("splice-site", model=model_id, device=_device())
def clean_sequence(sequence: str) -> str:
    """Normalize user input into an upper-case DNA string over ``VALID_BASES``.

    All whitespace is removed, the text is upper-cased, and ``U`` is rewritten to
    ``T``. Symbols listed in ``AMBIGUOUS_BASES`` are accepted but replaced by ``N``.

    Raises:
        gr.Error: If nothing remains after stripping whitespace, or if the input
            contains a symbol that is neither a valid nor an ambiguous base.
    """
    normalized = "".join(str(sequence or "").split()).upper().replace("U", "T")
    if not normalized:
        raise gr.Error("Sequence is empty.")

    unsupported = sorted(set(normalized) - VALID_BASES - AMBIGUOUS_BASES)
    if unsupported:
        raise gr.Error(f"DNA sequence contains unsupported symbols: {', '.join(unsupported)}.")

    # Anything left that is not a valid base is an accepted ambiguity code.
    return "".join(["N" if symbol not in VALID_BASES else symbol for symbol in normalized])
def load_input_file(input_file: Any):
    """Read an uploaded FASTA file and return its single sequence, cleaned.

    Args:
        input_file: The uploaded file value; either an object exposing a ``name``
            attribute or something ``Path`` accepts directly. ``None`` means no file.

    Returns:
        ``gr.update()`` (no change) when no file is given, otherwise the cleaned
        sequence of the only record in the file.

    Raises:
        gr.Error: If the file cannot be parsed, holds no records, or holds more
            than one record.
    """
    if input_file is None:
        return gr.update()

    parse_failure = "Could not parse uploaded file as FASTA."
    fasta_path = Path(getattr(input_file, "name", input_file))
    try:
        records = mmio.read_fasta_records(fasta_path)
    except mmio.InvalidStructureFile as error:
        raise gr.Error(parse_failure) from error

    if not records:
        raise gr.Error(parse_failure)

    record_count = len(records)
    if record_count > 1:
        raise gr.Error(
            f"This demo supports one sequence at a time. Uploaded FASTA contains {record_count} records."
        )
    return clean_sequence(records[0].sequence)
def normalize_prediction_result(result: Any, sequence: str) -> dict[str, Any]:
    """Coerce a pipeline result into the dictionary shape used by the rest of the app.

    A one-element list is unwrapped. A result that only carries a scalar ``score``
    (no per-position ``scores``) is expanded into a single score row and a single
    site whose position and nucleotide are ``None``.

    Args:
        result: Raw pipeline output: a dict, or a list holding exactly one dict.
        sequence: Fallback sequence used when the result has no ``sequence`` entry.

    Returns:
        A dict with ``splice_sites``, ``scores``, ``channels``,
        ``position_index_base`` and ``sequence``.

    Raises:
        gr.Error: If ``result`` is a list of any length other than one, or is not
            a dict after unwrapping.
    """
    payload = result
    if isinstance(payload, list):
        if len(payload) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(payload)}.")
        (payload,) = payload
    if not isinstance(payload, dict):
        raise gr.Error(f"Expected a prediction dictionary, got {type(payload).__name__}.")

    channel_names = list(map(str, payload.get("channels", [])))
    per_position = _list_of_dicts(payload.get("scores", []))
    called_sites = _list_of_dicts(payload.get("splice_sites", []))

    # Scalar-only output: synthesize one row so downstream tables still have content.
    if not per_position and "score" in payload:
        if not channel_names:
            channel_names = ["score"]
        label = channel_names[0]
        value = _safe_float(payload["score"])
        per_position = [{"position": None, "nucleotide": None, label: value}]
        called_sites = [{"position": None, "nucleotide": None, "type": label, "score": value}]

    return {
        "splice_sites": called_sites,
        "scores": per_position,
        "channels": channel_names,
        "position_index_base": int(payload.get("position_index_base", 0)),
        "sequence": clean_sequence(str(payload.get("sequence", sequence))),
    }
| def _list_of_dicts(value: Any) -> list[dict[str, Any]]: | |
| if value is None: | |
| return [] | |
| if not isinstance(value, list): | |
| raise gr.Error(f"Expected a list output, got {type(value).__name__}.") | |
| return [dict(item) for item in value if isinstance(item, dict)] | |
| def _safe_float(value: Any) -> float: | |
| if isinstance(value, (list, tuple)): | |
| if len(value) != 1: | |
| raise gr.Error("Expected a scalar score.") | |
| value = value[0] | |
| return float(value) | |
def predict_maxentscan(sequence: str, threshold: float, top_k: int) -> dict[str, Any]:
    """Scan ``sequence`` with the donor and acceptor MaxEntScan models.

    For each site type in ``MAXENTSCAN_MODELS`` every window of the configured
    length is scored, and the score is attributed to ``window start + site_offset``.
    Site types whose window is longer than the sequence are skipped.

    Args:
        sequence: Cleaned DNA sequence.
        threshold: Minimum score for a position to be reported as a splice site.
        top_k: Maximum number of reported splice sites (highest scores first).

    Returns:
        A dict with ``splice_sites`` (at most ``top_k``, sorted by descending
        score), ``scores`` (one row per sequence position with ``acceptor`` and
        ``donor`` values, ``None`` where no window covers the position),
        ``channels``, ``position_index_base``, ``sequence`` and ``windows_scored``
        (number of windows scored per site type).
    """
    scores_by_position: list[dict[str, Any]] = [
        {"position": position, "nucleotide": nucleotide, "acceptor": None, "donor": None}
        for position, nucleotide in enumerate(sequence)
    ]
    splice_sites: list[dict[str, Any]] = []
    windows_scored = {"acceptor": 0, "donor": 0}
    for site_type, config in MAXENTSCAN_MODELS.items():
        window = int(config["window"])
        site_offset = int(config["site_offset"])
        # Check the length first so no model is loaded when there is nothing to score.
        if len(sequence) < window:
            continue
        predictor = load_predictor(config["model_id"])
        windows = [sequence[start : start + window] for start in range(len(sequence) - window + 1)]
        windows_scored[site_type] = len(windows)
        results = predictor(windows, output_scores=True)
        # A single window may come back as a bare dict rather than a list.
        if isinstance(results, dict):
            results = [results]
        for start, result in enumerate(results):
            score = _safe_float(result.get("score"))
            position = start + site_offset
            scores_by_position[position][site_type] = score
            if score >= threshold:
                splice_sites.append(
                    {
                        "position": position,
                        "nucleotide": sequence[position],
                        "type": site_type,
                        "score": score,
                    }
                )
    splice_sites.sort(key=lambda item: float(item["score"]), reverse=True)
    return {
        "splice_sites": splice_sites[:top_k],
        "scores": scores_by_position,
        "channels": ["acceptor", "donor"],
        # Positions above come from ``enumerate`` and are therefore 0-based; the
        # key is required by the metadata assembled in ``predict``.
        "position_index_base": 0,
        "sequence": sequence,
        "windows_scored": windows_scored,
    }
def predict(
    model_label: str,
    sequence: str,
    threshold: float,
    top_k: int,
):
    """Run the selected checkpoint on ``sequence`` and build every UI output.

    Args:
        model_label: Key of ``MODEL_OPTIONS`` chosen in the dropdown.
        sequence: Raw user sequence; cleaned with ``clean_sequence`` before use.
        threshold: Site-calling threshold forwarded to the predictor and used for
            the table flags and the plot.
        top_k: Maximum number of sites to report; coerced to ``int``.

    Returns:
        A tuple ``(top_sites, scores, metadata, figure, csv_path, json_path)``:
        the ranked-sites table, the per-position score table, the run metadata
        dict, the score-track figure, and the paths of the CSV and JSON files.
    """
    sequence = clean_sequence(sequence)
    top_k = int(top_k)
    model_id = MODEL_OPTIONS[model_label]
    if model_label == "MaxEntScan":
        # MaxEntScan is a pair of window models scanned across the sequence.
        normalized = predict_maxentscan(sequence, threshold, top_k)
        model_ids: str | list[str] = [config["model_id"] for config in MAXENTSCAN_MODELS.values()]
    else:
        predictor = load_predictor(model_id)
        result = predictor(sequence, threshold=threshold, output_scores=True, top_k=top_k)
        normalized = normalize_prediction_result(result, sequence)
        model_ids = model_id
    top_sites = top_sites_dataframe(normalized, threshold=threshold, top_k=top_k)
    scores = scores_dataframe(normalized)
    figure = plot_score_track(normalized, threshold=threshold)
    metadata = {
        "model": model_ids,
        "model_label": model_label,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "length": len(normalized["sequence"]),
        # A payload may omit this key; default to 0-based positions, matching the
        # labels shown in the UI, instead of failing with KeyError.
        "position_index_base": normalized.get("position_index_base", 0),
        "threshold": threshold,
        "top_k": top_k,
        "channels": normalized["channels"],
        "num_splice_sites": len(normalized["splice_sites"]),
    }
    if "windows_scored" in normalized:
        metadata["windows_scored"] = normalized["windows_scored"]
    csv_path, json_path = write_result_files(normalized, metadata, scores)
    return top_sites, scores, metadata, figure, csv_path, json_path
def top_sites_dataframe(normalized: dict[str, Any], *, threshold: float, top_k: int) -> pd.DataFrame:
    """Build the ranked splice-site table shown in the UI.

    Uses ``normalized["splice_sites"]`` when present; otherwise falls back to
    ranking the per-position scores with ``rank_sites_from_scores``.

    Args:
        normalized: Normalized prediction payload.
        threshold: Score at or above which a row is flagged ``above_threshold``.
        top_k: Maximum number of rows.

    Returns:
        A DataFrame with columns ``position``, ``nucleotide``, ``type``, ``score``
        and ``above_threshold``, sorted by descending score.
    """
    sites = normalized["splice_sites"]
    if not sites:
        sites = rank_sites_from_scores(normalized, threshold=threshold, top_k=top_k)

    # Convert each score once, with the same helper for sorting and display, so a
    # score wrapped in a one-element list is ordered instead of breaking the sort.
    scored = [(_safe_float(site.get("score", 0.0)), site) for site in sites]
    scored.sort(key=lambda pair: pair[0], reverse=True)

    rows = [
        {
            "position": site.get("position"),
            "nucleotide": site.get("nucleotide"),
            "type": site.get("type"),
            "score": score,
            "above_threshold": score >= threshold,
        }
        for score, site in scored[:top_k]
    ]
    return pd.DataFrame(
        rows,
        columns=["position", "nucleotide", "type", "score", "above_threshold"],
    )
def rank_sites_from_scores(normalized: dict[str, Any], *, threshold: float, top_k: int) -> list[dict[str, Any]]:
    """Derive ranked site candidates from the per-position score rows.

    Every non-``None`` value of every channel selected by ``site_channels``
    becomes one candidate. Candidates are not filtered by ``threshold``; they are
    only flagged.

    Returns:
        Up to ``top_k`` candidate dicts (``position``, ``nucleotide``, ``type``,
        ``score``, ``above_threshold``), highest score first.
    """
    candidate_channels = site_channels(normalized["channels"])
    ranked: list[dict[str, Any]] = []
    for entry in normalized["scores"]:
        for name in candidate_channels:
            raw = entry.get(name)
            if raw is None:
                continue
            score = _safe_float(raw)
            ranked.append(
                {
                    "position": entry.get("position"),
                    "nucleotide": entry.get("nucleotide"),
                    "type": name,
                    "score": score,
                    "above_threshold": score >= threshold,
                }
            )
    ordered = sorted(ranked, key=lambda item: float(item["score"]), reverse=True)
    return ordered[:top_k]
def scores_dataframe(normalized: dict[str, Any]) -> pd.DataFrame:
    """Tabulate the per-position scores, one column per channel.

    Returns:
        A DataFrame with ``position`` and ``nucleotide`` followed by one column
        for each entry of ``normalized["channels"]``; values missing from a score
        row are left empty.
    """
    channel_names = normalized["channels"]
    records = [
        {
            "position": entry.get("position"),
            "nucleotide": entry.get("nucleotide"),
            **{name: entry.get(name) for name in channel_names},
        }
        for entry in normalized["scores"]
    ]
    return pd.DataFrame(records, columns=["position", "nucleotide", *channel_names])
def site_channels(channels: list[str]) -> list[str]:
    """Pick the channels that represent splice-site tracks.

    Returns:
        Every channel that is in ``SPLICE_SITE_CHANNELS``, ends with
        ``_splice_site``, or is ``acceptor``/``donor``. If none match, the first
        six channels other than ``no_splice``.
    """

    def _is_site_channel(name: str) -> bool:
        return name in SPLICE_SITE_CHANNELS or name.endswith("_splice_site") or name in {"acceptor", "donor"}

    matched = list(filter(_is_site_channel, channels))
    if not matched:
        return [name for name in channels if name != "no_splice"][:6]
    return matched
def plot_score_track(normalized: dict[str, Any], *, threshold: float):
    """Draw the per-position score track for up to six site channels.

    Args:
        normalized: Normalized prediction payload (``scores`` and ``channels`` are read).
        threshold: Drawn as a dashed horizontal line when at least one channel is plotted.

    Returns:
        A Matplotlib figure. When there are no score rows or no site channels the
        figure only contains an explanatory message.
    """
    # Imported here so the function carries its own dependency.
    from matplotlib.figure import Figure

    channels = site_channels(normalized["channels"])
    # Build the figure without pyplot: ``plt.subplots`` registers each figure in
    # pyplot's global figure manager, which keeps it alive until it is explicitly
    # closed. Nothing in this app closes figures, so they would pile up per request.
    fig = Figure(figsize=(10, 3.2))
    ax = fig.subplots()
    if not normalized["scores"] or not channels:
        ax.text(0.5, 0.5, "No per-position scores returned", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        return fig
    # Fall back to the row index when a row has no integer position.
    x = [
        row["position"] if isinstance(row.get("position"), int) else index
        for index, row in enumerate(normalized["scores"])
    ]
    plotted = 0
    for channel in channels[:6]:
        y = [row.get(channel) for row in normalized["scores"]]
        if all(value is None for value in y):
            continue
        ax.plot(x, y, linewidth=1.4, label=channel)
        plotted += 1
    if plotted == 0:
        ax.text(0.5, 0.5, "No plottable score channels", ha="center", va="center", transform=ax.transAxes)
    else:
        ax.axhline(threshold, color="0.3", linestyle="--", linewidth=0.9, label="threshold")
        ax.legend(loc="upper right", ncols=min(plotted + 1, 3), fontsize=8)
    ax.set_xlabel("Position (0-based)")
    ax.set_ylabel("Score")
    ax.set_ylim(bottom=0)
    ax.margins(x=0.01)
    fig.tight_layout()
    return fig
def write_result_files(normalized: dict[str, Any], metadata: dict[str, Any], scores: pd.DataFrame):
    """Write the downloadable CSV and JSON result files.

    Args:
        normalized: Normalized prediction payload; its keys become the top level
            of the JSON document.
        metadata: Run metadata, stored under the ``metadata`` key of the JSON document.
        scores: Per-position score table written to the CSV file without its index.

    Returns:
        A tuple ``(csv_path, json_path)`` of the two temporary files. They are
        created with ``delete=False`` and are not removed by this function.
    """
    # Reserve the CSV path and release the handle before pandas opens the file
    # itself; the ``with`` block guarantees the handle is closed on any exit.
    with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as csv_file:
        csv_path = csv_file.name
    scores.to_csv(csv_path, index=False)

    payload = {
        **normalized,
        "metadata": metadata,
    }
    # The context manager closes the handle even if serialization raises.
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json.dump(payload, json_file, indent=2)
    return csv_path, json_file.name
def initial_model(request: gr.Request):
    """Choose the dropdown's initial label from the page request.

    The ``model`` query parameter is read from ``request.query_params`` when
    available, otherwise parsed out of ``request.url``. It is mapped through
    ``MODEL_LABELS``; a missing or unknown value yields ``"OpenSpliceAI"``.
    """
    fallback = "OpenSpliceAI"
    if request is None:
        return fallback

    params = getattr(request, "query_params", None)
    requested = params.get("model") if params is not None else None

    # Second chance: parse the raw URL when the query-params lookup gave nothing.
    if not requested and getattr(request, "url", None):
        values = parse_qs(urlparse(str(request.url)).query).get("model")
        requested = values[0] if values else None

    return MODEL_LABELS.get(requested, fallback)
# Gradio UI definition.
# NOTE(review): the nesting below is reconstructed from a listing that had lost its
# indentation; confirm which components belong inside each ``gr.Row()``.
with gr.Blocks(title="Splice Site") as demo:
    gr.Markdown(
        "# Splice Site\n"
        "Run MultiMolecule splice-site checkpoints on a DNA sequence and inspect ranked site calls, "
        "per-position scores, score tracks, and normalized JSON output."
    )
    # Run settings: checkpoint, site threshold, and number of sites to list.
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value="OpenSpliceAI",
            label="Checkpoint",
        )
        threshold = gr.Slider(0.05, 0.95, value=0.5, step=0.05, label="Site threshold")
        top_k = gr.Slider(1, 100, value=25, step=1, label="Top sites")
    # Sequence input: typed or pasted text, or a FASTA upload that fills the textbox.
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    input_file = gr.File(
        label="Upload FASTA",
        file_types=[".fa", ".fas", ".fasta", ".ffn", ".fna"],
    )
    run = gr.Button("Run prediction", variant="primary")
    # Outputs, wired below in the order ``predict`` returns them.
    with gr.Row():
        top_sites = gr.Dataframe(label="Top predicted splice sites (0-based positions)")
        metadata = gr.JSON(label="Run metadata")
    score_track = gr.Plot(label="Per-position score track")
    scores = gr.Dataframe(label="Per-position scores (0-based positions)")
    with gr.Row():
        csv_download = gr.File(label="Download scores CSV")
        json_download = gr.File(label="Download JSON")
    # The button runs the prediction and fills every output component.
    run.click(
        predict,
        inputs=[model, sequence, threshold, top_k],
        outputs=[top_sites, scores, metadata, score_track, csv_download, json_download],
    )
    # A new upload replaces the textbox content with the file's cleaned sequence.
    input_file.change(load_input_file, inputs=input_file, outputs=sequence)
    # On page load, preselect the checkpoint named by the ``model`` query parameter.
    demo.load(initial_model, outputs=model)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()