Spaces:
Sleeping
Sleeping
| # MultiMolecule | |
| # Copyright (C) 2024-Present MultiMolecule | |
| # This file is part of MultiMolecule. | |
| # MultiMolecule is free software: you can redistribute it and/or modify | |
| # it under the terms of the GNU Affero General Public License as published by | |
| # the Free Software Foundation, either version 3 of the License, or | |
| # any later version. | |
| # MultiMolecule is distributed in the hope that it will be useful, | |
| # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
| # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
| # GNU Affero General Public License for more details. | |
| # You should have received a copy of the GNU Affero General Public License | |
| # along with this program. If not, see <http://www.gnu.org/licenses/>. | |
| # For additional terms and clarifications, please refer to our License FAQ at: | |
| # <https://multimolecule.danling.org/about/license-faq>. | |
| from __future__ import annotations | |
| import csv | |
| import json | |
| import re | |
| import tempfile | |
| from datetime import datetime, timezone | |
| from functools import lru_cache | |
| from typing import Any, Mapping | |
| from urllib.parse import parse_qs, urlparse | |
| import gradio as gr | |
| import matplotlib | |
| import numpy as np | |
| import torch | |
| from transformers import pipeline | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt # noqa: E402 | |
| import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers | |
# Display label -> model id passed to the Transformers pipeline.
MODEL_OPTIONS = {
    "A2Z Chromatin": "multimolecule/a2zchromatin",
    "Basset": "multimolecule/basset",
    "DeepMEL": "multimolecule/deepmel",
    "DeepSEA": "multimolecule/deepsea",
    "DeepSTARR": "multimolecule/deepstarr",
    "Malinois": "multimolecule/malinois",
    "MPRA-DragoNN": "multimolecule/mpradragonn",
    "scBasset": "multimolecule/scbasset",
    "Xpresso": "multimolecule/xpresso",
}
# Reverse lookup: model id -> display label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# Label preselected in the checkpoint dropdown.
DEFAULT_MODEL_LABEL = "DeepSEA"
# 600-nucleotide placeholder sequence shown in the input box.
DEFAULT_SEQUENCE = "ACGT" * 150
# Characters accepted by `clean_sequence` after normalisation.
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
| def _device() -> int: | |
| return 0 if torch.cuda.is_available() else -1 | |
| def _device_label() -> str: | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
@lru_cache(maxsize=4)
def load_predictor(model_id: str):
    """Build the ``regulatory-activity`` pipeline for ``model_id``, memoising recent results.

    Without the cache every prediction request constructs the pipeline again.
    The cache is bounded so that switching between many checkpoints cannot
    keep an unbounded number of pipelines alive.

    Args:
        model_id: Model identifier forwarded to ``transformers.pipeline``.

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    return pipeline("regulatory-activity", model=model_id, device=_device())
def clean_sequence(sequence: str) -> str:
    """Normalise raw textbox input into an upper-case DNA string.

    Blank lines and lines starting with ``>`` are dropped, all remaining
    whitespace is removed, the text is upper-cased and ``U`` is mapped to ``T``.

    Args:
        sequence: Raw user input; ``None`` or an empty string is treated as empty.

    Returns:
        The cleaned sequence, containing only characters from ``DNA_ALPHABET``.

    Raises:
        gr.Error: If nothing is left after cleaning, or if a character outside
            ``DNA_ALPHABET`` remains.
    """
    kept = [
        stripped
        for stripped in (raw.strip() for raw in (sequence or "").splitlines())
        if stripped and not stripped.startswith(">")
    ]
    cleaned = re.sub(r"\s+", "", "".join(kept)).upper().replace("U", "T")
    if not cleaned:
        raise gr.Error("Sequence is empty.")
    unsupported = sorted(set(cleaned) - DNA_ALPHABET)
    if unsupported:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(unsupported)}.")
    return cleaned
def parse_features(features_text: str | None) -> tuple[Any | None, list[int] | None]:
    """Parse the optional auxiliary-features textbox.

    The text is first tried as JSON (a number, a list, or an object with a
    ``"features"`` key). If it is not valid JSON it is split on whitespace,
    commas and semicolons and every token is converted with ``float``.

    Args:
        features_text: Raw textbox content; ``None`` or blank means "no features".

    Returns:
        ``(None, None)`` when no features were given, otherwise the values as
        nested Python lists (or a single float) together with their shape.

    Raises:
        gr.Error: If the text cannot be parsed, is empty after parsing, has more
            than two dimensions, or contains non-finite values.
    """
    stripped = (features_text or "").strip()
    if not stripped:
        return None, None

    try:
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        # Not JSON: fall back to a flat, delimiter-separated list of numbers.
        pieces = filter(None, re.split(r"[\s,;]+", stripped))
        try:
            parsed = list(map(float, pieces))
        except ValueError as error:
            raise gr.Error(
                "Auxiliary features must be a JSON numeric value/list or comma-separated numbers."
            ) from error
    else:
        if isinstance(parsed, Mapping):
            if "features" not in parsed:
                raise gr.Error('JSON object features must use a "features" key, for example {"features": [0, 0]}.')
            parsed = parsed["features"]

    try:
        matrix = np.asarray(parsed, dtype=np.float32)
    except (TypeError, ValueError) as error:
        raise gr.Error("Auxiliary features must contain only numeric values.") from error

    if matrix.size == 0:
        raise gr.Error("Auxiliary features are empty.")
    if matrix.ndim > 2:
        raise gr.Error("Auxiliary features must be a number, a 1-D list, or a 2-D batch-sized list.")
    if not np.isfinite(matrix).all():
        raise gr.Error("Auxiliary features must be finite numbers.")
    return matrix.tolist(), list(matrix.shape)
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Reduce a pipeline return value to a single prediction dictionary.

    Args:
        result: Either a dictionary or a list holding exactly one item.

    Returns:
        The prediction dictionary.

    Raises:
        gr.Error: If a list does not hold exactly one item, or the (unwrapped)
            value is not a ``dict``.
    """
    if isinstance(result, list):
        count = len(result)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (result,) = result
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
def score_rows_from_result(result: Mapping[str, Any]) -> list[list[Any]]:
    """Convert a prediction dictionary into ``[channel, score]`` table rows.

    A ``"score"`` entry takes precedence over ``"scores"``. ``"scores"`` may be
    a mapping of channel to score or a list handled by ``rows_from_score_list``.

    Args:
        result: Prediction dictionary; an optional ``"channels"`` entry supplies
            the row labels.

    Returns:
        One ``[label, float]`` row per score.

    Raises:
        gr.Error: If neither key holds a usable value, or a score is not a
            finite number.
    """
    channel_names = list(map(str, result.get("channels", [])))

    if "score" in result:
        labels = channel_names if channel_names else ["score"]
        return rows_from_values(result["score"], labels)

    per_channel = result["scores"] if "scores" in result else None
    if isinstance(per_channel, Mapping):
        return [[str(name), number_value(value)] for name, value in per_channel.items()]
    if isinstance(per_channel, list):
        return rows_from_score_list(per_channel, channel_names)

    raise gr.Error("The selected model did not return sequence-level score output.")
