| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| from __future__ import annotations |
|
|
| import json |
| import math |
| import tempfile |
| import time |
| from collections.abc import Mapping |
| from datetime import datetime, timezone |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any |
| from urllib.parse import parse_qs, urlparse |
|
|
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| import pandas as pd |
| import torch |
| from Bio import SeqIO |
| from transformers import pipeline |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import multimolecule |
|
|
# Dropdown label -> Hugging Face checkpoint id for every selectable splice variant-effect model.
MODEL_OPTIONS = {
    "MMSplice": "multimolecule/mmsplice",
    "MTSplice": "multimolecule/mtsplice",
    "HAL": "multimolecule/hal",
    "MaxEntScan score5": "multimolecule/maxentscan-score5",
    "MaxEntScan score3": "multimolecule/maxentscan-score3",
    "Pangolin": "multimolecule/pangolin",
    "SpTransformer": "multimolecule/sptransformer",
}
# Reverse lookup: checkpoint id -> dropdown label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# File extensions accepted for the paired-FASTA upload.
FASTA_SUFFIXES = {".fa", ".fasta", ".fna"}
# Symbols allowed after upper-casing and U->T: the four bases, the IUPAC ambiguity codes, and X.
VALID_DNA = set("ACGT") | set("NRYSWKMBDHV") | {"X"}
# Non-score columns: kept whole when flattening lists and excluded from the delta plot.
META_COLUMNS = {"scope", "position", "nucleotide", "sequence", "label", "type"}

# Synthetic example pair shown on start-up; the alternative carries a T inside the central C run.
DEFAULT_REFERENCE = "ACGT" * 25 + "CCCCCCCCCCCCCCCCCCCC" + "TGCA" * 25
DEFAULT_ALTERNATIVE = "ACGT" * 25 + "CCCCCCCCCCCTCCCCCCCC" + "TGCA" * 25
|
|
|
|
def _device() -> int:
    """Return the ``transformers.pipeline`` device index: ``0`` (first CUDA GPU) or ``-1`` (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
|
|
|
|
@lru_cache(maxsize=len(MODEL_OPTIONS))
def load_predictor(model_id: str):
    """Build the ``splice-variant-effect`` pipeline for ``model_id``, memoized per checkpoint.

    The cache is sized to the number of entries in ``MODEL_OPTIONS``, so every selectable checkpoint
    can stay loaded at once and switching back to one does not rebuild it.
    NOTE(review): assumes ``import multimolecule`` registers this task with transformers — confirm.
    """
    device_index = _device()
    return pipeline("splice-variant-effect", model=model_id, device=device_index)
|
|
|
|
def clean_sequence(sequence: str, label: str) -> str:
    """Normalize a user-supplied sequence to upper-case DNA.

    All whitespace is removed, letters are upper-cased and U is rewritten as T.

    Args:
        sequence: Raw text (``None`` is treated as empty).
        label: Human-readable name ("Reference"/"Alternative") used in error messages.

    Raises:
        gr.Error: if the cleaned sequence is empty or contains symbols outside ``VALID_DNA``.
    """
    compact = "".join(str(sequence or "").split())
    normalized = compact.upper().replace("U", "T")
    if not normalized:
        raise gr.Error(f"{label} sequence is empty.")

    unexpected = set(normalized) - VALID_DNA
    if unexpected:
        raise gr.Error(f"{label} sequence contains unsupported DNA symbols: {', '.join(sorted(unexpected))}.")
    return normalized
|
|
|
|
def validate_pair(reference: str, alternative: str) -> tuple[str, str]:
    """Clean both sequences and require them to be the same length.

    Returns:
        The cleaned ``(reference, alternative)`` pair.

    Raises:
        gr.Error: from ``clean_sequence`` for empty/invalid input, or when the lengths differ.
    """
    cleaned = (
        clean_sequence(reference, "Reference"),
        clean_sequence(alternative, "Alternative"),
    )
    ref_length, alt_length = (len(seq) for seq in cleaned)
    if ref_length == alt_length:
        return cleaned
    raise gr.Error(
        "Reference and alternative sequences must have the same length. "
        "This app does not perform genome-coordinate lookup or sequence reconstruction."
    )
|
|
|
|
def load_fasta_pair(input_file: Any):
    """Fill the reference/alternative textboxes from an uploaded two-record FASTA file.

    Args:
        input_file: Value of the ``gr.File`` component: a filesystem path, or an object exposing
            the path as ``.name``; ``None`` when the upload is cleared.

    Returns:
        The cleaned ``(reference, alternative)`` pair, or two no-op ``gr.update()`` objects when
        there is no file.

    Raises:
        gr.Error: if the extension is not a FASTA suffix, the file cannot be parsed as FASTA,
            it does not hold exactly two records, or the pair fails ``validate_pair``.
    """
    if input_file is None:
        return gr.update(), gr.update()

    # Accept either a plain path or a tempfile-like wrapper that carries the path in ``.name``.
    path = Path(getattr(input_file, "name", input_file))
    if path.suffix.lower() not in FASTA_SUFFIXES:
        raise gr.Error("Upload a FASTA file with two records: reference first, alternative second.")

    try:
        records = list(SeqIO.parse(path, "fasta"))
    except ValueError as exc:
        # Malformed FASTA and undecodable bytes (UnicodeDecodeError is a ValueError subclass) would
        # otherwise escape as a raw exception; re-raise as gr.Error so the user sees the reason.
        raise gr.Error(f"Could not parse {path.name} as FASTA: {exc}") from exc
    if len(records) != 2:
        raise gr.Error(f"Expected exactly two FASTA records, found {len(records)}.")

    return validate_pair(str(records[0].seq), str(records[1].seq))
|
|
|
|
def _json_safe(value: Any) -> Any:
    """Recursively convert ``value`` into plain Python containers and scalars.

    * ``torch.Tensor`` (detached, moved to CPU) and ``np.ndarray`` -> converted with ``tolist`` and
      the result processed again;
    * NumPy scalar types -> the matching Python scalar via ``item``;
    * mappings -> ``dict`` with ``str`` keys; lists and tuples -> ``list``;
    * anything else is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        value = value.detach().cpu()
    if isinstance(value, (torch.Tensor, np.ndarray)):
        return _json_safe(value.tolist())
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, Mapping):
        return {str(name): _json_safe(entry) for name, entry in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(entry) for entry in value]
    return value
