ZhiyuanChen committed on
Commit
2c6f65c
·
unverified ·
1 Parent(s): 7fc266e

implement methylation app

Browse files

Signed-off-by: Zhiyuan Chen <this@zyc.ai>

Files changed (4) hide show
  1. .pre-commit-config.yaml +50 -0
  2. README.md +18 -5
  3. app.py +225 -0
  4. requirements.txt +5 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+ repos:
4
+ - repo: https://github.com/PSF/black
5
+ rev: 25.12.0
6
+ hooks:
7
+ - id: black
8
+ args: [--safe, --quiet, --line-length=120]
9
+ - repo: https://github.com/PyCQA/isort
10
+ rev: 7.0.0
11
+ hooks:
12
+ - id: isort
13
+ name: isort
14
+ args: [--profile=black, --line-length=120]
15
+ - repo: https://github.com/PyCQA/flake8
16
+ rev: 7.3.0
17
+ hooks:
18
+ - id: flake8
19
+ args: [--max-line-length=120]
20
+ additional_dependencies:
21
+ - flake8-bugbear
22
+ - flake8-comprehensions
23
+ - flake8-simplify
24
+ - repo: https://github.com/asottile/pyupgrade
25
+ rev: v3.21.2
26
+ hooks:
27
+ - id: pyupgrade
28
+ args: [--keep-runtime-typing]
29
+ - repo: https://github.com/codespell-project/codespell
30
+ rev: v2.4.1
31
+ hooks:
32
+ - id: codespell
33
+ - repo: https://github.com/pre-commit/pre-commit-hooks
34
+ rev: v6.0.0
35
+ hooks:
36
+ - id: check-added-large-files
37
+ - id: check-ast
38
+ - id: check-builtin-literals
39
+ - id: check-case-conflict
40
+ - id: check-docstring-first
41
+ - id: check-json
42
+ - id: check-toml
43
+ - id: check-yaml
44
+ - id: debug-statements
45
+ - id: end-of-file-fixer
46
+ - id: fix-byte-order-marker
47
+ - id: mixed-line-ending
48
+ args: ["--fix=lf"]
49
+ - id: requirements-txt-fixer
50
+ - id: trailing-whitespace
README.md CHANGED
@@ -1,15 +1,28 @@
1
  ---
2
  title: Methylation
3
- emoji: 🌖
4
  colorFrom: purple
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
- short_description: Methylation
 
 
 
 
 
 
 
 
 
 
 
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
1
  ---
2
  title: Methylation
3
+ emoji: 🧬
4
  colorFrom: purple
5
+ colorTo: blue
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
+ python_version: "3.13"
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
+ suggested_hardware: t4-small
13
+ models:
14
+ - multimolecule/deepcpgdna-smallwood2014-serum
15
+ - multimolecule/deepcpgdna-smallwood2014-2i
16
+ - multimolecule/deepcpgdna-hou2016-hcc
17
+ - multimolecule/deepcpgdna-hou2016-hepg2
18
+ - multimolecule/deepcpgdna-hou2016-mesc
19
+ tags:
20
+ - biology
21
+ - dna
22
+ - methylation
23
+ - multimolecule
24
  ---
25
 
