| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| from __future__ import annotations |
|
|
| import csv |
| import json |
| import re |
| import tempfile |
| import time |
| from functools import lru_cache |
| from typing import Any, Mapping |
| from urllib.parse import parse_qs, urlparse |
|
|
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| import torch |
| from transformers import pipeline |
|
|
| matplotlib.use("Agg") |
|
|
| import matplotlib.pyplot as plt |
| import multimolecule |
|
|
# Example inputs pre-filled in the UI: a 1,000-base reference window and an
# alternative of identical length whose 0-based positions 500-503 read "TCGA"
# instead of "ACGT".
DEFAULT_REFERENCE_SEQUENCE = "ACGT" * 250
DEFAULT_ALTERNATIVE_SEQUENCE = "".join(["ACGT" * 125, "TCGA", "ACGT" * 124])
DEFAULT_MODEL_LABEL = "DeepSEA"

# Dropdown label -> checkpoint id handed to `transformers.pipeline(model=...)`.
MODEL_OPTIONS = {
    "A2Z Chromatin": "multimolecule/a2zchromatin",
    "Basset": "multimolecule/basset",
    "DeepMEL": "multimolecule/deepmel",
    "DeepSEA": "multimolecule/deepsea",
    "DeepSTARR": "multimolecule/deepstarr",
    "Malinois": "multimolecule/malinois",
    "MPRA-DragoNN": "multimolecule/mpradragonn",
    "scBasset": "multimolecule/scbasset",
    "Xpresso": "multimolecule/xpresso",
}
# Reverse lookup: checkpoint id -> dropdown label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))

