# jesseab's picture
# Code changes
# 09ace6e
import gradio as gr
import pandas as pd
import numpy as np
import json
from pathlib import Path
def load_css():
    """Return the text of ``css.css``, read as UTF-8 from the working directory."""
    stylesheet = Path('css.css')
    return stylesheet.read_text(encoding='utf-8')
# Metrics flattened out of each ``downstream_<split>`` dict: the suffix of the
# new column paired with the key path to follow inside that dict.
_DOWNSTREAM_METRIC_PATHS = (
    ("mae_age", ("mae_age",)),
    ("acc_sex", ("acc_sex",)),
    ("acc_sex_female", ("sex_per_class_acc", "female")),
    ("acc_sex_male", ("sex_per_class_acc", "male")),
    ("acc_diag", ("acc_diag",)),
    ("acc_diag_Alzheimers_disease", ("diag_per_class_acc", "Alzheimers_disease")),
    ("acc_diag_cognitively_normal", ("diag_per_class_acc", "cognitively_normal")),
)
_DOWNSTREAM_SPLITS = ("train", "val", "test")


def _lookup_nested(entry, key_path):
    """Follow ``key_path`` through nested dicts.

    Returns None as soon as the current level is not a dict (for example a
    NaN cell or a null nested entry) or when a key is absent.
    """
    current = entry
    for key in key_path:
        if not isinstance(current, dict):
            return None
        current = current.get(key)
    return current


def load_and_prepare_data(json_path="model_evals.json"):
    """Load the evaluation JSON and derive the columns the leaderboard shows.

    Steps, in order:
      1. read ``json_path`` with ``pd.read_json``;
      2. flatten every metric listed in ``_DOWNSTREAM_METRIC_PATHS`` out of the
         ``downstream_train`` / ``downstream_val`` / ``downstream_test`` dict
         columns into ``<split>_<metric>`` columns (None when unavailable);
      3. add a ``Model`` column encoded as ``"nickname"||url`` (parsed by the
         front-end script);
      4. add rank columns via ``compute_ranks``;
      5. replace remaining None values with NaN.

    Args:
        json_path: Path to the evaluation file.

    Returns:
        The prepared DataFrame.

    Raises:
        KeyError: if one of the ``downstream_<split>`` columns is missing.
    """
    df = pd.read_json(json_path)

    # Flatten downstream dictionaries (metric-major, split-minor column order).
    for suffix, key_path in _DOWNSTREAM_METRIC_PATHS:
        for split in _DOWNSTREAM_SPLITS:
            df[f"{split}_{suffix}"] = df[f"downstream_{split}"].apply(
                lambda entry, path=key_path: _lookup_nested(entry, path)
            )

    def value_or_default(row, key, default):
        # Fall back when the column is absent *or* the cell is null, so a
        # missing nickname does not end up rendered as "nan"/"None".
        value = row.get(key, default)
        return default if pd.isna(value) else value

    # Create a "Model" column as link data
    def make_model_link(row):
        nickname = value_or_default(row, "model_nickname", "Unknown")
        hf_path = value_or_default(row, "model_hf_path", "#")
        # We'll store it as "Nickname"||URL, to parse in JS
        return f'"{nickname}"||{hf_path}'

    df["Model"] = df.apply(make_model_link, axis=1)

    # Compute ranks
    df = compute_ranks(df)

    # Replace None -> NaN for numeric fields
    df = df.replace({None: np.nan})
    return df