|
|
|
|
def _is_scalar(value: Any) -> bool:
    """Return True when ``value`` is a non-text object that ``float()`` accepts (numbers, bools).

    ``None``, ``str`` and ``bytes`` are always rejected; numeric text like ``"1.5"`` would otherwise
    pass the ``float()`` probe.
    """
    if value is None or isinstance(value, (bytes, str)):
        return False
    try:
        float(value)
        convertible = True
    except (ValueError, TypeError):
        convertible = False
    return convertible
|
|
|
|
def _number(value: Any) -> float | Any:
    """Coerce a finite numeric scalar to ``float``; return anything else unchanged.

    Non-scalars (strings, lists, ``None``) and non-finite numbers (NaN, ±inf) are returned as the
    original object.
    """
    if _is_scalar(value):
        converted = float(value)
        if math.isfinite(converted):
            return converted
    return value
|
|
|
|
def _position_key(key: Any) -> bool:
    """Return True when ``str(key)`` parses as an integer (e.g. ``"12"``, ``-3``)."""
    text = str(key)
    try:
        int(text)
        parsed = True
    except ValueError:
        parsed = False
    return parsed
|
|
|
|
def _vector_row(values: list[Any], channels: list[str], scalar_column: str, scope: str = "sequence") -> dict[str, Any]:
    """Build one table row from a vector of values.

    Column names are chosen as follows:

    * one column per channel when ``channels`` is non-empty and matches ``len(values)``;
    * ``scalar_column`` when there is exactly one value;
    * otherwise ``<scalar_column>_<index>``.

    Numeric entries go through ``_number``. The row always starts with a ``scope`` entry.
    """
    if channels and len(channels) == len(values):
        names = channels
    elif len(values) == 1:
        names = [scalar_column]
    else:
        names = [f"{scalar_column}_{index}" for index in range(len(values))]
    return {"scope": scope, **{name: _number(item) for name, item in zip(names, values)}}
