# MultiMolecule # Copyright (C) 2024-Present MultiMolecule # This file is part of MultiMolecule. # MultiMolecule is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # any later version. # MultiMolecule is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # For additional terms and clarifications, please refer to our License FAQ at: # . from __future__ import annotations import csv import json import re import tempfile import time from functools import lru_cache from typing import Any, Mapping from urllib.parse import parse_qs, urlparse import gradio as gr import matplotlib import numpy as np import torch from transformers import pipeline matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers DEFAULT_REFERENCE_SEQUENCE = "ACGT" * 250 DEFAULT_ALTERNATIVE_SEQUENCE = "ACGT" * 125 + "TCGA" + "ACGT" * 124 DEFAULT_MODEL_LABEL = "DeepSEA" MODEL_OPTIONS = { "A2Z Chromatin": "multimolecule/a2zchromatin", "Basset": "multimolecule/basset", "DeepMEL": "multimolecule/deepmel", "DeepSEA": "multimolecule/deepsea", "DeepSTARR": "multimolecule/deepstarr", "Malinois": "multimolecule/malinois", "MPRA-DragoNN": "multimolecule/mpradragonn", "scBasset": "multimolecule/scbasset", "Xpresso": "multimolecule/xpresso", } MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()} TABLE_HEADERS = ["position", "nucleotide", "channel", "delta_score", "reference_score", "alternative_score"] DNA_ALPHABET = set("ACGTN") 
# Matches numbers such as "1", "-2.5", ".5", "3." and "1e-3" inside free-form feature text.
FLOAT_PATTERN = re.compile(r"[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[eE][-+]?\d+)?")


def _device() -> int:
    """Return the pipeline device index: ``0`` (first CUDA device) when CUDA is available, else ``-1`` (CPU)."""
    return 0 if torch.cuda.is_available() else -1


@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build a ``regulatory-variant-effect`` pipeline for ``model_id``.

    At most the two most recently used pipelines are kept cached.
    """
    return pipeline("regulatory-variant-effect", model=model_id, device=_device())


def clean_sequence(sequence: str, label: str) -> str:
    """Normalize a user-supplied DNA sequence and validate its alphabet.

    All whitespace is removed, letters are upper-cased and ``U`` is mapped to ``T``.

    Args:
        sequence: Raw textbox content; ``None`` is treated as empty.
        label: Name used in error messages (e.g. ``"Reference"``).

    Returns:
        The cleaned sequence, containing only symbols from DNA_ALPHABET.

    Raises:
        gr.Error: If the sequence is empty or contains unsupported symbols.
    """
    sequence = "".join(str(sequence or "").split()).upper().replace("U", "T")
    if not sequence:
        raise gr.Error(f"{label} sequence is empty.")
    invalid = sorted(set(sequence) - DNA_ALPHABET)
    if invalid:
        invalid_text = ", ".join(invalid)
        raise gr.Error(f"{label} sequence contains unsupported symbols: {invalid_text}. Use A, C, G, T, or N.")
    return sequence


def parse_features(features_text: str) -> Any | None:
    """Parse the optional numeric-features textbox.

    * Blank text -> ``None`` (no features).
    * A JSON object -> the value under the first present key of ``features``/``values``/
      ``reference_features``/``alternative_features``, or, if it holds only numbers, its values as a list.
    * A JSON string -> parsed again recursively.
    * Any other JSON value -> returned unchanged.
    * Non-JSON text -> every number found in it, as a list of floats.

    Raises:
        gr.Error: If non-JSON text contains no number, or a JSON object matches neither object form.
    """
    text = str(features_text or "").strip()
    if not text:
        return None
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Not JSON: fall back to scraping comma/space-separated numbers.
        values = FLOAT_PATTERN.findall(text)
        if not values:
            raise gr.Error("Features must be JSON or comma/space-separated numbers.")
        return [float(value) for value in values]
    if isinstance(parsed, Mapping):
        for key in ("features", "values", "reference_features", "alternative_features"):
            if key in parsed:
                return parsed[key]
        if all(isinstance(value, int | float) for value in parsed.values()):
            return list(parsed.values())
        raise gr.Error("Feature JSON objects must contain a features/values list or only numeric values.")
    if isinstance(parsed, str):
        return parse_features(parsed)
    return parsed


def feature_summary(features: Any | None) -> dict[str, Any]:
    """Describe parsed features for the run metadata: whether any were given and, if array-like, their shape."""
    if features is None:
        return {"provided": False}
    try:
        array = np.asarray(features, dtype=float)
    except (TypeError, ValueError):
        # Ragged or non-numeric input has no well-defined shape.
        return {"provided": True, "shape": None}
    return {"provided": True, "shape": list(array.shape)}


def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary, unwrapping a one-element list if the pipeline returned one.

    Raises:
        gr.Error: If the list does not have exactly one element or the result is not a dictionary.
    """
    if isinstance(result, list):
        if len(result) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
        result = result[0]
    if not isinstance(result, dict):
        raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
    return result