26
+ Interactive DNA methylation scoring with MultiMolecule.
27
+
28
+ Enter one DNA sequence, choose a DeepCpG-DNA checkpoint, and inspect the returned per-cell methylation score table, run metadata, and bar plot. Results can be downloaded as CSV or JSON.
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiMolecule
2
+ # Copyright (C) 2024-Present MultiMolecule
3
+
4
+ from __future__ import annotations
5
+
6
+ import csv
7
+ import json
8
+ import re
9
+ import tempfile
10
+ import time
11
+ from functools import lru_cache
12
+ from typing import Any, Mapping
13
+ from urllib.parse import parse_qs, urlparse
14
+
15
+ import gradio as gr
16
+ import matplotlib
17
+ import numpy as np
18
+ import torch
19
+ from transformers import pipeline
20
+
21
+ matplotlib.use("Agg")
22
+ import matplotlib.pyplot as plt # noqa: E402
23
+ import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
24
+
25
# Checkpoint labels mapped to their model ids.
MODEL_OPTIONS = {
    "DeepCpG-DNA Smallwood 2014 serum mESC": "multimolecule/deepcpgdna-smallwood2014-serum",
    "DeepCpG-DNA Smallwood 2014 2i mESC": "multimolecule/deepcpgdna-smallwood2014-2i",
    "DeepCpG-DNA Hou 2016 HCC": "multimolecule/deepcpgdna-hou2016-hcc",
    "DeepCpG-DNA Hou 2016 HepG2": "multimolecule/deepcpgdna-hou2016-hepg2",
    "DeepCpG-DNA Hou 2016 mESC": "multimolecule/deepcpgdna-hou2016-mesc",
}
# Reverse lookup: model id -> label.
MODEL_LABELS = dict(zip(MODEL_OPTIONS.values(), MODEL_OPTIONS.keys()))
# The first entry of MODEL_OPTIONS is the default checkpoint.
DEFAULT_MODEL_LABEL = next(iter(MODEL_OPTIONS))
# 1001 bases: 499 bases of repeated "ACGT", a "CG" dinucleotide at 0-based
# positions 499-500, then 500 bases of repeated "TGCA".
DEFAULT_SEQUENCE = "".join((("ACGT" * 125)[:499], "CG", ("TGCA" * 125)[:500]))
# Characters accepted in a cleaned DNA sequence.
DNA_ALPHABET = {"A", "C", "G", "T", "N"}
36
+
37
+
38
+ def _device() -> int:
39
+ return 0 if torch.cuda.is_available() else -1
40
+
41
+
42
+ def _device_label() -> str:
43
+ return "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+
46
@lru_cache(maxsize=2)
def load_predictor(model_id: str):
    """Build the ``"methylation"`` pipeline for ``model_id``.

    The result is memoized with ``lru_cache(maxsize=2)``, so the pipelines of
    at most two distinct ``model_id`` values are kept cached at any time.
    """
    predictor = pipeline("methylation", model=model_id, device=_device())
    return predictor
49
+
50
+
51
def clean_sequence(sequence: str) -> str:
    """Normalize user input into an upper-case DNA string.

    Lines are stripped; empty lines and lines starting with ``>`` are dropped.
    The remaining text has all whitespace removed, is upper-cased, and has
    every ``U`` replaced with ``T``.

    Raises:
        gr.Error: If nothing is left after cleaning, or if the cleaned
            sequence contains characters outside ``DNA_ALPHABET``.
    """
    stripped_lines = (raw.strip() for raw in str(sequence or "").splitlines())
    kept = [line for line in stripped_lines if line and not line.startswith(">")]
    joined = "".join(kept)
    sequence = re.sub(r"\s+", "", joined).upper().replace("U", "T")
    if not sequence:
        raise gr.Error("Sequence is empty.")
    invalid = sorted(set(sequence) - DNA_ALPHABET)
    if invalid:
        raise gr.Error(f"DNA sequence contains unsupported characters: {', '.join(invalid)}.")
    return sequence
64
+
65
+
66
def unpack_prediction_result(result: Any) -> dict[str, Any]:
    """Return the single prediction dictionary contained in ``result``.

    A list is accepted only when it holds exactly one element, which is then
    unwrapped. The (possibly unwrapped) value must be a ``dict``.

    Raises:
        gr.Error: If a list does not hold exactly one element, or the value
            is not a dictionary.
    """
    if isinstance(result, list):
        if len(result) != 1:
            raise gr.Error(f"Expected one prediction result, got {len(result)}.")
        (result,) = result
    if isinstance(result, dict):
        return result
    raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
74
+
75
+
76
def _to_builtin(values: Any) -> Any:
    """Convert array-like objects exposing ``tolist()`` to plain Python values."""
    tolist = getattr(values, "tolist", None)
    return tolist() if callable(tolist) else values


def score_rows_from_result(result: Mapping[str, Any]) -> list[list[Any]]:
    """Turn a prediction dictionary into ``[channel, score]`` table rows.

    Supported layouts:

    * ``result["score"]``: a scalar or a sequence of scores, labelled with
      ``result["channels"]`` (``"methylation"`` when no channels are given).
    * ``result["scores"]``: either a mapping of channel -> score, or a
      list/tuple of scores labelled with ``result["channels"]``.

    Values exposing ``tolist()`` are converted to plain Python values first,
    and a missing or ``None`` ``"channels"`` entry is treated as "no channel
    names".

    Raises:
        gr.Error: If neither layout is present, or a score is not a finite
            number (raised by the row helpers).
    """
    raw_channels = result.get("channels")
    channels = [] if raw_channels is None else [str(channel) for channel in raw_channels]
    if "score" in result:
        return rows_from_values(_to_builtin(result["score"]), channels or ["methylation"])
    if "scores" in result:
        scores = _to_builtin(result["scores"])
        if isinstance(scores, Mapping):
            return [[str(channel), number_value(score)] for channel, score in scores.items()]
        # Accept tuples as well as lists, matching what rows_from_values handles.
        if isinstance(scores, (list, tuple)):
            return rows_from_values(scores, channels)
    raise gr.Error("The selected model did not return methylation scores.")
