# ntv3_benchmark/src/streamlit_app.py
from typing import List
import os
import pandas as pd
import streamlit as st
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
# ---------------------------------------------------------------------
# Page config (must be the first Streamlit command)
# ---------------------------------------------------------------------
st.set_page_config(
page_title="NTv3 Benchmark",
layout="wide",
)
# ---------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------
COLORS = {
# Primary colors 1 (our models)
'blue_0': '#004697', # Darkest allowable blue
'blue_1': '#3973fc', # Main blue
'blue_2': '#7ea4fc', # Medium blue
'blue_3': '#c3d5fc', # Light blue (lightest allowable blue)
# Secondary colors 1
'red_1': '#ff554d', # Medium red
'red_2': '#ffe0de', # Light red
# Primary colors 2
'green_1': '#00b050', # Darkest green
'green_2': '#92d050', # Medium green
'green_3': '#c6e0b4', # Light green (lightest allowable green)
# Secondary colors 2
'gold_1': '#fdb932',
# Tertiary colors
'orange_1': '#ff975e',
'purple_1': '#9a6ce4',
'purple_2': '#bb9aef', # Medium purple
'purple_3': '#ceb5f5', # Light purple (lightest allowable purple)
# Grays (other models)
'gray_1': '#808080', # Darkest gray (use as a last resort)
'gray_2': '#b3b3b3', # Medium gray (start with this as the darkest when possible)
'gray_3': '#e6e6e6', # Lightest gray
'gray_4': '#ffffff', # It's actually just white (use as a last resort)
# If all other options are exhausted
'cyan_1': '#0096b4', # Darkest teal
'cyan_2': '#28bed2', # Medium cyan
'cyan_3': '#8cdceb', # Lightest cyan
'magenta_1': '#b428a0', # Darkest magenta
'magenta_2': '#dc50be', # Medium pink
'magenta_3': '#f5a0dc', # Lightest pink
'yellow_1': '#c8aa00', # Darkest yellow
'yellow_2': '#ffd200', # Medium yellow
'yellow_3': '#fff08c', # Lightest yellow
}
ASSAY_TYPE_MAPPING = {
'ATAC-seq': 'chromatin accessibility',
'DNase-seq': 'chromatin accessibility',
'Histone ChIP-seq': 'histone modifications',
'TF ChIP-seq': 'chromatin accessibility',
'PRO-cap': 'transcription initiation',
'eCLIP': 'RNA binding sites',
'RNA-seq': 'gene expression',
'ribo-seq': 'mRNA translation',
'Annotation': 'genome annotation',
"Exon": "exon",
"Intron": "intron",
"Splice acceptor": "splice acceptor",
"Start codon": "start codon",
}
ASSAY_COLORS = {
'chromatin accessibility': '#004697',
'histone modifications': '#cc0000',
'transcription initiation': '#ff9900',
'RNA binding sites': '#9933cc',
'gene expression': '#009900',
'mRNA translation': '#ff6699',
'genome annotation': '#ffcc00',
"intron": '#004697',
"exon": '#cc0000',
"splice acceptor": '#ff9900',
"start codon": '#9933cc',
}
ASSAY_COLORS["other"] = "#808080"
MODEL_COLORS = {
"NTv3 650M (pos)": COLORS['blue_0'],
    'NTv3 650M (pre)': COLORS['blue_1'],   # #3973fc (Main blue)
'NTv3 100M (pre)': COLORS['blue_2'], # #7ea4fc (Medium blue)
'NTv3 8M (pre)': COLORS['blue_3'], # #c3d5fc (Light blue)
    'Evo2 1B': COLORS['green_3'],          # #c6e0b4 (Light green)
"NTv2 500M": COLORS['gray_1'],
"BPNet arch. 6M": COLORS['cyan_1'],
"Residual CNN 44M": COLORS['magenta_1'],
"PlantCAD2 88M": COLORS["purple_1"],
"Caduceus 7M": COLORS["purple_2"],
"HyenaDNA 7M": COLORS["yellow_2"]
}
MODEL_TRAINING_STATUS = {
"NTv3 650M (pos)": "POS",
"NTv3 650M (pre)": "PRE",
"NTv3 100M (pre)": "PRE",
"NTv3 8M (pre)": "PRE",
"Residual CNN 44M": "SCRATCH",
"Caduceus 7M": "PRE",
"Evo2 1B": "PRE",
"NTv2 500M": "PRE",
"BPNet arch. 6M": "SCRATCH",
"PlantCAD2 88M": "PRE",
"HyenaDNA 7M": "PRE"
}
MODEL_GPU_MULTIPLIER = {
"Evo2 1B": 8, # trained on 8 GPUs
}
MODEL_NAMES = list(MODEL_COLORS.keys())
PLANT_SPECIES = ["tomato", "rice", "maize", "arabidopsis"]
ANIMAL_SPECIES = ["human", "chicken", "cattle"]
SPECIES_GROUPS = {
"Plants": PLANT_SPECIES,
"Animals": ANIMAL_SPECIES, # (your code calls these HUMAN_SPECIES, but they’re the “animal” set)
}
_LAST_UPDATED = "Dec 10, 2025"
_INTRO = """
The **NTv3 Benchmark** is a curated collection of 106 long-range genomic datasets
designed to evaluate models under realistic 32 kb input, single-base-pair output settings.
The benchmark spans two complementary task families: genome annotation (exon, intron, splice acceptor, start codon)
and functional-regulatory prediction, which includes diverse experimental tracks such as chromatin accessibility,
histone modifications, transcription initiation (PRO-cap), RNA binding (eCLIP), gene expression (RNA-seq),
and translation (Ribo-seq).
Data are drawn from a phylogenetically diverse set of species, including organisms seen during post-training
(human, chicken, arabidopsis, rice, maize) and entirely unseen species (cattle, tomato), with careful curation
to avoid data leakage. This design allows the benchmark to probe long-range sequence-to-function mapping,
cross-species generalization, and transfer across heterogeneous regulatory modalities,
including assays not present in prior multispecies training corpora. By standardizing sequence length,
resolution, and evaluation metrics across all tracks, the NTv3 Benchmark provides a controlled setting
for comparing representation quality across genomic foundation models.
The metrics used are:
- **Pearson correlations (multi-assay)**: per-dataset scores across species and models for functional tracks.
- **MCC (bed tracks)**: per-track MCC values across species and models for gene annotation tracks.
"""
HERE = os.path.dirname(os.path.abspath(__file__)) # /app/src
PROJECT_ROOT = os.path.dirname(HERE) # /app
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
SINGLE_TABLE_PATH = os.path.join(DATA_DIR, "ntv3_benchmark_results.csv")
# ---------------------------------------------------------------------
# Data loading & preprocessing
# ---------------------------------------------------------------------
@st.cache_data
def load_raw_data():
df = pd.read_csv(SINGLE_TABLE_PATH)
df.columns = [c.strip() for c in df.columns]
return df
def _normalize_training_time_to_gpu_hours(df: pd.DataFrame) -> pd.DataFrame:
"""
Your new column is `running_time`. In your sample it looks like seconds
(e.g. 317034 ~= 88 hours). We'll convert to hours if values look like seconds.
"""
if "running_time" not in df.columns:
return df
rt = pd.to_numeric(df["running_time"], errors="coerce")
    # Heuristic: a median above 10,000 almost certainly means the values are
    # seconds (316800 s = 88 h), so convert to hours.
if rt.dropna().median() > 10_000:
df["GPU hours"] = rt / 3600.0
else:
df["GPU hours"] = rt.astype(float)
return df
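# Minimal usage sketch for the helper above (it is not currently called by the
# app); assumes the raw CSV carries a `running_time` column in seconds:
#
#     raw = _normalize_training_time_to_gpu_hours(load_raw_data().copy())
#     st.write(raw["GPU hours"].describe())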
def _best_step_time_to_hours(s: pd.Series) -> pd.Series:
"""
Converts strings like '3 days 04:26:26.467000' to hours (float).
Works with pandas Timedelta parsing.
"""
td = pd.to_timedelta(s, errors="coerce")
return td.dt.total_seconds() / 3600.0
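# Example: "3 days 04:26:26" parses to 72 h + 4 h 26 min 26 s ≈ 76.44 hours.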
@st.cache_data
def load_expanded_data():
df = load_raw_data().copy()
df = df.rename(columns={"Metric": "Score", "model_name": "Model"})
df["Score"] = pd.to_numeric(df["Score"], errors="coerce")
if "best_step" in df.columns:
df["best_step"] = pd.to_numeric(df["best_step"], errors="coerce")
if "best_step_time" in df.columns:
df["best_step_time_hours"] = _best_step_time_to_hours(df["best_step_time"])
else:
df["best_step_time_hours"] = np.nan
    is_annot = df.get("assay_type", pd.Series("", index=df.index)).astype(str).eq("Annotation")
pearson_raw = df[~is_annot].copy()
mcc_raw = df[is_annot].copy()
# -------------------------
# Functional Tracks (Pearson)
# -------------------------
pearson_group_cols = ["species", "datasets", "Model"]
if "assay_type" in pearson_raw.columns:
pearson_group_cols.append("assay_type")
pearson_df = (
pearson_raw
.groupby(pearson_group_cols, as_index=False, dropna=False)
.agg({
"Score": "mean",
"best_step": "mean",
"best_step_time_hours": "mean",
})
)
# ✅ merge track_name_clean WHILE assay_type is still raw
if "track_name_clean" in pearson_raw.columns:
map_keys = ["species", "datasets"]
if "assay_type" in pearson_raw.columns:
map_keys.append("assay_type")
track_map = (
pearson_raw[map_keys + ["track_name_clean"]]
.dropna(subset=["track_name_clean"])
.drop_duplicates()
)
pearson_df = pearson_df.merge(track_map, on=map_keys, how="left")
# ✅ now it’s safe to map assay_type to categories
if "assay_type" in pearson_df.columns:
pearson_df["assay_type"] = (
pearson_df["assay_type"].astype(str).map(ASSAY_TYPE_MAPPING).fillna("Other")
)
# -------------------------
# Genome Annotation (MCC)
# -------------------------
mcc_df = (
mcc_raw
.groupby(["species", "datasets", "Model"], as_index=False, dropna=False)
.agg({
"Score": "mean",
"best_step": "mean",
"best_step_time_hours": "mean",
})
)
return pearson_df, mcc_df
_PEARSON_DF, _MCC_DF = load_expanded_data()
# Global sets (we'll further filter per-benchmark below)
_ALL_SPECIES = sorted(
set(_PEARSON_DF["species"].unique()).union(_MCC_DF["species"].unique())
)
_ALL_ASSAYS = (
sorted(_PEARSON_DF["assay_type"].dropna().unique())
if "assay_type" in _PEARSON_DF.columns
else []
)
_ALL_MODELS = MODEL_NAMES[:]
_BENCHMARKS = {
"Functional Tracks": {
"df": _PEARSON_DF,
"metric_label": "Pearson correlation",
"has_assay_type": True,
},
"Genome Annotation": {
"df": _MCC_DF,
"metric_label": "MCC",
"has_assay_type": False,
},
}
# ---------------------------------------------------------------------
# Computation helpers
# ---------------------------------------------------------------------
def filter_base_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
cfg = _BENCHMARKS[benchmark_name]
df = cfg["df"].copy()
# Species filter
if selected_species:
df = df[df["species"].isin(selected_species)]
# Assay type filter (Pearson only)
if cfg.get("has_assay_type", False) and selected_assays and "assay_type" in df.columns:
df = df[df["assay_type"].isin(selected_assays)]
# Dataset / bed track filter (for MCC, but safe to apply generally)
if selected_datasets and "datasets" in df.columns:
df = df[df["datasets"].isin(selected_datasets)]
# Model filter
if selected_models:
df = df[df["Model"].isin(selected_models)]
return df
def build_leaderboard(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame(columns=["Model", "Model Type", "Num entries", "Mean score"])
agg = (
df.groupby("Model")["Score"]
.mean()
.reset_index()
.rename(columns={"Score": "Mean score"})
)
agg["Mean score"] = agg["Mean score"].round(3)
agg["Num entries"] = (
df.groupby("Model")["Score"].count().reindex(agg["Model"]).values
)
# 👇 Add training regime column
agg["Training"] = agg["Model"].map(MODEL_TRAINING_STATUS).fillna("UNKNOWN")
# Sort by performance
agg = agg.sort_values("Mean score", ascending=False).reset_index(drop=True)
# Column order
agg = agg[["Model", "Training", "Num entries", "Mean score"]]
    # Start the index at 1 so the leaderboard reads as a ranking
agg.index += 1
return agg
def build_bar_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
"""For now, just one bar per model (same as leaderboard)."""
return build_leaderboard(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
def build_category_model_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
"""
Mean score per (category, Model) after applying the same filters.
Category = assay_type (Functional Tracks) or datasets (Genome Annotation).
"""
cfg = _BENCHMARKS[benchmark_name]
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame(columns=["Category", "Model", "Mean score"])
# Pick the right breakdown column
if cfg.get("has_assay_type", False) and "assay_type" in df.columns:
category_col = "assay_type"
category_label = "Assay type"
else:
category_col = "datasets"
category_label = "Dataset"
if category_col not in df.columns:
return pd.DataFrame(columns=["Category", "Model", "Mean score"])
out = (
df.groupby([category_col, "Model"], as_index=False)["Score"]
.mean()
.rename(columns={category_col: "Category", "Score": "Mean score"})
)
out["Mean score"] = out["Mean score"].round(3)
out.attrs["category_label"] = category_label # for nicer axis title
return out
def plot_breakdown_facets_sorted_models(
breakdown_df: pd.DataFrame,
metric_label: str,
height: int = 420,
):
categories = list(breakdown_df["Category"].dropna().unique())
categories = sorted(categories)
n = len(categories)
if n == 0:
return None
rows = 1
cols = n # 👈 everything in one row
# Global y-range (consistent scale)
y_min = breakdown_df["Mean score"].min()
y_max = breakdown_df["Mean score"].max()
pad = 0.05 * (y_max - y_min if y_max > y_min else 1.0)
y_range = [y_min - pad, y_max + pad]
fig = make_subplots(
rows=rows,
cols=cols,
subplot_titles=categories,
shared_yaxes=True,
horizontal_spacing=0.04, # tighter spacing
)
for i, cat in enumerate(categories):
r = (i // cols) + 1
c = (i % cols) + 1
sub = (
breakdown_df[breakdown_df["Category"] == cat]
.sort_values("Mean score", ascending=True)
)
fig.add_trace(
go.Bar(
x=sub["Model"],
y=sub["Mean score"],
marker_color=[MODEL_COLORS.get(m, "#808080") for m in sub["Model"]],
showlegend=False,
),
row=r,
col=c,
)
fig.update_xaxes(showticklabels=False, title_text="", row=r, col=c)
fig.update_yaxes(range=y_range, title_text="", row=r, col=c) # 👈 apply range
fig.update_layout(
height=height,
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
margin=dict(t=60, l=10, r=10, b=10),
)
# Single y-axis label on the leftmost panel
fig.update_yaxes(title_text=metric_label, row=1, col=1)
return fig
def build_pairwise_scatter_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
model_a: str,
model_b: str,
) -> pd.DataFrame:
cfg = _BENCHMARKS[benchmark_name]
models_for_filter = (
list(set(selected_models + [model_a, model_b]))
if selected_models else [model_a, model_b]
)
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
models_for_filter,
selected_datasets,
)
if df.empty:
return pd.DataFrame()
# ---- define "track identity" for head-to-head ----
# Always use datasets for the identity (x/y points)
track_cols = ["datasets"]
if cfg.get("has_assay_type", False) and "assay_type" in df.columns:
track_cols = ["assay_type", "datasets"]
keep_species = "species" in df.columns and (selected_species is None or len(selected_species) != 1)
id_cols = (["species"] if keep_species else []) + track_cols
wide = (
df[df["Model"].isin([model_a, model_b])]
.pivot_table(index=id_cols, columns="Model", values="Score", aggfunc="mean")
.reset_index()
)
if model_a not in wide.columns or model_b not in wide.columns:
return pd.DataFrame()
wide = wide.dropna(subset=[model_a, model_b])
# Nice display label: use datasets (not track_name_clean)
if "assay_type" in wide.columns:
wide["Track"] = wide["assay_type"].astype(str) + " / " + wide["datasets"].astype(str)
else:
wide["Track"] = wide["datasets"].astype(str)
wide = wide.rename(columns={model_a: "Model A", model_b: "Model B"})
# ---- Pearson-only: merge track_name_clean for hover ----
if benchmark_name == "Functional Tracks" and "track_name_clean" in df.columns:
merge_keys = id_cols.copy() # species? + assay_type? + datasets
track_map = (
df[merge_keys + ["track_name_clean"]]
.dropna(subset=["track_name_clean"])
.drop_duplicates()
)
wide = wide.merge(track_map, on=merge_keys, how="left")
return wide
def build_violin_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
# Use the same base filtering, but keep all per-track rows
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
# Keep only needed columns
keep = ["Model", "Score"]
for col in ["species", "assay_type", "datasets"]:
if col in df.columns:
keep.append(col)
return df[keep].copy()
def build_convergence_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
    x_mode: str = "GPU (hours)",  # "GPU (hours)" | "Steps (billions)"
) -> pd.DataFrame:
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame(columns=["Model", "X", "Performance"])
# Mean performance per model
out = (
df.groupby("Model", as_index=False)
.agg({"Score": "mean"})
.rename(columns={"Score": "Performance"})
)
# -------------------------
# X axis selection
# -------------------------
if x_mode == "Steps (billions)":
if "best_step" not in df.columns:
return pd.DataFrame(columns=["Model", "X", "Performance"])
x = (
df.groupby("Model", as_index=False)["best_step"]
.mean()
.rename(columns={"best_step": "X"})
)
    else:  # "GPU (hours)": wall-clock hours to reach the best step
if "best_step_time_hours" not in df.columns:
return pd.DataFrame(columns=["Model", "X", "Performance"])
x = (
df.groupby("Model", as_index=False)["best_step_time_hours"]
.mean()
.rename(columns={"best_step_time_hours": "X"})
)
    # 👇 Convert wall-clock hours to GPU hours for multi-GPU runs
    # (Evo2 1B was trained on 8 GPUs); step counts are left untouched.
    if x_mode != "Steps (billions)":
        x["X"] = x.apply(
            lambda r: r["X"] * MODEL_GPU_MULTIPLIER.get(r["Model"], 1),
            axis=1,
        )
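    # Worked example: a run whose best checkpoint took 88 wall-clock hours on
    # 8 GPUs is plotted at 88 × 8 = 704 GPU hours.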
# Merge + clean
out = out.merge(x, on="Model", how="left")
out = out.dropna(subset=["X", "Performance"])
out["Performance"] = out["Performance"].round(3)
return out
# ---------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------
def sidebar_toggle(label: str, value: bool = False, key: str | None = None) -> bool:
"""
Wrapper to use st.sidebar.toggle when available, otherwise fall back to checkbox.
This makes the app compatible with older Streamlit versions on Hugging Face.
"""
toggle_fn = getattr(st.sidebar, "toggle", None)
if toggle_fn is not None:
return toggle_fn(label, value=value, key=key)
# Fallback for older Streamlit versions
return st.sidebar.checkbox(label, value=value, key=key)
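# Usage mirrors st.sidebar.toggle / st.sidebar.checkbox, e.g. (hypothetical key):
#
#     if sidebar_toggle("human", value=True, key="species_Functional Tracks_Animals_human"):
#         selected_species.append("human")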
# ---------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------
def main():
st.title("🧬 NTv3 Benchmark")
st.markdown(_INTRO)
st.markdown(f"_Last updated: **{_LAST_UPDATED}**_")
# --- Sidebar filters ---
st.sidebar.header("Filters")
# Benchmark
benchmark_name = st.sidebar.selectbox(
"Benchmark",
options=list(_BENCHMARKS.keys()),
index=0,
)
cfg = _BENCHMARKS[benchmark_name]
df_bench = cfg["df"]
# Species toggles, but only for species present in this benchmark
st.sidebar.subheader("Species")
# Toggle: Plants vs Animals
species_group = st.sidebar.radio(
"Group",
options=["Animals", "Plants"],
index=0,
horizontal=True,
key=f"species_group_{benchmark_name}",
)
available_species_all = sorted(df_bench["species"].unique())
allowed_species = set(SPECIES_GROUPS[species_group]).intersection(available_species_all)
available_species = sorted(allowed_species)
selected_species: List[str] = []
for sp in available_species:
if sidebar_toggle(sp, value=True, key=f"species_{benchmark_name}_{species_group}_{sp}"):
selected_species.append(sp)
# (Optional) If no species exist for that group in this benchmark
if not available_species:
st.sidebar.info(f"No {species_group.lower()} species available for this benchmark.")
# Assay toggles (Pearson only), based on filtered species
if cfg.get("has_assay_type", False):
st.sidebar.subheader("Assay types")
if selected_species:
df_for_assays = df_bench[df_bench["species"].isin(selected_species)]
else:
df_for_assays = df_bench
available_assays = (
sorted(df_for_assays["assay_type"].dropna().unique())
if "assay_type" in df_for_assays.columns
else []
)
selected_assays: List[str] = []
for assay in available_assays:
if sidebar_toggle(assay, value=True, key=f"assay_{benchmark_name}_{assay}"):
selected_assays.append(assay)
else:
selected_assays = []
# Bed track / dataset toggles (MCC only), based on species selection
selected_datasets: List[str] = []
if benchmark_name == "Genome Annotation":
st.sidebar.subheader("Genome annotations")
if selected_species:
df_for_tracks = df_bench[df_bench["species"].isin(selected_species)]
else:
df_for_tracks = df_bench
available_datasets = sorted(df_for_tracks["datasets"].unique())
for ds in available_datasets:
if sidebar_toggle(ds, value=True, key=f"dataset_{benchmark_name}_{ds}"):
selected_datasets.append(ds)
else:
selected_datasets = []
# Model toggles (we keep all models in MODEL_NAMES; filters + data will prune)
st.sidebar.subheader("Models")
selected_models: List[str] = []
for model in _ALL_MODELS:
if sidebar_toggle(model, value=True, key=f"model_{model}"):
selected_models.append(model)
# -------------------------
# ✅ Validation: require ≥1 selection per relevant category
# -------------------------
missing = []
# Always required
if not selected_species:
missing.append("Species")
if not selected_models:
missing.append("Models")
# Required depending on benchmark
if cfg.get("has_assay_type", False) and not selected_assays:
missing.append("Assay types")
if benchmark_name == "Genome Annotation" and not selected_datasets:
missing.append("Genome annotations")
if missing:
# Show a single message and prevent *any* further display
st.error(
"Please select at least one item in each category. Currently missing: "
+ ", ".join(missing)
+ "."
)
st.stop()
# --- Main content ---
leaderboard_df = build_leaderboard(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
bar_df = build_bar_df(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
col1, col2 = st.columns([1, 1])
with col1:
st.subheader("🏅 Leaderboard")
st.write("\n") # spacer to match plotly padding
st.write("\n")
st.write("\n")
if leaderboard_df.empty:
st.info("No data for the selected filters.")
else:
st.dataframe(leaderboard_df, use_container_width=True)
with col2:
st.subheader("📈 Mean score per model")
if bar_df.empty:
st.info("No data for the selected filters.")
else:
# Order models by performance (least -> most)
bar_df = bar_df.sort_values("Mean score", ascending=True)
model_order = bar_df["Model"].tolist()
fig = px.bar(
bar_df,
x="Model",
y="Mean score",
color="Model",
color_discrete_map=MODEL_COLORS,
category_orders={"Model": model_order},
)
fig.update_layout(
barmode="group",
height=500,
xaxis_title="",
yaxis_title=cfg["metric_label"],
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
bargap=0.08,
)
fig.update_xaxes(showticklabels=False)
st.plotly_chart(fig, use_container_width=True)
# --- Breakdown plot: assay_type (Functional Tracks) OR datasets (Genome Annotation) ---
breakdown_df = build_category_model_df(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
    type_of_data = "assay type" if benchmark_name == "Functional Tracks" else "genome annotation"
st.subheader(f"🧪 Mean score by {type_of_data}")
if breakdown_df.empty:
st.info("No data for the selected filters.")
else:
fig_breakdown = plot_breakdown_facets_sorted_models(
breakdown_df,
metric_label=cfg["metric_label"],
height=300,
)
st.plotly_chart(fig_breakdown, use_container_width=True)
# ------------------------------------------------------------------
# Model comparison: Head-to-head (left) + Convergence (right)
# ------------------------------------------------------------------
left, right = st.columns([1, 1], gap="large")
with left:
st.markdown("#### ⚖️ Head-to-head (per track)")
model_picker_options = selected_models if selected_models else _ALL_MODELS
default_a = model_picker_options[0] if model_picker_options else _ALL_MODELS[0]
default_b = model_picker_options[1] if len(model_picker_options) > 1 else (
_ALL_MODELS[1] if len(_ALL_MODELS) > 1 else default_a
)
cA, cB = st.columns([1, 1])
with cA:
model_a = st.selectbox(
"Model A (y-axis)",
options=model_picker_options,
index=model_picker_options.index(default_a) if default_a in model_picker_options else 0,
key=f"pair_model_a_{benchmark_name}",
)
with cB:
b_options = [m for m in model_picker_options if m != model_a] or model_picker_options
model_b = st.selectbox(
"Model B (x-axis)",
options=b_options,
index=0,
key=f"pair_model_b_{benchmark_name}",
)
scatter_df = build_pairwise_scatter_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
model_a,
model_b,
)
if scatter_df.empty:
st.info("No overlapping tracks for the selected filters (or one model is missing values).")
else:
min_v = float(min(scatter_df["Model A"].min(), scatter_df["Model B"].min()))
max_v = float(max(scatter_df["Model A"].max(), scatter_df["Model B"].max()))
pad = 0.05 * (max_v - min_v if max_v > min_v else 1.0)
axis_range = [min_v - pad, max_v + pad]
tick_step = (axis_range[1] - axis_range[0]) / 5
            hover_cols = ["datasets"]
            if benchmark_name == "Functional Tracks" and "track_name_clean" in scatter_df.columns:
                hover_cols.append("track_name_clean")
color_col = "assay_type" if "assay_type" in scatter_df.columns else "datasets"
fig_scatter = px.scatter(
scatter_df,
x="Model B",
y="Model A",
color=color_col,
color_discrete_map=ASSAY_COLORS,
hover_name="Track",
hover_data=hover_cols,
)
fig_scatter.add_shape(
type="line",
x0=axis_range[0], y0=axis_range[0],
x1=axis_range[1], y1=axis_range[1],
xref="x", yref="y",
line=dict(color="red", dash="dot", width=2),
)
fig_scatter.update_layout(
height=550,
margin=dict(l=60, r=20, t=20, b=60),
xaxis=dict(
title=f"{model_b}{cfg['metric_label']}",
range=axis_range,
dtick=tick_step,
constrain="domain",
),
yaxis=dict(
title=f"{model_a}{cfg['metric_label']}",
range=axis_range,
dtick=tick_step,
scaleanchor="x",
scaleratio=1,
constrain="domain",
),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)
fig_scatter.update_layout(
legend=dict(
title="Assay type" if benchmark_name == "Functional Tracks" else "Genome Annotation",
x=0.98,
y=0.1,
xanchor="right",
yanchor="bottom",
bgcolor="rgba(255,255,255,0.2)", # semi-transparent white
bordercolor="rgba(0,0,0,0.2)",
borderwidth=1,
)
)
st.plotly_chart(fig_scatter, use_container_width=True)
with right:
st.markdown("#### ⏱️ Time to convergence")
x_mode = st.selectbox(
"X-axis",
options=["GPU (hours)", "Steps (billions)"],
index=0,
key=f"conv_x_mode_{benchmark_name}",
)
conv_df = build_convergence_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
x_mode=x_mode,
)
if conv_df.empty:
st.info("No convergence data found for the selected filters / x-axis mode.")
else:
fig_conv = px.scatter(
conv_df,
x="X",
y="Performance",
text="Model",
color="Model",
color_discrete_map=MODEL_COLORS,
hover_data=["Model", "X", "Performance"],
)
fig_conv.update_layout(
height=550,
                xaxis_title=x_mode,
yaxis_title=cfg["metric_label"],
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False, # ✅ no legend
)
fig_conv.update_traces(
marker=dict(size=14), # 👈 bigger dots
textposition="top center",
)
# Log scale only makes sense for hours (and sometimes best_step)
            if x_mode == "GPU (hours)":
fig_conv.update_xaxes(
type="log",
dtick=1,
minor=dict(ticks="", showgrid=False),
)
st.plotly_chart(fig_conv, use_container_width=True)
# ------------------------------------------------------------------
# Violin (full width, below)
# ------------------------------------------------------------------
st.subheader("🎻 Performance comparaison across tracks")
violin_df = build_violin_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if violin_df.empty:
st.info("No data for the selected filters.")
else:
model_order = (
violin_df
.groupby("Model")["Score"]
.median()
.sort_values(ascending=True)
.index
.tolist()
)
fig_violin = px.violin(
violin_df,
x="Model",
y="Score",
color="Model",
color_discrete_map=MODEL_COLORS,
box=True,
points=False,
category_orders={"Model": model_order},
)
fig_violin.update_layout(
height=650,
xaxis_title="",
yaxis_title=cfg["metric_label"],
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
)
fig_violin.update_traces(
box_visible=True,
meanline_visible=False,
)
st.plotly_chart(fig_violin, use_container_width=True)
if __name__ == "__main__":
main()