| |
| |
|
|
| from __future__ import annotations |
|
|
| import csv |
| import json |
| import re |
| import tempfile |
| import time |
| from functools import lru_cache |
| from typing import Any, Mapping |
| from urllib.parse import parse_qs, urlparse |
|
|
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| import torch |
| from transformers import pipeline |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import multimolecule |
|
|
# Dropdown label -> checkpoint id passed to the pipeline. Insertion order is
# the order in which the labels are offered in the "Checkpoint" dropdown.
MODEL_OPTIONS = {
    "APARENT2": "multimolecule/aparent2",
    "APARENT": "multimolecule/aparent",
}
# Reverse lookup (checkpoint id -> dropdown label), used to resolve the
# ``model`` URL query parameter into a dropdown value.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
DEFAULT_MODEL_LABEL = "APARENT2"
# 205-character placeholder input: a run of "A" with one "AATAAA" hexamer
# starting at offset 70.
DEFAULT_SEQUENCE = "".join(("A" * 70, "AATAAA", "A" * 129))
# Characters accepted by clean_sequence after upper-casing and U -> T mapping.
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
# Column order shared by the results table and the CSV/JSON exports.
TABLE_HEADERS = ["event", "position", "probability"]
|
|
|
|
def _device() -> int:
    """Return the device index given to the pipeline: 0 if CUDA is available, else -1."""
    if torch.cuda.is_available():
        return 0
    return -1
|
|
|
|
def _device_label() -> str:
    """Return ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise (reported in run metadata)."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build the ``"polyadenylation"`` pipeline for ``model_id``.

    The result is memoized per ``model_id`` (at most two entries are kept), so
    repeated predictions with the same checkpoint reuse the same pipeline
    object instead of constructing it again.
    """
    target_device = _device()
    predictor = pipeline("polyadenylation", model=model_id, device=target_device)
    return predictor
|
|
|
|
def clean_sequence(sequence: str) -> str:
    """Normalize raw text input into an upper-case DNA string.

    Lines starting with ``>`` (FASTA-style headers) and blank lines are
    dropped, all whitespace is removed, the text is upper-cased and every
    ``U`` is replaced by ``T``.

    Raises:
        gr.Error: if nothing remains after cleaning, or if the cleaned
            sequence contains characters outside ``DNA_ALPHABET``.
    """
    kept = [
        stripped
        for stripped in (raw.strip() for raw in str(sequence or "").splitlines())
        if stripped and not stripped.startswith(">")
    ]
    normalized = re.sub(r"\s+", "", "".join(kept))
    normalized = normalized.upper().replace("U", "T")
    if not normalized:
        raise gr.Error("Sequence is empty.")
    unsupported = sorted(set(normalized) - DNA_ALPHABET)
    if unsupported:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(unsupported)}.")
    return normalized