87
+
88
+
89
def rows_from_values(values: Any, channels: list[str]) -> list[list[Any]]:
    """Pair score values with channel names as ``[channel, score]`` rows.

    A list or tuple yields one row per value. When the number of channel
    names does not match the number of values, the names are replaced with
    ``methylation_0``, ``methylation_1``, and so on. Any other value is
    treated as a single score and labelled with the first channel name, or
    ``"methylation"`` when no names are given.
    """
    if not isinstance(values, (list, tuple)):
        return [[channels[0] if channels else "methylation", number_value(values)]]

    names = channels
    if len(names) != len(values):
        names = [f"methylation_{index}" for index in range(len(values))]

    table = []
    for name, value in zip(names, values):
        table.append([name, number_value(value)])
    return table
95
+
96
+
97
def number_value(value: Any) -> float:
    """Convert ``value`` to a finite ``float``.

    Raises:
        gr.Error: If ``float(value)`` fails, or the converted number is NaN
            or infinite.
    """
    try:
        number = float(value)
    except (TypeError, ValueError) as error:
        raise gr.Error(f"Score value {value!r} is not numeric.") from error

    if np.isfinite(number):
        return number
    raise gr.Error(f"Score value {value!r} is not finite.")
105
+
106
+
107
def plot_scores(rows: list[list[Any]], top_n: int | float):
    """Draw a horizontal bar chart of the highest-scoring channels.

    Args:
        rows: ``[channel, score]`` pairs.
        top_n: Maximum number of bars; falsy values fall back to 25 and the
            result is clamped to at least 1.

    Returns:
        The matplotlib figure. Its axes are switched off when ``rows`` is
        empty.
    """
    top_n = max(1, int(top_n or 25))
    values = [(str(channel), float(score)) for channel, score in rows]
    values = sorted(values, key=lambda item: item[1], reverse=True)[:top_n]

    # Grow the figure with the number of bars, bounded to 3-12 inches.
    height = max(3.0, min(12.0, 1.2 + 0.34 * len(values)))
    fig, ax = plt.subplots(figsize=(8.0, height))
    try:
        if not values:
            ax.set_axis_off()
            return fig

        # Truncate long labels to 55 characters plus an ellipsis.
        labels = [label if len(label) <= 58 else f"{label[:55]}..." for label, _ in values]
        scores = [score for _, score in values]
        y_positions = np.arange(len(values))

        ax.barh(y_positions, scores, color="#2f6f9f")
        ax.set_yticks(y_positions, labels)
        ax.invert_yaxis()  # highest score at the top
        if all(0.0 <= score <= 1.0 for score in scores):
            ax.set_xlim(0.0, 1.0)
        ax.set_xlabel("Methylation score")
        ax.grid(axis="x", alpha=0.2)
        fig.tight_layout()
        return fig
    finally:
        # pyplot keeps every figure it creates registered until it is closed.
        # Unregister it here so repeated calls do not accumulate figures; the
        # returned Figure object itself remains usable by the caller.
        plt.close(fig)
131
+
132
+
133
def write_result_files(
    metadata: Mapping[str, Any], result: Mapping[str, Any], rows: list[list[Any]]
) -> tuple[str, str]:
    """Write the score table to a CSV file and the full run to a JSON file.

    Args:
        metadata: Run metadata, stored under ``"metadata"`` in the JSON file.
        result: Raw prediction result, stored under ``"raw_result"``.
        rows: ``[channel, score]`` pairs, written to the CSV file under a
            ``channel,score`` header and to the JSON file under ``"scores"``.

    Returns:
        ``(csv_path, json_path)`` of two named temporary files. Both are
        created with ``delete=False``, so removing them is left to the caller.

    Raises:
        TypeError: If the JSON payload contains a value that cannot be
            serialized.
    """

    def _json_default(value: Any) -> Any:
        # Array-like values (anything exposing ``tolist()``) and non-dict
        # mappings are converted; everything else is rejected explicitly.
        tolist = getattr(value, "tolist", None)
        if callable(tolist):
            return tolist()
        if isinstance(value, Mapping):
            return dict(value)
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    # Serialize first so that a failure cannot leave an orphaned CSV file behind.
    json_text = json.dumps(
        {
            "metadata": dict(metadata),
            "scores": [{"channel": channel, "score": score} for channel, score in rows],
            "raw_result": result,
        },
        indent=2,
        default=_json_default,
    )

    # Context managers guarantee the handles are closed even if a write fails.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
    ) as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["channel", "score"])
        writer.writerows(rows)

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as json_file:
        json_file.write(json_text)

    return csv_file.name, json_file.name