def build_delta_rows(result: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Flatten a prediction result into table rows keyed by TABLE_HEADERS.

    Result layouts are checked in this order:

    * a scalar ``delta_score`` (optional ``reference_score``/``alternative_score``) -> one ``"score"`` row;
    * a ``delta_scores`` mapping of channel -> value -> one row per channel;
    * a ``delta_scores`` list of per-position/bin mappings -> see build_axis_delta_rows.

    Raises:
        gr.Error: If none of these layouts is present.
    """
    if "delta_score" in result:
        return [
            {
                "position": "",
                "nucleotide": "",
                "channel": "score",
                "delta_score": result.get("delta_score"),
                "reference_score": result.get("reference_score", ""),
                "alternative_score": result.get("alternative_score", ""),
            }
        ]
    delta_scores = result.get("delta_scores")
    if isinstance(delta_scores, Mapping):
        reference_scores = result.get("reference_scores")
        if not isinstance(reference_scores, Mapping):
            reference_scores = {}
        alternative_scores = result.get("alternative_scores")
        if not isinstance(alternative_scores, Mapping):
            alternative_scores = {}
        return [
            {
                "position": "",
                "nucleotide": "",
                "channel": str(channel),
                "delta_score": value,
                "reference_score": reference_scores.get(channel, ""),
                "alternative_score": alternative_scores.get(channel, ""),
            }
            for channel, value in delta_scores.items()
        ]
    if isinstance(delta_scores, list):
        return build_axis_delta_rows(result, delta_scores)
    raise gr.Error("The selected model did not return delta scores.")


def build_axis_delta_rows(result: Mapping[str, Any], delta_scores: list[Any]) -> list[dict[str, Any]]:
    """Expand per-position (or per-bin) delta scores into one row per (position, channel).

    Non-mapping entries of ``delta_scores`` are skipped. A row's position is its ``position`` key, else its
    ``bin`` key, else its list index. Channels come from ``result["channels"]`` when it is present and
    non-empty; otherwise each row's non-boolean numeric keys (other than position/bin/nucleotide) are used.
    Reference/alternative scores are matched to delta rows by position.
    """
    # ``or []`` also tolerates an explicit ``"channels": None``.
    channels = [str(channel) for channel in result.get("channels") or []]
    reference_scores = _index_axis_rows(result.get("reference_scores"))
    alternative_scores = _index_axis_rows(result.get("alternative_scores"))
    output_rows: list[dict[str, Any]] = []
    for row_index, row in enumerate(delta_scores):
        if not isinstance(row, Mapping):
            continue
        position = row.get("position", row.get("bin", row_index))
        channel_names = channels or [
            str(key) for key in row if key not in {"position", "bin", "nucleotide"} and _is_number(row[key])
        ]
        ref_row = reference_scores.get(position, {})
        alt_row = alternative_scores.get(position, {})
        for channel in channel_names:
            if channel not in row:
                continue
            output_rows.append(
                {
                    "position": position,
                    "nucleotide": row.get("nucleotide", ""),
                    "channel": channel,
                    "delta_score": row[channel],
                    "reference_score": ref_row.get(channel, ""),
                    "alternative_score": alt_row.get(channel, ""),
                }
            )
    return output_rows


def _index_axis_rows(rows: Any) -> dict[Any, Mapping[str, Any]]:
    """Index a list of score mappings by position (``position`` key, else ``bin``, else list index).

    Anything that is not a list yields an empty index; non-mapping entries are ignored.
    """
    if not isinstance(rows, list):
        return {}
    indexed = {}
    for row_index, row in enumerate(rows):
        if isinstance(row, Mapping):
            indexed[row.get("position", row.get("bin", row_index))] = row
    return indexed


def _is_number(value: Any) -> bool:
    """Return whether ``value`` is a Python or NumPy numeric scalar; ``bool`` is deliberately excluded."""
    return isinstance(value, int | float | np.number) and not isinstance(value, bool)


def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
    """Convert row dictionaries into lists ordered by TABLE_HEADERS (missing cells become ``""``)."""
    return [[row.get(header, "") for header in TABLE_HEADERS] for row in rows]


def plot_delta_rows(rows: list[Mapping[str, Any]], max_bars: int = 24):
    """Draw a horizontal bar chart of the ``max_bars`` rows with the largest absolute delta score.

    Rows whose ``delta_score`` is not a finite number are left out: NaN would scramble the magnitude
    ordering and infinities cannot be drawn. The figure is removed from pyplot's global registry before it
    is returned, so repeated requests in a long-running app do not accumulate open figures; the returned
    Figure can still be rendered with ``savefig``.
    """
    numeric_rows = [
        row for row in rows if _is_number(row.get("delta_score")) and np.isfinite(float(row["delta_score"]))
    ]
    fig, ax = plt.subplots(figsize=(7.0, 2.4))
    try:
        if not numeric_rows:
            ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
        else:
            top_rows = sorted(numeric_rows, key=lambda row: abs(float(row["delta_score"])), reverse=True)[:max_bars]
            labels = [_row_label(row) for row in top_rows]
            values = [float(row["delta_score"]) for row in top_rows]
            colors = ["#1b9e77" if value >= 0 else "#d95f02" for value in values]
            # Grow the figure with the number of bars, clamped to [2.4, 7.0] inches.
            height = min(7.0, max(2.4, 0.28 * len(top_rows) + 1.2))
            fig.set_size_inches(7.0, height, forward=True)
            ax.barh(range(len(top_rows)), values, color=colors)
            ax.axvline(0, color="#333333", linewidth=0.8)
            ax.set_yticks(range(len(top_rows)), labels)
            ax.invert_yaxis()  # largest magnitude at the top
            ax.set_xlabel("Alternative - reference")
            ax.set_title("Largest absolute delta scores")
            ax.tick_params(axis="y", labelsize=8)
        fig.tight_layout()
    finally:
        plt.close(fig)
    return fig


def _row_label(row: Mapping[str, Any]) -> str:
    """Build a bar label: ``"<position>[ <nucleotide>] <channel>"``, or just the channel when there is no position."""
    channel = str(row.get("channel", "score"))
    position = row.get("position")
    if position not in ("", None):
        nucleotide = row.get("nucleotide")
        suffix = f" {nucleotide}" if nucleotide not in ("", None) else ""
        return f"{position}{suffix} {channel}"
    return channel


def write_result_files(
    model_id: str,
    result: Mapping[str, Any],
    rows: list[Mapping[str, Any]],
    metadata: Mapping[str, Any],
) -> tuple[str, str]:
    """Write the delta table to a CSV file and the full run to a JSON file, both UTF-8 encoded.

    The files are created with ``delete=False`` so they outlive this call and can be offered for download.

    Returns:
        ``(csv_path, json_path)``.
    """
    table = [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows]
    # ``with`` guarantees the handles are closed even if writing fails; delete=False keeps the files.
    with tempfile.NamedTemporaryFile("w", suffix=".csv", newline="", encoding="utf-8", delete=False) as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
        writer.writeheader()
        writer.writerows(table)
    with tempfile.NamedTemporaryFile("w", suffix=".json", encoding="utf-8", delete=False) as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "model": model_id,
                "result": result,
                "delta_table": table,
            },
            json_file,
            indent=2,
            default=_json_default,
        )
    return csv_file.name, json_file.name


def _json_default(value: Any):
    """``json.dump`` fallback converting NumPy scalars/arrays and torch tensors to plain Python values.

    Raises:
        TypeError: For any other unsupported type, as ``json`` expects.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def predict(
    model_label: str,
    reference_sequence: str,
    alternative_sequence: str,
    reference_features_text: str,
    alternative_features_text: str,
):
    """Score a reference/alternative pair with the selected checkpoint and build every UI output.

    Returns:
        ``(table rows, run metadata, delta plot figure, CSV path, JSON path)``, matching the outputs
        wired to the "Run prediction" button.

    Raises:
        gr.Error: For an unknown checkpoint label, invalid or length-mismatched sequences, unparsable
            features, a checkpoint that fails to load, a failed prediction, or a result without delta scores.
    """
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        # E.g. an empty or unexpected dropdown value; report it instead of raising a bare KeyError.
        raise gr.Error(f"Unknown checkpoint {model_label!r}. Choose one of: {', '.join(MODEL_OPTIONS)}.")
    reference_sequence = clean_sequence(reference_sequence, "Reference")
    alternative_sequence = clean_sequence(alternative_sequence, "Alternative")
    if len(reference_sequence) != len(alternative_sequence):
        raise gr.Error(
            f"Reference and alternative sequences must have the same length. "
            f"Got {len(reference_sequence)} and {len(alternative_sequence)}."
        )
    reference_features = parse_features(reference_features_text)
    alternative_features = parse_features(alternative_features_text)

    # Timing includes model loading (cached after first use) as well as inference.
    started = time.perf_counter()
    try:
        predictor = load_predictor(model_id)
    except Exception as error:
        # Surface download/initialization failures in the UI instead of a generic server error.
        raise gr.Error(f"Could not load {model_id}: {error}") from error
    try:
        result = predictor(
            reference_sequence,
            alternative=alternative_sequence,
            features=reference_features,
            alternative_features=alternative_features,
        )
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error
    result = unpack_prediction_result(result)
    rows = build_delta_rows(result)
    if not rows:
        raise gr.Error("The selected model returned no tabular delta scores.")

    metadata = {
        "task": "regulatory-variant-effect",
        "model": model_id,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "reference_length": len(reference_sequence),
        "alternative_length": len(alternative_sequence),
        "reference_features": feature_summary(reference_features),
        "alternative_features": feature_summary(alternative_features),
        "alternative_features_inherit_reference": alternative_features is None and reference_features is not None,
        "score_definition": "alternative_minus_reference",
        "num_delta_rows": len(rows),
        "has_reference_scores": any(row.get("reference_score") not in ("", None) for row in rows),
        "has_alternative_scores": any(row.get("alternative_score") not in ("", None) for row in rows),
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(model_id, result, rows, metadata)
    return (
        table_values(rows),
        metadata,
        plot_delta_rows(rows),
        csv_path,
        json_path,
    )


def initial_model(request: gr.Request):
    """Choose the dropdown's initial checkpoint from the page URL's ``model`` query parameter.

    The parameter may be a model id (e.g. ``?model=multimolecule/basset``) or a dropdown label
    (e.g. ``?model=Basset``). A missing or unknown value falls back to DEFAULT_MODEL_LABEL.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL
    query_params = getattr(request, "query_params", None)
    model_id = None
    if query_params is not None:
        model_id = query_params.get("model")
    if not model_id and getattr(request, "url", None):
        # Fall back to parsing the raw URL when query_params is unavailable or lacks the key.
        parsed = parse_qs(urlparse(str(request.url)).query)
        model_values = parsed.get("model")
        model_id = model_values[0] if model_values else None
    if model_id in MODEL_OPTIONS:
        return model_id
    return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)


with gr.Blocks(title="Regulatory Variant Effect") as demo:
    gr.Markdown(
        "# Regulatory Variant Effect\n"
        "Score matched reference and alternative DNA windows with MultiMolecule regulatory variant-effect models."
    )
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value=DEFAULT_MODEL_LABEL,
        label="Checkpoint",
    )
    with gr.Row():
        reference_sequence = gr.Textbox(label="Reference DNA sequence", value=DEFAULT_REFERENCE_SEQUENCE, lines=5)
        alternative_sequence = gr.Textbox(label="Alternative DNA sequence", value=DEFAULT_ALTERNATIVE_SEQUENCE, lines=5)
    with gr.Accordion("Optional numeric features", open=False), gr.Row():
        reference_features = gr.Textbox(
            label="Reference features JSON/text",
            placeholder='[0.1, 0.2, 0.3] or {"features": [0.1, 0.2, 0.3]}',
            lines=3,
        )
        alternative_features = gr.Textbox(
            label="Alternative features JSON/text",
            placeholder="Leave blank to reuse reference features when provided.",
            lines=3,
        )
    run = gr.Button("Run prediction", variant="primary")
    delta_table = gr.Dataframe(headers=TABLE_HEADERS, label="Delta scores", interactive=False, wrap=True)
    with gr.Row():
        metadata = gr.JSON(label="Run metadata")
        delta_plot = gr.Plot(label="Delta plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")
    run.click(
        predict,
        inputs=[model, reference_sequence, alternative_sequence, reference_features, alternative_features],
        outputs=[delta_table, metadata, delta_plot, csv_download, json_download],
    )
    # Pre-select the checkpoint from the page URL on load.
    demo.load(initial_model, outputs=model)


if __name__ == "__main__":
    demo.launch()