# MultiMolecule # Copyright (C) 2024-Present MultiMolecule # This file is part of MultiMolecule. # MultiMolecule is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # any later version. # MultiMolecule is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # For additional terms and clarifications, please refer to our License FAQ at: # . from __future__ import annotations import csv import json import re import tempfile from datetime import datetime, timezone from functools import lru_cache from typing import Any, Mapping from urllib.parse import parse_qs, urlparse import gradio as gr import matplotlib import numpy as np import torch from transformers import pipeline matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers MODEL_OPTIONS = { "A2Z Chromatin": "multimolecule/a2zchromatin", "Basset": "multimolecule/basset", "DeepMEL": "multimolecule/deepmel", "DeepSEA": "multimolecule/deepsea", "DeepSTARR": "multimolecule/deepstarr", "Malinois": "multimolecule/malinois", "MPRA-DragoNN": "multimolecule/mpradragonn", "scBasset": "multimolecule/scbasset", "Xpresso": "multimolecule/xpresso", } MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()} DEFAULT_MODEL_LABEL = "DeepSEA" DEFAULT_SEQUENCE = "ACGT" * 150 DNA_ALPHABET = set("ACGTN") def _device() -> int: return 0 if torch.cuda.is_available() else -1 def _device_label() -> str: return "cuda" if torch.cuda.is_available() else "cpu" 
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build the ``regulatory-activity`` pipeline for ``model_id``, caching the result.

    The LRU cache holds at most two pipelines, so switching between more than two
    checkpoints evicts the least recently used one.
    """
    return pipeline("regulatory-activity", model=model_id, device=_device())


def clean_sequence(sequence: str) -> str:
    """Normalize user-supplied sequence text into an uppercase DNA string.

    Lines starting with ``>`` are dropped, all whitespace is removed, the text is
    upper-cased and every ``U`` is replaced by ``T``.

    Raises:
        gr.Error: If nothing remains after cleaning, or if a character outside
            ``DNA_ALPHABET`` is present.
    """
    kept = [
        stripped
        for line in (sequence or "").splitlines()
        if (stripped := line.strip()) and not stripped.startswith(">")
    ]
    sequence = re.sub(r"\s+", "", "".join(kept)).upper().replace("U", "T")
    if not sequence:
        raise gr.Error("Sequence is empty.")
    invalid = sorted(set(sequence) - DNA_ALPHABET)
    if invalid:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.")
    return sequence


def parse_features(features_text: str | None) -> tuple[Any | None, list[int] | None]:
    """Parse the optional auxiliary-features textbox.

    The text is first tried as JSON (a number, a list, or an object with a
    ``"features"`` key). If it is not valid JSON it is split on whitespace,
    commas and semicolons and each token is converted with ``float``.

    Returns:
        ``(None, None)`` when the text is blank; otherwise the values as nested
        Python lists (or a bare float for a scalar) plus the array shape as a list.

    Raises:
        gr.Error: If the text cannot be parsed as numbers, is empty, has more
            than two dimensions, or contains non-finite values.
    """
    text = (features_text or "").strip()
    if not text:
        return None, None
    try:
        features = json.loads(text)
    except json.JSONDecodeError:
        tokens = [token for token in re.split(r"[\s,;]+", text) if token]
        try:
            features = [float(token) for token in tokens]
        except ValueError as error:
            raise gr.Error(
                "Auxiliary features must be a JSON numeric value/list or comma-separated numbers."
            ) from error
    else:
        if isinstance(features, Mapping):
            if "features" not in features:
                raise gr.Error('JSON object features must use a "features" key, for example {"features": [0, 0]}.')
            features = features["features"]
    try:
        array = np.asarray(features, dtype=np.float32)
    except (TypeError, ValueError) as error:
        raise gr.Error("Auxiliary features must contain only numeric values.") from error
    if array.size == 0:
        raise gr.Error("Auxiliary features are empty.")
    if array.ndim > 2:
        raise gr.Error("Auxiliary features must be a number, a 1-D list, or a 2-D batch-sized list.")
    if not np.isfinite(array).all():
        raise gr.Error("Auxiliary features must be finite numbers.")
    return array.tolist(), list(array.shape)


def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary from a pipeline result.

    A one-element list is unwrapped; anything that is not (or does not unwrap to)
    a ``dict`` raises ``gr.Error``.
    """
    if isinstance(result, list):
        if len(result) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
        result = result[0]
    if not isinstance(result, dict):
        raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
    return result