def rows_from_values(values: Any, channels: list[str]) -> list[list[Any]]:
    """Pair score values with channel labels.

    Args:
        values: A single value, or a list/tuple of values.
        channels: Candidate labels. For a list/tuple they are used only when
            their count matches the number of values; otherwise generic
            ``score_<index>`` labels are generated.

    Returns:
        One ``[label, float]`` row per value.
    """
    if not isinstance(values, (list, tuple)):
        # Scalar result: use the first channel name if there is one.
        label = channels[0] if channels else "score"
        return [[label, number_value(values)]]

    if len(channels) == len(values):
        labels = channels
    else:
        labels = [f"score_{index}" for index in range(len(values))]
    return [[label, number_value(value)] for label, value in zip(labels, values)]
def rows_from_score_list(scores: list[Any], channels: list[str]) -> list[list[Any]]:
    """Flatten a list-shaped ``scores`` payload into ``[label, score]`` rows.

    A non-empty list of plain numbers is delegated to ``rows_from_values``.
    Otherwise each item is handled on its own: mappings contribute one row per
    key other than ``position``/``bin``/``nucleotide`` (those keys, when present,
    form a ``key=value`` label prefix); any other item becomes ``score_<index>``.

    Args:
        scores: The list found under the ``"scores"`` key.
        channels: Channel labels, used only for the all-numeric case.

    Returns:
        The flattened rows, in input order.
    """
    locator_keys = ("position", "bin", "nucleotide")
    skipped = frozenset(locator_keys)

    if scores and all(isinstance(entry, (int, float)) for entry in scores):
        return rows_from_values(scores, channels)

    table: list[list[Any]] = []
    for position, entry in enumerate(scores):
        if isinstance(entry, Mapping):
            prefix = " ".join(f"{key}={entry[key]}" for key in locator_keys if key in entry)
            table.extend(
                [f"{prefix} {key}" if prefix else str(key), number_value(value)]
                for key, value in entry.items()
                if key not in skipped
            )
        else:
            table.append([f"score_{position}", number_value(entry)])
    return table