|
|
|
|
def _flatten_mapping(
    mapping: Mapping[str, Any],
    channels: list[str],
    scalar_column: str,
    prefix: str | None = None,
) -> dict[str, Any]:
    """Flatten a (possibly nested) mapping into a single table row.

    Column naming (nested keys are joined with ``_``, e.g. ``donor_score``):

    * scalars, strings and ``None`` -> ``<key>`` (numbers coerced by ``_number``);
    * lists of scalars under a ``META_COLUMNS`` key -> the list is kept as-is;
    * lists of scalars whose length equals ``len(channels)`` -> one column per channel name, or
      ``<key>_<channel>`` when a channel name is already used in the row, so a second vector
      never overwrites the first;
    * other lists of scalars -> ``<key>_<index>``;
    * anything else -> stored unchanged under ``<key>``.

    Args:
        mapping: Mapping to flatten.
        channels: Channel names used to split channel-length vectors.
        scalar_column: Unused here; kept for signature compatibility with the other row builders.
        prefix: Key path of the enclosing mapping, if any.
    """
    row: dict[str, Any] = {}
    _flatten_mapping_into(row, mapping, channels, scalar_column, prefix)
    return row


def _flatten_mapping_into(
    row: dict[str, Any],
    mapping: Mapping[str, Any],
    channels: list[str],
    scalar_column: str,
    prefix: str | None,
) -> None:
    """Write the flattened columns of ``mapping`` into ``row`` in place (see ``_flatten_mapping``)."""
    for raw_key, raw_value in mapping.items():
        key = str(raw_key)
        column = f"{prefix}_{key}" if prefix else key
        value = _json_safe(raw_value)
        if _is_scalar(value) or value is None or isinstance(value, str):
            row[column] = _number(value)
        elif isinstance(value, Mapping):
            # Recurse into the same row so collision checks below also see sibling/parent columns.
            _flatten_mapping_into(row, value, channels, scalar_column, column)
        elif isinstance(value, list) and all(_is_scalar(item) for item in value):
            if key in META_COLUMNS:
                row[column] = value
            elif channels and len(value) == len(channels):
                names = channels
                if any(name in row for name in names):
                    # Another vector already claimed the bare channel names; qualify with the key path.
                    names = [f"{column}_{channel}" for channel in channels]
                for name, item in zip(names, value):
                    row[name] = _number(item)
            else:
                for index, item in enumerate(value):
                    row[f"{column}_{index}"] = _number(item)
        else:
            row[column] = value
|
|
|
|
def normalize_score_rows(score_value: Any, channels: list[str], scalar_column: str) -> list[dict[str, Any]]:
    """Convert one score field of a prediction result into flat table rows.

    Shapes handled (after ``_json_safe``):

    * ``None`` -> no rows;
    * a numeric scalar -> one ``{"scope": "sequence", scalar_column: value}`` row;
    * a mapping -> rows from ``_mapping_score_rows``;
    * a list -> rows from ``_list_score_rows``;
    * anything else (e.g. a string) -> one ``scope="sequence"`` row holding the value unchanged.

    Args:
        score_value: Raw score field (tensor, array, mapping, list, scalar, ...).
        channels: Channel names; scalar vectors of exactly this length become one column per channel.
        scalar_column: Column used for single values and as the prefix of unnamed vector entries.

    Returns:
        A list of row dictionaries suitable for ``pd.DataFrame``.
    """
    score_value = _json_safe(score_value)
    if score_value is None:
        return []
    if _is_scalar(score_value):
        return [{"scope": "sequence", scalar_column: _number(score_value)}]
    if isinstance(score_value, Mapping):
        return _mapping_score_rows(score_value, channels, scalar_column)
    if isinstance(score_value, list):
        return _list_score_rows(score_value, channels, scalar_column)
    return [{"scope": "sequence", scalar_column: score_value}]


