# Hugging Face Space (status when captured: Sleeping)
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import json | |
| from pathlib import Path | |
def load_css():
    """Return the contents of ``css.css`` as a string.

    The path is relative, so it is resolved against the current working
    directory. The file is decoded as UTF-8.
    """
    return Path('css.css').read_text(encoding='utf-8')
def _get_nested(record, keys):
    """Follow ``keys`` through nested dicts starting at ``record``.

    Returns None as soon as the current level is not a dict (for example
    NaN, None, or a missing/null intermediate entry) instead of raising.
    """
    current = record
    for key in keys:
        if not isinstance(current, dict):
            return None
        current = current.get(key)
    return current


def load_and_prepare_data(json_path="model_evals.json"):
    """Load the evaluation JSON and return a leaderboard-ready DataFrame.

    Steps, in order:
      1. Read ``json_path`` with ``pd.read_json``.
      2. Flatten the ``downstream_train`` / ``downstream_val`` /
         ``downstream_test`` dict columns into scalar columns named
         ``<split>_<metric>`` (e.g. ``test_acc_sex_female``).
      3. Add a ``Model`` column encoded as ``"nickname"||url``.
      4. Add rank columns via ``compute_ranks``.
      5. Replace None with NaN.

    Args:
        json_path: Path (or anything ``pd.read_json`` accepts) of the
            evaluation file.

    Returns:
        The prepared DataFrame.

    Raises:
        KeyError: If one of the three ``downstream_*`` columns is absent.
    """
    df = pd.read_json(json_path)

    # Output-column suffix -> key path inside each downstream dict.
    # Metrics are listed in the order their columns are created.
    metric_paths = (
        ("mae_age", ("mae_age",)),
        ("acc_sex", ("acc_sex",)),
        ("acc_sex_female", ("sex_per_class_acc", "female")),
        ("acc_sex_male", ("sex_per_class_acc", "male")),
        ("acc_diag", ("acc_diag",)),
        ("acc_diag_Alzheimers_disease", ("diag_per_class_acc", "Alzheimers_disease")),
        ("acc_diag_cognitively_normal", ("diag_per_class_acc", "cognitively_normal")),
    )
    for suffix, key_path in metric_paths:
        for split in ("train", "val", "test"):
            # ``key_path`` is bound as a default argument so each lambda
            # keeps its own path rather than the loop's final value.
            df[f"{split}_{suffix}"] = df[f"downstream_{split}"].apply(
                lambda record, key_path=key_path: _get_nested(record, key_path)
            )

    # Create a "Model" column as link data
    def make_model_link(row):
        nickname = row.get("model_nickname", "Unknown")
        hf_path = row.get("model_hf_path", "#")
        # We'll store it as "Nickname"||URL, to parse in JS
        return f'"{nickname}"||{hf_path}'

    df["Model"] = df.apply(make_model_link, axis=1)

    # Compute ranks
    df = compute_ranks(df)

    # Replace None -> NaN for numeric fields
    df = df.replace({None: np.nan})
    return df