154
+
155
+
156
def predict(model_label: str, sequence: str, top_n: int | float):
    """Run one methylation prediction and build every output of the app.

    Args:
        model_label: A key of ``MODEL_OPTIONS``.
        sequence: Raw user input; cleaned with ``clean_sequence``.
        top_n: Maximum number of bars in the score plot.

    Returns:
        ``(rows, metadata, figure, csv_path, json_path)``.

    Raises:
        gr.Error: If the label is unknown, the sequence is invalid, the
            pipeline fails, or the result carries no usable scores.
    """
    model_id = MODEL_OPTIONS.get(model_label)
    if model_id is None:
        # Report an unknown label as a gr.Error, like every other failure
        # here, instead of letting a bare KeyError escape.
        raise gr.Error(f"Unknown checkpoint: {model_label!r}.")
    sequence = clean_sequence(sequence)
    started = time.perf_counter()

    try:
        result = load_predictor(model_id)(sequence)
    except gr.Error:
        raise
    except Exception as error:
        raise gr.Error(f"Prediction failed for {model_id}: {error}") from error

    result = unpack_prediction_result(result)
    rows = score_rows_from_result(result)
    metadata = {
        "task": "methylation",
        "model": model_id,
        "model_label": model_label,
        "device": _device_label(),
        "sequence_length": len(sequence),
        "score_count": len(rows),
        "channels": result.get("channels", []),
        # Covers pipeline loading, inference, and row extraction.
        "elapsed_seconds": round(time.perf_counter() - started, 3),
    }
    csv_path, json_path = write_result_files(metadata, result, rows)
    return rows, metadata, plot_scores(rows, top_n), csv_path, json_path
182
+
183
+
184
def initial_model(request: gr.Request):
    """Pick the checkpoint label to preselect from the ``model`` query parameter.

    The parameter is read from ``request.query_params`` first; if that yields
    nothing and the request has a ``url``, the URL's query string is parsed
    instead. The value is looked up in ``MODEL_LABELS``; a missing request,
    a missing parameter, or an unknown model id gives ``DEFAULT_MODEL_LABEL``.
    """
    if request is None:
        return DEFAULT_MODEL_LABEL

    model_id = None
    query_params = getattr(request, "query_params", None)
    if query_params is not None:
        model_id = query_params.get("model")

    if not model_id and getattr(request, "url", None):
        query_string = urlparse(str(request.url)).query
        model_values = parse_qs(query_string).get("model")
        model_id = model_values[0] if model_values else None

    return MODEL_LABELS.get(model_id, DEFAULT_MODEL_LABEL)
194
+
195
+
196
# Page layout: controls, sequence input, results table and metadata, plot, downloads.
with gr.Blocks(title="Methylation") as demo:
    gr.Markdown(
        "# Methylation\nRun MultiMolecule DNA methylation checkpoints and inspect per-cell methylation scores."
    )

    with gr.Row():
        model = gr.Dropdown(choices=list(MODEL_OPTIONS), value=DEFAULT_MODEL_LABEL, label="Checkpoint")
        top_n = gr.Slider(1, 50, value=25, step=1, label="Bar count")

    sequence = gr.Textbox(label="DNA sequence", value=DEFAULT_SEQUENCE, lines=7)
    run = gr.Button("Run prediction", variant="primary")

    with gr.Row():
        scores = gr.Dataframe(headers=["channel", "score"], datatype=["str", "number"], label="Score table")
        metadata = gr.JSON(label="Run metadata")

    score_plot = gr.Plot(label="Score bar plot")

    with gr.Row():
        csv_download = gr.File(label="Download CSV")
        json_download = gr.File(label="Download JSON")

    # The order of the outputs matches the tuple returned by predict().
    run.click(
        predict,
        inputs=[model, sequence, top_n],
        outputs=[scores, metadata, score_plot, csv_download, json_download],
    )
    # On page load, set the dropdown from initial_model().
    demo.load(initial_model, outputs=model)


if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ multimolecule @ git+https://github.com/DLS5-Omics/multimolecule.git@master
3
+ numpy
4
+ torch
5
+ transformers