Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
b65f002
1
Parent(s):
beb6a82
refactor: clean code
Browse files- app.py +81 -59
- bigwig_export.py +11 -7
- ntv3_tracks_pipeline.py +72 -61
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -8,11 +8,10 @@ from pathlib import Path
|
|
| 8 |
import gradio as gr
|
| 9 |
import matplotlib
|
| 10 |
import matplotlib.colors as mcolors
|
| 11 |
-
import matplotlib.pyplot as plt
|
| 12 |
import numpy as np
|
| 13 |
import plotly.graph_objects as go
|
| 14 |
-
from plotly.subplots import make_subplots
|
| 15 |
import torch
|
|
|
|
| 16 |
|
| 17 |
from bigwig_export import _softmax_last, create_bigwig_zip
|
| 18 |
from ntv3_tracks_pipeline import (
|
|
@@ -57,7 +56,7 @@ def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
|
|
| 57 |
pipe = load_ntv3_tracks_pipeline(
|
| 58 |
model=model_id,
|
| 59 |
token=HF_TOKEN,
|
| 60 |
-
device="cpu", #
|
| 61 |
default_species=species,
|
| 62 |
verbose=False,
|
| 63 |
)
|
|
@@ -100,25 +99,29 @@ try:
|
|
| 100 |
except Exception:
|
| 101 |
|
| 102 |
def gpu(*args, **kwargs):
|
|
|
|
|
|
|
| 103 |
def wrap(fn):
|
| 104 |
return fn
|
| 105 |
|
| 106 |
return wrap
|
| 107 |
|
| 108 |
|
| 109 |
-
def _global_stride(
|
| 110 |
-
if target <= 0 or
|
| 111 |
return 1
|
| 112 |
-
return int(np.ceil(
|
| 113 |
|
| 114 |
|
| 115 |
-
def _make_tracks_figure(
|
|
|
|
|
|
|
| 116 |
"""Create an interactive plotly figure with multiple tracks."""
|
| 117 |
if not series:
|
| 118 |
raise gr.Error("Nothing to plot (no tracks/elements selected).")
|
| 119 |
|
| 120 |
n = len(series)
|
| 121 |
-
|
| 122 |
# Create subplots with shared x-axis
|
| 123 |
fig = make_subplots(
|
| 124 |
rows=n,
|
|
@@ -140,8 +143,10 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
|
|
| 140 |
|
| 141 |
# Convert color to rgba for fill
|
| 142 |
rgba = mcolors.to_rgba(color)
|
| 143 |
-
rgba_str =
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
# Add filled area (fill_between equivalent)
|
| 146 |
fig.add_trace(
|
| 147 |
go.Scatter(
|
|
@@ -149,12 +154,12 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
|
|
| 149 |
y=y,
|
| 150 |
mode="lines",
|
| 151 |
name=title,
|
| 152 |
-
line=
|
| 153 |
fill="tozeroy",
|
| 154 |
fillcolor=rgba_str,
|
| 155 |
-
hovertemplate=f"<b>{title}</b><br>"
|
| 156 |
-
|
| 157 |
-
|
| 158 |
showlegend=False,
|
| 159 |
),
|
| 160 |
row=i,
|
|
@@ -165,7 +170,7 @@ def _make_tracks_figure(x: np.ndarray, series: list[tuple[str, np.ndarray]], reg
|
|
| 165 |
fig.update_layout(
|
| 166 |
height=150 * n, # Adjust height based on number of tracks
|
| 167 |
width=1200,
|
| 168 |
-
margin=
|
| 169 |
hovermode="x unified", # Show all values at same x position
|
| 170 |
template="plotly_white",
|
| 171 |
)
|
|
@@ -278,7 +283,7 @@ def _format_track_for_display(track_id: str) -> str:
|
|
| 278 |
|
| 279 |
|
| 280 |
def _extract_track_id(display_value: str) -> str:
|
| 281 |
-
"""Extract track ID from display format
|
| 282 |
if " (" in display_value and display_value.endswith(")"):
|
| 283 |
# Extract track_id from format "display_name (track_id)"
|
| 284 |
return display_value.rsplit(" (", 1)[1][:-1]
|
|
@@ -455,6 +460,7 @@ def update_coords_on_species_change(species: str):
|
|
| 455 |
|
| 456 |
|
| 457 |
def reset_on_species_change(species: str):
|
|
|
|
| 458 |
# Clear results + selected when species changes (avoids mismatched IDs)
|
| 459 |
try:
|
| 460 |
track_ids = _get_bigwig_names(species) # warms cache if available
|
|
@@ -500,6 +506,7 @@ def predict(
|
|
| 500 |
bigwig_selected: list[str],
|
| 501 |
bed_elements: list[str],
|
| 502 |
):
|
|
|
|
| 503 |
tprint("start")
|
| 504 |
|
| 505 |
# Debug: verify species is being passed
|
|
@@ -515,10 +522,11 @@ def predict(
|
|
| 515 |
if use_coords:
|
| 516 |
# Check if this species supports coordinate-based fetching
|
| 517 |
if species not in SPECIES_WITH_COORDINATE_SUPPORT:
|
|
|
|
| 518 |
raise gr.Error(
|
| 519 |
-
f"Species '{species}' does not support coordinate-based sequence
|
| 520 |
-
f"Please provide a DNA sequence directly or use one of
|
| 521 |
-
f"{
|
| 522 |
)
|
| 523 |
if not chrom:
|
| 524 |
raise gr.Error("chrom is required when use_coords=True")
|
|
@@ -537,8 +545,10 @@ def predict(
|
|
| 537 |
|
| 538 |
# Verify species is in inputs before calling pipeline
|
| 539 |
if "species" not in inputs:
|
|
|
|
| 540 |
raise gr.Error(
|
| 541 |
-
f"Internal error: species not found in inputs dict.
|
|
|
|
| 542 |
)
|
| 543 |
|
| 544 |
tprint("inputs prepared")
|
|
@@ -576,12 +586,15 @@ def predict(
|
|
| 576 |
|
| 577 |
if not has_bigwigs and not has_bed:
|
| 578 |
raise gr.Error(
|
| 579 |
-
"No BigWig tracks or BED elements available for this species
|
|
|
|
| 580 |
)
|
| 581 |
|
| 582 |
if not has_bigwigs and bigwig_selected:
|
| 583 |
raise gr.Error(
|
| 584 |
-
"No BigWig tracks available for this species, but BigWig tracks
|
|
|
|
|
|
|
| 585 |
)
|
| 586 |
|
| 587 |
# Defaults if user picked none
|
|
@@ -617,17 +630,17 @@ def predict(
|
|
| 617 |
|
| 618 |
# Determine sequence length from available data
|
| 619 |
if has_bigwigs:
|
| 620 |
-
|
| 621 |
elif has_bed:
|
| 622 |
-
|
| 623 |
else:
|
| 624 |
raise gr.Error("No data available for plotting.")
|
| 625 |
|
| 626 |
-
stride = _global_stride(
|
| 627 |
|
| 628 |
x0 = int(out.pred_start or 0)
|
| 629 |
-
x1 = int(out.pred_end or (x0 +
|
| 630 |
-
x = np.linspace(x0, x1, num=
|
| 631 |
|
| 632 |
series: list[tuple[str, np.ndarray]] = []
|
| 633 |
|
|
@@ -645,14 +658,14 @@ def predict(
|
|
| 645 |
series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
|
| 646 |
|
| 647 |
tprint("figure data processed created")
|
| 648 |
-
|
| 649 |
# Build region string for x-axis label
|
| 650 |
region = (
|
| 651 |
f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
|
| 652 |
)
|
| 653 |
if out.assembly:
|
| 654 |
region += f" ({out.assembly})"
|
| 655 |
-
|
| 656 |
fig = _make_tracks_figure(x, series, region=region)
|
| 657 |
tprint("figure created")
|
| 658 |
|
|
@@ -680,7 +693,10 @@ def predict(
|
|
| 680 |
# -----------------------------
|
| 681 |
CSS = """
|
| 682 |
#tracks_plot { position: relative; width: 100% !important; max-width: 100% !important; }
|
| 683 |
-
#tracks_plot .wrap, #tracks_plot .plot-container {
|
|
|
|
|
|
|
|
|
|
| 684 |
|
| 685 |
#tracks_plot_download {
|
| 686 |
position: absolute;
|
|
@@ -916,7 +932,8 @@ function addDownloadIcon() {
|
|
| 916 |
btn.title = "Download PNG";
|
| 917 |
btn.innerHTML = `
|
| 918 |
<svg viewBox="0 0 24 24" aria-hidden="true">
|
| 919 |
-
<path d="M5 20h14v-2H5v2zm7-18v10.17l3.59-3.58L17 10l-5 5-5-5
|
|
|
|
| 920 |
</svg>
|
| 921 |
`;
|
| 922 |
btn.onclick = () => {
|
|
@@ -1024,8 +1041,10 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1024 |
<div class="intro-card">
|
| 1025 |
<h3>2) Choose signals</h3>
|
| 1026 |
<ul>
|
| 1027 |
-
<li>Search & select <strong>BigWig functional tracks</strong>
|
| 1028 |
-
|
|
|
|
|
|
|
| 1029 |
</ul>
|
| 1030 |
</div>
|
| 1031 |
|
|
@@ -1041,10 +1060,12 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1041 |
|
| 1042 |
<div class="intro-tip">
|
| 1043 |
<span class="intro-tip-icon">💡</span>
|
| 1044 |
-
<span><strong>Tip:</strong> The demo includes default settings that you can use
|
|
|
|
| 1045 |
</div>
|
| 1046 |
|
| 1047 |
-
<div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03);
|
|
|
|
| 1048 |
<strong>Available species:</strong> {_all_species_list}<br>
|
| 1049 |
<br>
|
| 1050 |
<strong>Species with functional tracks:</strong> {_bigwig_species_list}
|
|
@@ -1059,8 +1080,8 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1059 |
|
| 1060 |
# Model display names (without InstaDeepAI/ prefix) and their full IDs
|
| 1061 |
MODEL_OPTIONS = {
|
| 1062 |
-
"NTv3 650M (
|
| 1063 |
-
"NTv3 100M (
|
| 1064 |
}
|
| 1065 |
|
| 1066 |
# Reverse mapping: full ID -> display name
|
|
@@ -1112,11 +1133,9 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1112 |
+ ")"
|
| 1113 |
)
|
| 1114 |
with gr.Row():
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
)
|
| 1119 |
-
end = gr.Number(label="End", value=_default_coords["end"], precision=0)
|
| 1120 |
|
| 1121 |
# DNA sequence section - visible only when "Enter DNA sequence" is selected
|
| 1122 |
# Using Textbox directly (not wrapped in Group) to avoid visual border/line
|
|
@@ -1189,7 +1208,8 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1189 |
)
|
| 1190 |
|
| 1191 |
bigwig_no_tracks_msg = gr.Markdown(
|
| 1192 |
-
"⚠️ No functional genomic tracks available for this species
|
|
|
|
| 1193 |
visible=False,
|
| 1194 |
)
|
| 1195 |
|
|
@@ -1318,19 +1338,18 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1318 |
current_species, current_query, upd["value"]
|
| 1319 |
)
|
| 1320 |
|
| 1321 |
-
# Create a completely fresh update with explicit empty value
|
| 1322 |
-
#
|
| 1323 |
-
#
|
| 1324 |
-
#
|
| 1325 |
-
#
|
| 1326 |
-
|
| 1327 |
-
# Ensure no items from results_checked are in new_choices (they should already be filtered, but double-check)
|
| 1328 |
checked_track_ids = {_extract_track_id(x) for x in results_checked}
|
| 1329 |
new_choices_filtered = [
|
| 1330 |
c for c in new_choices if _extract_track_id(c) not in checked_track_ids
|
| 1331 |
]
|
| 1332 |
|
| 1333 |
-
# Create update with explicit empty value
|
|
|
|
| 1334 |
fresh_update = gr.update(
|
| 1335 |
choices=new_choices_filtered,
|
| 1336 |
value=[], # CRITICAL: Explicitly empty list to clear all checked state
|
|
@@ -1366,7 +1385,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1366 |
results_choices = (
|
| 1367 |
results_update.choices if hasattr(results_update, "choices") else []
|
| 1368 |
)
|
| 1369 |
-
except:
|
| 1370 |
# Fallback: get choices from the search function directly
|
| 1371 |
results_choices = _get_search_results_choices(
|
| 1372 |
current_species,
|
|
@@ -1395,10 +1414,12 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1395 |
):
|
| 1396 |
"""Update selected tracks when user checks/unchecks items directly."""
|
| 1397 |
# selected_value contains only the currently checked items
|
| 1398 |
-
# Update choices to match
|
|
|
|
| 1399 |
show_selected = bool(selected_value)
|
| 1400 |
|
| 1401 |
-
# Also update search results to reflect
|
|
|
|
| 1402 |
search_updates = search_bigwigs(current_species, current_query, selected_value)
|
| 1403 |
|
| 1404 |
return (
|
|
@@ -1462,7 +1483,7 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1462 |
]
|
| 1463 |
# Show selected tracks section if there are default tracks
|
| 1464 |
show_selected_tracks = bool(default_formatted)
|
| 1465 |
-
except:
|
| 1466 |
formatted_tracks = []
|
| 1467 |
default_formatted = []
|
| 1468 |
show_selected_tracks = False
|
|
@@ -1594,12 +1615,13 @@ with gr.Blocks(title="NTv3 Tracks Demo") as demo:
|
|
| 1594 |
try:
|
| 1595 |
zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
|
| 1596 |
return gr.update(value=zip_path, visible=True)
|
| 1597 |
-
except ImportError
|
| 1598 |
raise gr.Error(
|
| 1599 |
-
"pyBigWig is required for BigWig export.
|
|
|
|
| 1600 |
)
|
| 1601 |
-
except Exception as
|
| 1602 |
-
raise gr.Error(f"Error creating BigWig files: {str(
|
| 1603 |
|
| 1604 |
download_bigwig_btn.click(
|
| 1605 |
fn=download_bigwig_zip,
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
import matplotlib
|
| 10 |
import matplotlib.colors as mcolors
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import plotly.graph_objects as go
|
|
|
|
| 13 |
import torch
|
| 14 |
+
from plotly.subplots import make_subplots
|
| 15 |
|
| 16 |
from bigwig_export import _softmax_last, create_bigwig_zip
|
| 17 |
from ntv3_tracks_pipeline import (
|
|
|
|
| 56 |
pipe = load_ntv3_tracks_pipeline(
|
| 57 |
model=model_id,
|
| 58 |
token=HF_TOKEN,
|
| 59 |
+
device="cpu", # Prevents model.to(cuda) during import
|
| 60 |
default_species=species,
|
| 61 |
verbose=False,
|
| 62 |
)
|
|
|
|
| 99 |
except Exception:
|
| 100 |
|
| 101 |
def gpu(*args, **kwargs):
|
| 102 |
+
"""GPU decorator placeholder when spaces module is not available."""
|
| 103 |
+
|
| 104 |
def wrap(fn):
|
| 105 |
return fn
|
| 106 |
|
| 107 |
return wrap
|
| 108 |
|
| 109 |
|
| 110 |
+
def _global_stride(length: int, target: int) -> int:
|
| 111 |
+
if target <= 0 or length <= target:
|
| 112 |
return 1
|
| 113 |
+
return int(np.ceil(length / target))
|
| 114 |
|
| 115 |
|
| 116 |
+
def _make_tracks_figure(
|
| 117 |
+
x: np.ndarray, series: list[tuple[str, np.ndarray]], region: str = ""
|
| 118 |
+
):
|
| 119 |
"""Create an interactive plotly figure with multiple tracks."""
|
| 120 |
if not series:
|
| 121 |
raise gr.Error("Nothing to plot (no tracks/elements selected).")
|
| 122 |
|
| 123 |
n = len(series)
|
| 124 |
+
|
| 125 |
# Create subplots with shared x-axis
|
| 126 |
fig = make_subplots(
|
| 127 |
rows=n,
|
|
|
|
| 143 |
|
| 144 |
# Convert color to rgba for fill
|
| 145 |
rgba = mcolors.to_rgba(color)
|
| 146 |
+
rgba_str = (
|
| 147 |
+
f"rgba({int(rgba[0]*255)}, {int(rgba[1]*255)}, {int(rgba[2]*255)}, 0.3)"
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
# Add filled area (fill_between equivalent)
|
| 151 |
fig.add_trace(
|
| 152 |
go.Scatter(
|
|
|
|
| 154 |
y=y,
|
| 155 |
mode="lines",
|
| 156 |
name=title,
|
| 157 |
+
line={"color": color, "width": 1.5},
|
| 158 |
fill="tozeroy",
|
| 159 |
fillcolor=rgba_str,
|
| 160 |
+
hovertemplate=f"<b>{title}</b><br>"
|
| 161 |
+
+ "Position: %{x}<br>"
|
| 162 |
+
+ "Value: %{y:.4f}<extra></extra>",
|
| 163 |
showlegend=False,
|
| 164 |
),
|
| 165 |
row=i,
|
|
|
|
| 170 |
fig.update_layout(
|
| 171 |
height=150 * n, # Adjust height based on number of tracks
|
| 172 |
width=1200,
|
| 173 |
+
margin={"l": 80, "r": 20, "t": 40, "b": 60},
|
| 174 |
hovermode="x unified", # Show all values at same x position
|
| 175 |
template="plotly_white",
|
| 176 |
)
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
def _extract_track_id(display_value: str) -> str:
|
| 286 |
+
"""Extract track ID from display format or return as-is."""
|
| 287 |
if " (" in display_value and display_value.endswith(")"):
|
| 288 |
# Extract track_id from format "display_name (track_id)"
|
| 289 |
return display_value.rsplit(" (", 1)[1][:-1]
|
|
|
|
| 460 |
|
| 461 |
|
| 462 |
def reset_on_species_change(species: str):
|
| 463 |
+
"""Reset search and selected tracks when species changes."""
|
| 464 |
# Clear results + selected when species changes (avoids mismatched IDs)
|
| 465 |
try:
|
| 466 |
track_ids = _get_bigwig_names(species) # warms cache if available
|
|
|
|
| 506 |
bigwig_selected: list[str],
|
| 507 |
bed_elements: list[str],
|
| 508 |
):
|
| 509 |
+
"""Run prediction and return figure with tracks."""
|
| 510 |
tprint("start")
|
| 511 |
|
| 512 |
# Debug: verify species is being passed
|
|
|
|
| 522 |
if use_coords:
|
| 523 |
# Check if this species supports coordinate-based fetching
|
| 524 |
if species not in SPECIES_WITH_COORDINATE_SUPPORT:
|
| 525 |
+
supported = ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
|
| 526 |
raise gr.Error(
|
| 527 |
+
f"Species '{species}' does not support coordinate-based sequence "
|
| 528 |
+
f"fetching. Please provide a DNA sequence directly or use one of "
|
| 529 |
+
f"the supported species: {supported}"
|
| 530 |
)
|
| 531 |
if not chrom:
|
| 532 |
raise gr.Error("chrom is required when use_coords=True")
|
|
|
|
| 545 |
|
| 546 |
# Verify species is in inputs before calling pipeline
|
| 547 |
if "species" not in inputs:
|
| 548 |
+
input_keys = list(inputs.keys())
|
| 549 |
raise gr.Error(
|
| 550 |
+
f"Internal error: species not found in inputs dict. "
|
| 551 |
+
f"Inputs: {input_keys}"
|
| 552 |
)
|
| 553 |
|
| 554 |
tprint("inputs prepared")
|
|
|
|
| 586 |
|
| 587 |
if not has_bigwigs and not has_bed:
|
| 588 |
raise gr.Error(
|
| 589 |
+
"No BigWig tracks or BED elements available for this species "
|
| 590 |
+
"in the current model."
|
| 591 |
)
|
| 592 |
|
| 593 |
if not has_bigwigs and bigwig_selected:
|
| 594 |
raise gr.Error(
|
| 595 |
+
"No BigWig tracks available for this species, but BigWig tracks "
|
| 596 |
+
"were selected. Please deselect BigWig tracks or choose a "
|
| 597 |
+
"different species."
|
| 598 |
)
|
| 599 |
|
| 600 |
# Defaults if user picked none
|
|
|
|
| 630 |
|
| 631 |
# Determine sequence length from available data
|
| 632 |
if has_bigwigs:
|
| 633 |
+
seq_length = bw.shape[0]
|
| 634 |
elif has_bed:
|
| 635 |
+
seq_length = bed_logits.shape[0]
|
| 636 |
else:
|
| 637 |
raise gr.Error("No data available for plotting.")
|
| 638 |
|
| 639 |
+
stride = _global_stride(seq_length, PLOT_TARGET_POINTS)
|
| 640 |
|
| 641 |
x0 = int(out.pred_start or 0)
|
| 642 |
+
x1 = int(out.pred_end or (x0 + seq_length))
|
| 643 |
+
x = np.linspace(x0, x1, num=seq_length, endpoint=False)[::stride]
|
| 644 |
|
| 645 |
series: list[tuple[str, np.ndarray]] = []
|
| 646 |
|
|
|
|
| 658 |
series.append((ename, probs[:, eidx, 1][::stride].astype(float)))
|
| 659 |
|
| 660 |
tprint("figure data processed created")
|
| 661 |
+
|
| 662 |
# Build region string for x-axis label
|
| 663 |
region = (
|
| 664 |
f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
|
| 665 |
)
|
| 666 |
if out.assembly:
|
| 667 |
region += f" ({out.assembly})"
|
| 668 |
+
|
| 669 |
fig = _make_tracks_figure(x, series, region=region)
|
| 670 |
tprint("figure created")
|
| 671 |
|
|
|
|
| 693 |
# -----------------------------
|
| 694 |
CSS = """
|
| 695 |
#tracks_plot { position: relative; width: 100% !important; max-width: 100% !important; }
|
| 696 |
+
#tracks_plot .wrap, #tracks_plot .plot-container {
|
| 697 |
+
width: 100% !important;
|
| 698 |
+
max-width: 100% !important;
|
| 699 |
+
}
|
| 700 |
|
| 701 |
#tracks_plot_download {
|
| 702 |
position: absolute;
|
|
|
|
| 932 |
btn.title = "Download PNG";
|
| 933 |
btn.innerHTML = `
|
| 934 |
<svg viewBox="0 0 24 24" aria-hidden="true">
|
| 935 |
+
<path d="M5 20h14v-2H5v2zm7-18v10.17l3.59-3.58L17 10l-5 5-5-5
|
| 936 |
+
1.41-1.41L11 12.17V2h1z"/>
|
| 937 |
</svg>
|
| 938 |
`;
|
| 939 |
btn.onclick = () => {
|
|
|
|
| 1041 |
<div class="intro-card">
|
| 1042 |
<h3>2) Choose signals</h3>
|
| 1043 |
<ul>
|
| 1044 |
+
<li>Search & select <strong>BigWig functional tracks</strong>
|
| 1045 |
+
(RNA-seq, ChIP-seq, DNase…)</li>
|
| 1046 |
+
<li>Select <strong>BED genome annotation elements</strong>
|
| 1047 |
+
(exons, introns, promoters…)</li>
|
| 1048 |
</ul>
|
| 1049 |
</div>
|
| 1050 |
|
|
|
|
| 1060 |
|
| 1061 |
<div class="intro-tip">
|
| 1062 |
<span class="intro-tip-icon">💡</span>
|
| 1063 |
+
<span><strong>Tip:</strong> The demo includes default settings that you can use
|
| 1064 |
+
to get started, taking ~ 15 seconds to run for the example on human.</span>
|
| 1065 |
</div>
|
| 1066 |
|
| 1067 |
+
<div style="margin-top: 16px; padding: 12px; background: rgba(0,0,0,0.03);
|
| 1068 |
+
border-radius: 12px; font-size: 0.95rem;">
|
| 1069 |
<strong>Available species:</strong> {_all_species_list}<br>
|
| 1070 |
<br>
|
| 1071 |
<strong>Species with functional tracks:</strong> {_bigwig_species_list}
|
|
|
|
| 1080 |
|
| 1081 |
# Model display names (without InstaDeepAI/ prefix) and their full IDs
|
| 1082 |
MODEL_OPTIONS = {
|
| 1083 |
+
"NTv3 650M (post)": "InstaDeepAI/NTv3_650M_pos",
|
| 1084 |
+
"NTv3 100M (post)": "InstaDeepAI/NTv3_100M_pos",
|
| 1085 |
}
|
| 1086 |
|
| 1087 |
# Reverse mapping: full ID -> display name
|
|
|
|
| 1133 |
+ ")"
|
| 1134 |
)
|
| 1135 |
with gr.Row():
|
| 1136 |
+
chrom = gr.Textbox(label="Chromosome", value=_default_coords["chrom"])
|
| 1137 |
+
start = gr.Number(label="Start", value=_default_coords["start"], precision=0)
|
| 1138 |
+
end = gr.Number(label="End", value=_default_coords["end"], precision=0)
|
|
|
|
|
|
|
| 1139 |
|
| 1140 |
# DNA sequence section - visible only when "Enter DNA sequence" is selected
|
| 1141 |
# Using Textbox directly (not wrapped in Group) to avoid visual border/line
|
|
|
|
| 1208 |
)
|
| 1209 |
|
| 1210 |
bigwig_no_tracks_msg = gr.Markdown(
|
| 1211 |
+
"⚠️ No functional genomic tracks available for this species "
|
| 1212 |
+
"in the current model.",
|
| 1213 |
visible=False,
|
| 1214 |
)
|
| 1215 |
|
|
|
|
| 1338 |
current_species, current_query, upd["value"]
|
| 1339 |
)
|
| 1340 |
|
| 1341 |
+
# Create a completely fresh update with explicit empty value
|
| 1342 |
+
# to prevent any checked state. Force Gradio to clear checked state
|
| 1343 |
+
# by explicitly setting value to empty list.
|
| 1344 |
+
# Ensure no items from results_checked are in new_choices
|
| 1345 |
+
# (they should already be filtered, but double-check)
|
|
|
|
|
|
|
| 1346 |
checked_track_ids = {_extract_track_id(x) for x in results_checked}
|
| 1347 |
new_choices_filtered = [
|
| 1348 |
c for c in new_choices if _extract_track_id(c) not in checked_track_ids
|
| 1349 |
]
|
| 1350 |
|
| 1351 |
+
# Create update with explicit empty value
|
| 1352 |
+
# This should force Gradio to clear all checked items
|
| 1353 |
fresh_update = gr.update(
|
| 1354 |
choices=new_choices_filtered,
|
| 1355 |
value=[], # CRITICAL: Explicitly empty list to clear all checked state
|
|
|
|
| 1385 |
results_choices = (
|
| 1386 |
results_update.choices if hasattr(results_update, "choices") else []
|
| 1387 |
)
|
| 1388 |
+
except Exception:
|
| 1389 |
# Fallback: get choices from the search function directly
|
| 1390 |
results_choices = _get_search_results_choices(
|
| 1391 |
current_species,
|
|
|
|
| 1414 |
):
|
| 1415 |
"""Update selected tracks when user checks/unchecks items directly."""
|
| 1416 |
# selected_value contains only the currently checked items
|
| 1417 |
+
# Update choices to match current selections
|
| 1418 |
+
# (unchecked items are removed)
|
| 1419 |
show_selected = bool(selected_value)
|
| 1420 |
|
| 1421 |
+
# Also update search results to reflect new selection
|
| 1422 |
+
# (unchecked tracks can now appear in results)
|
| 1423 |
search_updates = search_bigwigs(current_species, current_query, selected_value)
|
| 1424 |
|
| 1425 |
return (
|
|
|
|
| 1483 |
]
|
| 1484 |
# Show selected tracks section if there are default tracks
|
| 1485 |
show_selected_tracks = bool(default_formatted)
|
| 1486 |
+
except Exception:
|
| 1487 |
formatted_tracks = []
|
| 1488 |
default_formatted = []
|
| 1489 |
show_selected_tracks = False
|
|
|
|
| 1615 |
try:
|
| 1616 |
zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
|
| 1617 |
return gr.update(value=zip_path, visible=True)
|
| 1618 |
+
except ImportError:
|
| 1619 |
raise gr.Error(
|
| 1620 |
+
"pyBigWig is required for BigWig export. "
|
| 1621 |
+
"Install with: pip install pyBigWig"
|
| 1622 |
)
|
| 1623 |
+
except Exception as exc:
|
| 1624 |
+
raise gr.Error(f"Error creating BigWig files: {str(exc)}")
|
| 1625 |
|
| 1626 |
download_bigwig_btn.click(
|
| 1627 |
fn=download_bigwig_zip,
|
bigwig_export.py
CHANGED
|
@@ -11,9 +11,9 @@ from typing import TYPE_CHECKING
|
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
try:
|
| 14 |
-
import pyBigWig
|
| 15 |
except ImportError:
|
| 16 |
-
pyBigWig = None
|
| 17 |
|
| 18 |
if TYPE_CHECKING:
|
| 19 |
from ntv3_tracks_pipeline import NTv3TracksOutput
|
|
@@ -75,16 +75,20 @@ def create_bigwig_zip(
|
|
| 75 |
chrom = out.chrom
|
| 76 |
if chrom is None:
|
| 77 |
raise ValueError(
|
| 78 |
-
"Chromosome information not available. Use genomic coordinates
|
| 79 |
)
|
| 80 |
|
| 81 |
start = out.start
|
| 82 |
end = out.end
|
|
|
|
|
|
|
| 83 |
window_len = out.window_len or (end - start)
|
| 84 |
|
| 85 |
# Calculate prediction region (center 37.5%)
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Create temporary directory for BigWig files
|
| 90 |
tmpdir = tempfile.gettempdir()
|
|
@@ -160,11 +164,11 @@ def create_bigwig_zip(
|
|
| 160 |
for bw_file in created_files:
|
| 161 |
try:
|
| 162 |
os.remove(bw_file)
|
| 163 |
-
except:
|
| 164 |
pass
|
| 165 |
try:
|
| 166 |
os.rmdir(output_dir)
|
| 167 |
-
except:
|
| 168 |
pass
|
| 169 |
|
| 170 |
return zip_path
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
try:
|
| 14 |
+
import pyBigWig # noqa: N816
|
| 15 |
except ImportError:
|
| 16 |
+
pyBigWig = None # noqa: N816
|
| 17 |
|
| 18 |
if TYPE_CHECKING:
|
| 19 |
from ntv3_tracks_pipeline import NTv3TracksOutput
|
|
|
|
| 75 |
chrom = out.chrom
|
| 76 |
if chrom is None:
|
| 77 |
raise ValueError(
|
| 78 |
+
"Chromosome information not available. Use genomic coordinates."
|
| 79 |
)
|
| 80 |
|
| 81 |
start = out.start
|
| 82 |
end = out.end
|
| 83 |
+
if start is None or end is None:
|
| 84 |
+
raise ValueError("Start and end coordinates are required for BigWig export.")
|
| 85 |
window_len = out.window_len or (end - start)
|
| 86 |
|
| 87 |
# Calculate prediction region (center 37.5%)
|
| 88 |
+
if out.pred_start is not None:
|
| 89 |
+
pred_start = out.pred_start
|
| 90 |
+
else:
|
| 91 |
+
pred_start = start + int(window_len * 0.3125)
|
| 92 |
|
| 93 |
# Create temporary directory for BigWig files
|
| 94 |
tmpdir = tempfile.gettempdir()
|
|
|
|
| 164 |
for bw_file in created_files:
|
| 165 |
try:
|
| 166 |
os.remove(bw_file)
|
| 167 |
+
except Exception:
|
| 168 |
pass
|
| 169 |
try:
|
| 170 |
os.rmdir(output_dir)
|
| 171 |
+
except Exception:
|
| 172 |
pass
|
| 173 |
|
| 174 |
return zip_path
|
ntv3_tracks_pipeline.py
CHANGED
|
@@ -74,13 +74,13 @@ SPECIES_WITH_COORDINATE_SUPPORT = {
|
|
| 74 |
# Assembly -> API URL template mapping
|
| 75 |
# ---------------------------------------------------------------------
|
| 76 |
# Default API URL template (UCSC format) that works for most species
|
| 77 |
-
DEFAULT_API_URL_TEMPLATE = "https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}"
|
| 78 |
|
| 79 |
# for species with different format, add the assembly name to the mapping
|
| 80 |
# The template should use {chrom}, {start}, and {end} as placeholders.
|
| 81 |
ASSEMBLY_TO_API_URL_TEMPLATE = {
|
| 82 |
# Arabidopsis thaliana (TAIR10) - uses hub URL format
|
| 83 |
-
"TAIR10": "https://api.genome.ucsc.edu/getData/sequence?hubUrl=http://genome.ucsc.edu/goldenPath/help/examples/hubExamples/hubAssembly/plantAraTha1/hub.txt;genome=araTha1;chrom={chrom};start={start};end={end}",
|
| 84 |
}
|
| 85 |
|
| 86 |
|
|
@@ -124,7 +124,8 @@ def _get_dna_sequence(assembly: str, chrom: str, start: int, end: int) -> str:
|
|
| 124 |
"""
|
| 125 |
if requests is None:
|
| 126 |
raise ImportError(
|
| 127 |
-
"requests is required for genome download.
|
|
|
|
| 128 |
)
|
| 129 |
|
| 130 |
# Get API URL template for this assembly, or use default
|
|
@@ -151,12 +152,11 @@ def _ensure_fasta_for_assembly(assembly: str, cache_dir: str | Path) -> Path:
|
|
| 151 |
if fa_path.exists():
|
| 152 |
return fa_path
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
)
|
| 160 |
|
| 161 |
import gzip
|
| 162 |
|
|
@@ -340,7 +340,8 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 340 |
else:
|
| 341 |
self.tokenizer = tokenizer
|
| 342 |
|
| 343 |
-
# Extract model_id from config if not already set
|
|
|
|
| 344 |
if self.model_id is None and self.config is not None:
|
| 345 |
self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
|
| 346 |
self.config, "name_or_path", None
|
|
@@ -374,29 +375,57 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 374 |
model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
|
| 375 |
)
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
def _sanitize_parameters(self, **kwargs):
|
| 378 |
return {}, {}, {}
|
| 379 |
|
| 380 |
-
def _get_model_device(self) -> torch.device:
|
| 381 |
return next(self.model.parameters()).device
|
| 382 |
|
| 383 |
def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
|
| 384 |
species = inputs.get("species", self.default_species)
|
| 385 |
if species not in SPECIES_TO_ASSEMBLY:
|
|
|
|
| 386 |
raise ValueError(
|
| 387 |
-
f"Unsupported species='{species}'. Supported species: {
|
| 388 |
)
|
| 389 |
assembly = SPECIES_TO_ASSEMBLY[species]
|
| 390 |
|
| 391 |
cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
|
| 392 |
if assembly not in cfg_assemblies:
|
| 393 |
raise ValueError(
|
| 394 |
-
f"Species '{species}' maps to assembly '{assembly}',
|
|
|
|
| 395 |
f"Available assemblies: {cfg_assemblies}"
|
| 396 |
)
|
| 397 |
return species, assembly
|
| 398 |
|
| 399 |
-
def _maybe_force_cpu_for_mps_long(
|
| 400 |
self, input_ids_cpu: torch.Tensor
|
| 401 |
) -> torch.device:
|
| 402 |
dev = self._get_model_device()
|
|
@@ -405,40 +434,15 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 405 |
if seq_len >= self.mps_force_cpu_length:
|
| 406 |
if self.verbose:
|
| 407 |
print(
|
| 408 |
-
f"[NTv3TracksPipeline] MPS detected and input is long
|
| 409 |
-
"Switching model + inputs to CPU
|
|
|
|
| 410 |
)
|
| 411 |
self.model.to("cpu")
|
| 412 |
self.model.eval()
|
| 413 |
return torch.device("cpu")
|
| 414 |
return dev
|
| 415 |
|
| 416 |
-
def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
|
| 417 |
-
"""
|
| 418 |
-
Return BigWig track IDs for the assembly corresponding to `species`.
|
| 419 |
-
No model forward pass.
|
| 420 |
-
"""
|
| 421 |
-
sp = species or self.default_species
|
| 422 |
-
assembly = SPECIES_TO_ASSEMBLY.get(sp)
|
| 423 |
-
if assembly is None:
|
| 424 |
-
raise ValueError(
|
| 425 |
-
f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
|
| 426 |
-
)
|
| 427 |
-
|
| 428 |
-
if assembly not in self.config.bigwigs_per_file_assembly:
|
| 429 |
-
raise ValueError(
|
| 430 |
-
f"Assembly {assembly} not found in checkpoint config. "
|
| 431 |
-
f"Available: {list(self.config.bigwigs_per_file_assembly.keys())}"
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
return list(self.config.bigwigs_per_file_assembly[assembly])
|
| 435 |
-
|
| 436 |
-
def available_bed_element_names(self) -> list[str]:
|
| 437 |
-
"""
|
| 438 |
-
Return BED element names available in this checkpoint (no forward pass).
|
| 439 |
-
"""
|
| 440 |
-
return list(self.bed_element_names or [])
|
| 441 |
-
|
| 442 |
def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
|
| 443 |
species, assembly = self._resolve_species_and_assembly(inputs)
|
| 444 |
|
|
@@ -506,19 +510,6 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 506 |
def forward(self, model_inputs, **forward_params):
|
| 507 |
return self._forward(model_inputs, **forward_params)
|
| 508 |
|
| 509 |
-
def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
|
| 510 |
-
meta = model_inputs.pop("meta")
|
| 511 |
-
if self.verbose:
|
| 512 |
-
print(f"Running on device: {self._get_model_device()}")
|
| 513 |
-
with torch.no_grad():
|
| 514 |
-
out = self.model(
|
| 515 |
-
input_ids=model_inputs["input_ids"],
|
| 516 |
-
species_ids=model_inputs["species_ids"],
|
| 517 |
-
return_dict=True,
|
| 518 |
-
)
|
| 519 |
-
out["meta"] = meta
|
| 520 |
-
return out
|
| 521 |
-
|
| 522 |
def postprocess(
|
| 523 |
self, model_outputs: dict[str, Any], **kwargs: Any
|
| 524 |
) -> NTv3TracksOutput:
|
|
@@ -565,6 +556,19 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 565 |
pred_end=meta.get("pred_end"),
|
| 566 |
)
|
| 567 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
def __call__(
|
| 569 |
self,
|
| 570 |
inputs,
|
|
@@ -584,7 +588,8 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 584 |
if plot:
|
| 585 |
if out.bigwig_track_names is None:
|
| 586 |
raise ValueError(
|
| 587 |
-
"bigwig_track_names missing; expected
|
|
|
|
| 588 |
)
|
| 589 |
if out.bed_element_names is None:
|
| 590 |
raise ValueError("bed element names missing from config.")
|
|
@@ -600,17 +605,22 @@ class NTv3TracksPipeline(Pipeline):
|
|
| 600 |
]
|
| 601 |
if missing_tracks:
|
| 602 |
raise ValueError(
|
| 603 |
-
f"The following tracks are not available in
|
| 604 |
-
f"
|
|
|
|
|
|
|
| 605 |
)
|
| 606 |
|
| 607 |
missing_elements = [
|
| 608 |
e for e in elements_to_plot if e not in bed_element_names
|
| 609 |
]
|
| 610 |
if missing_elements:
|
|
|
|
|
|
|
| 611 |
raise ValueError(
|
| 612 |
-
f"The following elements are not available in
|
| 613 |
-
f"
|
|
|
|
| 614 |
)
|
| 615 |
|
| 616 |
# Build bigwig tracks dict (title -> y)
|
|
@@ -662,7 +672,8 @@ def load_ntv3_tracks_pipeline(
|
|
| 662 |
device:
|
| 663 |
"auto", "cpu", "cuda", "mps"
|
| 664 |
pipeline_kwargs:
|
| 665 |
-
Extra kwargs passed to NTv3TracksPipeline
|
|
|
|
| 666 |
"""
|
| 667 |
pipe = NTv3TracksPipeline(
|
| 668 |
model=model,
|
|
|
|
| 74 |
# Assembly -> API URL template mapping
|
| 75 |
# ---------------------------------------------------------------------
|
| 76 |
# Default API URL template (UCSC format) that works for most species
|
| 77 |
+
DEFAULT_API_URL_TEMPLATE = "https://api.genome.ucsc.edu/getData/sequence?genome={assembly};chrom={chrom};start={start};end={end}" # noqa: E501
|
| 78 |
|
| 79 |
# for species with different format, add the assembly name to the mapping
|
| 80 |
# The template should use {chrom}, {start}, and {end} as placeholders.
|
| 81 |
ASSEMBLY_TO_API_URL_TEMPLATE = {
|
| 82 |
# Arabidopsis thaliana (TAIR10) - uses hub URL format
|
| 83 |
+
"TAIR10": "https://api.genome.ucsc.edu/getData/sequence?hubUrl=http://genome.ucsc.edu/goldenPath/help/examples/hubExamples/hubAssembly/plantAraTha1/hub.txt;genome=araTha1;chrom={chrom};start={start};end={end}", # noqa: E501
|
| 84 |
}
|
| 85 |
|
| 86 |
|
|
|
|
| 124 |
"""
|
| 125 |
if requests is None:
|
| 126 |
raise ImportError(
|
| 127 |
+
"requests is required for genome download. "
|
| 128 |
+
"Install with: pip install requests"
|
| 129 |
)
|
| 130 |
|
| 131 |
# Get API URL template for this assembly, or use default
|
|
|
|
| 152 |
if fa_path.exists():
|
| 153 |
return fa_path
|
| 154 |
|
| 155 |
+
# This function is deprecated - use _get_dna_sequence with API instead
|
| 156 |
+
raise ValueError(
|
| 157 |
+
f"FASTA file download is no longer supported for assembly='{assembly}'. "
|
| 158 |
+
f"Please use _get_dna_sequence() with API-based sequence fetching instead."
|
| 159 |
+
)
|
|
|
|
| 160 |
|
| 161 |
import gzip
|
| 162 |
|
|
|
|
| 340 |
else:
|
| 341 |
self.tokenizer = tokenizer
|
| 342 |
|
| 343 |
+
# Extract model_id from config if not already set
|
| 344 |
+
# (following ntv3_gff_pipeline.py pattern)
|
| 345 |
if self.model_id is None and self.config is not None:
|
| 346 |
self.model_id = getattr(self.config, "_name_or_path", None) or getattr(
|
| 347 |
self.config, "name_or_path", None
|
|
|
|
| 375 |
model=self.model, tokenizer=self.tokenizer, device=-1, **kwargs
|
| 376 |
)
|
| 377 |
|
| 378 |
+
def available_bigwig_track_names(self, species: str | None = None) -> list[str]:
|
| 379 |
+
"""
|
| 380 |
+
Return BigWig track IDs for the assembly corresponding to `species`.
|
| 381 |
+
No model forward pass.
|
| 382 |
+
"""
|
| 383 |
+
sp = species or self.default_species
|
| 384 |
+
assembly = SPECIES_TO_ASSEMBLY.get(sp)
|
| 385 |
+
if assembly is None:
|
| 386 |
+
raise ValueError(
|
| 387 |
+
f"Unknown species={sp}. Supported: {sorted(SPECIES_TO_ASSEMBLY.keys())}"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
if assembly not in self.config.bigwigs_per_file_assembly:
|
| 391 |
+
raise ValueError(
|
| 392 |
+
f"Assembly {assembly} not found in checkpoint config. "
|
| 393 |
+
f"Available: {list(self.config.bigwigs_per_file_assembly.keys())}"
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
return list(self.config.bigwigs_per_file_assembly[assembly])
|
| 397 |
+
|
| 398 |
+
def available_bed_element_names(self) -> list[str]:
|
| 399 |
+
"""
|
| 400 |
+
Return BED element names available in this checkpoint (no forward pass).
|
| 401 |
+
"""
|
| 402 |
+
return list(self.bed_element_names or [])
|
| 403 |
+
|
| 404 |
def _sanitize_parameters(self, **kwargs):
|
| 405 |
return {}, {}, {}
|
| 406 |
|
| 407 |
+
def _get_model_device(self) -> torch.device: # noqa: CCE001
|
| 408 |
return next(self.model.parameters()).device
|
| 409 |
|
| 410 |
def _resolve_species_and_assembly(self, inputs: dict[str, Any]) -> tuple[str, str]:
|
| 411 |
species = inputs.get("species", self.default_species)
|
| 412 |
if species not in SPECIES_TO_ASSEMBLY:
|
| 413 |
+
supported = sorted(SPECIES_TO_ASSEMBLY.keys())
|
| 414 |
raise ValueError(
|
| 415 |
+
f"Unsupported species='{species}'. " f"Supported species: {supported}"
|
| 416 |
)
|
| 417 |
assembly = SPECIES_TO_ASSEMBLY[species]
|
| 418 |
|
| 419 |
cfg_assemblies = list(self.config.bigwigs_per_file_assembly.keys())
|
| 420 |
if assembly not in cfg_assemblies:
|
| 421 |
raise ValueError(
|
| 422 |
+
f"Species '{species}' maps to assembly '{assembly}', "
|
| 423 |
+
f"but that assembly is not available in this checkpoint. "
|
| 424 |
f"Available assemblies: {cfg_assemblies}"
|
| 425 |
)
|
| 426 |
return species, assembly
|
| 427 |
|
| 428 |
+
def _maybe_force_cpu_for_mps_long( # noqa: CCE001
|
| 429 |
self, input_ids_cpu: torch.Tensor
|
| 430 |
) -> torch.device:
|
| 431 |
dev = self._get_model_device()
|
|
|
|
| 434 |
if seq_len >= self.mps_force_cpu_length:
|
| 435 |
if self.verbose:
|
| 436 |
print(
|
| 437 |
+
f"[NTv3TracksPipeline] MPS detected and input is long "
|
| 438 |
+
f"(tokens={seq_len}). Switching model + inputs to CPU "
|
| 439 |
+
"for this run."
|
| 440 |
)
|
| 441 |
self.model.to("cpu")
|
| 442 |
self.model.eval()
|
| 443 |
return torch.device("cpu")
|
| 444 |
return dev
|
| 445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
def preprocess(self, inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
|
| 447 |
species, assembly = self._resolve_species_and_assembly(inputs)
|
| 448 |
|
|
|
|
| 510 |
def forward(self, model_inputs, **forward_params):
|
| 511 |
return self._forward(model_inputs, **forward_params)
|
| 512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
def postprocess(
|
| 514 |
self, model_outputs: dict[str, Any], **kwargs: Any
|
| 515 |
) -> NTv3TracksOutput:
|
|
|
|
| 556 |
pred_end=meta.get("pred_end"),
|
| 557 |
)
|
| 558 |
|
| 559 |
+
def _forward(self, model_inputs: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
|
| 560 |
+
meta = model_inputs.pop("meta")
|
| 561 |
+
if self.verbose:
|
| 562 |
+
print(f"Running on device: {self._get_model_device()}")
|
| 563 |
+
with torch.no_grad():
|
| 564 |
+
out = self.model(
|
| 565 |
+
input_ids=model_inputs["input_ids"],
|
| 566 |
+
species_ids=model_inputs["species_ids"],
|
| 567 |
+
return_dict=True,
|
| 568 |
+
)
|
| 569 |
+
out["meta"] = meta
|
| 570 |
+
return out
|
| 571 |
+
|
| 572 |
def __call__(
|
| 573 |
self,
|
| 574 |
inputs,
|
|
|
|
| 588 |
if plot:
|
| 589 |
if out.bigwig_track_names is None:
|
| 590 |
raise ValueError(
|
| 591 |
+
"bigwig_track_names missing; expected "
|
| 592 |
+
"cfg.bigwigs_per_file_assembly[assembly]."
|
| 593 |
)
|
| 594 |
if out.bed_element_names is None:
|
| 595 |
raise ValueError("bed element names missing from config.")
|
|
|
|
| 605 |
]
|
| 606 |
if missing_tracks:
|
| 607 |
raise ValueError(
|
| 608 |
+
f"The following tracks are not available in "
|
| 609 |
+
f"bigwig_names: {missing_tracks}\n"
|
| 610 |
+
f"First 50 available: {bigwig_names[:50]}"
|
| 611 |
+
f"{'...' if len(bigwig_names) > 50 else ''}"
|
| 612 |
)
|
| 613 |
|
| 614 |
missing_elements = [
|
| 615 |
e for e in elements_to_plot if e not in bed_element_names
|
| 616 |
]
|
| 617 |
if missing_elements:
|
| 618 |
+
first_50 = bed_element_names[:50]
|
| 619 |
+
ellipsis = "..." if len(bed_element_names) > 50 else ""
|
| 620 |
raise ValueError(
|
| 621 |
+
f"The following elements are not available in "
|
| 622 |
+
f"bed_element_names: {missing_elements}\n"
|
| 623 |
+
f"First 50 available: {first_50}{ellipsis}"
|
| 624 |
)
|
| 625 |
|
| 626 |
# Build bigwig tracks dict (title -> y)
|
|
|
|
| 672 |
device:
|
| 673 |
"auto", "cpu", "cuda", "mps"
|
| 674 |
pipeline_kwargs:
|
| 675 |
+
Extra kwargs passed to NTv3TracksPipeline
|
| 676 |
+
(default_species, genome_cache_dir, etc.).
|
| 677 |
"""
|
| 678 |
pipe = NTv3TracksPipeline(
|
| 679 |
model=model,
|
requirements.txt
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
gradio>=4.0.0
|
|
|
|
| 2 |
matplotlib
|
| 3 |
numpy
|
| 4 |
plotly
|
| 5 |
-
kaleido
|
| 6 |
pyBigWig
|
| 7 |
pyfaidx
|
| 8 |
requests
|
|
|
|
| 1 |
gradio>=4.0.0
|
| 2 |
+
kaleido
|
| 3 |
matplotlib
|
| 4 |
numpy
|
| 5 |
plotly
|
|
|
|
| 6 |
pyBigWig
|
| 7 |
pyfaidx
|
| 8 |
requests
|