def compute_ranks(df):
    """Add per-metric and aggregate rank columns to ``df``.

    For every test metric that exists as a column, a ``<metric>_rank``
    column is added using dense ranking (rank 1 is best). Values that are
    not numeric are treated as missing and receive no rank. Rows whose
    ``model_type`` equals ``"IDP"`` are excluded from the reconstruction
    rankings; if ``model_type`` is absent, every row takes part.

    Three summary columns are then added:
      * ``Recon rank``: mean of the available reconstruction ranks.
      * ``Downstream rank``: mean of the available downstream ranks.
      * ``Overall rank``: mean of the two, NaN if either is NaN.

    Args:
        df: DataFrame of model evaluations. It is modified in place.

    Returns:
        The same DataFrame object, with the rank columns added.
    """
    # Reconstruction
    recon_lower_is_better = ["test_l1", "test_perceptual"]
    recon_higher_is_better = ["test_ssim", "test_psnr"]
    # Downstream
    downstream_lower_is_better = ["test_mae_age"]
    downstream_higher_is_better = [
        "test_acc_sex", "test_acc_sex_female", "test_acc_sex_male",
        "test_acc_diag", "test_acc_diag_Alzheimers_disease", "test_acc_diag_cognitively_normal",
    ]

    # IDP => skip recon
    if "model_type" in df.columns:
        recon_mask = df["model_type"] != "IDP"
    else:
        recon_mask = pd.Series(True, index=df.index)

    def rank_col(col, ascending, mask=None):
        """Write ``<col>_rank``; rows outside ``mask`` get NaN."""
        if col not in df.columns:
            return
        # Coerce first: flattened columns may still be object dtype with
        # None entries, which cannot be ranked reliably.
        values = pd.to_numeric(df[col], errors="coerce")
        if mask is not None:
            # NaN rows are left unranked and do not consume a rank, so
            # this equals ranking only the masked subset.
            values = values.where(mask)
        df[col + "_rank"] = values.rank(method="dense", ascending=ascending)

    for c in recon_lower_is_better:
        rank_col(c, True, recon_mask)
    for c in recon_higher_is_better:
        rank_col(c, False, recon_mask)
    for c in downstream_lower_is_better:
        rank_col(c, True)
    for c in downstream_higher_is_better:
        rank_col(c, False)

    def mean_rank(col_list):
        """Row-wise mean of the existing rank columns, ignoring NaN."""
        rank_cols = [c + "_rank" for c in col_list if c + "_rank" in df.columns]
        if not rank_cols:
            return pd.Series(np.nan, index=df.index, dtype="float64")
        return df[rank_cols].mean(axis=1)

    df["Recon rank"] = mean_rank(recon_lower_is_better + recon_higher_is_better)
    df["Downstream rank"] = mean_rank(downstream_lower_is_better + downstream_higher_is_better)
    # NaN propagates through the addition, so a row missing either
    # aggregate gets a NaN overall rank.
    df["Overall rank"] = (df["Recon rank"] + df["Downstream rank"]) / 2
    return df
# Fixed left-to-right column order of the leaderboard table.
# Each metric contributes three adjacent columns: train, val, test.
_COLUMN_SPLITS = ("train", "val", "test")
_RECON_COLUMN_METRICS = ("l1", "perceptual", "ssim", "psnr")
_DOWNSTREAM_COLUMN_METRICS = (
    "mae_age",
    "acc_sex",
    "acc_sex_female",
    "acc_sex_male",
    "acc_diag",
    "acc_diag_Alzheimers_disease",
    "acc_diag_cognitively_normal",
)

DEFAULT_COLUMN_ORDER = [
    "Model",
    "model_type",
    "Overall rank",
    "Recon rank",
    *(
        f"{split}_{metric}"
        for metric in _RECON_COLUMN_METRICS
        for split in _COLUMN_SPLITS
    ),
    "Downstream rank",
    *(
        f"{split}_{metric}"
        for metric in _DOWNSTREAM_COLUMN_METRICS
        for split in _COLUMN_SPLITS
    ),
]
# Numeric value that is displayed as missing ("N/A") in the table.
_MISSING_SENTINEL = 9999


def _to_cell(val):
    """Convert one DataFrame value into a JSON-serialisable table cell.

    Missing values and the 9999 sentinel become None, numbers are rounded
    to 4 significant digits, and anything else is stringified.
    """
    if pd.isna(val):
        return None
    if isinstance(val, (int, float, np.integer, np.floating)):
        # Compare before rounding so that only an exact 9999 is hidden,
        # not real values that merely round to "9999".
        if val == _MISSING_SENTINEL:
            return None
        return float(f"{float(val):.4g}")
    return str(val)


def _json_for_script(obj):
    """Serialise ``obj`` as JSON that is safe inside an inline <script>.

    "<" is emitted as a unicode escape so that a string value containing
    "</script>" or "<!--" cannot end or derail the script element.
    """
    return json.dumps(obj).replace("<", "\\u003c")