def number_value(value: Any) -> float:
    """Convert ``value`` to a finite ``float``.

    Args:
        value: Anything accepted by ``float()``.

    Returns:
        The converted number.

    Raises:
        gr.Error: If ``float()`` rejects the value with ``TypeError`` or
            ``ValueError``, or the result is NaN or infinite.
    """
    try:
        converted = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    if np.isfinite(converted):
        return converted
    raise gr.Error(f"Score value {value!r} is not finite.")
def plot_scores(rows: list[list[Any]], top_n: int | float) -> Any:
    """Draw a horizontal bar chart of the highest-magnitude scores.

    Args:
        rows: ``[channel, score]`` pairs.
        top_n: Maximum number of bars; falsy values fall back to 20 and the
            result is clamped to at least 1.

    Returns:
        The Matplotlib figure. With no rows, the figure holds a single blank axis.
    """
    top_n = max(1, int(top_n or 20))
    values = [(str(channel), float(score)) for channel, score in rows]
    values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:top_n]
    # Grow the figure with the bar count, but keep it between 3 and 12 inches tall.
    height = max(3.0, min(12.0, 1.2 + 0.36 * len(values)))
    fig, ax = plt.subplots(figsize=(9, height))
    try:
        if not values:
            ax.set_axis_off()
            return fig
        # Truncate long labels so they do not squeeze the plotting area.
        labels = [label if len(label) <= 54 else f"{label[:51]}..." for label, _ in values]
        scores = [score for _, score in values]
        y_positions = np.arange(len(values))
        colors = ["#2f6f9f" if score >= 0 else "#c75146" for score in scores]
        ax.barh(y_positions, scores, color=colors)
        ax.set_yticks(y_positions, labels)
        ax.invert_yaxis()  # largest magnitude at the top
        ax.axvline(0, color="#555555", linewidth=0.8)
        ax.set_xlabel("Score")
        ax.set_title(f"Top {len(values)} channels by absolute score")
        ax.grid(axis="x", alpha=0.2)
        fig.tight_layout()
        return fig
    finally:
        # pyplot keeps a reference to every figure it creates until it is closed.
        # In a long-running server that grows without bound, so deregister the
        # figure here; the returned Figure object can still be rendered.
        plt.close(fig)
def write_result_files(
    metadata: Mapping[str, Any],
    result: Mapping[str, Any],
    rows: list[list[Any]],
) -> tuple[str, str]:
    """Write the score table to a temporary CSV and the full run to a temporary JSON file.

    Both files are created with ``delete=False`` so they outlive this call;
    the caller receives their paths.

    Args:
        metadata: Run metadata stored under ``"metadata"`` in the JSON file.
        result: Raw prediction dictionary stored under ``"raw_result"``.
        rows: ``[channel, score]`` pairs written to the CSV and to ``"scores"``.

    Returns:
        ``(csv_path, json_path)``.

    Raises:
        TypeError: If ``metadata``, ``result`` or ``rows`` hold values that
            ``json`` cannot serialise.
    """
    # Serialise first: if the payload is not JSON-serialisable we fail before
    # any file exists, instead of leaving an orphaned CSV and a truncated JSON.
    payload = json.dumps(
        {
            "metadata": metadata,
            "scores": [{"channel": channel, "score": score} for channel, score in rows],
            "raw_result": result,
        },
        indent=2,
    )

    # Context managers guarantee the handles are closed even if a write fails;
    # explicit UTF-8 keeps the output independent of the host locale.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["channel", "score"])
        writer.writerows(rows)

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json_file.write(payload)

    return csv_file.name, json_file.name