def compute_ranks(df):
    """Add per-metric and aggregate rank columns to ``df``.

    For every ranked test metric present in ``df`` a ``<metric>_rank`` column
    is written using dense ranking (1 = best). Reconstruction metrics are only
    ranked for rows whose ``model_type`` is not ``"IDP"``. Three aggregate
    columns are then added:

      * ``Recon rank``      - mean of the available reconstruction ranks
      * ``Downstream rank`` - mean of the available downstream ranks
      * ``Overall rank``    - mean of the two; NaN if either one is NaN

    Args:
        df: DataFrame with a ``model_type`` column and any subset of the
            ranked metric columns.

    Returns:
        The same DataFrame object; it is modified in place.

    NOTE(review): ``build_app`` displays the value 9999 as missing, but such
    values are ranked here like any other number - confirm whether sentinel
    values can reach the ranked columns.
    """
    # Reconstruction
    recon_lower_is_better = ["test_l1", "test_perceptual"]
    recon_higher_is_better = ["test_ssim", "test_psnr"]
    # Downstream
    downstream_lower_is_better = ["test_mae_age"]
    downstream_higher_is_better = [
        "test_acc_sex", "test_acc_sex_female", "test_acc_sex_male",
        "test_acc_diag", "test_acc_diag_Alzheimers_disease", "test_acc_diag_cognitively_normal",
    ]

    # IDP => skip recon
    recon_mask = df["model_type"] != "IDP"

    def rank_col(col, ascending, mask=None):
        """Write ``<col>_rank``; rows outside ``mask`` are left unranked."""
        if col not in df.columns:
            return
        if mask is None:
            df[col + "_rank"] = df[col].rank(method="dense", ascending=ascending)
        else:
            df.loc[mask, col + "_rank"] = df.loc[mask, col].rank(
                method="dense", ascending=ascending
            )

    for col in recon_lower_is_better:
        rank_col(col, True, recon_mask)
    for col in recon_higher_is_better:
        rank_col(col, False, recon_mask)
    for col in downstream_lower_is_better:
        rank_col(col, True)
    for col in downstream_higher_is_better:
        rank_col(col, False)

    def mean_rank(metric_cols):
        """Row-wise mean over the existing rank columns, ignoring NaN."""
        rank_cols = [f"{c}_rank" for c in metric_cols if f"{c}_rank" in df.columns]
        if not rank_cols:
            return pd.Series(np.nan, index=df.index, dtype="float64")
        # skipna is the default: rows with no rank at all come out as NaN.
        return df[rank_cols].mean(axis=1)

    df["Recon rank"] = mean_rank(recon_lower_is_better + recon_higher_is_better)
    df["Downstream rank"] = mean_rank(
        downstream_lower_is_better + downstream_higher_is_better
    )
    # NaN propagates through the addition, so a row missing either aggregate
    # gets no overall rank.
    df["Overall rank"] = (df["Recon rank"] + df["Downstream rank"]) / 2
    return df
# Fixed left-to-right column order of the leaderboard table, assembled from
# three groups: identity columns, reconstruction columns, downstream columns.
DEFAULT_COLUMN_ORDER = (
    # Identity + overall rank
    [
        "Model",
        "model_type",
        "Overall rank",
    ]
    # Reconstruction rank followed by train/val/test triples per metric
    + [
        "Recon rank",
        "train_l1", "val_l1", "test_l1",
        "train_perceptual", "val_perceptual", "test_perceptual",
        "train_ssim", "val_ssim", "test_ssim",
        "train_psnr", "val_psnr", "test_psnr",
    ]
    # Downstream rank followed by train/val/test triples per metric
    + [
        "Downstream rank",
        "train_mae_age", "val_mae_age", "test_mae_age",
        "train_acc_sex", "val_acc_sex", "test_acc_sex",
        "train_acc_sex_female", "val_acc_sex_female", "test_acc_sex_female",
        "train_acc_sex_male", "val_acc_sex_male", "test_acc_sex_male",
        "train_acc_diag", "val_acc_diag", "test_acc_diag",
        "train_acc_diag_Alzheimers_disease", "val_acc_diag_Alzheimers_disease", "test_acc_diag_Alzheimers_disease",
        "train_acc_diag_cognitively_normal", "val_acc_diag_cognitively_normal", "test_acc_diag_cognitively_normal",
    ]
)
def _format_cell(val):
    """Convert one DataFrame cell into a JSON-friendly table value.

    Missing values become None, numbers are rounded to 4 significant digits
    (a value that formats as "9999" is treated as missing), and anything else
    is stringified.
    """
    if pd.isna(val):
        return None
    # numpy scalars are accepted alongside the builtin numeric types.
    if isinstance(val, (int, float, np.integer, np.floating)):
        val_str = f"{val:.4g}"
        # If "9999", interpret as missing:
        return float(val_str) if val_str != "9999" else None
    return str(val)


