# MultiMolecule # Copyright (C) 2024-Present MultiMolecule from __future__ import annotations import csv import json import re import tempfile import time from functools import lru_cache from typing import Any, Mapping from urllib.parse import parse_qs, urlparse import gradio as gr import matplotlib import numpy as np import torch from transformers import pipeline matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402 import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers MODEL_OPTIONS = { "APARENT2": "multimolecule/aparent2", "APARENT": "multimolecule/aparent", } MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()} DEFAULT_MODEL_LABEL = "APARENT2" DEFAULT_SEQUENCE = "A" * 70 + "AATAAA" + "A" * 129 DNA_ALPHABET = set("ACGTN") TABLE_HEADERS = ["event", "position", "probability"] def _device() -> int: return 0 if torch.cuda.is_available() else -1 def _device_label() -> str: return "cuda" if torch.cuda.is_available() else "cpu" @lru_cache(maxsize=2) def load_predictor(model_id: str): return pipeline("polyadenylation", model=model_id, device=_device()) def clean_sequence(sequence: str) -> str: lines = [] for line in str(sequence or "").splitlines(): line = line.strip() if line and not line.startswith(">"): lines.append(line) sequence = re.sub(r"\s+", "", "".join(lines)).upper().replace("U", "T") if not sequence: raise gr.Error("Sequence is empty.") invalid = sorted(set(sequence) - DNA_ALPHABET) if invalid: raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.") return sequence def unpack_prediction_result(result: Any) -> dict[str, Any]: if isinstance(result, list): if len(result) != 1: raise gr.Error(f"Expected one prediction result, got {len(result)}.") result = result[0] if not isinstance(result, dict): raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.") return result def rows_from_result(result: Mapping[str, Any]) -> list[dict[str, Any]]: if isinstance(result.get("cleavage_distribution"), list): return [_cleavage_row(row) for row in result["cleavage_distribution"]] if "score" in result: return [ { "event": str(result.get("channel", "polyadenylation")), "position": "", "probability": number_value(result["score"]), } ] if isinstance(result.get("scores"), Mapping): return [ {"event": str(channel), "position": "", "probability": number_value(score)} for channel, score in result["scores"].items() ] raise gr.Error("The selected model did not return polyadenylation scores.") def _cleavage_row(row: Any) -> dict[str, Any]: if not isinstance(row, Mapping): raise gr.Error("Cleavage distribution rows must be dictionaries.") if "event" in row: return {"event": str(row["event"]), "position": "", "probability": number_value(row.get("probability"))} return { "event": "cleavage", "position": row.get("position", ""), "probability": number_value(row.get("probability")), } def number_value(value: Any) -> float: try: number = float(value) except (TypeError, ValueError) as error: raise gr.Error(f"Score value {value!r} is not numeric.") from error if not np.isfinite(number): raise gr.Error(f"Score value {value!r} is not finite.") return number def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]: return [[row.get(header, "") for header in TABLE_HEADERS] for row in rows] def plot_polyadenylation(rows: list[Mapping[str, Any]]): position_rows = [ (int(row["position"]), float(row["probability"])) for row in rows if row.get("position") not in ("", None) and _is_number(row.get("probability")) ] no_cleavage = next((float(row["probability"]) for row in rows if row.get("event") == "no_cleavage"), None) fig, ax = plt.subplots(figsize=(8.0, 3.2)) if position_rows: position_rows.sort() positions = [position for position, _ in position_rows] probabilities = [probability for _, probability in position_rows] ax.plot(positions, probabilities, color="#2f6f9f", linewidth=1.8) ax.fill_between(positions, probabilities, color="#9dcbec", alpha=0.35) ax.set_xlabel("Position") ax.set_ylabel("Cleavage probability") if no_cleavage is not None: ax.text( 0.99, 0.95, f"no cleavage: {no_cleavage:.3f}", ha="right", va="top", transform=ax.transAxes, ) else: labels = [str(row.get("event", "score")) for row in rows] values = [float(row.get("probability", 0.0)) for row in rows] ax.barh(np.arange(len(values)), values, color="#2f6f9f") ax.set_yticks(np.arange(len(values)), labels) ax.invert_yaxis() ax.set_xlabel("Score") ax.grid(axis="y", alpha=0.2) fig.tight_layout() return fig def _is_number(value: Any) -> bool: return isinstance(value, int | float | np.number) def write_result_files( metadata: Mapping[str, Any], result: Mapping[str, Any], rows: list[Mapping[str, Any]], ) -> tuple[str, str]: csv_file = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS) writer.writeheader() writer.writerows({header: row.get(header, "") for header in TABLE_HEADERS} for row in rows) csv_file.close() json_file = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) json.dump( { "metadata": dict(metadata), "rows": [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows], "raw_result": result, }, json_file, indent=2, ) json_file.close() return csv_file.name, json_file.name def predict(model_label: str, sequence: str): model_id = MODEL_OPTIONS[model_label] sequence = clean_sequence(sequence) started = time.perf_counter() try: result = load_predictor(model_id)(sequence) except gr.Error: raise except Exception as error: raise gr.Error(f"Prediction failed for {model_id}: {error}") from error result = unpack_prediction_result(result) rows = rows_from_result(result) metadata = { "task": "polyadenylation", "model": model_id, "model_label": model_label, "device": _device_label(), "sequence_length": len(sequence), "row_count": len(rows), "elapsed_seconds": round(time.perf_counter() - started, 3), } csv_path, json_path = write_result_files(metadata, result, rows) return table_values(rows), metadata, plot_polyadenylation(rows), csv_path, json_path def initial_model(request: gr.Request): if request is None: return DEFAULT_MODEL_LABEL query_params = getattr(request, "query_params", None) model_id = query_params.get("model") if query_params is not None else None if not model_id and getattr(request, "url", None): parsed = parse_qs(urlparse(str(request.url)).query) model_values = parsed.get("model") model_id = model_values[0] if model_values else None return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL) with gr.Blocks(title="Polyadenylation") as demo: gr.Markdown( "# Polyadenylation\n" "Run MultiMolecule polyadenylation checkpoints and inspect APA isoform or cleavage-position scores." ) model = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value=DEFAULT_MODEL_LABEL, label="Checkpoint") sequence = gr.Textbox(label="DNA sequence", value=DEFAULT_SEQUENCE, lines=5) run = gr.Button("Run prediction", variant="primary") with gr.Row(): table = gr.Dataframe(headers=TABLE_HEADERS, label="Polyadenylation scores", interactive=False) metadata = gr.JSON(label="Run metadata") plot = gr.Plot(label="Polyadenylation plot") with gr.Row(): csv_download = gr.File(label="Download CSV") json_download = gr.File(label="Download JSON") run.click(predict, inputs=[model, sequence], outputs=[table, metadata, plot, csv_download, json_download]) demo.load(initial_model, outputs=model) if __name__ == "__main__": demo.launch()