def _positioned_vector_row(position: int, values: list[Any], channels: list[str], scalar_column: str) -> dict[str, Any]:
    """Build a ``position`` row from a vector of values (without the ``scope`` entry)."""
    row: dict[str, Any] = {"position": position}
    row.update(_vector_row(values, channels, scalar_column, scope="position"))
    row.pop("scope", None)
    return row


def _series_score_rows(score_value: Mapping[str, Any]) -> list[dict[str, Any]] | None:
    """Return one row per index for a mapping of equally long series, or ``None`` if it is not one.

    The scalar-only lists must agree on a single length greater than 1, and every value must be a
    list of exactly that length.
    """
    lengths = {
        len(value)
        for value in score_value.values()
        if isinstance(value, list) and all(_is_scalar(item) for item in value)
    }
    if len(lengths) != 1:
        return None
    (length,) = lengths
    # Requiring every list (including non-scalar ones) to have this length avoids an IndexError on
    # shorter lists and silent truncation of longer ones.
    if length <= 1 or not all(isinstance(value, list) and len(value) == length for value in score_value.values()):
        return None
    return [
        {"position": position, **{str(key): _number(value[position]) for key, value in score_value.items()}}
        for position in range(length)
    ]


def _mapping_score_rows(score_value: Mapping[str, Any], channels: list[str], scalar_column: str) -> list[dict[str, Any]]:
    """Rows for a mapping score: integer-like keys become positions, equal-length series become
    per-index rows, and any other mapping is flattened into a single row."""
    if not score_value:
        return [_flatten_mapping(score_value, channels, scalar_column)]

    if all(_position_key(key) for key in score_value):
        rows = []
        for key, value in score_value.items():
            position = int(str(key))
            if isinstance(value, Mapping):
                rows.append({"position": position, **_flatten_mapping(value, channels, scalar_column)})
            elif isinstance(value, list):
                rows.append(_positioned_vector_row(position, value, channels, scalar_column))
            else:
                rows.append({"position": position, scalar_column: _number(value)})
        return rows

    series_rows = _series_score_rows(score_value)
    if series_rows is not None:
        return series_rows
    return [_flatten_mapping(score_value, channels, scalar_column)]


def _list_score_rows(score_value: list[Any], channels: list[str], scalar_column: str) -> list[dict[str, Any]]:
    """Rows for a list score: a flat scalar vector is one row; otherwise one row per element, using
    the element index as ``position`` for list and scalar elements."""
    if not score_value:
        return []
    if all(_is_scalar(item) for item in score_value):
        return [_vector_row(score_value, channels, scalar_column)]

    rows = []
    # ``score_value`` was already converted by ``_json_safe``, so its items are plain Python objects.
    for index, item in enumerate(score_value):
        if isinstance(item, Mapping):
            rows.append(_flatten_mapping(item, channels, scalar_column))
        elif isinstance(item, list):
            rows.append(_positioned_vector_row(index, item, channels, scalar_column))
        elif _is_scalar(item):
            rows.append({"position": index, scalar_column: _number(item)})
        # Any other item (e.g. None or a string) is skipped.
    return rows
|
|
|
|
def result_table(result: Mapping[str, Any], score_key: str, scores_key: str, scalar_column: str) -> pd.DataFrame:
    """Build the display table for one score family (delta / reference / alternative).

    Args:
        result: Unpacked prediction output.
        score_key: Singular key (e.g. ``"delta_score"``), used as the fallback.
        scores_key: Plural key (e.g. ``"delta_scores"``), preferred when present and not ``None``.
        scalar_column: Column name for scalar / unnamed values.

    Returns:
        A DataFrame with the identifier columns (scope, position, nucleotide, sequence, label, type)
        first and score columns after them, or an empty DataFrame when there is nothing to show.
    """
    # ``channels`` may be missing or explicitly ``None``; both mean "no channel names".
    channels = [str(channel) for channel in (result.get("channels") or [])]
    # Fall back to the singular key when the plural one is absent *or* set to None, so an explicit
    # ``None`` under the plural key does not hide a populated singular score.
    score_value = result.get(scores_key)
    if score_value is None:
        score_value = result.get(score_key)
    rows = normalize_score_rows(score_value, channels, scalar_column)
    if not rows:
        return pd.DataFrame()

    table = pd.DataFrame(rows)
    # Fixed display order for identifier columns (META_COLUMNS is an unordered set).
    ordered = [column for column in ("scope", "position", "nucleotide", "sequence", "label", "type") if column in table]
    remaining = [column for column in table.columns if column not in ordered]
    return table[ordered + remaining]
