ZhiyuanChen commited on
Commit
ae141c8
·
unverified ·
1 Parent(s): dc7e249

implement splice-variant-effect app

Browse files

Signed-off-by: Zhiyuan Chen <this@zyc.ai>

Files changed (4) hide show
  1. .pre-commit-config.yaml +50 -0
  2. README.md +24 -6
  3. app.py +491 -0
  4. requirements.txt +7 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ default_language_version:
2
+ python: python3
3
+ repos:
4
+ - repo: https://github.com/PSF/black
5
+ rev: 25.12.0
6
+ hooks:
7
+ - id: black
8
+ args: [--safe, --quiet, --line-length=120]
9
+ - repo: https://github.com/PyCQA/isort
10
+ rev: 7.0.0
11
+ hooks:
12
+ - id: isort
13
+ name: isort
14
+ args: [--profile=black, --line-length=120]
15
+ - repo: https://github.com/PyCQA/flake8
16
+ rev: 7.3.0
17
+ hooks:
18
+ - id: flake8
19
+ args: [--max-line-length=120]
20
+ additional_dependencies:
21
+ - flake8-bugbear
22
+ - flake8-comprehensions
23
+ - flake8-simplify
24
+ - repo: https://github.com/asottile/pyupgrade
25
+ rev: v3.21.2
26
+ hooks:
27
+ - id: pyupgrade
28
+ args: [--keep-runtime-typing]
29
+ - repo: https://github.com/codespell-project/codespell
30
+ rev: v2.4.1
31
+ hooks:
32
+ - id: codespell
33
+ - repo: https://github.com/pre-commit/pre-commit-hooks
34
+ rev: v6.0.0
35
+ hooks:
36
+ - id: check-added-large-files
37
+ - id: check-ast
38
+ - id: check-builtin-literals
39
+ - id: check-case-conflict
40
+ - id: check-docstring-first
41
+ - id: check-json
42
+ - id: check-toml
43
+ - id: check-yaml
44
+ - id: debug-statements
45
+ - id: end-of-file-fixer
46
+ - id: fix-byte-order-marker
47
+ - id: mixed-line-ending
48
+ args: ["--fix=lf"]
49
+ - id: requirements-txt-fixer
50
+ - id: trailing-whitespace
README.md CHANGED
@@ -1,15 +1,33 @@
1
  ---
2
  title: Splice Variant Effect
3
- emoji: 🐠
4
- colorFrom: red
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
- short_description: Splice Variant Effect
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  ---
14
 
15
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
  title: Splice Variant Effect
3
+ emoji: 🧬
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
+ python_version: "3.13"
9
  app_file: app.py
10
  pinned: false
11
  license: agpl-3.0
12
+ suggested_hardware: t4-small
13
+ models:
14
+ - multimolecule/mmsplice
15
+ - multimolecule/mtsplice
16
+ - multimolecule/hal
17
+ - multimolecule/maxentscan-score5
18
+ - multimolecule/maxentscan-score3
19
+ - multimolecule/pangolin
20
+ - multimolecule/sptransformer
21
+ tags:
22
+ - biology
23
+ - dna
24
+ - splicing
25
+ - variant-effect
26
+ - multimolecule
27
  ---
28
 
