ZhiyuanChen commited on
Commit
7b1dcd7
·
unverified ·
1 Parent(s): 56ac439

implement polyadenylation app

Browse files

Signed-off-by: Zhiyuan Chen <this@zyc.ai>

Files changed (4) hide show
  1. .pre-commit-config.yaml +51 -0
  2. README.md +15 -5
  3. app.py +250 -0
  4. requirements.txt +5 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+ repos:
4
+ - repo: https://github.com/PSF/black
5
+ rev: 25.12.0
6
+ hooks:
7
+ - id: black
8
+ args: [--safe, --quiet, --line-length=120]
9
+ - repo: https://github.com/PyCQA/isort
10
+ rev: 7.0.0
11
+ hooks:
12
+ - id: isort
13
+ name: isort
14
+ args: [--profile=black, --line-length=120]
15
+ - repo: https://github.com/PyCQA/flake8
16
+ rev: 7.3.0
17
+ hooks:
18
+ - id: flake8
19
+ args: [--max-line-length=120]
20
+ additional_dependencies:
21
+ - flake8-bugbear
22
+ - flake8-comprehensions
23
+ - flake8-simplify
24
+ - repo: https://github.com/asottile/pyupgrade
25
+ rev: v3.21.2
26
+ hooks:
27
+ - id: pyupgrade
28
+ args: [--keep-runtime-typing]
29
+ - repo: https://github.com/codespell-project/codespell
30
+ rev: v2.4.1
31
+ hooks:
32
+ - id: codespell
33
+ args: [--ignore-regex=(?i)aparent]
34
+ - repo: https://github.com/pre-commit/pre-commit-hooks
35
+ rev: v6.0.0
36
+ hooks:
37
+ - id: check-added-large-files
38
+ - id: check-ast
39
+ - id: check-builtin-literals
40
+ - id: check-case-conflict
41
+ - id: check-docstring-first
42
+ - id: check-json
43
+ - id: check-toml
44
+ - id: check-yaml
45
+ - id: debug-statements
46
+ - id: end-of-file-fixer
47
+ - id: fix-byte-order-marker
48
+ - id: mixed-line-ending
49
+ args: ["--fix=lf"]
50
+ - id: requirements-txt-fixer
51
+ - id: trailing-whitespace
README.md CHANGED
@@ -1,15 +1,25 @@
1
  ---
2
  title: Polyadenylation
3
- emoji: 🚀
4
  colorFrom: purple
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
- short_description: Polyadenylation
 
 
 
 
 
 
 
 
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
  title: Polyadenylation
3
+ emoji: 🧬
4
  colorFrom: purple
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
+ python_version: "3.13"
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
+ suggested_hardware: t4-small
13
+ models:
14
+ - multimolecule/aparent2
15
+ - multimolecule/aparent
16
+ tags:
17
+ - biology
18
+ - dna
19
+ - polyadenylation
20
+ - multimolecule
21
  ---
22
 