|
|
|
|
def dataframe_records(table: pd.DataFrame) -> list[dict[str, Any]]:
    """Return ``table`` as a list of JSON-compatible row dictionaries (``[]`` for an empty table).

    The rows round-trip through ``DataFrame.to_json``, so missing values (NaN) come back as ``None``
    and NumPy scalars become plain Python numbers.
    """
    if not table.empty:
        serialized = table.to_json(orient="records")
        return json.loads(serialized)
    return []
|
|
|
|
def difference_summary(reference: str, alternative: str) -> dict[str, Any]:
    """Summarize where two aligned sequences differ.

    Returns:
        ``{"count": <number of mismatches>, "positions": <first 25 mismatches as
        {"position", "reference", "alternative"} dicts>, "positions_truncated": <more than 25>}``.
        Positions are 0-based. Comparison stops at the shorter input (``zip``); callers validate
        equal lengths first.
    """
    mismatches = []
    for position, pair in enumerate(zip(reference, alternative)):
        ref_base, alt_base = pair
        if ref_base == alt_base:
            continue
        mismatches.append({"position": position, "reference": ref_base, "alternative": alt_base})

    total = len(mismatches)
    return {
        "count": total,
        "positions": mismatches[:25],
        "positions_truncated": total > 25,
    }
|
|
|
|
def make_delta_plot(delta_table: pd.DataFrame, model_label: str):
    """Draw a horizontal bar chart of the 20 largest-magnitude delta scores.

    Bars are labelled ``<column>@<position>`` when the row has a position. Otherwise the label is
    ``<column>#<row number>`` for multi-row tables, or just ``<column>`` for a single row.
    Positive scores are blue and negative scores red.

    Args:
        delta_table: Delta-score table from ``result_table`` (may be empty).
        model_label: Display name used in the plot title.

    Returns:
        A ``matplotlib.figure.Figure``. When there are no numeric, non-missing values it only shows
        the message "No numeric delta scores".
    """
    # Build the figure without pyplot: ``plt.subplots`` registers every figure in pyplot's global
    # manager, and nothing here closes them, so each prediction would leak one figure for the life
    # of the server process.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(7, 2.8))
    ax = fig.subplots()
    values: list[tuple[str, float]] = []

    if not delta_table.empty:
        numeric_columns = [
            column
            for column in delta_table.columns
            if column not in META_COLUMNS and pd.api.types.is_numeric_dtype(delta_table[column])
        ]
        # matplotlib puts identical category labels on the same bar, so rows without a position
        # need a distinguishing suffix when there is more than one row.
        several_rows = len(delta_table) > 1
        for row_number, (_, row) in enumerate(delta_table.iterrows()):
            position = row.get("position")
            if position is not None and pd.notna(position):
                suffix = f"@{int(position)}"
            elif several_rows:
                suffix = f"#{row_number}"
            else:
                suffix = ""
            for column in numeric_columns:
                value = row.get(column)
                if pd.notna(value):
                    values.append((f"{column}{suffix}", float(value)))

    # Keep the 20 largest magnitudes; reverse so the largest is drawn at the top of the chart.
    values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:20]
    values.reverse()
    if not values:
        ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center")
        ax.set_axis_off()
        fig.tight_layout()
        return fig

    labels, scores = zip(*values)
    colors = ["#2563eb" if score >= 0 else "#dc2626" for score in scores]
    ax.barh(labels, scores, color=colors)
    ax.axvline(0, color="#111827", linewidth=0.8)
    ax.set_title(f"{model_label} top delta scores")
    ax.set_xlabel("alternative - reference")
    ax.tick_params(axis="y", labelsize=8)
    fig.tight_layout()
    return fig