29
+ Interactive splice variant-effect scoring with MultiMolecule.
30
+
31
+ Enter same-length reference and alternative DNA sequence windows, or upload a two-record FASTA file with the reference first and the alternative second. The app runs the Transformers `splice-variant-effect` pipeline registered by MultiMolecule and reports delta scores, optional reference and alternative scores, run metadata, and downloadable CSV/JSON outputs.
32
+
33
+ This Space intentionally works on supplied sequence windows only. It does not perform genome-coordinate lookup, transcript annotation, or reference sequence reconstruction.
app.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MultiMolecule
2
+ # Copyright (C) 2024-Present MultiMolecule
3
+
4
+ # This file is part of MultiMolecule.
5
+
6
+ # MultiMolecule is free software: you can redistribute it and/or modify
7
+ # it under the terms of the GNU Affero General Public License as published by
8
+ # the Free Software Foundation, either version 3 of the License, or
9
+ # any later version.
10
+
11
+ # MultiMolecule is distributed in the hope that it will be useful,
12
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
13
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14
+ # GNU Affero General Public License for more details.
15
+
16
+ # You should have received a copy of the GNU Affero General Public License
17
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
18
+
19
+ # For additional terms and clarifications, please refer to our License FAQ at:
20
+ # <https://multimolecule.danling.org/about/license-faq>.
21
+
22
+ from __future__ import annotations
23
+
24
+ import json
25
+ import math
26
+ import tempfile
27
+ import time
28
+ from collections.abc import Mapping
29
+ from datetime import datetime, timezone
30
+ from functools import lru_cache
31
+ from pathlib import Path
32
+ from typing import Any
33
+ from urllib.parse import parse_qs, urlparse
34
+
35
+ import gradio as gr
36
+ import matplotlib
37
+ import numpy as np
38
+ import pandas as pd
39
+ import torch
40
+ from Bio import SeqIO
41
+ from transformers import pipeline
42
+
43
+ matplotlib.use("Agg")
44
+ import matplotlib.pyplot as plt # noqa: E402
45
+ import multimolecule # noqa: E402, F401 - registers MultiMolecule models and pipelines with Transformers
46
+
47
+ MODEL_OPTIONS = {
48
+ "MMSplice": "multimolecule/mmsplice",
49
+ "MTSplice": "multimolecule/mtsplice",
50
+ "HAL": "multimolecule/hal",
51
+ "MaxEntScan score5": "multimolecule/maxentscan-score5",
52
+ "MaxEntScan score3": "multimolecule/maxentscan-score3",
53
+ "Pangolin": "multimolecule/pangolin",
54
+ "SpTransformer": "multimolecule/sptransformer",
55
+ }
56
+ MODEL_LABELS = {model_id: label for label, model_id in MODEL_OPTIONS.items()}
57
+ FASTA_SUFFIXES = {".fa", ".fasta", ".fna"}
58
+ VALID_DNA = set("ACGTNRYSWKMBDHVX")
59
+ META_COLUMNS = {"scope", "position", "nucleotide", "sequence", "label", "type"}
60
+
61
+ DEFAULT_REFERENCE = "ACGT" * 25 + "CCCCCCCCCCCCCCCCCCCC" + "TGCA" * 25
62
+ DEFAULT_ALTERNATIVE = "ACGT" * 25 + "CCCCCCCCCCCTCCCCCCCC" + "TGCA" * 25
63
+
64
+
65
+ def _device() -> int:
66
+ return 0 if torch.cuda.is_available() else -1
67
+
68
+
69
+ @lru_cache(maxsize=len(MODEL_OPTIONS))
70
+ def load_predictor(model_id: str):
71
+ return pipeline("splice-variant-effect", model=model_id, device=_device())
72
+
73
+
74
+ def clean_sequence(sequence: str, label: str) -> str:
75
+ sequence = "".join(str(sequence or "").split()).upper().replace("U", "T")
76
+ if not sequence:
77
+ raise gr.Error(f"{label} sequence is empty.")
78
+
79
+ invalid = sorted(set(sequence) - VALID_DNA)
80
+ if invalid:
81
+ raise gr.Error(f"{label} sequence contains unsupported DNA symbols: {', '.join(invalid)}.")
82
+ return sequence
83
+
84
+
85
+ def validate_pair(reference: str, alternative: str) -> tuple[str, str]:
86
+ reference = clean_sequence(reference, "Reference")
87
+ alternative = clean_sequence(alternative, "Alternative")
88
+ if len(reference) != len(alternative):
89
+ raise gr.Error(
90
+ "Reference and alternative sequences must have the same length. "
91
+ "This app does not perform genome-coordinate lookup or sequence reconstruction."
92
+ )
93
+ return reference, alternative
94
+
95
+
96
+ def load_fasta_pair(input_file: Any):
97
+ if input_file is None:
98
+ return gr.update(), gr.update()
99
+
100
+ path = Path(getattr(input_file, "name", input_file))
101
+ if path.suffix.lower() not in FASTA_SUFFIXES:
102
+ raise gr.Error("Upload a FASTA file with two records: reference first, alternative second.")
103
+
104
+ records = list(SeqIO.parse(path, "fasta"))
105
+ if len(records) != 2:
106
+ raise gr.Error(f"Expected exactly two FASTA records, found {len(records)}.")
107
+
108
+ reference, alternative = validate_pair(str(records[0].seq), str(records[1].seq))
109
+ return reference, alternative
110
+
111
+
112
+ def _json_safe(value: Any) -> Any:
113
+ if isinstance(value, torch.Tensor):
114
+ return _json_safe(value.detach().cpu().tolist())
115
+ if isinstance(value, np.ndarray):
116
+ return _json_safe(value.tolist())
117
+ if isinstance(value, np.generic):
118
+ return value.item()
119
+ if isinstance(value, Mapping):
120
+ return {str(key): _json_safe(item) for key, item in value.items()}
121
+ if isinstance(value, (list, tuple)):
122
+ return [_json_safe(item) for item in value]
123
+ return value
124
+
125
+
126
+ def _is_scalar(value: Any) -> bool:
127
+ if isinstance(value, (str, bytes)) or value is None:
128
+ return False
129
+ try:
130
+ float(value)
131
+ except (TypeError, ValueError):
132
+ return False
133
+ return True
134
+
135
+
136
+ def _number(value: Any) -> float | Any:
137
+ if not _is_scalar(value):
138
+ return value
139
+ number = float(value)
140
+ if math.isfinite(number):
141
+ return number
142
+ return value
143
+
144
+
145
+ def _position_key(key: Any) -> bool:
146
+ try:
147
+ int(str(key))
148
+ except ValueError:
149
+ return False
150
+ return True
151
+
152
+
153
+ def _vector_row(values: list[Any], channels: list[str], scalar_column: str, scope: str = "sequence") -> dict[str, Any]:
154
+ row: dict[str, Any] = {"scope": scope}
155
+ if channels and len(values) == len(channels):
156
+ row.update({channel: _number(value) for channel, value in zip(channels, values)})
157
+ elif len(values) == 1:
158
+ row[scalar_column] = _number(values[0])
159
+ else:
160
+ row.update({f"{scalar_column}_{index}": _number(value) for index, value in enumerate(values)})
161
+ return row
162
+
163
+
164
+ def _flatten_mapping(
165
+ mapping: Mapping[str, Any],
166
+ channels: list[str],
167
+ scalar_column: str,
168
+ prefix: str | None = None,
169
+ ) -> dict[str, Any]:
170
+ row: dict[str, Any] = {}
171
+ for key, value in mapping.items():
172
+ key = str(key)
173
+ column = f"{prefix}_{key}" if prefix else key
174
+ value = _json_safe(value)
175
+ if _is_scalar(value) or value is None or isinstance(value, str):
176
+ row[column] = _number(value)
177
+ elif isinstance(value, Mapping):
178
+ row.update(_flatten_mapping(value, channels, scalar_column, prefix=column))
179
+ elif isinstance(value, list) and all(_is_scalar(item) for item in value):
180
+ if key in META_COLUMNS:
181
+ row[column] = value
182
+ elif channels and len(value) == len(channels):
183
+ row.update({channel: _number(item) for channel, item in zip(channels, value)})
184
+ else:
185
+ row.update({f"{column}_{index}": _number(item) for index, item in enumerate(value)})
186
+ else:
187
+ row[column] = value
188
+ return row
189
+
190
+
191
+ def normalize_score_rows(score_value: Any, channels: list[str], scalar_column: str) -> list[dict[str, Any]]:
192
+ score_value = _json_safe(score_value)
193
+ if score_value is None:
194
+ return []
195
+
196
+ if _is_scalar(score_value):
197
+ return [{"scope": "sequence", scalar_column: _number(score_value)}]
198
+
199
+ if isinstance(score_value, Mapping):
200
+ if score_value and not all(_position_key(key) for key in score_value):
201
+ series_lengths = {
202
+ len(value)
203
+ for value in score_value.values()
204
+ if isinstance(value, list) and all(_is_scalar(item) for item in value)
205
+ }
206
+ if len(series_lengths) == 1:
207
+ length = series_lengths.pop()
208
+ if length > 1 and all(isinstance(value, list) for value in score_value.values()):
209
+ return [
210
+ {
211
+ "position": position,
212
+ **{str(key): _number(value[position]) for key, value in score_value.items()},
213
+ }
214
+ for position in range(length)
215
+ ]
216
+ if score_value and all(_position_key(key) for key in score_value):
217
+ rows = []
218
+ for key, value in score_value.items():
219
+ row = {"position": int(str(key))}
220
+ if isinstance(value, Mapping):
221
+ row.update(_flatten_mapping(value, channels, scalar_column))
222
+ elif isinstance(value, list):
223
+ row.update(_vector_row(value, channels, scalar_column, scope="position"))
224
+ row.pop("scope", None)
225
+ else:
226
+ row[scalar_column] = _number(value)
227
+ rows.append(row)
228
+ return rows
229
+ return [_flatten_mapping(score_value, channels, scalar_column)]
230
+
231
+ if isinstance(score_value, list):
232
+ if not score_value:
233
+ return []
234
+ if all(_is_scalar(item) for item in score_value):
235
+ return [_vector_row(score_value, channels, scalar_column)]
236
+ rows = []
237
+ for index, item in enumerate(score_value):
238
+ item = _json_safe(item)
239
+ if isinstance(item, Mapping):
240
+ rows.append(_flatten_mapping(item, channels, scalar_column))
241
+ elif isinstance(item, list):
242
+ row = {"position": index}
243
+ row.update(_vector_row(item, channels, scalar_column, scope="position"))
244
+ row.pop("scope", None)
245
+ rows.append(row)
246
+ elif _is_scalar(item):
247
+ rows.append({"position": index, scalar_column: _number(item)})
248
+ return rows
249
+
250
+ return [{"scope": "sequence", scalar_column: score_value}]
251
+
252
+
253
+ def result_table(result: Mapping[str, Any], score_key: str, scores_key: str, scalar_column: str) -> pd.DataFrame:
254
+ channels = [str(channel) for channel in result.get("channels", [])]
255
+ score_value = result.get(scores_key, result.get(score_key))
256
+ rows = normalize_score_rows(score_value, channels, scalar_column)
257
+ if not rows:
258
+ return pd.DataFrame()
259
+
260
+ table = pd.DataFrame(rows)
261
+ ordered = [column for column in ("scope", "position", "nucleotide", "sequence", "label", "type") if column in table]
262
+ remaining = [column for column in table.columns if column not in ordered]
263
+ return table[ordered + remaining]
264
+
265
+
266
+ def dataframe_records(table: pd.DataFrame) -> list[dict[str, Any]]:
267
+ if table.empty:
268
+ return []
269
+ return json.loads(table.to_json(orient="records"))
270
+
271
+
272
+ def difference_summary(reference: str, alternative: str) -> dict[str, Any]:
273
+ differences = [
274
+ {
275
+ "position": index,
276
+ "reference": ref_base,
277
+ "alternative": alt_base,
278
+ }
279
+ for index, (ref_base, alt_base) in enumerate(zip(reference, alternative))
280
+ if ref_base != alt_base
281
+ ]
282
+ return {
283
+ "count": len(differences),
284
+ "positions": differences[:25],
285
+ "positions_truncated": len(differences) > 25,
286
+ }
287
+
288
+
289
+ def make_delta_plot(delta_table: pd.DataFrame, model_label: str):
290
+ fig, ax = plt.subplots(figsize=(7, 2.8))
291
+ values: list[tuple[str, float]] = []
292
+
293
+ if not delta_table.empty:
294
+ numeric_columns = [
295
+ column
296
+ for column in delta_table.columns
297
+ if column not in META_COLUMNS and pd.api.types.is_numeric_dtype(delta_table[column])
298
+ ]
299
+ for _, row in delta_table.iterrows():
300
+ position = row.get("position")
301
+ for column in numeric_columns:
302
+ value = row.get(column)
303
+ if pd.notna(value):
304
+ suffix = f"@{int(position)}" if position is not None and pd.notna(position) else ""
305
+ values.append((f"{column}{suffix}", float(value)))
306
+
307
+ values = sorted(values, key=lambda item: abs(item[1]), reverse=True)[:20]
308
+ values.reverse()
309
+ if not values:
310
+ ax.text(0.5, 0.5, "No numeric delta scores", ha="center", va="center")
311
+ ax.set_axis_off()
312
+ fig.tight_layout()
313
+ return fig
314
+
315
+ labels, scores = zip(*values)
316
+ colors = ["#2563eb" if score >= 0 else "#dc2626" for score in scores]
317
+ ax.barh(labels, scores, color=colors)
318
+ ax.axvline(0, color="#111827", linewidth=0.8)
319
+ ax.set_title(f"{model_label} top delta scores")
320
+ ax.set_xlabel("alternative - reference")
321
+ ax.tick_params(axis="y", labelsize=8)
322
+ fig.tight_layout()
323
+ return fig
324
+
325
+
326
+ def write_result_files(
327
+ metadata: dict[str, Any],
328
+ result: Mapping[str, Any],
329
+ delta_table: pd.DataFrame,
330
+ reference_table: pd.DataFrame,
331
+ alternative_table: pd.DataFrame,
332
+ ) -> tuple[str, str]:
333
+ csv_tables = []
334
+ for score_set, table in (
335
+ ("delta", delta_table),
336
+ ("reference", reference_table),
337
+ ("alternative", alternative_table),
338
+ ):
339
+ if not table.empty:
340
+ csv_table = table.copy()
341
+ csv_table.insert(0, "score_set", score_set)
342
+ csv_tables.append(csv_table)
343
+ csv_payload = pd.concat(csv_tables, ignore_index=True, sort=False) if csv_tables else pd.DataFrame()
344
+
345
+ csv_file = tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="")
346
+ csv_path = csv_file.name
347
+ csv_file.close()
348
+ csv_payload.to_csv(csv_path, index=False)
349
+
350
+ json_payload = {
351
+ "metadata": metadata,
352
+ "result": _json_safe(result),
353
+ "tables": {
354
+ "delta": dataframe_records(delta_table),
355
+ "reference": dataframe_records(reference_table),
356
+ "alternative": dataframe_records(alternative_table),
357
+ },
358
+ }
359
+ json_file = tempfile.NamedTemporaryFile("w", suffix=".json", delete=False)
360
+ json_path = json_file.name
361
+ json_file.close()
362
+ with open(json_path, "w") as handle:
363
+ json.dump(json_payload, handle, indent=2)
364
+
365
+ return csv_path, json_path
366
+
367
+
368
+ def unpack_prediction_result(result: Any) -> Mapping[str, Any]:
369
+ result = _json_safe(result)
370
+ if isinstance(result, list):
371
+ if len(result) != 1:
372
+ raise gr.Error(f"Expected one prediction result, got {len(result)}.")
373
+ result = result[0]
374
+ if not isinstance(result, Mapping):
375
+ raise gr.Error(f"Expected a prediction dictionary, got {type(result).__name__}.")
376
+ return result
377
+
378
+
379
+ def predict(model_label: str, reference: str, alternative: str):
380
+ started = time.perf_counter()
381
+ model_id = MODEL_OPTIONS[model_label]
382
+ reference, alternative = validate_pair(reference, alternative)
383
+
384
+ try:
385
+ predictor = load_predictor(model_id)
386
+ result = unpack_prediction_result(predictor(reference, alternative=alternative))
387
+ except gr.Error:
388
+ raise
389
+ except Exception as exc:
390
+ raise gr.Error(f"Prediction failed for {model_label}: {exc}") from exc
391
+
392
+ delta_table = result_table(result, "delta_score", "delta_scores", "delta_score")
393
+ reference_table = result_table(result, "reference_score", "reference_scores", "reference_score")
394
+ alternative_table = result_table(result, "alternative_score", "alternative_scores", "alternative_score")
395
+ metadata = {
396
+ "task": "splice-variant-effect",
397
+ "model": model_id,
398
+ "model_label": model_label,
399
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
400
+ "reference_length": len(reference),
401
+ "alternative_length": len(alternative),
402
+ "differences": difference_summary(reference, alternative),
403
+ "channels": result.get("channels", []),
404
+ "output_fields": sorted(result.keys()),
405
+ "runtime_seconds": round(time.perf_counter() - started, 3),
406
+ "timestamp_utc": datetime.now(timezone.utc).isoformat(),
407
+ }
408
+ csv_path, json_path = write_result_files(metadata, result, delta_table, reference_table, alternative_table)
409
+ delta_plot = make_delta_plot(delta_table, model_label)
410
+ return delta_table, reference_table, alternative_table, metadata, delta_plot, csv_path, json_path
411
+
412
+
413
+ def initial_model(request: gr.Request):
414
+ if request is None:
415
+ return "MMSplice"
416
+
417
+ query_params = getattr(request, "query_params", None)
418
+ model_id = None
419
+ if query_params is not None:
420
+ model_id = query_params.get("model")
421
+ if not model_id and getattr(request, "url", None):
422
+ parsed = parse_qs(urlparse(str(request.url)).query)
423
+ model_values = parsed.get("model")
424
+ model_id = model_values[0] if model_values else None
425
+
426
+ return MODEL_LABELS.get(model_id, "MMSplice")
427
+
428
+
429
+ with gr.Blocks(title="Splice Variant Effect") as demo:
430
+ gr.Markdown(
431
+ "# Splice Variant Effect\n"
432
+ "Score paired reference and alternative DNA windows with MultiMolecule splice variant-effect models."
433
+ )
434
+
435
+ model = gr.Dropdown(
436
+ choices=list(MODEL_OPTIONS.keys()),
437
+ value="MMSplice",
438
+ label="Checkpoint",
439
+ )
440
+
441
+ with gr.Row():
442
+ reference = gr.Textbox(
443
+ label="Reference DNA sequence",
444
+ value=DEFAULT_REFERENCE,
445
+ lines=5,
446
+ )
447
+ alternative = gr.Textbox(
448
+ label="Alternative DNA sequence",
449
+ value=DEFAULT_ALTERNATIVE,
450
+ lines=5,
451
+ )
452
+
453
+ input_file = gr.File(
454
+ label="Upload paired FASTA (reference record first, alternative record second)",
455
+ file_types=[".fa", ".fasta", ".fna"],
456
+ )
457
+ run = gr.Button("Run variant effect", variant="primary")
458
+
459
+ with gr.Row():
460
+ delta_scores = gr.Dataframe(label="Delta scores")
461
+ run_metadata = gr.JSON(label="Run metadata")
462
+
463
+ with gr.Row():
464
+ reference_scores = gr.Dataframe(label="Reference scores")
465
+ alternative_scores = gr.Dataframe(label="Alternative scores")
466
+
467
+ delta_plot = gr.Plot(label="Top delta scores")
468
+
469
+ with gr.Row():
470
+ csv_download = gr.File(label="Download CSV")
471
+ json_download = gr.File(label="Download JSON")
472
+
473
+ run.click(
474
+ predict,
475
+ inputs=[model, reference, alternative],
476
+ outputs=[
477
+ delta_scores,
478
+ reference_scores,
479
+ alternative_scores,
480
+ run_metadata,
481
+ delta_plot,
482
+ csv_download,
483
+ json_download,
484
+ ],
485
+ )
486
+ input_file.change(load_fasta_pair, inputs=input_file, outputs=[reference, alternative])
487
+ demo.load(initial_model, outputs=model)
488
+
489
+
490
+ if __name__ == "__main__":
491
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ biopython
2
+ matplotlib
3
+ multimolecule @ git+https://github.com/DLS5-Omics/multimolecule.git@master
4
+ numpy
5
+ pandas
6
+ torch
7
+ transformers