# ntv3_benchmark/src/streamlit_app.py
from typing import List
import os
import pandas as pd
import streamlit as st
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# ---------------------------------------------------------------------
# Page config (must be the first Streamlit command)
# ---------------------------------------------------------------------
st.set_page_config(
page_title="NTv3 Benchmark",
layout="wide",
)
# ---------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------
COLORS = {
# Primary colors 1 (our models)
'blue_0': '#004697', # Darkest allowable blue
'blue_1': '#3973fc', # Main blue
'blue_2': '#7ea4fc', # Medium blue
'blue_3': '#c3d5fc', # Light blue (lightest allowable blue)
# Secondary colors 1
'red_1': '#ff554d', # Medium red
'red_2': '#ffe0de', # Light red
# Primary colors 2
'green_1': '#00b050', # Darkest green
'green_2': '#92d050', # Medium green
'green_3': '#c6e0b4', # Light green (lightest allowable green)
# Secondary colors 2
'gold_1': '#fdb932',
# Tertiary colors
'orange_1': '#ff975e',
'purple_1': '#9a6ce4',
'purple_2': '#bb9aef', # Medium purple
'purple_3': '#ceb5f5', # Light purple (lightest allowable purple)
# Grays (other models)
'gray_1': '#808080', # Darkest gray (use as a last resort)
'gray_2': '#b3b3b3', # Medium gray (start with this as the darkest when possible)
'gray_3': '#e6e6e6', # Lightest gray
'gray_4': '#ffffff', # It's actually just white (use as a last resort)
# If all other options are exhausted
'cyan_1': '#0096b4', # Darkest teal
'cyan_2': '#28bed2', # Medium cyan
'cyan_3': '#8cdceb', # Lightest cyan
'magenta_1': '#b428a0', # Darkest magenta
    'magenta_2': '#dc50be',  # Medium magenta
    'magenta_3': '#f5a0dc',  # Lightest magenta
'yellow_1': '#c8aa00', # Darkest yellow
'yellow_2': '#ffd200', # Medium yellow
'yellow_3': '#fff08c', # Lightest yellow
}
MODEL_COLORS = {
"NTv3 650M (post)": COLORS['blue_0'],
    'NTv3 650M (pre)': COLORS['blue_1'],  # #3973fc (Main blue)
'NTv3 100M (pre)': COLORS['blue_2'], # #7ea4fc (Medium blue)
'NTv3 8M (pre)': COLORS['blue_3'], # #c3d5fc (Light blue)
    'Evo2 1B': COLORS['green_3'],  # #c6e0b4 (Light green)
"NTv2 500M": COLORS['gray_1'],
"BPNet arch. 6M": COLORS['cyan_1'],
"Residual CNN 44M": COLORS['magenta_1'],
"PlantCAD2 88M": COLORS["purple_1"],
"Caduceus 7M": COLORS["purple_2"]
}
MODEL_NAMES = list(MODEL_COLORS.keys())
PLANT_SPECIES = ["tomato", "rice", "maize", "arabidopsis"]
ANIMAL_SPECIES = ["human", "chicken", "cattle"]
SPECIES_GROUPS = {
    "Plants": PLANT_SPECIES,
    "Animals": ANIMAL_SPECIES,
}
_LAST_UPDATED = "Dec 10, 2025"
_INTRO = """
Benchmark across gene annotation and functional tracks.

- **Pearson correlations (multi-assay)**: per-dataset scores across species and models.
- **MCC (bed tracks)**: per-track MCC values across species and models.

These tasks measure a model's ability to generalize to unseen tracks, species, and assay types.
"""
HERE = os.path.dirname(os.path.abspath(__file__)) # /app/src
PROJECT_ROOT = os.path.dirname(HERE) # /app
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
PEARSON_PATH = os.path.join(DATA_DIR, "bigwig_dataset.csv")
MCC_PATH = os.path.join(DATA_DIR, "bed_dataset.csv")
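# Expected on-disk layout, inferred from the paths above (illustrative; the
# actual Space layout may differ):
#
#   /app
#   ├── data/
#   │   ├── bigwig_dataset.csv   # species, assay_type, datasets, model_name, pearson correlation
#   │   └── bed_dataset.csv      # species, datasets, model_name, MCC
#   └── src/
#       └── streamlit_app.py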
# ---------------------------------------------------------------------
# Data loading & preprocessing
# ---------------------------------------------------------------------
@st.cache_data
def load_raw_data():
pearson_df = pd.read_csv(PEARSON_PATH)
mcc_df = pd.read_csv(MCC_PATH)
pearson_df.columns = [c.strip() for c in pearson_df.columns]
mcc_df.columns = [c.strip() for c in mcc_df.columns]
return pearson_df, mcc_df
@st.cache_data
def load_expanded_data():
"""
Load data in the new format where each row is already:
(species, [assay_type], datasets, model_name, metric)
and convert into a unified schema:
species, assay_type?, datasets, Model, Score
For Pearson:
If multiple rows share (species, assay_type, datasets, Model),
we average their Score.
For MCC:
If multiple rows share (species, datasets, Model),
we average their Score.
"""
pearson_df, mcc_df = load_raw_data()
# --- Pearson correlations ---
# Expect columns: species, assay_type, datasets, model_name, pearson correlation
pearson_df = pearson_df.rename(
columns={
"model_name": "Model",
"pearson correlation": "Score",
}
)
pearson_group_cols = ["species", "datasets", "Model"]
if "assay_type" in pearson_df.columns:
pearson_group_cols.append("assay_type")
pearson_df = (
pearson_df
.groupby(pearson_group_cols, as_index=False, dropna=False)["Score"]
.mean()
)
# --- MCC (bed tracks) ---
# Expect columns: species, datasets, model_name, MCC
mcc_df = mcc_df.rename(
columns={
"model_name": "Model",
"MCC": "Score",
}
)
# Collapse duplicates with same (species, datasets, Model)
mcc_group_cols = ["species", "datasets", "Model"]
mcc_df = (
mcc_df
.groupby(mcc_group_cols, as_index=False, dropna=False)["Score"]
.mean()
)
# Optional sanity checks
for df_name, df in [("pearson", pearson_df), ("mcc", mcc_df)]:
required = {"species", "datasets", "Model", "Score"}
missing = required - set(df.columns)
if missing:
st.error(f"{df_name} dataframe missing columns: {missing}")
return pearson_df, mcc_df
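# Illustrative collapse performed above (scores are invented). Two raw rows
# sharing the same (species, datasets, Model) key...
#
#   species  datasets  Model       Score          species  datasets  Model       Score
#   human    tss       NTv2 500M   0.40    -->    human    tss       NTv2 500M   0.45
#   human    tss       NTv2 500M   0.50
#
# ...become a single mean-aggregated row in the unified frame.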
_PEARSON_DF, _MCC_DF = load_expanded_data()
# Global sets (we'll further filter per-benchmark below)
_ALL_SPECIES = sorted(
set(_PEARSON_DF["species"].unique()).union(_MCC_DF["species"].unique())
)
_ALL_ASSAYS = (
sorted(_PEARSON_DF["assay_type"].dropna().unique())
if "assay_type" in _PEARSON_DF.columns
else []
)
_ALL_MODELS = MODEL_NAMES[:]
_BENCHMARKS = {
"Functional Tracks": {
"df": _PEARSON_DF,
"metric_label": "Pearson correlation",
"has_assay_type": True,
},
"Genome Annotation": {
"df": _MCC_DF,
"metric_label": "MCC",
"has_assay_type": False,
},
}
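# The entire UI below is driven by this dict. A new benchmark could be added
# with one more entry (hypothetical sketch, assuming its frame already uses the
# unified species / datasets / Model / Score schema):
#
#   _BENCHMARKS["My New Task"] = {
#       "df": my_new_task_df,
#       "metric_label": "AUROC",
#       "has_assay_type": False,
#   }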
# ---------------------------------------------------------------------
# Computation helpers
# ---------------------------------------------------------------------
def filter_base_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
cfg = _BENCHMARKS[benchmark_name]
df = cfg["df"].copy()
# Species filter
if selected_species:
df = df[df["species"].isin(selected_species)]
# Assay type filter (Pearson only)
if cfg.get("has_assay_type", False) and selected_assays and "assay_type" in df.columns:
df = df[df["assay_type"].isin(selected_assays)]
# Dataset / bed track filter (for MCC, but safe to apply generally)
if selected_datasets and "datasets" in df.columns:
df = df[df["datasets"].isin(selected_datasets)]
# Model filter
if selected_models:
df = df[df["Model"].isin(selected_models)]
return df
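# Example call (hypothetical selections; an empty list for any argument means
# "do not filter on that dimension"):
#
#   df = filter_base_df("Functional Tracks", ["human"], ["ATAC"], ["NTv2 500M"], [])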
def build_leaderboard(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame(columns=["Model", "Num entries", "Mean score"])
agg = (
df.groupby("Model")["Score"]
.mean()
.reset_index()
.rename(columns={"Score": "Mean score"})
)
agg["Mean score"] = agg["Mean score"].round(3)
agg["Num entries"] = (
df.groupby("Model")["Score"].count().reindex(agg["Model"]).values
)
agg = agg.sort_values("Mean score", ascending=False).reset_index(drop=True)
agg = agg[["Model", "Num entries", "Mean score"]]
# Ensure the index starts with 1
agg.index += 1
return agg
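# Shape of the returned leaderboard (values invented; the index starts at 1 so
# it doubles as a rank column in st.dataframe):
#
#        Model             Num entries   Mean score
#   1    NTv3 650M (post)           24        0.612
#   2    NTv2 500M                  24        0.540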
def build_bar_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
"""For now, just one bar per model (same as leaderboard)."""
return build_leaderboard(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
def build_category_model_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
"""
Mean score per (category, Model) after applying the same filters.
Category = assay_type (Functional Tracks) or datasets (Genome Annotation).
"""
cfg = _BENCHMARKS[benchmark_name]
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame(columns=["Category", "Model", "Mean score"])
# Pick the right breakdown column
if cfg.get("has_assay_type", False) and "assay_type" in df.columns:
category_col = "assay_type"
category_label = "Assay type"
else:
category_col = "datasets"
category_label = "Dataset"
if category_col not in df.columns:
return pd.DataFrame(columns=["Category", "Model", "Mean score"])
out = (
df.groupby([category_col, "Model"], as_index=False)["Score"]
.mean()
.rename(columns={category_col: "Category", "Score": "Mean score"})
)
out["Mean score"] = out["Mean score"].round(3)
out.attrs["category_label"] = category_label # for nicer axis title
return out
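# Note: DataFrame.attrs is an experimental pandas feature and is not guaranteed
# to survive later operations (copies, merges, etc.), so callers should read
# "category_label" right after this call. The same caveat applies to the
# "axis_label" attr set in build_radar_df below.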
def plot_breakdown_facets_sorted_models(
breakdown_df: pd.DataFrame,
metric_label: str,
height: int = 420,
):
categories = list(breakdown_df["Category"].dropna().unique())
categories = sorted(categories)
n = len(categories)
if n == 0:
return None
rows = 1
cols = n # 👈 everything in one row
# Global y-range (consistent scale)
y_min = breakdown_df["Mean score"].min()
y_max = breakdown_df["Mean score"].max()
pad = 0.05 * (y_max - y_min if y_max > y_min else 1.0)
y_range = [y_min - pad, y_max + pad]
fig = make_subplots(
rows=rows,
cols=cols,
subplot_titles=categories,
shared_yaxes=True,
horizontal_spacing=0.04, # tighter spacing
)
for i, cat in enumerate(categories):
r = (i // cols) + 1
c = (i % cols) + 1
sub = (
breakdown_df[breakdown_df["Category"] == cat]
.sort_values("Mean score", ascending=True)
)
fig.add_trace(
go.Bar(
x=sub["Model"],
y=sub["Mean score"],
marker_color=[MODEL_COLORS.get(m, "#808080") for m in sub["Model"]],
showlegend=False,
),
row=r,
col=c,
)
fig.update_xaxes(showticklabels=False, title_text="", row=r, col=c)
fig.update_yaxes(range=y_range, title_text="", row=r, col=c) # 👈 apply range
fig.update_layout(
height=height,
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
margin=dict(t=60, l=10, r=10, b=10),
)
# Single y-axis label on the leftmost panel
fig.update_yaxes(title_text=metric_label, row=1, col=1)
return fig
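# Usage note: with rows fixed at 1, each category gets its own panel on the
# shared, padded y-range computed above, so bar heights stay directly
# comparable across panels.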
def build_radar_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
cfg = _BENCHMARKS[benchmark_name]
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if df.empty:
return pd.DataFrame()
# Choose axis column
if cfg.get("has_assay_type", False) and "assay_type" in df.columns:
axis_col = "assay_type"
axis_label = "Assay type"
else:
axis_col = "datasets"
axis_label = "Dataset"
radar_df = (
df.groupby([axis_col, "Model"], as_index=False)["Score"]
.mean()
.rename(columns={axis_col: "Axis", "Score": "Value"})
)
radar_df.attrs["axis_label"] = axis_label
return radar_df
def build_pairwise_scatter_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
model_a: str,
model_b: str,
) -> pd.DataFrame:
"""
Returns a per-track dataframe with columns:
Track, Model A, Model B, (optional) species, (optional) assay_type, datasets
Where each row corresponds to a specific track (datasets [+ assay_type]).
"""
cfg = _BENCHMARKS[benchmark_name]
# Filter using the same UI toggles, but ensure the chosen models are included
models_for_filter = list(set(selected_models + [model_a, model_b])) if selected_models else [model_a, model_b]
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
models_for_filter,
selected_datasets,
)
if df.empty:
return pd.DataFrame()
# Define what “a specific track” means
track_cols = ["datasets"]
if cfg.get("has_assay_type", False) and "assay_type" in df.columns:
track_cols = ["assay_type", "datasets"]
# (Optional) keep species in hover if multiple are selected
keep_species = "species" in df.columns and (selected_species is None or len(selected_species) != 1)
id_cols = (["species"] if keep_species else []) + track_cols
# Pivot into two model columns
wide = (
df[df["Model"].isin([model_a, model_b])]
.pivot_table(index=id_cols, columns="Model", values="Score", aggfunc="mean")
.reset_index()
)
# Require both values to exist for a dot
if model_a not in wide.columns or model_b not in wide.columns:
return pd.DataFrame()
wide = wide.dropna(subset=[model_a, model_b])
# Nice “Track” label for display
if "assay_type" in wide.columns:
wide["Track"] = wide["assay_type"].astype(str) + " / " + wide["datasets"].astype(str)
else:
wide["Track"] = wide["datasets"].astype(str)
# Rename for plotting
wide = wide.rename(columns={model_a: "Model A", model_b: "Model B"})
return wide
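# Illustrative output for a Pearson-style benchmark (values invented): one row
# per track, with both models' scores side by side after the pivot.
#
#   species  assay_type  datasets  Model A  Model B  Track
#   human    ATAC        track_1     0.61     0.55   ATAC / track_1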
def build_violin_df(
benchmark_name: str,
selected_species: List[str],
selected_assays: List[str],
selected_models: List[str],
selected_datasets: List[str],
) -> pd.DataFrame:
# Use the same base filtering, but keep all per-track rows
df = filter_base_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
# Keep only needed columns
keep = ["Model", "Score"]
for col in ["species", "assay_type", "datasets"]:
if col in df.columns:
keep.append(col)
return df[keep].copy()
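# Unlike the aggregating helpers above, this keeps one row per (model, track)
# observation so the violin plot can show the full score distribution rather
# than a single mean.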
def plot_radar(
radar_df: pd.DataFrame,
metric_label: str,
height: int = 600,
):
if radar_df.empty:
return None
axes = radar_df["Axis"].unique().tolist()
# Global radial range
r_min = radar_df["Value"].min()
r_max = radar_df["Value"].max()
pad = 0.05 * (r_max - r_min if r_max > r_min else 1.0)
r_range = [r_min - pad, r_max + pad]
fig = go.Figure()
for model in radar_df["Model"].unique():
sub = radar_df[radar_df["Model"] == model]
# Ensure consistent axis ordering
sub = sub.set_index("Axis").reindex(axes)
fig.add_trace(
go.Scatterpolar(
r=sub["Value"],
theta=axes,
fill="toself",
name=model,
line_color=MODEL_COLORS.get(model),
opacity=0.75,
)
)
fig.update_layout(
height=height,
polar=dict(
bgcolor="rgba(0,0,0,0)", # 👈 polar background
radialaxis=dict(
title=metric_label,
range=r_range,
tickformat=".2f",
showgrid=True,
gridcolor="rgba(0,0,0,0.15)", # subtle grid
),
angularaxis=dict(
showgrid=True,
gridcolor="rgba(0,0,0,0.15)",
),
),
paper_bgcolor="rgba(0,0,0,0)", # 👈 entire figure background
plot_bgcolor="rgba(0,0,0,0)", # 👈 plot area
showlegend=True,
legend_title_text="Model",
margin=dict(t=40, b=40, l=40, r=40),
)
return fig
# ---------------------------------------------------------------------
# UI helpers
# ---------------------------------------------------------------------
def sidebar_toggle(label: str, value: bool = False, key: str | None = None) -> bool:
"""
Wrapper to use st.sidebar.toggle when available, otherwise fall back to checkbox.
This makes the app compatible with older Streamlit versions on Hugging Face.
"""
toggle_fn = getattr(st.sidebar, "toggle", None)
if toggle_fn is not None:
return toggle_fn(label, value=value, key=key)
# Fallback for older Streamlit versions
return st.sidebar.checkbox(label, value=value, key=key)
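# Example: the call below renders st.toggle on recent Streamlit releases and
# falls back to st.checkbox on older runtimes, returning a bool either way
# (label and key are hypothetical):
#
#   show_all = sidebar_toggle("Show all models", value=True, key="show_all")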
# ---------------------------------------------------------------------
# Streamlit UI
# ---------------------------------------------------------------------
def main():
st.title("🧬 NTv3 Benchmark")
st.markdown(_INTRO)
st.markdown(f"_Last updated: **{_LAST_UPDATED}**_")
# --- Sidebar filters ---
st.sidebar.header("Filters")
# Benchmark
benchmark_name = st.sidebar.selectbox(
"Benchmark",
options=list(_BENCHMARKS.keys()),
index=0,
)
cfg = _BENCHMARKS[benchmark_name]
df_bench = cfg["df"]
# Species toggles, but only for species present in this benchmark
st.sidebar.subheader("Species")
# Toggle: Plants vs Animals
species_group = st.sidebar.radio(
"Group",
options=["Animals", "Plants"],
index=0,
horizontal=True,
key=f"species_group_{benchmark_name}",
)
available_species_all = sorted(df_bench["species"].unique())
allowed_species = set(SPECIES_GROUPS[species_group]).intersection(available_species_all)
available_species = sorted(allowed_species)
selected_species: List[str] = []
for sp in available_species:
if sidebar_toggle(sp, value=True, key=f"species_{benchmark_name}_{species_group}_{sp}"):
selected_species.append(sp)
# (Optional) If no species exist for that group in this benchmark
if not available_species:
st.sidebar.info(f"No {species_group.lower()} species available for this benchmark.")
# Assay toggles (Pearson only), based on filtered species
if cfg.get("has_assay_type", False):
st.sidebar.subheader("Assay types")
if selected_species:
df_for_assays = df_bench[df_bench["species"].isin(selected_species)]
else:
df_for_assays = df_bench
available_assays = (
sorted(df_for_assays["assay_type"].dropna().unique())
if "assay_type" in df_for_assays.columns
else []
)
selected_assays: List[str] = []
for assay in available_assays:
if sidebar_toggle(assay, value=True, key=f"assay_{benchmark_name}_{assay}"):
selected_assays.append(assay)
else:
selected_assays = []
# Bed track / dataset toggles (MCC only), based on species selection
selected_datasets: List[str] = []
if benchmark_name == "Genome Annotation":
st.sidebar.subheader("Genome annotations")
if selected_species:
df_for_tracks = df_bench[df_bench["species"].isin(selected_species)]
else:
df_for_tracks = df_bench
available_datasets = sorted(df_for_tracks["datasets"].unique())
for ds in available_datasets:
if sidebar_toggle(ds, value=True, key=f"dataset_{benchmark_name}_{ds}"):
selected_datasets.append(ds)
else:
selected_datasets = []
# Model toggles (we keep all models in MODEL_NAMES; filters + data will prune)
st.sidebar.subheader("Models")
selected_models: List[str] = []
for model in _ALL_MODELS:
if sidebar_toggle(model, value=True, key=f"model_{model}"):
selected_models.append(model)
# --- Main content ---
leaderboard_df = build_leaderboard(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
bar_df = build_bar_df(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
col1, col2 = st.columns([1, 1])
with col1:
st.subheader("🏅 Leaderboard (per model)")
st.write("\n") # 👈 spacer to match plotly padding
st.write("\n")
st.write("\n")
if leaderboard_df.empty:
st.info("No data for the selected filters.")
else:
st.dataframe(leaderboard_df, use_container_width=True)
with col2:
st.subheader("📈 Mean score per model")
if bar_df.empty:
st.info("No data for the selected filters.")
else:
# Order models by performance (least -> most)
bar_df = bar_df.sort_values("Mean score", ascending=True)
model_order = bar_df["Model"].tolist()
fig = px.bar(
bar_df,
x="Model",
y="Mean score",
color="Model",
color_discrete_map=MODEL_COLORS,
category_orders={"Model": model_order}, # enforce ordering on x
)
fig.update_layout(
barmode="group",
height=500,
xaxis_title="",
yaxis_title="Mean score",
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
bargap=0.08,
)
# Hide x-axis model names (same style as the panels)
fig.update_xaxes(showticklabels=False)
st.plotly_chart(fig, use_container_width=True)
# --- Breakdown plot: assay_type (Functional Tracks) OR datasets (Genome Annotation) ---
breakdown_df = build_category_model_df(
benchmark_name, selected_species, selected_assays, selected_models, selected_datasets
)
st.subheader("🧪 Mean score by assay type / dataset (all models)")
if breakdown_df.empty:
st.info("No data for the selected filters.")
else:
fig_breakdown = plot_breakdown_facets_sorted_models(
breakdown_df,
metric_label=cfg["metric_label"],
height=300,
)
st.plotly_chart(fig_breakdown, use_container_width=True)
st.subheader("🕸️ Performance by assay type / dataset (radar)")
radar_df = build_radar_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if radar_df.empty:
st.info("No data for the selected filters.")
else:
fig_radar = plot_radar(
radar_df,
metric_label=cfg["metric_label"],
)
st.plotly_chart(fig_radar, use_container_width=True)
st.subheader("⚖️ Model comparison")
left, right = st.columns([1, 1], gap="large")
with left:
st.markdown("#### Head-to-head (per track)")
model_picker_options = selected_models if selected_models else _ALL_MODELS
default_a = model_picker_options[0] if model_picker_options else _ALL_MODELS[0]
default_b = model_picker_options[1] if len(model_picker_options) > 1 else (
_ALL_MODELS[1] if len(_ALL_MODELS) > 1 else default_a
)
cA, cB = st.columns([1, 1])
with cA:
model_a = st.selectbox(
"Model A (y-axis)",
options=model_picker_options,
index=model_picker_options.index(default_a) if default_a in model_picker_options else 0,
key=f"pair_model_a_{benchmark_name}",
)
with cB:
b_options = [m for m in model_picker_options if m != model_a] or model_picker_options
model_b = st.selectbox(
"Model B (x-axis)",
options=b_options,
index=0,
key=f"pair_model_b_{benchmark_name}",
)
scatter_df = build_pairwise_scatter_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
model_a,
model_b,
)
if scatter_df.empty:
st.info("No overlapping tracks for the selected filters (or one model is missing values).")
else:
min_v = float(min(scatter_df["Model A"].min(), scatter_df["Model B"].min()))
max_v = float(max(scatter_df["Model A"].max(), scatter_df["Model B"].max()))
pad = 0.05 * (max_v - min_v if max_v > min_v else 1.0)
axis_range = [min_v - pad, max_v + pad]
tick_step = (axis_range[1] - axis_range[0]) / 5
hover_cols = ["Track"]
for c in ["species", "assay_type", "datasets"]:
if c in scatter_df.columns:
hover_cols.append(c)
# Model A on Y, Model B on X
fig_scatter = px.scatter(
scatter_df,
x="Model B",
y="Model A",
hover_name="Track",
hover_data=hover_cols,
)
# Red diagonal y=x
fig_scatter.add_shape(
type="line",
x0=axis_range[0], y0=axis_range[0],
x1=axis_range[1], y1=axis_range[1],
xref="x", yref="y",
line=dict(color="red", dash="dot", width=2),
)
# Square + identical scale/ticks (works even with use_container_width=True)
fig_scatter.update_layout(
height=550,
margin=dict(l=60, r=20, t=20, b=60),
xaxis=dict(
                    title=f"{model_b} ({cfg['metric_label']})",
range=axis_range,
dtick=tick_step,
constrain="domain",
),
yaxis=dict(
                    title=f"{model_a} ({cfg['metric_label']})",
range=axis_range,
dtick=tick_step,
scaleanchor="x", # lock y to x
scaleratio=1,
constrain="domain",
),
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
)
st.plotly_chart(fig_scatter, use_container_width=True)
with right:
st.markdown("#### All models (distribution across tracks)")
violin_df = build_violin_df(
benchmark_name,
selected_species,
selected_assays,
selected_models,
selected_datasets,
)
if violin_df.empty:
st.info("No data for the selected filters.")
else:
# Order models by median performance (least -> most)
model_order = (
violin_df
.groupby("Model")["Score"]
.median()
.sort_values(ascending=True)
.index
.tolist()
)
fig_violin = px.violin(
violin_df,
x="Model",
y="Score",
color="Model",
color_discrete_map=MODEL_COLORS,
box=True, # keep inner boxplot
points=False, # 👈 remove all dots
category_orders={"Model": model_order}, # 👈 enforce ordering
)
fig_violin.update_layout(
height=650,
xaxis_title="",
yaxis_title=cfg["metric_label"],
plot_bgcolor="rgba(0,0,0,0)",
paper_bgcolor="rgba(0,0,0,0)",
showlegend=False,
)
fig_violin.update_traces(
box_visible=True,
meanline_visible=False,
)
# Optional: hide model names if you prefer a cleaner look
# fig_violin.update_xaxes(showticklabels=False)
st.plotly_chart(fig_violin, use_container_width=True)
if __name__ == "__main__":
main()