import csv
import os
import tempfile
import time
import uuid
from pathlib import Path
import gradio as gr
import matplotlib
import matplotlib.colors as mcolors
import numpy as np
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from bigwig_export import _softmax_last, create_bigwig_zip
from ntv3_tracks_pipeline import (
ASSEMBLY_TO_SPECIES,
BED_ELEMENT_COLORS,
SPECIES_WITH_COORDINATE_SUPPORT,
load_ntv3_tracks_pipeline,
)
matplotlib.use("Agg")
# -----------------------------
# Env / auth
# -----------------------------
# Model checkpoint and default species can be overridden through the environment.
MODEL_ID = os.environ.get("MODEL_ID", "InstaDeepAI/NTv3_650M_post")
DEFAULT_SPECIES = os.environ.get("DEFAULT_SPECIES", "human")

# The first non-empty variable wins; they are tried in this order.
HF_TOKEN = (
    os.environ.get("NTV3_HF_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
# The `or` chain yields "" (not None) when the last variable is set but empty,
# so test truthiness: an empty token is as unusable as a missing one.
if not HF_TOKEN:
    raise RuntimeError(
        "Missing Hugging Face token. Set NTV3_HF_TOKEN as a Space Secret."
    )

# Upper bound on the number of points drawn per track (see _global_stride).
PLOT_TARGET_POINTS = int(os.environ.get("PLOT_TARGET_POINTS", "1500"))
# Maximum number of entries shown in the track search results.
SEARCH_MAX_RESULTS = int(os.environ.get("SEARCH_MAX_RESULTS", "20"))
# Maximum accepted input length, in characters / base pairs (1 MiB).
MAX_SEQUENCE_SIZE = 1_048_576
# -----------------------------
# Load pipeline (reloadable)
# -----------------------------
pipe = None
current_model_id = MODEL_ID


def load_pipeline(model_id: str, species: str = DEFAULT_SPECIES):
    """(Re)build the module-level pipeline for ``model_id``.

    The new pipeline replaces the global ``pipe`` and ``current_model_id`` is
    updated to match; the pipeline is also returned.
    """
    global pipe, current_model_id
    loader_kwargs = {
        "model": model_id,
        "token": HF_TOKEN,
        "device": "cpu",  # Prevents model.to(cuda) during import
        "default_species": species,
        "verbose": False,
    }
    pipe = load_ntv3_tracks_pipeline(**loader_kwargs)
    current_model_id = model_id
    return pipe


# Build the pipeline once at import time with the configured defaults.
load_pipeline(MODEL_ID, DEFAULT_SPECIES)
# -----------------------------
# Helpers
# -----------------------------
_t0 = None
_tlast = None


def tprint(msg: str):
    """Print the time elapsed since the previous call and since the first call."""
    global _t0, _tlast
    if _t0 is None:
        # First call: both reference points start now.
        _tlast = time.perf_counter()
        _t0 = _tlast
    # CUDA ops are async, so synchronize to get real timings.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    now = time.perf_counter()
    since_last = now - _tlast
    since_first = now - _t0
    print(f"[timing] {msg}: {since_last:.3f}s (total {since_first:.3f}s)")
    _tlast = now
# GPU decorator
try:
    import spaces

    gpu = spaces.GPU
except Exception:

    def gpu(*args, **kwargs):
        """No-op stand-in for ``spaces.GPU`` when the ``spaces`` module is unavailable.

        Both decorator spellings are supported:

        * ``@gpu`` -- the decorated function arrives as the only positional
          argument and is returned unchanged;
        * ``@gpu(...)`` -- the arguments are ignored and a pass-through
          decorator is returned.
        """
        # Bare usage: returning `wrap` here would replace the decorated
        # function with the decorator itself and break every later call.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return args[0]

        def wrap(fn):
            return fn

        return wrap
def _global_stride(length: int, target: int) -> int:
    """Return the subsampling step that keeps about ``target`` of ``length`` points.

    A step of 1 (no subsampling) is returned when ``target`` is not positive
    or ``length`` already fits within it.
    """
    needs_subsampling = target > 0 and length > target
    return int(np.ceil(length / target)) if needs_subsampling else 1
def _make_tracks_figure(
    x: np.ndarray, series: list[tuple[str, np.ndarray]], region: str = ""
):
    """Create an interactive plotly figure with one stacked subplot per track.

    Args:
        x: x-axis values shared by every track.
        series: ``(title, values)`` pairs, drawn top to bottom in the given order.
        region: label for the bottom x-axis; a generic label is used when empty.

    Returns:
        The plotly figure created by ``make_subplots`` with one trace per track.

    Raises:
        gr.Error: if ``series`` is empty.
    """
    if not series:
        raise gr.Error("Nothing to plot (no tracks/elements selected).")
    n = len(series)

    # Layout tiers: fewer tracks get more spacing, height and top margin so
    # the subplot titles do not overlap.
    if n <= 2:
        vertical_spacing, height_per_track, top_margin = 0.15, 200, 60
    elif n <= 4:
        vertical_spacing, height_per_track, top_margin = 0.08, 170, 50
    else:
        vertical_spacing, height_per_track, top_margin = 0.04, 150, 40
    # make_subplots rejects a spacing above 1 / (rows - 1). Cap the total gap
    # at half of the figure so large selections still render; the cap only
    # takes effect from 14 tracks upwards.
    if n > 1:
        vertical_spacing = min(vertical_spacing, 0.5 / (n - 1))

    # Create subplots with shared x-axis
    fig = make_subplots(
        rows=n,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=vertical_spacing,
        subplot_titles=[title for title, _ in series],
    )

    bigwig_color = "#4A90E2"  # Blue, used for every track without a BED element color
    for row, (title, y) in enumerate(series, 1):
        color = BED_ELEMENT_COLORS.get(title, bigwig_color)
        # Same color at 30% opacity for the filled area under the line
        red, green, blue, _ = mcolors.to_rgba(color)
        fill_rgba = f"rgba({int(red*255)}, {int(green*255)}, {int(blue*255)}, 0.3)"
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=title,
                line={"color": color, "width": 1.5},
                fill="tozeroy",
                fillcolor=fill_rgba,
                # <br> separates hover lines; <extra></extra> hides the
                # secondary box that would repeat the trace name.
                hovertemplate=f"<b>{title}</b><br>"
                + "Position: %{x}<br>"
                + "Value: %{y:.4f}<extra></extra>",
                showlegend=False,
            ),
            row=row,
            col=1,
        )

    fig.update_layout(
        height=height_per_track * n,  # Adjust height based on number of tracks
        width=1200,
        margin={"l": 80, "r": 20, "t": top_margin, "b": 60},
        hovermode="x unified",  # Show all values at same x position
        template="plotly_white",
        modebar={
            "activecolor": "#7dd3fc",  # Blue color for active/hovered buttons
            "bgcolor": "rgba(255, 255, 255, 0.9)",
            "color": "#7dd3fc",  # Blue color for buttons
            "orientation": "v",
        },
    )
    # Without row/col the update applies to every y-axis: hide tick labels,
    # keep a light grid.
    fig.update_yaxes(
        showticklabels=False,
        showgrid=True,
        gridcolor="rgba(0,0,0,0.1)",
    )
    # Only the last subplot's x-axis carries the region label
    fig.update_xaxes(
        title_text=region if region else "Genomic position / index",
        showgrid=True,
        gridcolor="rgba(0,0,0,0.1)",
        row=n,
        col=1,
    )
    return fig