23
+ Interactive polyadenylation scoring with MultiMolecule.
24
+
25
+ Enter a DNA sequence, choose an APARENT-family checkpoint, and inspect APA isoform or base-resolution cleavage scores. Results can be downloaded as CSV or JSON.
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiMolecule
2
+ # Copyright (C) 2024-Present MultiMolecule
3
+
4
+ from __future__ import annotations
5
+
6
+ import csv
7
+ import json
8
+ import re
9
+ import tempfile
10
+ import time
11
+ from functools import lru_cache
12
+ from typing import Any, Mapping
13
+ from urllib.parse import parse_qs, urlparse
14
+
15
+ import gradio as gr
16
+ import matplotlib
17
+ import numpy as np
18
+ import torch
19
+ from transformers import pipeline
20
+
21
+ matplotlib.use("Agg")
22
+ import matplotlib.pyplot as plt # noqa: E402
23
+ import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
24
+
25
+ MODEL_OPTIONS = {
26
+ "APARENT2": "multimolecule/aparent2",
27
+ "APARENT": "multimolecule/aparent",
28
+ }
29
+ MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()}
30
+ DEFAULT_MODEL_LABEL = "APARENT2"
31
+ DEFAULT_SEQUENCE = "A" * 70 + "AATAAA" + "A" * 129
32
+ DNA_ALPHABET = set("ACGTN")
33
+ TABLE_HEADERS = ["event", "position", "probability"]
34
+
35
+
36
+ def _device() -> int:
37
+ return 0 if torch.cuda.is_available() else -1
38
+
39
+
40
+ def _device_label() -> str:
41
+ return "cuda" if torch.cuda.is_available() else "cpu"
42
+
43
+
44
+ @lru_cache(maxsize=2)
45
+ def load_predictor(model_id: str):
46
+ return pipeline("polyadenylation", model=model_id, device=_device())
47
+
48
+
49
+ def clean_sequence(sequence: str) -> str:
50
+ lines = []
51
+ for line in str(sequence or "").splitlines():
52
+ line = line.strip()
53
+ if line and not line.startswith(">"):
54
+ lines.append(line)
55
+ sequence = re.sub(r"\s+", "", "".join(lines)).upper().replace("U", "T")
56
+ if not sequence:
57
+ raise gr.Error("Sequence is empty.")
58
+ invalid = sorted(set(sequence) - DNA_ALPHABET)
59
+ if invalid:
60
+ raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.")
61
+ return sequence
62
+
63
+
64
+ def unpack_prediction_result(result: Any) -> dict[str, Any]:
65
+ if isinstance(result, list):
66
+ if len(result) != 1:
67
+ raise gr.Error(f"Expected one prediction result, got {len(result)}.")
68
+ result = result[0]
69
+ if not isinstance(result, dict):
70
+ raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
71
+ return result
72
+
73
+
74
+ def rows_from_result(result: Mapping[str, Any]) -> list[dict[str, Any]]:
75
+ if isinstance(result.get("cleavage_distribution"), list):
76
+ return [_cleavage_row(row) for row in result["cleavage_distribution"]]
77
+ if "score" in result:
78
+ return [
79
+ {
80
+ "event": str(result.get("channel", "polyadenylation")),
81
+ "position": "",
82
+ "probability": number_value(result["score"]),
83
+ }
84
+ ]
85
+ if isinstance(result.get("scores"), Mapping):
86
+ return [
87
+ {"event": str(channel), "position": "", "probability": number_value(score)}
88
+ for channel, score in result["scores"].items()
89
+ ]
90
+ raise gr.Error("The selected model did not return polyadenylation scores.")
91
+
92
+
93
+ def _cleavage_row(row: Any) -> dict[str, Any]:
94
+ if not isinstance(row, Mapping):
95
+ raise gr.Error("Cleavage distribution rows must be dictionaries.")
96
+ if "event" in row:
97
+ return {"event": str(row["event"]), "position": "", "probability": number_value(row.get("probability"))}
98
+ return {
99
+ "event": "cleavage",
100
+ "position": row.get("position", ""),
101
+ "probability": number_value(row.get("probability")),
102
+ }
103
+
104
+
105
+ def number_value(value: Any) -> float:
106
+ try:
107
+ number = float(value)
108
+ except (TypeError, ValueError) as error:
109
+ raise gr.Error(f"Score value {value!r} is not numeric.") from error
110
+ if not np.isfinite(number):
111
+ raise gr.Error(f"Score value {value!r} is not finite.")
112
+ return number
113
+
114
+
115
+ def table_values(rows: list[Mapping[str, Any]]) -> list[list[Any]]:
116
+ return [[row.get(header, "") for header in TABLE_HEADERS] for row in rows]
117
+
118
+
119
+ def plot_polyadenylation(rows: list[Mapping[str, Any]]):
120
+ position_rows = [
121
+ (int(row["position"]), float(row["probability"]))
122
+ for row in rows
123
+ if row.get("position") not in ("", None) and _is_number(row.get("probability"))
124
+ ]
125
+ no_cleavage = next((float(row["probability"]) for row in rows if row.get("event") == "no_cleavage"), None)
126
+
127
+ fig, ax = plt.subplots(figsize=(8.0, 3.2))
128
+ if position_rows:
129
+ position_rows.sort()
130
+ positions = [position for position, _ in position_rows]
131
+ probabilities = [probability for _, probability in position_rows]
132
+ ax.plot(positions, probabilities, color="#2f6f9f", linewidth=1.8)
133
+ ax.fill_between(positions, probabilities, color="#9dcbec", alpha=0.35)
134
+ ax.set_xlabel("Position")
135
+ ax.set_ylabel("Cleavage probability")
136
+ if no_cleavage is not None:
137
+ ax.text(
138
+ 0.99,
139
+ 0.95,
140
+ f"no cleavage: {no_cleavage:.3f}",
141
+ ha="right",
142
+ va="top",
143
+ transform=ax.transAxes,
144
+ )
145
+ else:
146
+ labels = [str(row.get("event", "score")) for row in rows]
147
+ values = [float(row.get("probability", 0.0)) for row in rows]
148
+ ax.barh(np.arange(len(values)), values, color="#2f6f9f")
149
+ ax.set_yticks(np.arange(len(values)), labels)
150
+ ax.invert_yaxis()
151
+ ax.set_xlabel("Score")
152
+ ax.grid(axis="y", alpha=0.2)
153
+ fig.tight_layout()
154
+ return fig
155
+
156
+
157
+ def _is_number(value: Any) -> bool:
158
+ return isinstance(value, int | float | np.number)
159
+
160
+
161
+ def write_result_files(
162
+ metadata: Mapping[str, Any],
163
+ result: Mapping[str, Any],
164
+ rows: list[Mapping[str, Any]],
165
+ ) -> tuple[str, str]:
166
+ csv_file = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="")
167
+ writer = csv.DictWriter(csv_file, fieldnames=TABLE_HEADERS)
168
+ writer.writeheader()
169
+ writer.writerows({header: row.get(header, "") for header in TABLE_HEADERS} for row in rows)
170
+ csv_file.close()
171
+
172
+ json_file = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
173
+ json.dump(
174
+ {
175
+ "metadata": dict(metadata),
176
+ "rows": [{header: row.get(header, "") for header in TABLE_HEADERS} for row in rows],
177
+ "raw_result": result,
178
+ },
179
+ json_file,
180
+ indent=2,
181
+ )
182
+ json_file.close()
183
+ return csv_file.name, json_file.name
184
+
185
+
186
+ def predict(model_label: str, sequence: str):
187
+ model_id = MODEL_OPTIONS[model_label]
188
+ sequence = clean_sequence(sequence)
189
+ started = time.perf_counter()
190
+
191
+ try:
192
+ result = load_predictor(model_id)(sequence)
193
+ except gr.Error:
194
+ raise
195
+ except Exception as error:
196
+ raise gr.Error(f"Prediction failed for {model_id}: {error}") from error
197
+
198
+ result = unpack_prediction_result(result)
199
+ rows = rows_from_result(result)
200
+ metadata = {
201
+ "task": "polyadenylation",
202
+ "model": model_id,
203
+ "model_label": model_label,
204
+ "device": _device_label(),
205
+ "sequence_length": len(sequence),
206
+ "row_count": len(rows),
207
+ "elapsed_seconds": round(time.perf_counter() - started, 3),
208
+ }
209
+ csv_path, json_path = write_result_files(metadata, result, rows)
210
+ return table_values(rows), metadata, plot_polyadenylation(rows), csv_path, json_path
211
+
212
+
213
+ def initial_model(request: gr.Request):
214
+ if request is None:
215
+ return DEFAULT_MODEL_LABEL
216
+ query_params = getattr(request, "query_params", None)
217
+ model_id = query_params.get("model") if query_params is not None else None
218
+ if not model_id and getattr(request, "url", None):
219
+ parsed = parse_qs(urlparse(str(request.url)).query)
220
+ model_values = parsed.get("model")
221
+ model_id = model_values[0] if model_values else None
222
+ return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)
223
+
224
+
225
+ with gr.Blocks(title="Polyadenylation") as demo:
226
+ gr.Markdown(
227
+ "# Polyadenylation\n"
228
+ "Run MultiMolecule polyadenylation checkpoints and inspect APA isoform or cleavage-position scores."
229
+ )
230
+
231
+ model = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value=DEFAULT_MODEL_LABEL, label="Checkpoint")
232
+ sequence = gr.Textbox(label="DNA sequence", value=DEFAULT_SEQUENCE, lines=5)
233
+ run = gr.Button("Run prediction", variant="primary")
234
+
235
+ with gr.Row():
236
+ table = gr.Dataframe(headers=TABLE_HEADERS, label="Polyadenylation scores", interactive=False)
237
+ metadata = gr.JSON(label="Run metadata")
238
+
239
+ plot = gr.Plot(label="Polyadenylation plot")
240
+
241
+ with gr.Row():
242
+ csv_download = gr.File(label="Download CSV")
243
+ json_download = gr.File(label="Download JSON")
244
+
245
+ run.click(predict, inputs=[model, sequence], outputs=[table, metadata, plot, csv_download, json_download])
246
+ demo.load(initial_model, outputs=model)
247
+
248
+
249
+ if __name__ == "__main__":
250
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ multimolecule @ git+https://github.com/DLS5-Omics/multimolecule.git@master
3
+ numpy
4
+ torch
5
+ transformers