def _json_for_script(obj):
    """Serialize ``obj`` to JSON that is safe to inline inside ``<script>``.

    ``<`` is emitted as the escape ``\\u003c`` so a data value containing
    ``</script>`` (or ``<!--``) cannot terminate or alter the script block.
    The escape is valid in both JSON and JavaScript string literals.
    """
    return json.dumps(obj).replace("<", "\\u003c")


def build_app():
    """Build the Gradio Blocks UI for the leaderboard.

    Loads the stylesheet and the prepared evaluation data, serializes the
    rows (columns in ``DEFAULT_COLUMN_ORDER``) into an inline script that
    renders and sorts the table in the browser, and appends the contents of
    ``docs.md`` when that file exists.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).

    NOTE(review): the table body is filled by the inline ``<script>``; whether
    ``gr.HTML`` executes inline scripts depends on the Gradio version -
    confirm against the deployed version.
    """
    css_str = load_css()
    df = load_and_prepare_data()

    # Convert each row to dict with 4 sig digits for numeric
    data_list = [
        {col: _format_cell(row.get(col, np.nan)) for col in DEFAULT_COLUMN_ORDER}
        for _, row in df.iterrows()
    ]
    table_data_json = _json_for_script(data_list)
    col_order_json = _json_for_script(DEFAULT_COLUMN_ORDER)

    # HTML for scoreboard
    scoreboard_html = """
<div id="leaderboard-container" style="overflow-x:auto;">
<table id="leaderboard" class="custom-board">
<thead>
<!-- Top row -->
<tr>
<!-- 1) Model pinned (sticky-col pinned-col) -->
<th class="sticky-col pinned-col sortable" data-col="Model" rowspan="3">
Model <span class="sort-arrow">▲</span>
</th>
<!-- 2) Model Type (not pinned) -->
<th data-col="model_type" class="sortable" rowspan="3">
Model Type <span class="sort-arrow">▲</span>
</th>
<!-- 3) Overall rank (not pinned) -->
<th data-col="Overall rank" class="sortable" rowspan="3">
Overall Rank <span class="sort-arrow">▲</span>
</th>
<!-- Next columns stay the same order -->
<th data-col="Recon rank" class="sortable" rowspan="3">
Recon Rank <span class="sort-arrow">▲</span>
</th>
<th colspan="12">Reconstruction Error</th>
<th data-col="Downstream rank" class="sortable" rowspan="3">
Downstream Rank <span class="sort-arrow">▲</span>
</th>
<th colspan="21">Downstream Models</th>
</tr>
<!-- 2nd row -->
<tr>
<th colspan="3">L1</th>
<th colspan="3">Perceptual</th>
<th colspan="3">SSIM</th>
<th colspan="3">PSNR</th>
<th colspan="3">Age (MAE)</th>
<th colspan="3">Sex (Overall)</th>
<th colspan="3">Sex (Female)</th>
<th colspan="3">Sex (Male)</th>
<th colspan="3">Diagnosis (Overall)</th>
<th colspan="3">Diagnosis (Alzheimers_disease)</th>
<th colspan="3">Diagnosis (cognitively_normal)</th>
</tr>
<!-- 3rd row -->
<tr>
<th data-col="train_l1" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_l1" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_l1" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_perceptual" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_perceptual" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_perceptual" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_ssim" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_ssim" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_ssim" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_psnr" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_psnr" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_psnr" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_mae_age" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_mae_age" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_mae_age" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_sex" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_sex" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_sex" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_sex_female" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_sex_female" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_sex_female" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_sex_male" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_sex_male" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_sex_male" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_diag" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_diag" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_diag" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_diag_Alzheimers_disease" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_diag_Alzheimers_disease" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_diag_Alzheimers_disease" class="sortable">Test <span class="sort-arrow">▲</span></th>
<th data-col="train_acc_diag_cognitively_normal" class="sortable">Train <span class="sort-arrow">▲</span></th>
<th data-col="val_acc_diag_cognitively_normal" class="sortable">Val <span class="sort-arrow">▲</span></th>
<th data-col="test_acc_diag_cognitively_normal" class="sortable">Test <span class="sort-arrow">▲</span></th>
</tr>
</thead>
<tbody id="leaderboard-body"></tbody>
</table>
</div>
<script>
window.tableData = """ + table_data_json + """;
window.defaultColumnOrder = """ + col_order_json + """;
// Default sort by Overall rank ascending
window.currentSort = { col: "Overall rank", dir: "asc" };
function initTable() {
// sort first
sortData();
// then render + setup
renderTableBody();
setupSortableHeaders();
updateHeaderArrows();
}
function setupSortableHeaders() {
const thEls = document.querySelectorAll('#leaderboard th.sortable');
thEls.forEach(th => {
th.addEventListener('click', () => {
const col = th.dataset.col;
if (window.currentSort.col === col) {
window.currentSort.dir = (window.currentSort.dir === 'asc') ? 'desc' : 'asc';
} else {
window.currentSort.col = col;
window.currentSort.dir = 'asc';
}
sortData();
renderTableBody();
updateHeaderArrows();
});
});
}
function sortData() {
const col = window.currentSort.col;
if (!col) return;
const asc = (window.currentSort.dir === 'asc');
window.tableData.sort((a, b) => {
const aVal = (a[col] === null) ? null : a[col];
const bVal = (b[col] === null) ? null : b[col];
// null => sorted last
if (aVal === null && bVal === null) return 0;
if (aVal === null) return 1;
if (bVal === null) return -1;
if (typeof aVal === 'number' && typeof bVal === 'number') {
return asc ? aVal - bVal : bVal - aVal;
} else {
const av = String(aVal);
const bv = String(bVal);
if (av === bv) return 0;
return asc ? (av < bv ? -1 : 1) : (av > bv ? -1 : 1);
}
});
}
function updateHeaderArrows() {
// reset all arrow icons
const allArrows = document.querySelectorAll('#leaderboard th.sortable .sort-arrow');
allArrows.forEach(ar => {
ar.style.color = '#aaa';
ar.textContent = '▲';
});
if (window.currentSort.col) {
const selector = '#leaderboard th.sortable[data-col="'+window.currentSort.col+'"] .sort-arrow';
const target = document.querySelector(selector);
if (target) {
target.style.color = 'orange';
if (window.currentSort.dir === 'desc') {
target.textContent = '▼';
} else {
target.textContent = '▲';
}
}
}
}
function renderTableBody() {
const tbody = document.getElementById('leaderboard-body');
tbody.innerHTML = '';
window.tableData.forEach(row => {
const tr = document.createElement('tr');
window.defaultColumnOrder.forEach(col => {
const val = row[col];
const td = document.createElement('td');
// if col is pinned => add pinned classes to td
if (col === "Model") {
td.classList.add("sticky-col", "pinned-col");
}
if (val === null) {
td.textContent = 'N/A';
} else if (col === 'Model') {
// stored as "nickname"||URL
const parts = val.split('||');
let nickname = parts[0] || '"Unknown"';
if (nickname.startsWith('"') && nickname.endsWith('"')) {
nickname = nickname.slice(1, -1);
}
const href = parts[1] || '#';
// Build the link with DOM APIs so nickname/URL are never parsed as HTML
const link = document.createElement('a');
link.href = href;
link.target = '_blank';
link.rel = 'noopener noreferrer';
link.textContent = nickname;
td.appendChild(link);
} else {
td.textContent = val;
}
tr.appendChild(td);
});
tbody.appendChild(tr);
});
}
setTimeout(initTable, 50);
</script>
"""

    # Read docs.md (optional; an empty Markdown block is rendered without it)
    docs_path = Path("docs.md")
    docs_content = ""
    if docs_path.is_file():
        docs_content = docs_path.read_text(encoding="utf-8")

    with gr.Blocks(css=css_str) as demo:
        gr.HTML(
            "<div id='leaderboard-heading'>🧠↗️ Brain2vec Leaderboard</div>"
        )
        gr.HTML(scoreboard_html)
        gr.Markdown(docs_content)
    return demo
# Script entry point: build the Blocks UI and start serving it. Nothing is
# launched when the module is merely imported.
if __name__ == "__main__":
    app = build_app()
    app.launch()