|
|
|
|
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary contained in ``result``.

    ``result`` may be the dictionary itself or a one-element list wrapping it.

    Raises:
        gr.Error: if a list does not hold exactly one item, or if the
            (unwrapped) value is not a ``dict``.
    """
    payload = result
    if isinstance(payload, list):
        count = len(payload)
        if count != 1:
            raise gr.Error(f"Expected one prediction result, got {count}.")
        (payload,) = payload
    if isinstance(payload, dict):
        return payload
    raise gr.Error(f"Expected a prediction dictionary, got {type(payload).__name__}.")
|
|
|
|
def rows_from_result(result: Mapping[str, Any]) -> list[dict[str, Any]]:
    """Convert a prediction dictionary into ``event``/``position``/``probability`` rows.

    Three result shapes are recognized, checked in this order:

    1. ``cleavage_distribution`` is a list: one row per entry.
    2. a ``score`` key is present: a single row named after ``channel``
       (``"polyadenylation"`` when no channel is given).
    3. ``scores`` is a mapping: one row per ``channel -> score`` item.

    Raises:
        gr.Error: if none of the shapes match, or a score is not a finite number.
    """
    distribution = result.get("cleavage_distribution")
    if isinstance(distribution, list):
        return [_cleavage_row(entry) for entry in distribution]

    if "score" in result:
        channel_name = str(result.get("channel", "polyadenylation"))
        single_row = {
            "event": channel_name,
            "position": "",
            "probability": number_value(result["score"]),
        }
        return [single_row]

    channel_scores = result.get("scores")
    if isinstance(channel_scores, Mapping):
        rows = []
        for channel, score in channel_scores.items():
            rows.append({"event": str(channel), "position": "", "probability": number_value(score)})
        return rows

    raise gr.Error("The selected model did not return polyadenylation scores.")
|
|
|
|
def _cleavage_row(row: Any) -> dict[str, Any]:
    """Normalize one ``cleavage_distribution`` entry into a table row.

    Entries carrying an ``event`` key keep that name and get an empty
    position; all other entries are labelled ``"cleavage"`` and keep their
    ``position`` (empty string when absent).

    Raises:
        gr.Error: if ``row`` is not a mapping or its probability is not a
            finite number.
    """
    if not isinstance(row, Mapping):
        raise gr.Error("Cleavage distribution rows must be dictionaries.")
    has_named_event = "event" in row
    label = str(row["event"]) if has_named_event else "cleavage"
    position = "" if has_named_event else row.get("position", "")
    return {
        "event": label,
        "position": position,
        "probability": number_value(row.get("probability")),
    }
|
|
|
|
def number_value(value: Any) -> float:
    """Coerce ``value`` to a finite Python ``float``.

    Args:
        value: anything accepted by ``float()`` (numbers, numeric strings,
            NumPy scalars, ...).

    Returns:
        The value as a built-in ``float``.

    Raises:
        gr.Error: if ``value`` cannot be converted, is NaN/infinite, or is an
            integer too large to be represented as a float.
    """
    try:
        number = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error
    except OverflowError as error:
        # float() raises OverflowError for integers outside the float range;
        # report it like any other non-finite score instead of leaking it.
        raise gr.Error(f"Score value {value!r} is not finite.") from error
    if not np.isfinite(number):
        raise gr.Error(f"Score value {value!r} is not finite.")
    return number
|
|
|
|
def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
    """Return ``rows`` as a list of lists ordered by ``TABLE_HEADERS``.

    Missing columns are filled with an empty string.
    """
    values = []
    for row in rows:
        values.append([row.get(column, "") for column in TABLE_HEADERS])
    return values
|
|
|
|
def plot_polyadenylation(rows: list[Mapping[str, Any]]):
    """Draw the prediction rows as a Matplotlib figure.

    When at least one row has a position and a numeric probability, the
    figure is a filled line plot of probability over position, annotated with
    the ``no_cleavage`` probability if such a row exists. Otherwise every row
    is drawn as a horizontal bar labelled with its event name.

    Args:
        rows: row dictionaries with ``event``, ``position`` and
            ``probability`` keys.

    Returns:
        The Matplotlib figure. It is already closed in pyplot (so it does not
        accumulate in pyplot's global figure registry across requests) but it
        remains a fully rendered figure object that can still be saved.
    """
    position_rows = sorted(
        (int(row["position"]), float(row["probability"]))
        for row in rows
        if row.get("position") not in ("", None) and _is_number(row.get("probability"))
    )
    no_cleavage = next((float(row["probability"]) for row in rows if row.get("event") == "no_cleavage"), None)

    fig, ax = plt.subplots(figsize=(8.0, 3.2))
    try:
        if position_rows:
            positions = [position for position, _ in position_rows]
            probabilities = [probability for _, probability in position_rows]
            ax.plot(positions, probabilities, color="#2f6f9f", linewidth=1.8)
            ax.fill_between(positions, probabilities, color="#9dcbec", alpha=0.35)
            ax.set_xlabel("Position")
            ax.set_ylabel("Cleavage probability")
            if no_cleavage is not None:
                # The no-cleavage row has no position, so show it as a label
                # in the top-right corner instead of a point on the curve.
                ax.text(
                    0.99,
                    0.95,
                    f"no cleavage: {no_cleavage:.3f}",
                    ha="right",
                    va="top",
                    transform=ax.transAxes,
                )
        else:
            labels = [str(row.get("event", "score")) for row in rows]
            values = [float(row.get("probability", 0.0)) for row in rows]
            ticks = np.arange(len(values))
            ax.barh(ticks, values, color="#2f6f9f")
            ax.set_yticks(ticks, labels)
            # Keep the first row at the top, matching the table order.
            ax.invert_yaxis()
            ax.set_xlabel("Score")
        ax.grid(axis="y", alpha=0.2)
        fig.tight_layout()
    finally:
        # plt.subplots registers the figure with pyplot, which keeps it alive
        # until it is closed. Without this, every prediction would leak one
        # figure. Closing only unregisters it; the object stays usable.
        plt.close(fig)
    return fig
|
|
|
|
def _is_number(value: Any) -> bool:
    """Return True if ``value`` is an ``int``, ``float`` or NumPy numeric scalar.

    Note that ``bool`` is a subclass of ``int`` and therefore also counts.
    """
    numeric_types = (int, float, np.number)
    return isinstance(value, numeric_types)
|
|
|
|
def _json_default(value: Any) -> Any:
    """Convert NumPy scalars and arrays to plain Python values for ``json.dump``."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def write_result_files(
    metadata: Mapping[str, Any],
    result: Mapping[str, Any],
    rows: list[Mapping[str, Any]],
) -> tuple[str, str]:
    """Write the prediction to a temporary CSV file and a temporary JSON file.

    The CSV holds one line per row with the ``TABLE_HEADERS`` columns. The
    JSON holds the run metadata, the same rows, and the raw model result.

    Args:
        metadata: run metadata stored under ``"metadata"`` in the JSON file.
        result: raw prediction dictionary stored under ``"raw_result"``.
        rows: table rows; missing columns are written as empty strings.

    Returns:
        ``(csv_path, json_path)``. Both files are created with
        ``delete=False`` so they outlive this call; this function does not
        remove them.
    """
    # Project once so both exports contain exactly the same rows.
    export_rows = [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows]

    # ``with`` guarantees the handles are closed even if writing fails.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
        writer.writeheader()
        writer.writerows(export_rows)

    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as json_file:
        json.dump(
            {
                "metadata": dict(metadata),
                "rows": export_rows,
                "raw_result": result,
            },
            json_file,
            indent=2,
            # ``position`` values and the raw result are passed through
            # unconverted, so they may hold NumPy scalars/arrays that the
            # default encoder rejects.
            default=_json_default,
        )
    return csv_file.name, json_file.name