def score_rows_from_result(result: Mapping[str, Any]) -> list[list[Any]]:
    """Convert a prediction dictionary into ``[channel, score]`` table rows.

    Reads ``result["score"]`` if present, otherwise ``result["scores"]`` (a mapping
    or a list). Channel names come from ``result["channels"]`` when available.

    Raises:
        gr.Error: If neither key yields a usable score container.
    """
    # A present-but-None "channels" entry is treated like a missing one instead of crashing.
    raw_channels = result.get("channels")
    channels = [] if raw_channels is None else [str(channel) for channel in raw_channels]
    if "score" in result:
        return rows_from_values(result["score"], channels or ["score"])
    if "scores" in result:
        scores = result["scores"]
        if isinstance(scores, Mapping):
            return [[str(channel), number_value(score)] for channel, score in scores.items()]
        if isinstance(scores, list):
            return rows_from_score_list(scores, channels)
    raise gr.Error("The selected model did not return sequence-level score output.")


def rows_from_values(values: Any, channels: list[str]) -> list[list[Any]]:
    """Pair score values with channel names.

    For a list/tuple, ``channels`` is used only when its length matches; otherwise
    rows are labelled ``score_0``, ``score_1``, ... A non-sequence value yields a
    single row labelled with the first channel (or ``"score"``).
    """
    if isinstance(values, (list, tuple)):
        if len(channels) != len(values):
            channels = [f"score_{index}" for index in range(len(values))]
        return [[channel, number_value(value)] for channel, value in zip(channels, values)]
    return [[channels[0] if channels else "score", number_value(values)]]


def rows_from_score_list(scores: list[Any], channels: list[str]) -> list[list[Any]]:
    """Convert a list-valued ``scores`` entry into table rows.

    A non-empty list of plain numbers is delegated to ``rows_from_values``.
    Otherwise each item is handled individually: mappings contribute one row per
    key, prefixed by any ``position``/``bin``/``nucleotide`` entries they carry,
    and non-mapping items become ``score_<index>`` rows.
    """
    if scores and all(isinstance(score, (int, float)) for score in scores):
        return rows_from_values(scores, channels)
    locator_keys = ("position", "bin", "nucleotide")
    rows: list[list[Any]] = []
    for index, item in enumerate(scores):
        if not isinstance(item, Mapping):
            rows.append([f"score_{index}", number_value(item)])
            continue
        prefix = " ".join(f"{key}={item[key]}" for key in locator_keys if key in item)
        for key, value in item.items():
            if key in locator_keys:
                continue
            label = str(key) if not prefix else f"{prefix} {key}"
            rows.append([label, number_value(value)])
    return rows


def number_value(value: Any) -> float:
    """Convert ``value`` to a finite ``float`` or raise ``gr.Error``."""
    try:
        number = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    if not np.isfinite(number):
        raise gr.Error(f"Score value {value!r} is not finite.")
    return number


def plot_scores(rows: list[list[Any]], top_n: int | float) -> Any:
    """Draw a horizontal bar chart of the ``top_n`` rows with the largest absolute score.

    Args:
        rows: ``[channel, score]`` pairs.
        top_n: Number of bars; falsy values fall back to 20 and the minimum is 1.

    Returns:
        The Matplotlib figure. With no rows, the figure has its axes hidden.
    """
    top_n = max(1, int(top_n or 20))
    values = [(str(channel), float(score)) for channel, score in rows]
    values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:top_n]
    height = max(3.0, min(12.0, 1.2 + 0.36 * len(values)))
    fig, ax = plt.subplots(figsize=(9, height))
    # pyplot keeps every figure it creates in a global registry. In a long-running app
    # that registry would grow on each request, so unregister immediately; the Figure
    # object itself stays usable and can still be rendered by the caller.
    plt.close(fig)
    if not values:
        ax.set_axis_off()
        return fig
    # Long channel names are shortened to 54 characters so the y-axis stays readable.
    labels = [label if len(label) <= 54 else f"{label[:51]}..." for label, _ in values]
    scores = [score for _, score in values]
    y_positions = np.arange(len(values))
    colors = ["#2f6f9f" if score >= 0 else "#c75146" for score in scores]
    ax.barh(y_positions, scores, color=colors)
    ax.set_yticks(y_positions, labels)
    ax.invert_yaxis()  # largest |score| at the top
    ax.axvline(0, color="#555555", linewidth=0.8)
    ax.set_xlabel("Score")
    ax.set_title(f"Top {len(values)} channels by absolute score")
    ax.grid(axis="x", alpha=0.2)
    fig.tight_layout()
    return fig