# Cache track lists per species so search is instant after first load
_BIGWIG_CACHE: dict[str, list[str]] = {}
# Cache for track metadata (track_id -> display_name)
_TRACK_METADATA_CACHE: dict[str, str] = {}


def _load_track_metadata() -> dict[str, str]:
    """Load track metadata from the bundled CSV and map track IDs to display names.

    A display name joins, in order: the tissue, the assay (followed by the
    strand, with "plus"/"minus" shown as "+"/"-"), and the experiment target
    unless it is "none" or "RNA-seq". A row with none of these falls back to
    its track ID. A non-empty result is cached for later calls.

    Returns:
        Mapping of ``file_id`` to display name; empty when the CSV is missing
        or cannot be read.
    """
    if _TRACK_METADATA_CACHE:
        return _TRACK_METADATA_CACHE
    csv_path = Path(__file__).parent / "data" / "functional_tracks_metadata.csv"
    if not csv_path.exists():
        return {}

    def _field(row: dict, key: str) -> str:
        # DictReader fills cells missing from a short row with None, so a bare
        # `.strip()` would raise and discard the whole file.
        return (row.get(key) or "").strip()

    strand_symbols = {"plus": "+", "minus": "-"}
    metadata = {}
    try:
        with open(csv_path, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                track_id = row["file_id"]
                tissue = _field(row, "tissue")
                assay = _field(row, "assay")
                experiment_target = _field(row, "experiment_target")
                strand = _field(row, "strand")

                parts = []
                if tissue:
                    parts.append(tissue)
                if assay:
                    # Include strand information next to the assay if available
                    strand = strand_symbols.get(strand, strand)
                    parts.append(f"{assay} {strand}" if strand else assay)
                if experiment_target and experiment_target not in ("none", "RNA-seq"):
                    parts.append(experiment_target)
                # Fallback to the ID if the row carries no usable metadata
                metadata[track_id] = " ".join(parts) if parts else track_id
    except (OSError, UnicodeDecodeError, csv.Error, KeyError) as e:
        # Metadata only improves labels; the app keeps working with raw IDs.
        print(f"Warning: Could not load track metadata: {e}")
        return {}
    _TRACK_METADATA_CACHE.update(metadata)
    return metadata
def _get_track_display_name(track_id: str) -> str:
    """Look up the display name for ``track_id``; unknown IDs are returned unchanged."""
    return _load_track_metadata().get(track_id, track_id)
def _format_track_for_display(track_id: str) -> str:
    """Render a track as ``'display_name (track_id)'``, or the bare ID without metadata."""
    label = _get_track_display_name(track_id)
    # A label equal to the ID means no metadata was found; avoid "ID (ID)".
    return f"{label} ({track_id})" if label != track_id else track_id
def _extract_track_id(display_value: str) -> str:
    """Recover the track ID from ``'display_name (track_id)'``; other values pass through."""
    looks_formatted = display_value.endswith(")") and " (" in display_value
    if not looks_formatted:
        return display_value  # assume it is already a bare ID
    _, _, tail = display_value.rpartition(" (")
    return tail[:-1]  # drop the closing parenthesis
def _get_bigwig_names(species: str) -> list[str]:
    """Return the BigWig track names of the current model for ``species``, cached.

    The cache key combines the active model ID with the species: the pipeline
    can be reloaded with a different model, and a species-only key would keep
    serving the previous model's track list.

    Any exception raised by ``pipe.available_bigwig_track_names`` propagates
    (as does ``AttributeError`` when no pipeline is loaded).
    """
    cache_key = f"{current_model_id}::{species}"
    if cache_key not in _BIGWIG_CACHE:
        _BIGWIG_CACHE[cache_key] = pipe.available_bigwig_track_names(species)
    return _BIGWIG_CACHE[cache_key]
def _get_bed_element_names(species: str) -> list[str]:
    """List the BED element names the current pipeline reports for ``species``.

    Returns an empty list when no pipeline is loaded or when the lookup raises
    ``ValueError`` or ``AttributeError``.
    """
    names: list[str] = []
    if pipe is not None:
        try:
            names = pipe.available_bed_element_names(species)
        except (ValueError, AttributeError):
            names = []
    return names
def _format_bed_element_for_display(element_name: str) -> str:
    """Turn an element key such as ``protein_coding_gene`` into a title-cased label."""
    words = element_name.split("_")
    return " ".join(words).title()
def _has_bigwigs(species: str) -> bool:
    """Report whether the current model exposes at least one BigWig track for ``species``."""
    try:
        return len(_get_bigwig_names(species)) != 0
    except (ValueError, AttributeError):
        # Species not in config or pipeline not loaded
        return False
def _get_species_with_bigwigs() -> set[str]:
    """Collect every species in ``ASSEMBLY_TO_SPECIES`` that has BigWig tracks.

    Returns an empty set when no pipeline is loaded.
    """
    if pipe is None:
        return set()
    return {name for name in ASSEMBLY_TO_SPECIES.values() if _has_bigwigs(name)}
def _rank_search(query: str, names: list[str], limit: int) -> list[str]:
    """Return at most ``limit`` entries of ``names`` that match ``query``.

    Matching is case-insensitive on the stripped query. Names that start with
    the query come first, followed by names that contain it elsewhere; within
    each group the input order is preserved. An empty or whitespace-only
    query yields an empty list instead of the whole catalogue.

    The whole list is scanned once per call; ``limit`` only truncates the
    result and does not restrict which names are examined.
    """
    needle = (query or "").strip().lower()
    if not needle:
        return []
    prefix_hits: list[str] = []
    inner_hits: list[str] = []
    for name in names:
        # Position 0 is a prefix match, any later position a plain "contains".
        position = name.lower().find(needle)
        if position == 0:
            prefix_hits.append(name)
        elif position > 0:
            inner_hits.append(name)
    return (prefix_hits + inner_hits)[:limit]
def search_bigwigs(species: str, query: str, current_selected: list[str]):
    """Search the species' BigWig tracks and refresh the results/selected lists.

    A track matches when the lower-cased query is a substring of its ID, its
    display name, or its combined display format; already selected tracks are
    left out. An empty query produces no results rather than the full list.

    Returns:
        Two ``gr.update`` values: the results list (at most
        ``SEARCH_MAX_RESULTS`` entries, nothing checked) and the selected
        list, which always shows every selected track.
    """
    needle = (query or "").strip().lower()
    chosen = current_selected or []
    hits: list[str] = []
    if needle:
        names = _get_bigwig_names(species)
        metadata = _load_track_metadata()
        # IDs of tracks that are already selected are excluded from results
        already_selected = {_extract_track_id(item) for item in chosen}
        for track_id in names:
            if track_id in already_selected:
                continue
            label = metadata.get(track_id, track_id)
            shown = _format_track_for_display(track_id)
            if any(needle in text.lower() for text in (track_id, label, shown)):
                hits.append(shown)
        hits = hits[:SEARCH_MAX_RESULTS]
    # The selected section is visible while typing or when anything is selected
    selected_visible = bool(needle) or bool(chosen)
    return (
        gr.update(choices=hits, value=[], interactive=True),
        gr.update(visible=selected_visible, choices=chosen, value=chosen),
    )
def add_selected(current_selected: list[str], to_add: list[str]):
    """Merge ``to_add`` into the selection and return the update for the selected list.

    Both arguments may hold bare IDs or display-formatted entries. Tracks
    whose ID is already selected are not added again, and every entry of the
    result is display-formatted.
    """
    track_ids = [_extract_track_id(item) for item in (current_selected or [])]
    for item in to_add or []:
        candidate = _extract_track_id(item)
        if candidate not in track_ids:
            track_ids.append(candidate)
    labels = [_format_track_for_display(tid) for tid in track_ids]
    # Every selected track is both a choice and checked (no limit)
    return gr.update(choices=labels, value=labels)
def remove_selected(current_selected: list[str], to_remove: list[str]):
    """Drop ``to_remove`` from the selection and return the update for the selected list.

    Entries are compared as given (exact string match). The list is hidden
    once nothing remains selected.
    """
    # Build the lookup set once instead of once per selected entry.
    dropped = set(to_remove or [])
    remaining = [item for item in (current_selected or []) if item not in dropped]
    return gr.update(choices=remaining, value=remaining, visible=bool(remaining))
def reset_on_species_change(species: str):
    """Reset the search box, results and selected tracks after a species change.

    The selection is replaced by the default tracks that exist for the new
    species, which avoids carrying over IDs from another species. The selected
    list is hidden when no default applies or the species has no BigWig tracks.

    Returns:
        Three ``gr.update`` values: query textbox, results list, selected list.
    """
    cleared_query = gr.update(value="")
    cleared_results = gr.update(choices=[], value=[])
    try:
        # Also warms the track-name cache for this species
        available = set(_get_bigwig_names(species))
    except (ValueError, AttributeError):
        # Species doesn't have bigwigs, that's okay
        return (
            cleared_query,
            cleared_results,
            gr.update(choices=[], value=[], visible=False),
        )
    defaults = [
        _format_track_for_display(tid)
        for tid in DEFAULT_BIGWIG_TRACKS
        if tid in available
    ]
    # As everywhere else, the selected list only offers the selected tracks;
    # listing every track of the species here would render the whole catalogue
    # as unchecked checkboxes.
    return (
        cleared_query,
        cleared_results,
        gr.update(choices=defaults, value=defaults, visible=bool(defaults)),
    )
# -----------------------------
# Predict
# -----------------------------
def _build_predict_inputs(
    seq: str, species: str, chrom: str, start: int, end: int, use_coords: bool
) -> dict:
    """Validate the user input and build the dict handed to the pipeline.

    Returns:
        ``{"chrom", "start", "end", "species"}`` when ``use_coords`` is true,
        otherwise ``{"seq", "species"}`` with the sequence stripped.

    Raises:
        gr.Error: species without coordinate support, missing chromosome,
            missing/inverted/oversized region, or missing/oversized sequence.
    """
    if use_coords:
        # Check if this species supports coordinate-based fetching
        if species not in SPECIES_WITH_COORDINATE_SUPPORT:
            supported = ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
            raise gr.Error(
                f"Species '{species}' does not support coordinate-based sequence "
                f"fetching. Please provide a DNA sequence directly or use one of "
                f"the supported species: {supported}"
            )
        if not chrom:
            raise gr.Error("chrom is required when use_coords=True")
        if start is None or end is None or int(end) <= int(start):
            raise gr.Error("start/end must be set and end > start when use_coords=True")
        # Reject oversized regions before any sequence is fetched
        region_length = int(end) - int(start)
        if region_length > MAX_SEQUENCE_SIZE:
            raise gr.Error(
                f"Requested genomic region is too large ({region_length:,} base pairs). "
                f"Maximum allowed size is {MAX_SEQUENCE_SIZE:,} base pairs (1MB). "
                f"Please select a smaller region."
            )
        return {
            "chrom": chrom,
            "start": int(start),
            "end": int(end),
            "species": species,
        }

    if not seq or not seq.strip():
        raise gr.Error("seq is required when use_coords=False")
    seq_stripped = seq.strip()
    if len(seq_stripped) > MAX_SEQUENCE_SIZE:
        raise gr.Error(
            f"Sequence input is too large ({len(seq_stripped):,} characters). "
            f"Maximum allowed size is {MAX_SEQUENCE_SIZE:,} characters (1MB). "
            f"Please use a shorter sequence or use genomic coordinates instead."
        )
    return {"seq": seq_stripped, "species": species}


@gpu
def predict(
    seq: str,
    species: str,
    chrom: str,
    start: int,
    end: int,
    input_type: str,
    bigwig_selected: list[str],
    bed_elements: list[str],
):
    """Run the pipeline on the requested input and build the tracks figure.

    Args:
        seq: DNA sequence, used unless ``input_type`` selects coordinates.
        species: species key passed to the pipeline; required.
        chrom, start, end: genomic region, used when ``input_type`` is
            ``"Use genomic coordinates"``.
        input_type: value of the input-method radio button.
        bigwig_selected: selected tracks, as bare IDs or display-formatted
            entries; empty/``None`` falls back to the available defaults.
        bed_elements: selected BED element names; empty/``None`` falls back
            to the available defaults.

    Returns:
        A 7-tuple: visibility updates for the predictions heading and note,
        the plot update carrying the figure, the visibility update for the
        BigWig download button, the raw pipeline output, the BigWig track IDs
        that were plotted, and the BED elements that were plotted.

    Raises:
        gr.Error: on invalid input, unknown track/element names, or when the
            model has nothing to plot for this species.
    """
    tprint("start")
    if not species:
        raise gr.Error("Species parameter is missing. Please select a species.")
    # Entries may be display-formatted; API callers may also pass None.
    bigwig_selected = [_extract_track_id(tid) for tid in (bigwig_selected or [])]
    use_coords = input_type == "Use genomic coordinates"
    inputs = _build_predict_inputs(seq, species, chrom, start, end, use_coords)
    tprint("inputs prepared")

    # Move to GPU only once the ZeroGPU context is active, and only if the
    # model is not already on the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    current = next(pipe.model.parameters()).device.type  # "cpu" or "cuda"
    if current != device:
        pipe.model.to(device)
        tprint(f"model moved to {device}")
    pipe.model.eval()
    print(f"Running on {next(pipe.model.parameters()).device}")
    tprint("model ready to run inference")

    out = pipe(inputs)
    tprint("inference completed")

    bw_names = out.bigwig_track_names or []
    bw = out.bigwig_tracks_logits
    bed_names = out.bed_element_names or []
    bed_logits = out.bed_tracks_logits

    # Check if we have any tracks/elements to plot
    has_bigwigs = bw is not None and len(bw_names) > 0
    has_bed = bed_logits is not None and len(bed_names) > 0
    if not has_bigwigs and not has_bed:
        raise gr.Error(
            "No BigWig tracks or BED elements available for this species "
            "in the current model."
        )
    if not has_bigwigs and bigwig_selected:
        raise gr.Error(
            "No BigWig tracks available for this species, but BigWig tracks "
            "were selected. Please deselect BigWig tracks or choose a "
            "different species."
        )

    # Defaults if the user picked none, limited to what this output provides
    if has_bigwigs and not bigwig_selected:
        bigwig_selected = [tid for tid in DEFAULT_BIGWIG_TRACKS if tid in bw_names]
    if (not bed_elements) and bed_names:
        bed_elements = [elem for elem in DEFAULT_BED_ELEMENTS if elem in bed_names]

    # Validate (important for API usage)
    if has_bigwigs and bigwig_selected:
        missing_tracks = [t for t in bigwig_selected if t not in bw_names]
        if missing_tracks:
            raise gr.Error(f"Unknown BigWig track id(s): {missing_tracks}")
    if bed_elements:
        missing_elems = [e for e in bed_elements if e not in bed_names]
        if missing_elems:
            raise gr.Error(f"Unknown BED element(s): {missing_elems}")

    # At least one of the two outputs exists here (checked above); the first
    # axis of either gives the number of predicted positions.
    seq_length = bw.shape[0] if has_bigwigs else bed_logits.shape[0]
    stride = _global_stride(seq_length, PLOT_TARGET_POINTS)
    x0 = int(out.pred_start or 0)
    x1 = int(out.pred_end or (x0 + seq_length))
    x = np.linspace(x0, x1, num=seq_length, endpoint=False)[::stride]

    series: list[tuple[str, np.ndarray]] = []
    # Add BigWig tracks if available and selected
    if has_bigwigs and bigwig_selected:
        for tid in bigwig_selected:
            idx = bw_names.index(tid)
            # Use clean display name instead of track ID
            display_name = _get_track_display_name(tid)
            series.append((display_name, bw[:, idx][::stride].astype(float)))
    # Add BED elements if available and selected
    if bed_logits is not None and bed_elements:
        probs = _softmax_last(bed_logits)
        for ename in bed_elements:
            display_name = ename.replace("_", " ").lower()
            eidx = bed_names.index(ename)
            # NOTE(review): index 1 of the last axis is presumably the
            # "element present" class -- confirm against the pipeline.
            series.append((display_name, probs[:, eidx, 1][::stride].astype(float)))
    tprint("figure data processed created")

    # Build region string for x-axis label
    region = (
        f"{out.chrom}:{out.pred_start}-{out.pred_end}" if out.chrom else f"{x0}-{x1}"
    )
    if out.assembly:
        region += f" ({out.assembly})"
    fig = _make_tracks_figure(x, series, region=region)
    tprint("figure created")

    return (
        gr.update(visible=True),  # predictions_heading
        gr.update(visible=True),  # predictions_note
        gr.update(value=fig, visible=True),  # plot
        gr.update(visible=True),  # download_bigwig_btn
        out,
        bigwig_selected,
        bed_elements,
    )
# -----------------------------
# UI (keep your download icon setup)
# -----------------------------
# Load CSS from external file; an empty stylesheet is used when style.css is
# not present next to this module.
CSS_PATH = Path(__file__).parent / "style.css"
CSS = CSS_PATH.read_text() if CSS_PATH.exists() else ""
# JavaScript source kept as a string. It restyles the buttons and button-like
# links inside the page <footer> (transparent background, thin light-blue
# border) and re-applies that styling on load, after short delays, and
# whenever a MutationObserver reports DOM changes.
# NOTE(review): where this snippet is injected into the page is not visible
# in this part of the file -- confirm at the place the app is launched.
JS = """
// Remove blue backgrounds from footer buttons after page loads
function removeFooterButtonBackgrounds() {
// Target all buttons and links in footer
const footer = document.querySelector('footer');
if (footer) {
const buttons = footer.querySelectorAll('button, a[role="button"], a[class*="button"]');
buttons.forEach(btn => {
btn.style.setProperty('background', 'transparent', 'important');
btn.style.setProperty('background-color', 'transparent', 'important');
btn.style.setProperty('border', '1px solid rgba(125, 211, 252, 0.2)', 'important');
});
}
}
// Run on page load
if (document.readyState === 'loading') {
document.addEventListener('DOMContentLoaded', removeFooterButtonBackgrounds);
} else {
removeFooterButtonBackgrounds();
}
// Also run after a delay to catch dynamically loaded content
setTimeout(removeFooterButtonBackgrounds, 100);
setTimeout(removeFooterButtonBackgrounds, 500);
setTimeout(removeFooterButtonBackgrounds, 1000);
// Use MutationObserver to watch for dynamically added footer buttons
const observer = new MutationObserver(() => {
removeFooterButtonBackgrounds();
});
// Observe changes to the footer
const footer = document.querySelector('footer');
if (footer) {
observer.observe(footer, {
childList: true,
subtree: true,
attributes: true,
attributeFilter: ['style', 'class']
});
}
// Also observe the entire document for footer additions
observer.observe(document.body, {
childList: true,
subtree: true
});
"""
# Tracks pre-selected in the UI and used when a request selects none.
DEFAULT_BIGWIG_TRACKS = [
    "ENCSR056HPM",  # K562 RNA-seq
    "ENCSR921NMD",  # K562 DNAse
    "ENCSR000DWD",  # K562 H3k4me3
    "ENCSR000AKO",  # K562 CTCF
    "ENCSR561FEE_P",  # HepG2 RNA-seq
    "ENCSR000EJV",  # HepG2 DNAse
    "ENCSR000AMP",  # HepG2 H3k4me3
    "ENCSR000BIE",  # HepG2 CTCF
]
# BED elements pre-selected in the UI.
DEFAULT_BED_ELEMENTS = ["protein_coding_gene", "exon", "intron"]

# What the loaded model offers for the default species. The BED list is small
# enough to be shown in full as a dropdown.
_init_bed = pipe.available_bed_element_names(DEFAULT_SPECIES)
_init_bigwig = _get_bigwig_names(DEFAULT_SPECIES)

# Keep only the defaults that are actually available, in their listed order.
_init_bigwig_selected_ids = [
    track_id for track_id in DEFAULT_BIGWIG_TRACKS if track_id in _init_bigwig
]
_init_bigwig_selected = [
    _format_track_for_display(track_id) for track_id in _init_bigwig_selected_ids
]
_init_bed_selected = [name for name in DEFAULT_BED_ELEMENTS if name in _init_bed]

# Dropdown entries are (display_name, value) tuples; the selection itself keeps
# the raw element names.
_init_bed_choices = [
    (_format_bed_element_for_display(name), name) for name in _init_bed
]
_init_bed_selected_values = _init_bed_selected

# Region pre-filled in the coordinate inputs, per species.
# NOTE(review): the drosophila start/end are identical to the human ones --
# confirm this is intentional.
DEFAULT_COORDS = {
    "human": {"chrom": "chr19", "start": 6_700_000, "end": 6_831_072},
    "mouse": {"chrom": "chr1", "start": 9_880_168, "end": 10_142_312},
    "drosophila_melanogaster": {"chrom": "chr2L", "start": 6_700_000, "end": 6_831_072},
    "arabidopsis_thaliana": {"chrom": "chr1", "start": 13_135_095, "end": 13_397_239},
}
# Species without an entry fall back to the human region.
_default_coords = DEFAULT_COORDS.get(DEFAULT_SPECIES, DEFAULT_COORDS["human"])
# Species keys are shown with spaces instead of underscores and a leading capital
def _format_species_name(species: str) -> str:
    """Return the display label for a species key (``str.capitalize`` casing)."""
    spaced = " ".join(species.split("_"))
    return spaced.capitalize()
# Every species known to the pipeline, as a comma-separated display string
_all_species = sorted(ASSEMBLY_TO_SPECIES.values())
_all_species_formatted = list(map(_format_species_name, _all_species))
_all_species_list = ", ".join(_all_species_formatted)

# Species for which the current model has BigWig tracks
_species_with_bigwigs = _get_species_with_bigwigs()
_bigwig_species_formatted = sorted(map(_format_species_name, _species_with_bigwigs))
if _bigwig_species_formatted:
    _bigwig_species_list = ", ".join(_bigwig_species_formatted)
else:
    _bigwig_species_list = "None (BED elements only)"
with gr.Blocks(title="NTv3 Tracks Demo") as demo:
gr.Markdown(
f"""
🧬 NTv3 Tracks Demo
Predict and visualize functional genomics signals directly from DNA using
Nucleotide Transformer v3.
🧬 Functional genomics
📊 BigWig tracks
🎯 Genome annotation
🌍 Multi-species
⚡ Interactive visualization
1) Provide input
- Select a model and species
- Use genomic coordinates (chrom, start, end), or
- Paste a DNA sequence
2) Choose signals
- Search & select BigWig functional tracks
(RNA-seq, ChIP-seq, DNase…)
- Select BED genome annotation elements
(exons, introns, promoters…)
3) Explore
- View stacked tracks across the region
- Compare multiple tracks side-by-side
- Download plot and BigWig files
💡
Tip: The demo includes default settings that you can use
to get started, taking ~ 15 seconds to run for the example on human.
Available species: {_all_species_list}
Species with functional tracks: {_bigwig_species_list}
""",
elem_id="intro_markdown",
)
gr.Markdown("# Select NTv3 post-trained model")
# Model display names (without InstaDeepAI/ prefix) and their full IDs
MODEL_OPTIONS = {
    "NTv3 650M (post)": "InstaDeepAI/NTv3_650M_post",
    "NTv3 100M (post)": "InstaDeepAI/NTv3_100M_post",
}
# Reverse mapping: full ID -> display name
MODEL_ID_TO_DISPLAY = {v: k for k, v in MODEL_OPTIONS.items()}
# A model configured outside MODEL_OPTIONS is listed under its raw ID, so the
# dropdown always shows the model that is actually loaded and its initial
# value is one of the choices. change_model accepts raw IDs as well.
current_display_name = MODEL_ID_TO_DISPLAY.get(current_model_id, current_model_id)
model_choices = list(MODEL_OPTIONS)
if current_display_name not in model_choices:
    model_choices.append(current_display_name)
model_selector = gr.Dropdown(
    choices=model_choices,
    value=current_display_name,
    label="Model",
)
model_status = gr.Markdown("", visible=False)
gr.Markdown("# Input DNA sequence")
# Get all available species from the pipeline and format for display
all_species = sorted(ASSEMBLY_TO_SPECIES.values())
# Format choices as (display_name, value) tuples so dropdown shows formatted names
# but returns actual species values
species_choices = [(_format_species_name(s), s) for s in all_species]
species = gr.Dropdown(
choices=species_choices,
value=DEFAULT_SPECIES,
label="Species",
)
# Radio buttons for input type selection
is_supported_default = DEFAULT_SPECIES in SPECIES_WITH_COORDINATE_SUPPORT
initial_input_type = (
"Use genomic coordinates" if is_supported_default else "Enter DNA sequence"
)
input_type = gr.Radio(
choices=["Use genomic coordinates", "Enter DNA sequence"],
value=initial_input_type,
label="Input method",
visible=is_supported_default, # Only show if species supports coordinates
)
# Coordinates section - visible only when "Use genomic coordinates" is selected
with gr.Group(
visible=is_supported_default
and initial_input_type == "Use genomic coordinates",
elem_id="coords_group",
) as coords_group:
gr.Markdown(
"**Genomic coordinates** (supported species: "
+ ", ".join(sorted(SPECIES_WITH_COORDINATE_SUPPORT))
+ ")"
)
with gr.Row():
# chrom = gr.Textbox(
# label="Chromosome", value=_default_coords["chrom"], elem_id="chromosome_input"
# )
chrom = gr.Dropdown(
label="Chromosome",
choices=[], # no predefined list
value=_default_coords["chrom"],
allow_custom_value=True, # user can type anything (e.g. chr19, scaffold_123)
filterable=True, # enables typing/search UI
elem_id="chromosome_input",
)
start = gr.Number(
label="Start", value=_default_coords["start"], precision=0, elem_id="start_input"
)
end = gr.Number(
label="End", value=_default_coords["end"], precision=0, elem_id="end_input"
)
# DNA sequence section - visible only when "Enter DNA sequence" is selected
# Using Textbox directly (not wrapped in Group) to avoid visual border/line
seq = gr.Textbox(
lines=4,
label="Input DNA sequence",
placeholder="ACGT...",
visible=initial_input_type == "Enter DNA sequence",
elem_id="dna_sequence_input",
)
def change_model(display_name: str, species: str):
    """Reload the pipeline for the chosen model and report the outcome.

    Args:
        display_name: key of ``MODEL_OPTIONS``; any other value is treated as
            a model ID and passed through unchanged.
        species: species the pipeline is loaded for.

    Returns:
        A single ``gr.update`` for the status message that sets its text and
        makes it visible.
    """
    # Fallback: assume it's already a model ID or custom value
    model_id = MODEL_OPTIONS.get(display_name, display_name)
    try:
        load_pipeline(model_id, species)
    except Exception as e:
        # UI boundary: show the failure to the user instead of raising.
        return gr.update(value=f"❌ Error loading model: {e}", visible=True)
    # Best-effort cache warm-up. A species without BigWig tracks must not be
    # reported as a model-loading failure, so use the non-raising helper.
    _has_bigwigs(species)
    return gr.update(value="✅ Model loaded successfully", visible=True)


model_selector.change(
    fn=change_model,
    inputs=[model_selector, species],
    # One update carries both text and visibility; the component is listed once.
    outputs=[model_status],
)
gr.Markdown("# Select functional tracks")
# Button to download tracks metadata
def get_metadata_file_path():
    """Return the metadata CSV path as a string, or ``None`` when the file is absent."""
    candidate = Path(__file__).parent / "data" / "functional_tracks_metadata.csv"
    return str(candidate) if candidate.exists() else None
metadata_file_path = get_metadata_file_path()
download_metadata_btn = gr.Button(
"📋 Download metadata for all functional tracks",
variant="secondary",
visible=metadata_file_path is not None,
)
metadata_download_file = gr.File(
label="Tracks metadata",
visible=False,
)
def download_metadata():
    """Return the update that shows the metadata file, or hides it when it is unavailable."""
    available = bool(metadata_file_path) and Path(metadata_file_path).exists()
    if not available:
        return gr.update(visible=False)
    return gr.update(value=metadata_file_path, visible=True)
download_metadata_btn.click(
fn=download_metadata,
inputs=[],
outputs=[metadata_download_file],
)
bigwig_no_tracks_msg = gr.Markdown(
"⚠️ No functional genomic tracks available for this species "
"in the current model.",
visible=False,
)
bigwig_query = gr.Textbox(
label="Search functional tracks (auto-search while typing)",
placeholder="Type to search… (e.g. heart DNAse-seq)",
)
bigwig_results = gr.CheckboxGroup(
choices=[],
label="Results (click to add to Selected)",
)
bigwig_selected = gr.CheckboxGroup(
choices=_init_bigwig_selected,
value=_init_bigwig_selected,
label="Selected functional tracks (used for prediction)",
visible=bool(
_init_bigwig_selected
), # Show if there are default tracks, otherwise hidden
)
with gr.Row(visible=True) as bigwig_buttons_row:
bigwig_clear_btn = gr.Button("Clear search results")
bigwig_remove_btn = gr.Button("Remove all selected")
gr.Markdown("# Select genome annotation elements")
bed_elements = gr.Dropdown(
choices=_init_bed_choices,
value=_init_bed_selected_values if _init_bed_selected_values else [],
multiselect=True,
label="Genome annotation elements (search + select)",
elem_id="bed_elements_dropdown",
)
btn = gr.Button("Predict", elem_id="predict_btn")
predictions_heading = gr.Markdown(
"# NTv3 predictions for selected tracks and elements\n\n", visible=False
)
predictions_note = gr.Markdown(
"Note: NTv3 predictions are for the 37.5% center of the input sequence.",
visible=False,
)
plot = gr.Plot(label="", elem_id="tracks_plot", visible=False)
# State to store prediction output and selections for BigWig export
prediction_state = gr.State(value=None)
bigwig_selected_state = gr.State(value=[])
bed_elements_state = gr.State(value=[])
download_bigwig_btn = gr.Button(
"📥 Download tracks as BigWig files (ZIP)",
variant="secondary",
visible=False,
)
export_bigwig = gr.File(label="Download BigWig files", visible=False)
# --- wiring (live search + auto-add) ---
# Live search on every keystroke and when text changes (including deletion)
bigwig_query.input(
fn=search_bigwigs,
inputs=[species, bigwig_query, bigwig_selected],
outputs=[bigwig_results, bigwig_selected],
)
# Also trigger on change to catch deletions
bigwig_query.change(
fn=search_bigwigs,
inputs=[species, bigwig_query, bigwig_selected],
outputs=[bigwig_results, bigwig_selected],
)
# Plain-list variant of the track search (no gr.update wrapper)
def _get_search_results_choices(
    species: str, query: str, current_selected: list[str]
) -> list[str]:
    """Return the display-formatted tracks matching ``query``, without selected ones.

    The lower-cased query is matched as a substring of the track ID, its
    display name, or its display format. An empty query returns an empty
    list; at most ``SEARCH_MAX_RESULTS`` entries are returned.
    """
    needle = (query or "").strip().lower()
    if not needle:
        return []
    names = _get_bigwig_names(species)
    metadata = _load_track_metadata()
    # IDs of tracks that are already selected
    taken = (
        {_extract_track_id(item) for item in current_selected}
        if current_selected
        else set()
    )
    hits: list[str] = []
    for track_id in names:
        if track_id in taken:
            continue
        shown = _format_track_for_display(track_id)
        haystacks = (track_id, metadata.get(track_id, track_id), shown)
        if any(needle in text.lower() for text in haystacks):
            hits.append(shown)
    return hits[:SEARCH_MAX_RESULTS]
# Auto-add: ticking an entry in the results list moves it into Selected and
# leaves the results list with nothing checked ("click to add" behaviour).
def _auto_add(
    selected_now: list[str],
    results_checked: list[str],
    current_query: str,
    current_results: list[str],
    current_species: str,
):
    """Merge the checked results into the selection and rebuild the results.

    Returns a pair of updates: one for the selected-tracks group (made
    visible only when the new selection is non-empty) and one for the
    results group (refreshed choices, nothing checked).
    """
    selection_update = add_selected(selected_now, results_checked)
    updated_selection = selection_update["value"]

    # Re-run the search so tracks that were just selected drop out of it.
    remaining_choices = _get_search_results_choices(
        current_species, current_query, updated_selection
    )

    selected_component_update = gr.update(
        **selection_update, visible=bool(updated_selection)
    )
    # An empty value un-checks everything in the results list.
    results_component_update = gr.update(choices=remaining_choices, value=[])
    return selected_component_update, results_component_update
# Thin pass-through used as the event callback for the results list.
def _auto_add_wrapper(
    selected_now: list[str],
    results_checked: list[str],
    current_query: str,
    current_results: list[str],
    current_species: str,
):
    """Forward all arguments, unchanged and in order, to ``_auto_add``."""
    forwarded = (
        selected_now,
        results_checked,
        current_query,
        current_results,
        current_species,
    )
    return _auto_add(*forwarded)
# When the checked state of the results list changes, run the auto-add
# callback. `bigwig_results` is passed twice because the callback takes both
# a `results_checked` and a `current_results` argument.
bigwig_results.change(
    fn=_auto_add_wrapper,
    inputs=[bigwig_selected, bigwig_results, bigwig_query, bigwig_results, species],
    outputs=[bigwig_selected, bigwig_results],
)
# Keep the Selected group and the search results in sync as soon as the
# user ticks or unticks an entry in Selected.
def _update_selected_tracks(
    selected_value: list[str], current_query: str, current_species: str
):
    """Shrink Selected to the checked entries and refresh the search results.

    ``selected_value`` holds only the entries that are still checked; it
    becomes both the choices and the value of the Selected group, which is
    hidden when the list is empty. The first element returned by
    ``search_bigwigs`` is used as the update for the results list, so
    tracks that were just unchecked can show up there again.
    """
    results_update = search_bigwigs(current_species, current_query, selected_value)[0]
    selected_update = gr.update(
        choices=selected_value,
        value=selected_value,
        visible=bool(selected_value),
    )
    return selected_update, results_update
# Re-sync the Selected group and the results list whenever the selection
# changes. `bigwig_selected` is both an input and an output of this listener.
bigwig_selected.change(
    fn=_update_selected_tracks,
    inputs=[bigwig_selected, bigwig_query, species],
    outputs=[bigwig_selected, bigwig_results],
)
# Reset the search area: no results listed, empty query box.
def _clear_results():
    """Return updates that empty the results list and blank the query text."""
    emptied_results = gr.update(choices=[], value=[])
    blank_query = gr.update(value="")
    return emptied_results, blank_query
# "Clear search results" empties the results list and the query box.
bigwig_clear_btn.click(
    fn=_clear_results,
    inputs=[],
    outputs=[bigwig_results, bigwig_query],
)
# Remove: the Selected group is passed as BOTH inputs of `remove_selected`.
# NOTE(review): presumably this removes every selected track, matching the
# "Remove all selected" button label - confirm against `remove_selected`.
bigwig_remove_btn.click(
    fn=remove_selected,
    inputs=[bigwig_selected, bigwig_selected],
    outputs=[bigwig_selected],
)
# On species change, `reset_on_species_change` updates the query box, the
# results list and the Selected group.
species.change(
    fn=reset_on_species_change,
    inputs=[species],
    outputs=[bigwig_query, bigwig_results, bigwig_selected],
)
# Update coordinates visibility and values when species changes
def update_on_species_change(species: str, input_type_val: str):
    """Recompute every species-dependent widget after a species switch.

    Args:
        species: Newly selected species.
        input_type_val: Value of the input-type radio when the event fired.
            It is accepted because the listener passes it, but visibility is
            derived from the radio value this function itself sets, so that
            all returned updates agree with each other.

    Returns:
        A 12-tuple of ``gr.update`` objects, in this order: chromosome,
        start, end, input-type radio, coordinates group, sequence box,
        "no tracks" message, selected tracks, track query, track results,
        track buttons row, annotation-elements dropdown.
    """
    is_supported = species in SPECIES_WITH_COORDINATE_SUPPORT
    has_bigwigs = _has_bigwigs(species)
    coords = DEFAULT_COORDS.get(species, DEFAULT_COORDS["human"])

    # The radio is reset below to "Use genomic coordinates" for supported
    # species and to "Enter DNA sequence" otherwise. Coordinates are shown
    # exactly when the radio ends up on "Use genomic coordinates"; using the
    # pre-change radio value here would contradict the value being set.
    new_input_type = (
        "Use genomic coordinates" if is_supported else "Enter DNA sequence"
    )
    show_coords = is_supported
    show_seq = not show_coords

    # Default tracks that exist for this species, formatted for display.
    default_formatted = []
    if has_bigwigs:
        try:
            available_ids = set(_get_bigwig_names(species))
            default_formatted = [
                _format_track_for_display(tid)
                for tid in DEFAULT_BIGWIG_TRACKS
                if tid in available_ids
            ]
        except Exception as exc:
            # Best effort: keep the UI usable without default tracks, but
            # leave a trace instead of failing silently.
            print(
                "[update_on_species_change] Could not load tracks for "
                f"{species!r}: {exc}"
            )
            default_formatted = []
    show_selected_tracks = bool(default_formatted)

    # Get BED elements available for this species
    bed_element_names = _get_bed_element_names(species)
    # Filter default BED elements to only those available for this species
    default_bed_selected = [
        elem for elem in DEFAULT_BED_ELEMENTS if elem in bed_element_names
    ]
    # Format BED elements for display: use tuples (display_name, value)
    bed_element_choices = [
        (_format_bed_element_for_display(elem), elem) for elem in bed_element_names
    ]

    return (
        gr.update(visible=show_coords, value=coords["chrom"]),
        gr.update(visible=show_coords, value=coords["start"]),
        gr.update(visible=show_coords, value=coords["end"]),
        gr.update(
            visible=is_supported,
            value=new_input_type,
        ),  # Update input_type radio
        gr.update(visible=show_coords),  # Show/hide coords_group
        gr.update(visible=show_seq),  # Show/hide seq
        gr.update(
            visible=not has_bigwigs
        ),  # Show "no tracks" message if no bigwigs
        gr.update(
            visible=show_selected_tracks,
            # The Selected group lists only selected tracks (choices equal
            # the value), rather than every track of the species.
            choices=default_formatted,
            value=default_formatted,
        ),  # Show bigwig selection with defaults if available
        gr.update(visible=has_bigwigs),  # Show bigwig query if available
        gr.update(visible=has_bigwigs),  # Show bigwig results if available
        gr.update(visible=has_bigwigs),  # Show bigwig buttons if available
        gr.update(
            choices=bed_element_choices,
            value=default_bed_selected,
        ),  # Update BED elements dropdown with species-specific elements
    )
# Reset the input-type radio for the newly chosen species.
def update_input_type_on_species_change(species: str):
    """Return the radio update matching the species' coordinate support.

    Species with coordinate support get a visible radio set to
    "Use genomic coordinates"; all others get a hidden radio set to
    "Enter DNA sequence".
    """
    if species in SPECIES_WITH_COORDINATE_SUPPORT:
        return gr.update(visible=True, value="Use genomic coordinates")
    # No coordinate support: sequence input is the only option.
    return gr.update(visible=False, value="Enter DNA sequence")
# Switch between the coordinate inputs and the sequence box.
def update_input_visibility(input_type_val: str, species: str):
    """Return visibility updates for the coordinate and sequence inputs.

    Coordinate widgets are visible only when the species has coordinate
    support and the radio is on "Use genomic coordinates"; the sequence box
    is visible in every other case. Updates are returned in the order:
    coordinates group, sequence box, chromosome, start, end.
    """
    wants_coords = input_type_val == "Use genomic coordinates"
    coords_visible = species in SPECIES_WITH_COORDINATE_SUPPORT and wants_coords

    coords_group_upd, chrom_upd, start_upd, end_upd = (
        gr.update(visible=coords_visible) for _ in range(4)
    )
    seq_upd = gr.update(visible=not coords_visible)
    return coords_group_upd, seq_upd, chrom_upd, start_upd, end_upd
# Species change, listener 1: reset the input-type radio.
# NOTE(review): the listener below also writes `input_type`, so the radio is
# updated by two handlers on the same event - confirm both are needed.
species.change(
    fn=update_input_type_on_species_change,
    inputs=[species],
    outputs=[input_type],
)
# Species change, listener 2: refresh every species-dependent widget. The
# order of `outputs` must match the order of the tuple returned by
# `update_on_species_change`.
species.change(
    fn=update_on_species_change,
    inputs=[species, input_type],
    outputs=[
        chrom,
        start,
        end,
        input_type,
        coords_group,
        seq,
        bigwig_no_tracks_msg,
        bigwig_selected,
        bigwig_query,
        bigwig_results,
        bigwig_buttons_row,
        bed_elements,
    ],
)
# Radio change: toggle between the coordinate inputs and the sequence box.
input_type.change(
    fn=update_input_visibility,
    inputs=[input_type, species],
    outputs=[coords_group, seq, chrom, start, end],
)
def show_prediction_ui():
    """Reveal the prediction area right after "Predict" is clicked.

    Returns four updates, in order: heading, note and plot made visible,
    and the BigWig download button hidden (it is shown again once a
    prediction is available).
    """
    # heading, note, plot (the plot area shows the progress bar)
    revealed = [gr.update(visible=True) for _ in range(3)]
    hidden_download = gr.update(visible=False)
    return (*revealed, hidden_download)
# Two-step chain on the Predict button:
#   1. reveal the prediction area immediately and hide the download button;
#   2. then run `predict` and write its results to the outputs below.
btn.click(
    fn=show_prediction_ui,
    inputs=[],
    outputs=[
        predictions_heading,
        predictions_note,
        plot,
        download_bigwig_btn,
    ],
).then(
    fn=predict,
    inputs=[
        seq,
        species,
        chrom,
        start,
        end,
        input_type,
        bigwig_selected,
        bed_elements,
    ],
    # `predict` must return one value per output, in this order. The last
    # three are State components that keep the prediction and the selections
    # for the BigWig export.
    outputs=[
        predictions_heading,
        predictions_note,
        plot,
        download_bigwig_btn,
        # meta,
        prediction_state,
        bigwig_selected_state,
        bed_elements_state,
    ],
    api_name="predict",
)
def download_bigwig_zip(out, bw_selected, bed_selected):
    """Create the BigWig ZIP for the stored prediction and reveal it.

    Args:
        out: Prediction output taken from the prediction state; ``None``
            when no prediction has been stored yet.
        bw_selected: Functional tracks stored alongside the prediction.
        bed_selected: Annotation elements stored alongside the prediction.

    Returns:
        A ``gr.update`` that sets the file component to the ZIP path and
        makes it visible.

    Raises:
        gr.Error: If there is no prediction to export, if ``ImportError`` is
            raised while creating the archive (reported as a missing
            pyBigWig), or if the archive creation fails for any other
            reason. The original exception is attached as the cause.
    """
    if out is None:
        # The state starts as None; fail with a clear message instead of an
        # opaque error from inside the exporter.
        raise gr.Error(
            "No prediction available. Run a prediction before downloading "
            "BigWig files."
        )
    try:
        zip_path = create_bigwig_zip(out, bw_selected, bed_selected)
    except ImportError as exc:
        raise gr.Error(
            "pyBigWig is required for BigWig export. "
            "Install with: pip install pyBigWig"
        ) from exc
    except Exception as exc:
        raise gr.Error(f"Error creating BigWig files: {str(exc)}") from exc
    return gr.update(value=zip_path, visible=True)
# Build the ZIP from the stored prediction and selections (State components)
# and hand the result to the file component.
download_bigwig_btn.click(
    fn=download_bigwig_zip,
    inputs=[prediction_state, bigwig_selected_state, bed_elements_state],
    outputs=[export_bigwig],
)
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        ssr_mode=False,
        show_error=True,
        # Allow files under the system temp directory to be served.
        # NOTE(review): presumably where the BigWig ZIP is written - confirm.
        allowed_paths=[tempfile.gettempdir()],
        css=CSS,
        js=JS,
    )