def predict(
    model_label: str,
    sequence: str,
    features_text: str,
    top_n: int | float,
):
    """Run one prediction and assemble everything the UI displays.

    Args:
        model_label: Display label of the checkpoint; must be a key of ``MODEL_OPTIONS``.
        sequence: Raw DNA sequence text, normalised by ``clean_sequence``.
        features_text: Optional auxiliary features, parsed by ``parse_features``.
        top_n: Maximum number of bars in the score plot.

    Returns:
        ``(rows, metadata, figure, csv_path, json_path)``, matching the order of
        the output components wired to the run button.

    Raises:
        gr.Error: For an unknown checkpoint label, invalid input, any failure
            while loading or running the pipeline, or an unusable result.
    """
    try:
        model_id = MODEL_OPTIONS[model_label]
    except (KeyError, TypeError) as error:
        # A cleared dropdown or a stale client can send a label we do not know;
        # surface a readable message instead of a bare KeyError.
        raise gr.Error(f"Unknown checkpoint: {model_label!r}.") from error

    sequence = clean_sequence(sequence)
    features, features_shape = parse_features(features_text)

    try:
        predictor = load_predictor(model_id)
        if features is None:
            result = predictor(sequence)
        else:
            result = predictor(sequence, features=features)
    except gr.Error:
        raise
    except Exception as error:
        # Boundary handler: show any loading/inference failure in the UI.
        # Fall back to the exception type when the message itself is empty.
        raise gr.Error(str(error) or type(error).__name__) from error

    result = unpack_prediction_result(result)
    rows = score_rows_from_result(result)
    metadata = {
        "task": "regulatory-activity",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "features_provided": features is not None,
        "features_shape": features_shape,
        "score_count": len(rows),
        "channels": result.get("channels", []),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    figure = plot_scores(rows, top_n)
    csv_path, json_path = write_result_files(metadata, result, rows)
    return rows, metadata, figure, csv_path, json_path
def initial_model(request: gr.Request):
    """Pick the checkpoint label to preselect when the page loads.

    The ``model`` query parameter is read from ``request.query_params`` and, if
    that yields nothing, from the query string of ``request.url``. The value may
    be either a display label or a model id.

    Args:
        request: The incoming request; may be ``None``.

    Returns:
        A key of ``MODEL_OPTIONS``; ``DEFAULT_MODEL_LABEL`` when nothing matches.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    requested = None
    params = getattr(request, "query_params", None)
    if params is not None:
        requested = params.get("model")

    if not requested and getattr(request, "url", None):
        # Fall back to parsing the raw URL.
        from_url = parse_qs(urlparse(str(request.url)).query).get("model")
        requested = from_url[0] if from_url else None

    if requested in MODEL_OPTIONS:
        # Already a display label.
        return requested
    return MODEL_LABELS.get(requested, DEFAULT_MODEL_LABEL)
# NOTE(review): the row grouping below is reconstructed from component order,
# because the original indentation was not available — confirm the layout.
with gr.Blocks(title="Regulatory Activity") as demo:
    gr.Markdown(
        "# Regulatory Activity\n"
        "Run MultiMolecule sequence-level DNA regulatory checkpoints and inspect the returned activity scores."
    )

    # --- Inputs ---
    with gr.Row():
        model = gr.Dropdown(
            choices=list(MODEL_OPTIONS),
            value=DEFAULT_MODEL_LABEL,
            label="Checkpoint",
        )
        top_n = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Bar count")
    sequence = gr.Textbox(
        label="DNA sequence",
        value=DEFAULT_SEQUENCE,
        lines=5,
    )
    features = gr.Textbox(
        label="Auxiliary numeric features (optional)",
        placeholder='JSON list, {"features": [...]}, or comma-separated numbers',
        lines=2,
    )
    run = gr.Button("Run prediction", variant="primary")

    # --- Outputs ---
    with gr.Row():
        scores = gr.Dataframe(
            headers=["channel", "score"],
            datatype=["str", "number"],
            label="Score table",
        )
        metadata = gr.JSON(label="Run metadata")
    score_plot = gr.Plot(label="Score bar plot")
    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # --- Events ---
    # The outputs list follows the order of the tuple returned by `predict`.
    run.click(
        fn=predict,
        inputs=[model, sequence, features, top_n],
        outputs=[scores, metadata, score_plot, csv_download, json_download],
    )
    # On page load, preselect the checkpoint named by the `model` query parameter.
    demo.load(fn=initial_model, outputs=model)

if __name__ == "__main__":
    demo.launch()