# Column order shared by the results table, the CSV export and the JSON "delta_table".
TABLE_HEADERS = ["position", "nucleotide", "channel", "delta_score", "reference_score", "alternative_score"]
# Symbols a cleaned sequence may contain (RNA "U" is rewritten to "T" before checking).
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
# Signed decimal / scientific-notation numbers, used when feature text is not JSON.
FLOAT_PATTERN = re.compile(r"[-+]?(?:(?:\d*\.\d+)|(?:\d+\.?))(?:[eE][-+]?\d+)?")
|
|
|
|
| def _device() -> int: |
| return 0 if torch.cuda.is_available() else -1 |
|
|
|
|
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build, or reuse, a "regulatory-variant-effect" pipeline for `model_id`.

    Pipelines are memoised in an LRU cache holding two entries, so alternating between
    two checkpoints does not rebuild them; selecting a third evicts the least recently
    used pipeline.
    """
    task_name = "regulatory-variant-effect"
    device_index = _device()
    return pipeline(task_name, model=model_id, device=device_index)
|
|
|
|
def clean_sequence(sequence: str, label: str) -> str:
    """Normalise a user-supplied DNA sequence.

    All whitespace is removed, letters are upper-cased and RNA "U" is mapped to "T".

    Args:
        sequence: Raw text from the UI; `None` is treated as empty.
        label: Human-readable name ("Reference"/"Alternative") used in error messages.

    Returns:
        The cleaned sequence.

    Raises:
        gr.Error: if the cleaned sequence is empty or contains symbols outside
            `DNA_ALPHABET`.
    """
    raw_text = str(sequence or "")
    without_spaces = "".join(raw_text.split())
    normalised = without_spaces.upper().replace("U", "T")
    if not normalised:
        raise gr.Error(f"{label} sequence is empty.")

    unsupported = sorted(set(normalised).difference(DNA_ALPHABET))
    if unsupported:
        listed = ", ".join(unsupported)
        raise gr.Error(f"{label} sequence contains unsupported symbols: {listed}. Use A, C, G, T, or N.")
    return normalised
|
|
|
|
| def parse_features(features_text: str) -> Any | None: |
| text = str(features_text or "").strip() |
| if not text: |
| return None |
|
|
| try: |
| parsed = json.loads(text) |
| except json.JSONDecodeError: |
| values = FLOAT_PATTERN.findall(text) |
| if not values: |
| raise gr.Error("Features must be JSON or comma/space-separated numbers.") |
| return [float(value) for value in values] |
|
|
| if isinstance(parsed, Mapping): |
| for key in ("features", "values", "reference_features", "alternative_features"): |
| if key in parsed: |
| return parsed[key] |
| if all(isinstance(value, int | float) for value in parsed.values()): |
| return list(parsed.values()) |
| raise gr.Error("Feature JSON objects must contain a features/values list or only numeric values.") |
| if isinstance(parsed, str): |
| return parse_features(parsed) |
| return parsed |
|
|
|
|
def feature_summary(features: Any | None) -> dict[str, Any]:
    """Describe optional features for the run metadata without echoing their values.

    Returns `{"provided": False}` when `features` is `None`. Otherwise returns
    `{"provided": True, "shape": ...}`, where shape is the NumPy float-array shape as
    a list, or `None` when the value cannot be converted to a float array.
    """
    if features is not None:
        try:
            shape = list(np.asarray(features, dtype=float).shape)
        except (TypeError, ValueError):
            shape = None
        return {"provided": True, "shape": shape}
    return {"provided": False}
|
|
|
|
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary produced by the pipeline.

    A one-element list is unwrapped; a bare dict is returned unchanged.

    Raises:
        gr.Error: if a list does not hold exactly one item, or the (unwrapped) value
            is not a dict.
    """
    if isinstance(result, list):
        count = len(result)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (result,) = result
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
|
|
|
|
def build_delta_rows(result: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Flatten a variant-effect result into rows keyed by `TABLE_HEADERS`.

    Three result layouts are recognised, checked in this order:

    1. a scalar `delta_score` (with optional `reference_score` / `alternative_score`)
       becomes one row whose channel is "score";
    2. `delta_scores` as a list of per-axis mappings is delegated to
       `build_axis_delta_rows`;
    3. `delta_scores` as a channel -> delta mapping becomes one row per channel.
       Reference/alternative values come from `reference_scores` /
       `alternative_scores`, but only when those are mappings too.

    Raises:
        gr.Error: if none of these layouts is present.
    """
    if "delta_score" in result:
        scalar_row = {
            "position": "",
            "nucleotide": "",
            "channel": "score",
            "delta_score": result.get("delta_score"),
            "reference_score": result.get("reference_score", ""),
            "alternative_score": result.get("alternative_score", ""),
        }
        return [scalar_row]

    per_channel = result.get("delta_scores")

    if isinstance(per_channel, list):
        return build_axis_delta_rows(result, per_channel)

    if not isinstance(per_channel, Mapping):
        raise gr.Error("The selected model did not return delta scores.")

    def _channel_lookup(candidate: Any) -> Mapping[Any, Any]:
        # Scores not keyed by channel cannot be paired with a delta; ignore them.
        return candidate if isinstance(candidate, Mapping) else {}

    reference_by_channel = _channel_lookup(result.get("reference_scores"))
    alternative_by_channel = _channel_lookup(result.get("alternative_scores"))
    return [
        {
            "position": "",
            "nucleotide": "",
            "channel": str(channel),
            "delta_score": delta,
            "reference_score": reference_by_channel.get(channel, ""),
            "alternative_score": alternative_by_channel.get(channel, ""),
        }
        for channel, delta in per_channel.items()
    ]
|
|
|
|
def build_axis_delta_rows(result: Mapping[str, Any], delta_scores: list[Any]) -> list[dict[str, Any]]:
    """Flatten per-position (or per-bin) delta rows into one table row per channel.

    Each mapping in `delta_scores` is located by its `position` key, else its `bin`
    key, else its list index; non-mapping entries are skipped. Channel names come from
    `result["channels"]` when that is non-empty. Otherwise they are the entry's keys,
    other than position/bin/nucleotide, whose values are numbers. Matching reference
    and alternative values are looked up by the same position key in
    `result["reference_scores"]` / `result["alternative_scores"]`; missing ones
    become "".
    """
    declared_channels = [str(name) for name in result.get("channels", [])]
    reference_by_position = _index_axis_rows(result.get("reference_scores"))
    alternative_by_position = _index_axis_rows(result.get("alternative_scores"))
    flattened: list[dict[str, Any]] = []

    for index, entry in enumerate(delta_scores):
        if not isinstance(entry, Mapping):
            continue
        position = entry.get("position", entry.get("bin", index))
        if declared_channels:
            channel_names = declared_channels
        else:
            channel_names = [
                str(key)
                for key in entry
                if key not in {"position", "bin", "nucleotide"} and _is_number(entry[key])
            ]
        reference_row = reference_by_position.get(position, {})
        alternative_row = alternative_by_position.get(position, {})
        flattened.extend(
            {
                "position": position,
                "nucleotide": entry.get("nucleotide", ""),
                "channel": channel,
                "delta_score": entry[channel],
                "reference_score": reference_row.get(channel, ""),
                "alternative_score": alternative_row.get(channel, ""),
            }
            for channel in channel_names
            if channel in entry
        )
    return flattened
|
|
|
|
| def _index_axis_rows(rows: Any) -> dict[Any, Mapping[str, Any]]: |
| if not isinstance(rows, list): |
| return {} |
| indexed = {} |
| for row_index, row in enumerate(rows): |
| if isinstance(row, Mapping): |
| indexed[row.get("position", row.get("bin", row_index))] = row |
| return indexed |
|
|
|
|
| def _is_number(value: Any) -> bool: |
| return isinstance(value, int | float | np.number) |
|
|
|
|
def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
    """Convert row dicts into positional lists ordered like `TABLE_HEADERS`.

    This is the value format passed to `gr.Dataframe`; missing columns become "".
    """

    def to_cells(record: Mapping[str, Any]) -> list[Any]:
        return [record.get(column, "") for column in TABLE_HEADERS]

    return [to_cells(record) for record in rows]
|
|
|
|
def plot_delta_rows(rows: list[Mapping[str, Any]], max_bars: int = 24):
    """Draw a horizontal bar chart of the largest-magnitude delta scores.

    Args:
        rows: Delta-table rows, as produced by `build_delta_rows`.
        max_bars: Maximum number of bars; rows are ranked by `abs(delta_score)`.

    Returns:
        A matplotlib figure for `gr.Plot`. When no row has a finite numeric
        `delta_score`, the figure only shows a placeholder message.
    """
    # Only finite numeric deltas can be ranked and drawn. NaN makes every comparison
    # in the magnitude sort False, so the ranking becomes arbitrary, and +/-inf has
    # no place on a bar axis.
    numeric_rows = [
        row
        for row in rows
        if _is_number(row.get("delta_score")) and np.isfinite(float(row["delta_score"]))
    ]
    fig, ax = plt.subplots(figsize=(7.0, 2.4))
    try:
        if not numeric_rows:
            ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center", transform=ax.transAxes)
            ax.set_axis_off()
        else:
            top_rows = sorted(numeric_rows, key=lambda row: abs(float(row["delta_score"])), reverse=True)[:max_bars]
            labels = [_row_label(row) for row in top_rows]
            values = [float(row["delta_score"]) for row in top_rows]
            colors = ["#1b9e77" if value >= 0 else "#d95f02" for value in values]

            # Grow the figure with the number of bars, clamped to 2.4-7.0 inches.
            height = min(7.0, max(2.4, 0.28 * len(top_rows) + 1.2))
            fig.set_size_inches(7.0, height, forward=True)
            ax.barh(range(len(top_rows)), values, color=colors)
            ax.axvline(0, color="#333333", linewidth=0.8)
            ax.set_yticks(range(len(top_rows)), labels)
            ax.invert_yaxis()  # largest magnitude at the top
            ax.set_xlabel("Alternative - reference")
            ax.set_title("Largest absolute delta scores")
            ax.tick_params(axis="y", labelsize=8)
        fig.tight_layout()
    finally:
        # pyplot keeps every figure it creates alive until it is closed, which leaks
        # memory in a long-running server. Closing only unregisters the figure from
        # pyplot; the Figure object can still be rendered by gr.Plot via savefig.
        plt.close(fig)
    return fig
|
|
|
|
| def _row_label(row: Mapping[str, Any]) -> str: |
| channel = str(row.get("channel", "score")) |
| position = row.get("position") |
| if position not in ("", None): |
| nucleotide = row.get("nucleotide") |
| suffix = f" {nucleotide}" if nucleotide not in ("", None) else "" |
| return f"{position}{suffix} {channel}" |
| return channel |
|
|
|
|
def write_result_files(
    model_id: str,
    result: Mapping[str, Any],
    rows: list[Mapping[str, Any]],
    metadata: Mapping[str, Any],
) -> tuple[str, str]:
    """Persist a run as downloadable CSV and JSON files.

    The CSV holds only the delta table (columns `TABLE_HEADERS`). The JSON holds the
    metadata, the model id, the raw pipeline result and the same delta table; values
    the `json` module cannot serialize are passed to `_json_default`.

    Both files are created with `delete=False` so they outlive this function; nothing
    here deletes them.

    Returns:
        `(csv_path, json_path)`.
    """
    table = [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows]

    # `with` guarantees the handles are flushed and closed even if writing fails.
    # An explicit UTF-8 encoding keeps non-ASCII channel names from depending on the
    # platform's default encoding.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", newline="", encoding="utf-8", delete=False
    ) as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
        writer.writeheader()
        writer.writerows(table)

    with tempfile.NamedTemporaryFile("w", suffix=".json", encoding="utf-8", delete=False) as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "model": model_id,
                "result": result,
                "delta_table": table,
            },
            json_file,
            indent=2,
            default=_json_default,
        )
    return csv_file.name, json_file.name
|
|
|
|
| def _json_default(value: Any): |
| if isinstance(value, np.generic): |
| return value.item() |
| if isinstance(value, np.ndarray): |
| return value.tolist() |
| raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable") |
|
|
|
|
def predict(
    model_label: str,
    reference_sequence: str,
    alternative_sequence: str,
    reference_features_text: str,
    alternative_features_text: str,
):
    """Score a reference/alternative DNA pair with the selected checkpoint.

    Args:
        model_label: Dropdown label; must be a key of `MODEL_OPTIONS`.
        reference_sequence: Reference DNA window (normalised by `clean_sequence`).
        alternative_sequence: Alternative window; after cleaning it must have the
            same length as the reference.
        reference_features_text: Optional reference features (see `parse_features`).
        alternative_features_text: Optional alternative features.

    Returns:
        `(table_rows, metadata, figure, csv_path, json_path)`, in the order of the
        Gradio outputs wired to the Run button.

    Raises:
        gr.Error: for an unknown checkpoint label, invalid or mismatched sequences,
            malformed features, a checkpoint that fails to load or predict, or a
            result without tabular delta scores.
    """
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        # A bare KeyError would surface as an opaque server error in the UI.
        choices = ", ".join(MODEL_OPTIONS)
        raise gr.Error(f"Unknown checkpoint {model_label!r}. Choose one of: {choices}.")

    reference_sequence = clean_sequence(reference_sequence, "Reference")
    alternative_sequence = clean_sequence(alternative_sequence, "Alternative")
    if len(reference_sequence) != len(alternative_sequence):
        raise gr.Error(
            f"Reference and alternative sequences must have the same length. "
            f"Got {len(reference_sequence)} and {len(alternative_sequence)}."
        )

    reference_features = parse_features(reference_features_text)
    alternative_features = parse_features(alternative_features_text)
    started = time.perf_counter()

    # Loading has its own guard so download/initialisation failures are reported to
    # the user as a gr.Error instead of escaping as an unhandled exception.
    try:
        predictor = load_predictor(model_id)
    except Exception as error:
        raise gr.Error(f"Could not load {model_id}: {error}") from error

    try:
        result = predictor(
            reference_sequence,
            alternative=alternative_sequence,
            features=reference_features,
            alternative_features=alternative_features,
        )
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error

    result = unpack_prediction_result(result)
    rows = build_delta_rows(result)
    if not rows:
        raise gr.Error("The selected model returned no tabular delta scores.")

    metadata = {
        "task": "regulatory-variant-effect",
        "model": model_id,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "reference_length": len(reference_sequence),
        "alternative_length": len(alternative_sequence),
        "reference_features": feature_summary(reference_features),
        "alternative_features": feature_summary(alternative_features),
        # NOTE(review): records the UI's stated behaviour ("Leave blank to reuse
        # reference features"); confirm the pipeline actually falls back this way.
        "alternative_features_inherit_reference": alternative_features is None and reference_features is not None,
        "score_definition": "alternative_minus_reference",
        "num_delta_rows": len(rows),
        "has_reference_scores": any(row.get("reference_score") not in ("", None) for row in rows),
        "has_alternative_scores": any(row.get("alternative_score") not in ("", None) for row in rows),
        # Covers model loading (possibly cached), inference and row building.
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(model_id, result, rows, metadata)

    return (
        table_values(rows),
        metadata,
        plot_delta_rows(rows),
        csv_path,
        json_path,
    )
|
|
|
|
def initial_model(request: gr.Request):
    """Choose the dropdown's initial checkpoint from the page URL's `?model=` parameter.

    The value may be a checkpoint id (`multimolecule/deepsea`), its short name
    (`deepsea`) or its display label (`DeepSEA`); matching is case-insensitive.
    A missing request or parameter, or an unrecognised value, falls back to
    `DEFAULT_MODEL_LABEL`.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    query_params = getattr(request, "query_params", None)
    requested = None
    if query_params is not None:
        requested = query_params.get("model")
    if not requested and getattr(request, "url", None):
        # Fallback: read the parameter straight from the request URL.
        model_values = parse_qs(urlparse(str(request.url)).query).get("model")
        requested = model_values[0] if model_values else None
    if not requested:
        return DEFAULT_MODEL_LABEL

    wanted = str(requested).strip().lower()
    for label, model_id in MODEL_OPTIONS.items():
        aliases = {model_id.lower(), model_id.rsplit("/", 1)[-1].lower(), label.lower()}
        if wanted in aliases:
            return label
    return DEFAULT_MODEL_LABEL
|
|
|
|
# Gradio UI: inputs, a Run button, and result outputs wired to `predict`.
with gr.Blocks(title="Regulatory Variant Effect") as demo:
    gr.Markdown(
        "# Regulatory Variant Effect\n"
        "Score matched reference and alternative DNA windows with MultiMolecule regulatory variant-effect models."
    )

    # Checkpoint picker showing the MODEL_OPTIONS labels. `demo.load` below may
    # replace the initial value from a `?model=` query parameter.
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS.keys()),
        value=DEFAULT_MODEL_LABEL,
        label="Checkpoint",
    )

    with gr.Row():
        reference_sequence = gr.Textbox(label="Reference DNA sequence", value=DEFAULT_REFERENCE_SEQUENCE, lines=5)
        alternative_sequence = gr.Textbox(label="Alternative DNA sequence", value=DEFAULT_ALTERNATIVE_SEQUENCE, lines=5)

    # Optional side-features (parsed by `parse_features`), collapsed by default.
    with gr.Accordion("Optional numeric features", open=False), gr.Row():
        reference_features = gr.Textbox(
            label="Reference features JSON/text",
            placeholder='[0.1, 0.2, 0.3] or {"features": [0.1, 0.2, 0.3]}',
            lines=3,
        )
        alternative_features = gr.Textbox(
            label="Alternative features JSON/text",
            placeholder="Leave blank to reuse reference features when provided.",
            lines=3,
        )

    run = gr.Button("Run prediction", variant="primary")

    # Outputs: table, metadata, plot and downloads, matching predict's return tuple.
    delta_table = gr.Dataframe(headers=TABLE_HEADERS, label="Delta scores", interactive=False, wrap=True)
    with gr.Row():
        metadata = gr.JSON(label="Run metadata")
        delta_plot = gr.Plot(label="Delta plot")

    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # `inputs` order must match predict's positional parameters and `outputs` order
    # must match the tuple it returns.
    run.click(
        predict,
        inputs=[model, reference_sequence, alternative_sequence, reference_features, alternative_features],
        outputs=[delta_table, metadata, delta_plot, csv_download, json_download],
    )
    # On page load, let initial_model preselect a checkpoint from the request URL.
    demo.load(initial_model, outputs=model)
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly.
    demo.launch()
|
|