# MultiMolecule # Copyright (C) 2024-Present MultiMolecule # This file is part of MultiMolecule. # MultiMolecule is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # any later version. # MultiMolecule is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # For additional terms and clarifications, please refer to our License FAQ at: # . from __future__ import annotations import json import math import tempfile import time from collections.abc import Mapping from datetime import datetime, timezone from functools import lru_cache from pathlib import Path from typing import Any from urllib.parse import parse_qs, urlparse import gradio as gr import matplotlib import numpy as np import pandas as pd import torch from Bio import SeqIO from transformers import pipeline matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers MODEL_OPTIONS = { "MMSplice": "multimolecule/mmsplice", "MTSplice": "multimolecule/mtsplice", "HAL": "multimolecule/hal", "MaxEntScan score5": "multimolecule/maxentscan-score5", "MaxEntScan score3": "multimolecule/maxentscan-score3", "Pangolin": "multimolecule/pangolin", "SpTransformer": "multimolecule/sptransformer", } MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()} FASTA_SUFFIXES = {".fa", ".fasta", ".fna"} VALID_DNA = set("ACGTNRYSWKMBDHVX") META_COLUMNS = {"scope", "position", "nucleotide", "sequence", "label", "type"} DEFAULT_REFERENCE = "ACGT" * 25 + "CCCCCCCCCCCCCCCCCCCC" + "TGCA" * 25 DEFAULT_ALTERNATIVE = "ACGT" * 25 + "CCCCCCCCCCCTCCCCCCCC" + "TGCA" * 25 def _device() -> int: return 0 if torch.cuda.is_available() else -1 @lru_cache(maxsize=len(MODEL_OPTIONS)) def load_predictor(model_id: str): return pipeline("splice-variant-effect", model=model_id, device=_device()) def clean_sequence(sequence: str, label: str) -> str: sequence = "".join(str(sequence or "").split()).upper().replace("U", "T") if not sequence: raise gr.Error(f"{label} sequence is empty.") invalid = sorted(set(sequence) - VALID_DNA) if invalid: raise gr.Error(f"{label} sequence contains unsupported DNA symbols: {', '.join(invalid)}.") return sequence def validate_pair(reference: str, alternative: str) -> tuple[str, str]: reference = clean_sequence(reference, "Reference") alternative = clean_sequence(alternative, "Alternative") if len(reference) != len(alternative): raise gr.Error( "Reference and alternative sequences must have the same length. " "This app does not perform genome-coordinate lookup or sequence reconstruction." ) return reference, alternative def load_fasta_pair(input_file: Any): if input_file is None: return gr.update(), gr.update() path = Path(getattr(input_file, "name", input_file)) if path.suffix.lower() not in FASTA_SUFFIXES: raise gr.Error("Upload a FASTA file with two records: reference first, alternative second.") records = list(SeqIO.parse(path, "fasta")) if len(records) != 2: raise gr.Error(f"Expected exactly two FASTA records, found {len(records)}.") reference, alternative = validate_pair(str(records[0].seq), str(records[1].seq)) return reference, alternative def _json_safe(value: Any) -> Any: if isinstance(value, torch.Tensor): return _json_safe(value.detach().cpu().tolist()) if isinstance(value, np.ndarray): return _json_safe(value.tolist()) if isinstance(value, np.generic): return value.item() if isinstance(value, Mapping): return {str(key): _json_safe(item) for key, item in value.items()} if isinstance(value, (list, tuple)): return [_json_safe(item) for item in value] return value def _is_scalar(value: Any) -> bool: if isinstance(value, (str, bytes)) or value is None: return False try: float(value) except (TypeError, ValueError): return False return True def _number(value: Any) -> float | Any: if not _is_scalar(value): return value number = float(value) if math.isfinite(number): return number return value def _position_key(key: Any) -> bool: try: int(str(key)) except ValueError: return False return True def _vector_row(values: list[Any], channels: list[str], scalar_column: str, scope: str = "sequence") -> dict[str, Any]: row: dict[str, Any] = {"scope": scope} if channels and len(values) == len(channels): row.update({channel: _number(value) for channel, value in zip(channels, values)}) elif len(values) == 1: row[scalar_column] = _number(values[0]) else: row.update({f"{scalar_column}_{index}": _number(value) for index, value in enumerate(values)}) return row def _flatten_mapping( mapping: Mapping[str, Any], channels: list[str], scalar_column: str, prefix: str | None = None, ) -> dict[str, Any]: row: dict[str, Any] = {} for key, value in mapping.items(): key = str(key) column = f"{prefix}_{key}" if prefix else key value = _json_safe(value) if _is_scalar(value) or value is None or isinstance(value, str): row[column] = _number(value) elif isinstance(value, Mapping): row.update(_flatten_mapping(value, channels, scalar_column, prefix=column)) elif isinstance(value, list) and all(_is_scalar(item) for item in value): if key in META_COLUMNS: row[column] = value elif channels and len(value) == len(channels): row.update({channel: _number(item) for channel, item in zip(channels, value)}) else: row.update({f"{column}_{index}": _number(item) for index, item in enumerate(value)}) else: row[column] = value return row def normalize_score_rows(score_value: Any, channels: list[str], scalar_column: str) -> list[dict[str, Any]]: score_value = _json_safe(score_value) if score_value is None: return [] if _is_scalar(score_value): return [{"scope": "sequence", scalar_column: _number(score_value)}] if isinstance(score_value, Mapping): if score_value and not all(_position_key(key) for key in score_value): series_lengths = { len(value) for value in score_value.values() if isinstance(value, list) and all(_is_scalar(item) for item in value) } if len(series_lengths) == 1: length = series_lengths.pop() if length > 1 and all(isinstance(value, list) for value in score_value.values()): return [ { "position": position, **{str(key): _number(value[position]) for key, value in score_value.items()}, } for position in range(length) ] if score_value and all(_position_key(key) for key in score_value): rows = [] for key, value in score_value.items(): row = {"position": int(str(key))} if isinstance(value, Mapping): row.update(_flatten_mapping(value, channels, scalar_column)) elif isinstance(value, list): row.update(_vector_row(value, channels, scalar_column, scope="position")) row.pop("scope", None) else: row[scalar_column] = _number(value) rows.append(row) return rows return [_flatten_mapping(score_value, channels, scalar_column)] if isinstance(score_value, list): if not score_value: return [] if all(_is_scalar(item) for item in score_value): return [_vector_row(score_value, channels, scalar_column)] rows = [] for index, item in enumerate(score_value): item = _json_safe(item) if isinstance(item, Mapping): rows.append(_flatten_mapping(item, channels, scalar_column)) elif isinstance(item, list): row = {"position": index} row.update(_vector_row(item, channels, scalar_column, scope="position")) row.pop("scope", None) rows.append(row) elif _is_scalar(item): rows.append({"position": index, scalar_column: _number(item)}) return rows return [{"scope": "sequence", scalar_column: score_value}] def result_table(result: Mapping[str, Any], score_key: str, scores_key: str, scalar_column: str) -> pd.DataFrame: channels = [str(channel) for channel in result.get("channels", [])] score_value = result.get(scores_key, result.get(score_key)) rows = normalize_score_rows(score_value, channels, scalar_column) if not rows: return pd.DataFrame() table = pd.DataFrame(rows) ordered = [column for column in ("scope", "position", "nucleotide", "sequence", "label", "type") if column in table] remaining = [column for column in table.columns if column not in ordered] return table[ordered + remaining] def dataframe_records(table: pd.DataFrame) -> list[dict[str, Any]]: if table.empty: return [] return json.loads(table.to_json(orient="records")) def _list_of_dicts(value: Any) -> list[dict[str, Any]]: value = _json_safe(value) if value is None: return [] if not isinstance(value, list): raise gr.Error(f"Expected a list output, got {type(value).__name__}.") return [dict(item) for item in value if isinstance(item, Mapping)] def difference_summary(reference: str, alternative: str) -> dict[str, Any]: differences = [ { "position": index, "reference": ref_base, "alternative": alt_base, } for index, (ref_base, alt_base) in enumerate(zip(reference, alternative)) if ref_base != alt_base ] return { "count": len(differences), "positions": differences[:25], "positions_truncated": len(differences) > 25, } def make_delta_plot(effects_table: pd.DataFrame, delta_table: pd.DataFrame, model_label: str): fig, ax = plt.subplots(figsize=(7, 2.8)) values: list[tuple[str, float]] = [] if not effects_table.empty and "delta_score" in effects_table.columns: for _, row in effects_table.iterrows(): value = row.get("delta_score") if pd.isna(value): continue channel = row.get("channel", "delta") position = row.get("position") suffix = f"@{int(position)}" if position is not None and pd.notna(position) else "" values.append((f"{channel}{suffix}", float(value))) if not values and not delta_table.empty: numeric_columns = [ column for column in delta_table.columns if column not in META_COLUMNS and pd.api.types.is_numeric_dtype(delta_table[column]) ] for _, row in delta_table.iterrows(): position = row.get("position") for column in numeric_columns: value = row.get(column) if pd.notna(value): suffix = f"@{int(position)}" if position is not None and pd.notna(position) else "" values.append((f"{column}{suffix}", float(value))) values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:20] values.reverse() if not values: ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center") ax.set_axis_off() fig.tight_layout() return fig labels, scores = zip(*values) colors = ["#2563eb" if score >= 0 else "#dc2626" for score in scores] ax.barh(labels, scores, color=colors) ax.axvline(0, color="#111827", linewidth=0.8) ax.set_title(f"{model_label} top delta scores") ax.set_xlabel("alternative - reference") ax.tick_params(axis="y", labelsize=8) fig.tight_layout() return fig def write_result_files( metadata: dict[str, Any], result: Mapping[str, Any], effects_table: pd.DataFrame, delta_table: pd.DataFrame, reference_table: pd.DataFrame, alternative_table: pd.DataFrame, ) -> tuple[str, str]: csv_tables = [] if not effects_table.empty: csv_table = effects_table.copy() csv_table.insert(0, "score_set", "variant_effect") csv_tables.append(csv_table) for score_set, table in ( ("delta", delta_table), ("reference", reference_table), ("alternative", alternative_table), ): if not table.empty: csv_table = table.copy() csv_table.insert(0, "score_set", score_set) csv_tables.append(csv_table) csv_payload = pd.concat(csv_tables, ignore_index=True, sort=False) if csv_tables else pd.DataFrame() csv_file = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") csv_path = csv_file.name csv_file.close() csv_payload.to_csv(csv_path, index=False) json_payload = { "metadata": metadata, "result": _json_safe(result), "tables": { "variant_effects": dataframe_records(effects_table), "delta": dataframe_records(delta_table), "reference": dataframe_records(reference_table), "alternative": dataframe_records(alternative_table), }, } json_file = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) json_path = json_file.name json_file.close() with open(json_path, "w") as handle: json.dump(json_payload, handle, indent=2) return csv_path, json_path def unpack_prediction_result(result: Any) -> Mapping[str, Any]: result = _json_safe(result) if isinstance(result, list): if len(result) != 1: raise gr.Error(f"Expected one prediction result, got {len(result)}.") result = result[0] if not isinstance(result, Mapping): raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.") return result def predict(model_label: str, reference: str, alternative: str, top_k: int): started = time.perf_counter() model_id = MODEL_OPTIONS[model_label] reference, alternative = validate_pair(reference, alternative) top_k = int(top_k) try: predictor = load_predictor(model_id) result = unpack_prediction_result(predictor(reference, alternative=alternative, top_k=top_k)) except gr.Error: raise except Exception as exc: raise gr.Error(f"Prediction failed for {model_label}: {exc}") from exc delta_table = result_table(result, "delta_score", "delta_scores", "delta_score") reference_table = result_table(result, "reference_score", "reference_scores", "reference_score") alternative_table = result_table(result, "alternative_score", "alternative_scores", "alternative_score") effects_table = pd.DataFrame(_list_of_dicts(result.get("variant_effects", []))) metadata = { "task": "splice-variant-effect", "model": model_id, "model_label": model_label, "device": "cuda" if torch.cuda.is_available() else "cpu", "position_index_base": result.get("position_index_base", 0), "reference_length": len(reference), "alternative_length": len(alternative), "differences": difference_summary(reference, alternative), "channels": result.get("channels", []), "top_k": top_k, "output_fields": sorted(result.keys()), "runtime_seconds": round(time.perf_counter() - started, 3), "timestamp_utc": datetime.now(timezone.utc).isoformat(), } csv_path, json_path = write_result_files( metadata, result, effects_table, delta_table, reference_table, alternative_table, ) delta_plot = make_delta_plot(effects_table, delta_table, model_label) return effects_table, delta_table, reference_table, alternative_table, metadata, delta_plot, csv_path, json_path def initial_model(request: gr.Request): if request is None: return "MMSplice" query_params = getattr(request, "query_params", None) model_id = None if query_params is not None: model_id = query_params.get("model") if not model_id and getattr(request, "url", None): parsed = parse_qs(urlparse(str(request.url)).query) model_values = parsed.get("model") model_id = model_values[0] if model_values else None return MODEL_LABELS.get(model_id, "MMSplice") with gr.Blocks(title="Splice Variant Effect") as demo: gr.Markdown( "# Splice Variant Effect\n" "Score paired reference and alternative DNA windows with MultiMolecule splice variant-effect models." ) with gr.Row(): model = gr.Dropdown( choices=list(MODEL_OPTIONS.keys()), value="MMSplice", label="Checkpoint", ) top_k = gr.Slider(1, 100, value=20, step=1, label="Top effects") with gr.Row(): reference = gr.Textbox( label="Reference DNA sequence", value=DEFAULT_REFERENCE, lines=5, ) alternative = gr.Textbox( label="Alternative DNA sequence", value=DEFAULT_ALTERNATIVE, lines=5, ) input_file = gr.File( label="Upload paired FASTA (reference record first, alternative record second)", file_types=[".fa", ".fasta", ".fna"], ) run = gr.Button("Run variant effect", variant="primary") with gr.Row(): variant_effects = gr.Dataframe(label="Top delta effects (0-based sequence positions)") run_metadata = gr.JSON(label="Run metadata") delta_scores = gr.Dataframe(label="Delta scores (0-based sequence positions)") with gr.Row(): reference_scores = gr.Dataframe(label="Reference scores") alternative_scores = gr.Dataframe(label="Alternative scores") delta_plot = gr.Plot(label="Top delta scores") with gr.Row(): csv_download = gr.File(label="Download CSV") json_download = gr.File(label="Download JSON") run.click( predict, inputs=[model, reference, alternative, top_k], outputs=[ variant_effects, delta_scores, reference_scores, alternative_scores, run_metadata, delta_plot, csv_download, json_download, ], ) input_file.change(load_fasta_pair, inputs=input_file, outputs=[reference, alternative]) demo.load(initial_model, outputs=model) if __name__ == "__main__": demo.launch()