def _scoreboard_html(table_data_json, col_order_json):
    """Return the leaderboard table markup plus its sorting/rendering script."""
    return """
    <div id="leaderboard-container" style="overflow-x:auto;">
    <table id="leaderboard" class="custom-board">
    <thead>
    <!-- Top row -->
    <tr>
    <!-- 1) Model pinned (sticky-col pinned-col) -->
    <th class="sticky-col pinned-col sortable" data-col="Model" rowspan="3">
    Model <span class="sort-arrow">▲</span>
    </th>
    <!-- 2) Model Type (not pinned) -->
    <th data-col="model_type" class="sortable" rowspan="3">
    Model Type <span class="sort-arrow">▲</span>
    </th>
    <!-- 3) Overall rank (not pinned) -->
    <th data-col="Overall rank" class="sortable" rowspan="3">
    Overall Rank <span class="sort-arrow">▲</span>
    </th>
    <!-- Next columns stay the same order -->
    <th data-col="Recon rank" class="sortable" rowspan="3">
    Recon Rank <span class="sort-arrow">▲</span>
    </th>
    <th colspan="12">Reconstruction Error</th>
    <th data-col="Downstream rank" class="sortable" rowspan="3">
    Downstream Rank <span class="sort-arrow">▲</span>
    </th>
    <th colspan="21">Downstream Models</th>
    </tr>
    <!-- 2nd row -->
    <tr>
    <th colspan="3">L1</th>
    <th colspan="3">Perceptual</th>
    <th colspan="3">SSIM</th>
    <th colspan="3">PSNR</th>
    <th colspan="3">Age (MAE)</th>
    <th colspan="3">Sex (Overall)</th>
    <th colspan="3">Sex (Female)</th>
    <th colspan="3">Sex (Male)</th>
    <th colspan="3">Diagnosis (Overall)</th>
    <th colspan="3">Diagnosis (Alzheimers_disease)</th>
    <th colspan="3">Diagnosis (cognitively_normal)</th>
    </tr>
    <!-- 3rd row -->
    <tr>
    <th data-col="train_l1" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_l1" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_l1" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_perceptual" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_perceptual" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_perceptual" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_ssim" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_ssim" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_ssim" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_psnr" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_psnr" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_psnr" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_mae_age" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_mae_age" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_mae_age" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_sex" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_sex" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_sex" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_sex_female" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_sex_female" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_sex_female" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_sex_male" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_sex_male" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_sex_male" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_diag" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_diag" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_diag" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_diag_Alzheimers_disease" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_diag_Alzheimers_disease" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_diag_Alzheimers_disease" class="sortable">Test <span class="sort-arrow">▲</span></th>
    <th data-col="train_acc_diag_cognitively_normal" class="sortable">Train <span class="sort-arrow">▲</span></th>
    <th data-col="val_acc_diag_cognitively_normal" class="sortable">Val <span class="sort-arrow">▲</span></th>
    <th data-col="test_acc_diag_cognitively_normal" class="sortable">Test <span class="sort-arrow">▲</span></th>
    </tr>
    </thead>
    <tbody id="leaderboard-body"></tbody>
    </table>
    </div>
    <script>
    window.tableData = """ + table_data_json + """;
    window.defaultColumnOrder = """ + col_order_json + """;
    // Default sort by Overall rank ascending
    window.currentSort = { col: "Overall rank", dir: "asc" };
    function initTable() {
      // sort first
      sortData();
      // then render + setup
      renderTableBody();
      setupSortableHeaders();
      updateHeaderArrows();
    }
    function setupSortableHeaders() {
      const thEls = document.querySelectorAll('#leaderboard th.sortable');
      thEls.forEach(th => {
        th.addEventListener('click', () => {
          const col = th.dataset.col;
          if (window.currentSort.col === col) {
            window.currentSort.dir = (window.currentSort.dir === 'asc') ? 'desc' : 'asc';
          } else {
            window.currentSort.col = col;
            window.currentSort.dir = 'asc';
          }
          sortData();
          renderTableBody();
          updateHeaderArrows();
        });
      });
    }
    function sortData() {
      const col = window.currentSort.col;
      if (!col) return;
      const asc = (window.currentSort.dir === 'asc');
      window.tableData.sort((a, b) => {
        const aVal = (a[col] === null) ? null : a[col];
        const bVal = (b[col] === null) ? null : b[col];
        // null => sorted last
        if (aVal === null && bVal === null) return 0;
        if (aVal === null) return 1;
        if (bVal === null) return -1;
        if (typeof aVal === 'number' && typeof bVal === 'number') {
          return asc ? aVal - bVal : bVal - aVal;
        } else {
          const av = String(aVal);
          const bv = String(bVal);
          if (av === bv) return 0;
          return asc ? (av < bv ? -1 : 1) : (av > bv ? -1 : 1);
        }
      });
    }
    function updateHeaderArrows() {
      // reset all arrow icons
      const allArrows = document.querySelectorAll('#leaderboard th.sortable .sort-arrow');
      allArrows.forEach(ar => {
        ar.style.color = '#aaa';
        ar.textContent = '▲';
      });
      if (window.currentSort.col) {
        const selector = '#leaderboard th.sortable[data-col="'+window.currentSort.col+'"] .sort-arrow';
        const target = document.querySelector(selector);
        if (target) {
          target.style.color = 'orange';
          if (window.currentSort.dir === 'desc') {
            target.textContent = '▼';
          } else {
            target.textContent = '▲';
          }
        }
      }
    }
    function renderTableBody() {
      const tbody = document.getElementById('leaderboard-body');
      tbody.innerHTML = '';
      window.tableData.forEach(row => {
        const tr = document.createElement('tr');
        window.defaultColumnOrder.forEach(col => {
          const val = row[col];
          const td = document.createElement('td');
          // if col is pinned => add pinned classes to td
          if (col === "Model") {
            td.classList.add("sticky-col", "pinned-col");
          }
          if (val === null) {
            td.textContent = 'N/A';
          } else if (col === 'Model') {
            // stored as "nickname"||URL
            const parts = val.split('||');
            let nickname = parts[0] || '"Unknown"';
            if (nickname.startsWith('"') && nickname.endsWith('"')) {
              nickname = nickname.slice(1, -1);
            }
            const href = parts[1] || '#';
            // Build the link with DOM APIs so the nickname and URL are
            // treated as text/attribute values, never parsed as markup.
            const link = document.createElement('a');
            link.href = href;
            link.target = '_blank';
            link.rel = 'noopener';
            link.textContent = nickname;
            td.appendChild(link);
          } else {
            td.textContent = val;
          }
          tr.appendChild(td);
        });
        tbody.appendChild(tr);
      });
    }
    setTimeout(initTable, 50);
    </script>
    """