|
|
|
|
def _finite_json(value: Any) -> Any:
    """Return a copy of ``value`` with every NaN/±inf float (at any depth) replaced by ``None``."""
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, Mapping):
        return {key: _finite_json(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_finite_json(item) for item in value]
    return value


def write_result_files(
    metadata: dict[str, Any],
    result: Mapping[str, Any],
    delta_table: pd.DataFrame,
    reference_table: pd.DataFrame,
    alternative_table: pd.DataFrame,
) -> tuple[str, str]:
    """Write the CSV and JSON downloads for one prediction and return ``(csv_path, json_path)``.

    The CSV stacks the non-empty delta/reference/alternative tables, tagged by a leading
    ``score_set`` column. The JSON holds the run metadata, the JSON-safe raw result and the three
    tables as record lists; non-finite floats are written as ``null`` so the file is strict JSON.
    The files are created with ``delete=False`` and are not removed by this function.
    """
    csv_tables = []
    for score_set, table in (
        ("delta", delta_table),
        ("reference", reference_table),
        ("alternative", alternative_table),
    ):
        if table.empty:
            continue
        csv_table = table.copy()
        # Tag each row with its source table so the stacked CSV stays unambiguous.
        csv_table.insert(0, "score_set", score_set)
        csv_tables.append(csv_table)
    csv_payload = pd.concat(csv_tables, ignore_index=True, sort=False) if csv_tables else pd.DataFrame()

    # Write through the temp-file handle itself; ``delete=False`` keeps the file after it closes
    # so Gradio can serve it. pandas expects file objects opened with ``newline=""``.
    with tempfile.NamedTemporaryFile(
        "w", prefix="splice_variant_effect_", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        csv_payload.to_csv(csv_file, index=False)

    json_payload = {
        "metadata": metadata,
        "result": _json_safe(result),
        "tables": {
            "delta": dataframe_records(delta_table),
            "reference": dataframe_records(reference_table),
            "alternative": dataframe_records(alternative_table),
        },
    }
    # Scores may contain NaN/±inf, which ``json.dump`` would emit as the non-standard tokens
    # ``NaN``/``Infinity``. Write them as ``null`` (as ``DataFrame.to_json`` does for NaN) and let
    # ``allow_nan=False`` guarantee strict JSON.
    with tempfile.NamedTemporaryFile(
        "w", prefix="splice_variant_effect_", suffix=".json", delete=False, encoding="utf-8"
    ) as json_file:
        json.dump(_finite_json(json_payload), json_file, indent=2, allow_nan=False)

    return csv_file.name, json_file.name
|
|
|
|
def unpack_prediction_result(result: Any) -> Mapping[str, Any]:
    """Return the single prediction dictionary from a pipeline output.

    The output is first made JSON-safe. A list is accepted only if it holds exactly one element.

    Raises:
        gr.Error: if a list does not contain exactly one item, or the final value is not a mapping.
    """
    payload = _json_safe(result)
    if isinstance(payload, list):
        count = len(payload)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (payload,) = payload
    if isinstance(payload, Mapping):
        return payload
    raise gr.Error(f"Expected a prediction dictionary, got {type(payload).__name__}.")
|
|
|
|
def predict(model_label: str, reference: str, alternative: str):
    """Run one splice variant-effect prediction and build every UI output.

    Args:
        model_label: Dropdown label (a key of ``MODEL_OPTIONS``).
        reference: Reference DNA sequence.
        alternative: Alternative DNA sequence; must have the same length after cleaning.

    Returns:
        ``(delta_table, reference_table, alternative_table, metadata, delta_plot, csv_path, json_path)``,
        in the order of the ``run.click`` outputs. ``runtime_seconds`` in the metadata covers
        validation, inference and table building, not file writing or plotting.

    Raises:
        gr.Error: for invalid input, an unexpected result shape, or any failure inside the pipeline.
    """
    start_time = time.perf_counter()
    model_id = MODEL_OPTIONS[model_label]
    reference, alternative = validate_pair(reference, alternative)

    try:
        raw_prediction = load_predictor(model_id)(reference, alternative=alternative)
        result = unpack_prediction_result(raw_prediction)
    except gr.Error:
        # Already a user-facing message; pass it through unchanged.
        raise
    except Exception as exc:
        # Wrap loading/inference failures so the UI shows the underlying cause.
        raise gr.Error(f"Prediction failed for {model_label}: {exc}") from exc

    delta_table = result_table(result, "delta_score", "delta_scores", "delta_score")
    reference_table = result_table(result, "reference_score", "reference_scores", "reference_score")
    alternative_table = result_table(result, "alternative_score", "alternative_scores", "alternative_score")

    metadata = {
        "task": "splice-variant-effect",
        "model": model_id,
        "model_label": model_label,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "reference_length": len(reference),
        "alternative_length": len(alternative),
        "differences": difference_summary(reference, alternative),
        "channels": result.get("channels", []),
        "output_fields": sorted(result.keys()),
        "runtime_seconds": round(time.perf_counter() - start_time, 3),
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
    }
    csv_path, json_path = write_result_files(metadata, result, delta_table, reference_table, alternative_table)
    figure = make_delta_plot(delta_table, model_label)
    return delta_table, reference_table, alternative_table, metadata, figure, csv_path, json_path
|
|
|
|
def initial_model(request: gr.Request):
    """Choose the dropdown's initial label from a ``?model=<checkpoint id>`` query parameter.

    ``request.query_params`` is read first; if it yields nothing, the query string of
    ``request.url`` is parsed instead. Falls back to ``"MMSplice"`` when there is no request,
    no parameter, or an id that is not in ``MODEL_LABELS``.
    """
    fallback = "MMSplice"
    if request is None:
        return fallback

    requested = None
    params = getattr(request, "query_params", None)
    if params is not None:
        requested = params.get("model")
    if not requested and getattr(request, "url", None):
        values = parse_qs(urlparse(str(request.url)).query).get("model")
        requested = values[0] if values else None

    return MODEL_LABELS.get(requested, fallback)
|
|
|
|
# Gradio UI. Components are created in display order inside the Blocks context, and the event
# listeners are attached inside it as well.
with gr.Blocks(title="Splice Variant Effect") as demo:
    gr.Markdown(
        "# Splice Variant Effect\n"
        "Score paired reference and alternative DNA windows with MultiMolecule splice variant-effect models."
    )

    # Checkpoint picker; ``initial_model`` may replace the default from a ``?model=`` parameter on load.
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value="MMSplice",
        label="Checkpoint",
    )

    with gr.Row():
        reference = gr.Textbox(
            label="Reference DNA sequence",
            value=DEFAULT_REFERENCE,
            lines=5,
        )
        alternative = gr.Textbox(
            label="Alternative DNA sequence",
            value=DEFAULT_ALTERNATIVE,
            lines=5,
        )

    # Offer exactly the extensions ``load_fasta_pair`` accepts (single source of truth).
    input_file = gr.File(
        label="Upload paired FASTA (reference record first, alternative record second)",
        file_types=sorted(FASTA_SUFFIXES),
    )
    run = gr.Button("Run variant effect", variant="primary")

    with gr.Row():
        delta_scores = gr.Dataframe(label="Delta scores")
        run_metadata = gr.JSON(label="Run metadata")

    with gr.Row():
        reference_scores = gr.Dataframe(label="Reference scores")
        alternative_scores = gr.Dataframe(label="Alternative scores")

    delta_plot = gr.Plot(label="Top delta scores")

    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # Output order must match the tuple returned by ``predict``.
    run.click(
        predict,
        inputs=[model, reference, alternative],
        outputs=[
            delta_scores,
            reference_scores,
            alternative_scores,
            run_metadata,
            delta_plot,
            csv_download,
            json_download,
        ],
    )
    # Uploading a FASTA file overwrites both sequence textboxes.
    input_file.change(load_fasta_pair, inputs=input_file, outputs=[reference, alternative])
    demo.load(initial_model, outputs=model)
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly, not when it is imported.
    demo.launch()
|
|