def _json_default(value: Any) -> Any:
    """``json.dump`` fallback that converts NumPy/PyTorch values to plain Python.

    Any other unsupported type still raises ``TypeError``, as ``json`` does by default.
    """
    if isinstance(value, (np.generic, np.ndarray)):
        return value.tolist()
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def write_result_files(
    metadata: Mapping[str, Any],
    result: Mapping[str, Any],
    rows: list[list[Any]],
) -> tuple[str, str]:
    """Write the score table to a temporary CSV and the full run to a temporary JSON file.

    Both files are created with ``delete=False`` so they outlive this call; the
    caller receives their paths.

    Returns:
        ``(csv_path, json_path)``.
    """
    # Context managers guarantee the handles are closed even if serialization fails.
    # UTF-8 is explicit so output does not depend on the host locale.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["channel", "score"])
        writer.writerows(rows)

    payload = {
        "metadata": metadata,
        "scores": [{"channel": channel, "score": score} for channel, score in rows],
        "raw_result": result,
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        # NOTE(review): the raw pipeline output may carry NumPy/PyTorch values - not
        # verifiable from this file; _json_default only matters if it does.
        json.dump(payload, json_file, indent=2, default=_json_default)
    return csv_file.name, json_file.name


def predict(
    model_label: str,
    sequence: str,
    features_text: str,
    top_n: int | float,
):
    """Run the selected checkpoint on one sequence and prepare all UI outputs.

    Args:
        model_label: A key of ``MODEL_OPTIONS`` (the dropdown value).
        sequence: Raw sequence text; cleaned with ``clean_sequence``.
        features_text: Optional auxiliary features; parsed with ``parse_features``.
        top_n: Number of bars for the score plot.

    Returns:
        ``(rows, metadata, figure, csv_path, json_path)`` matching the click-handler outputs.

    Raises:
        gr.Error: For an unknown checkpoint, invalid input, or any pipeline failure.
    """
    # A cleared dropdown yields None; report that as a user-facing error, not a KeyError.
    model_id = MODEL_OPTIONS.get(model_label) if isinstance(model_label, str) else None
    if model_id is None:
        raise gr.Error(f"Unknown checkpoint: {model_label!r}.")
    sequence = clean_sequence(sequence)
    features, features_shape = parse_features(features_text)
    try:
        predictor = load_predictor(model_id)
        if features is None:
            result = predictor(sequence)
        else:
            result = predictor(sequence, features=features)
    except gr.Error:
        raise
    except Exception as error:
        # UI boundary: surface model-loading/inference failures as a Gradio error message.
        raise gr.Error(str(error)) from error
    result = unpack_prediction_result(result)
    rows = score_rows_from_result(result)
    metadata = {
        "task": "regulatory-activity",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "features_provided": features is not None,
        "features_shape": features_shape,
        "score_count": len(rows),
        "channels": result.get("channels", []),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    figure = plot_scores(rows, top_n)
    csv_path, json_path = write_result_files(metadata, result, rows)
    return rows, metadata, figure, csv_path, json_path


def initial_model(request: gr.Request):
    """Choose the dropdown value on page load from the ``model`` query parameter.

    The parameter may be either a dropdown label or a model identifier. Anything
    else, including a missing parameter or request, selects ``DEFAULT_MODEL_LABEL``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL
    query_params = getattr(request, "query_params", None)
    model_id = None
    if query_params is not None:
        model_id = query_params.get("model")
    # Fall back to parsing the raw URL when query_params gave nothing.
    if not model_id and getattr(request, "url", None):
        parsed = parse_qs(urlparse(str(request.url)).query)
        model_values = parsed.get("model")
        model_id = model_values[0] if model_values else None
    if model_id in MODEL_OPTIONS:
        return model_id
    return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)


# ---- UI layout and event wiring ----
with gr.Blocks(title="Regulatory Activity") as demo:
    gr.Markdown(
        "# Regulatory Activity\n"
        "Run MultiMolecule sequence-level DNA regulatory checkpoints and inspect the returned activity scores."
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            value=DEFAULT_MODEL_LABEL,
            label="Checkpoint",
        )
        top_n = gr.Slider(1, 50, value=20, step=1, label="Bar count")
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    features = gr.Textbox(
        label="Auxiliary numeric features (optional)",
        placeholder='JSON list, {"features": [...]}, or comma-separated numbers',
        lines=2,
    )
    run = gr.Button("Run prediction", variant="primary")
    with gr.Row():
        scores = gr.Dataframe(
            headers=["channel", "score"],
            datatype=["str", "number"],
            label="Score table",
        )
        metadata = gr.JSON(label="Run metadata")
    score_plot = gr.Plot(label="Score bar plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # Output order must match the tuple returned by ``predict``.
    run.click(
        predict,
        inputs=[model, sequence, features, top_n],
        outputs=[scores, metadata, score_plot, csv_download, json_download],
    )
    # Preselect the checkpoint from the page URL when the app loads.
    demo.load(initial_model, outputs=model)


if __name__ == "__main__":
    demo.launch()