def build_app():
    """Assemble the leaderboard as a Gradio Blocks app.

    Loads the stylesheet and the prepared evaluation data, serialises one
    dict per model (keyed by ``DEFAULT_COLUMN_ORDER``) into the page's
    inline script, and lays out a heading, the table, and the contents of
    ``docs.md`` when that file exists.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    css_str = load_css()
    df = load_and_prepare_data()

    # One dict per model; columns absent from the frame become None.
    data_list = [
        {col: _to_cell(row.get(col, np.nan)) for col in DEFAULT_COLUMN_ORDER}
        for _, row in df.iterrows()
    ]
    table_data_json = _json_for_script(data_list)
    col_order_json = _json_for_script(DEFAULT_COLUMN_ORDER)

    # HTML for scoreboard
    scoreboard_html = _scoreboard_html(table_data_json, col_order_json)

    # Read docs.md
    docs_path = Path("docs.md")
    docs_content = docs_path.read_text(encoding="utf-8") if docs_path.is_file() else ""

    with gr.Blocks(css=css_str) as demo:
        gr.HTML(
            "<div id='leaderboard-heading'>🧠↗️ Brain2vec Leaderboard</div>"
        )
        gr.HTML(scoreboard_html)
        gr.Markdown(docs_content)
    return demo
if __name__ == "__main__":
    # Script entry point: build the Blocks app and start serving it.
    build_app().launch()