|
|
|
|
def predict(model_label: str, sequence: str):
    """Run the selected checkpoint on ``sequence`` and build all UI outputs.

    Args:
        model_label: a key of ``MODEL_OPTIONS`` (the dropdown label).
        sequence: raw sequence text; it is normalized with ``clean_sequence``.

    Returns:
        A 5-tuple ``(table, metadata, figure, csv_path, json_path)``: the
        table as a list of lists, the run metadata dictionary, the Matplotlib
        figure, and the paths of the exported CSV and JSON files.

    Raises:
        gr.Error: for an unknown checkpoint label, an invalid sequence, a
            failing model call, or an unexpected result shape.
    """
    try:
        model_id = MODEL_OPTIONS[model_label]
    except (KeyError, TypeError) as error:
        # Surface a readable message instead of a raw KeyError when the label
        # is missing/None or not one of the offered checkpoints.
        raise gr.Error(f"Unknown checkpoint: {model_label!r}.") from error
    sequence = clean_sequence(sequence)
    started = time.perf_counter()

    try:
        result = load_predictor(model_id)(sequence)
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error

    result = unpack_prediction_result(result)
    rows = rows_from_result(result)
    metadata = {
        "task": "polyadenylation",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "row_count": len(rows),
        # Measured from just before the pipeline is fetched, so the first
        # call for a checkpoint includes pipeline construction.
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(metadata, result, rows)
    return table_values(rows), metadata, plot_polyadenylation(rows), csv_path, json_path
|
|
|
|
def initial_model(request: gr.Request):
    """Pick the dropdown label to preselect when the page loads.

    The checkpoint id is read from the ``model`` query parameter, first via
    ``request.query_params`` and, if that yields nothing, by parsing the
    query string of ``request.url``. Ids that are not in ``MODEL_LABELS``
    (or a missing request/parameter) fall back to ``DEFAULT_MODEL_LABEL``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    params = getattr(request, "query_params", None)
    requested = None if params is None else params.get("model")

    if not requested:
        url = getattr(request, "url", None)
        if url:
            candidates = parse_qs(urlparse(str(url)).query).get("model")
            requested = candidates[0] if candidates else None

    return MODEL_LABELS.get(requested, DEFAULT_MODEL_LABEL)
|
|
|
|
with gr.Blocks(title="Polyadenylation") as demo:
    # Page header.
    gr.Markdown(
        "# Polyadenylation\n"
        "Run MultiMolecule polyadenylation checkpoints and inspect APA isoform or cleavage-position scores."
    )

    # Inputs: checkpoint selector, sequence box and the run button.
    model = gr.Dropdown(
        choices=list(MODEL_OPTIONS),
        value=DEFAULT_MODEL_LABEL,
        label="Checkpoint",
    )
    sequence = gr.Textbox(label="DNA sequence", value=DEFAULT_SEQUENCE, lines=5)
    run = gr.Button("Run prediction", variant="primary")

    # Outputs: score table next to the run metadata, the plot below, then downloads.
    with gr.Row():
        table = gr.Dataframe(
            headers=TABLE_HEADERS,
            label="Polyadenylation scores",
            interactive=False,
        )
        metadata = gr.JSON(label="Run metadata")

    plot = gr.Plot(label="Polyadenylation plot")

    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # The output order must match the tuple returned by predict.
    run.click(
        predict,
        inputs=[model, sequence],
        outputs=[table, metadata, plot, csv_download, json_download],
    )
    # On page load, preselect the checkpoint named in the ``model`` query parameter.
    demo.load(initial_model, outputs=model)


if __name__ == "__main__":
    demo.launch()
|
|