cortexlab-dashboard / pages /1_Brain_Alignment.py
siddhant-rajhans
Complete dashboard redesign: futuristic glassmorphism UI
8643122
"""Brain Alignment Benchmark - Research Grade."""
import streamlit as st
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from session import init_session, log_analysis, carry_rois, download_csv_button, show_analysis_log
from theme import inject_theme, section_header
from utils import (
ALIGNMENT_METHODS, ROI_GROUPS, make_roi_indices,
permutation_test, bootstrap_ci, fdr_correction, noise_ceiling,
compute_rdm,
)
from synthetic import generate_realistic_predictions, generate_correlated_features
# Page-level configuration must be the first Streamlit call on the page.
st.set_page_config(
    page_title="Brain Alignment",
    page_icon="🎯",
    layout="wide",
)
# Shared bootstrapping helpers from the project: session defaults, visual theme,
# then the running analysis log.
for _page_setup_step in (init_session, inject_theme, show_analysis_log):
    _page_setup_step()
st.title("🎯 Brain Alignment Benchmark")
st.markdown(
    "Score how well AI model representations align with predicted brain responses, "
    "with full statistical testing."
)
# --- Sidebar ---
# All user configuration for the page. The module-level names assigned here
# (stimulus, n_timepoints, seed, model_configs, methods, n_perm, n_boot,
# apply_fdr) are read by the rest of the script.
_STIMULUS_OPTIONS = ["visual", "auditory", "language", "multimodal"]
with st.sidebar:
    st.header("Configuration")
    # Defaults are pulled from shared session state, which other pages may have
    # written. Coerce them into each widget's valid domain: an unknown stimulus
    # would make list.index raise ValueError, and an out-of-range or non-int
    # number makes Streamlit reject the widget's default value.
    _stored_stimulus = st.session_state.get("stimulus_type", "visual")
    stimulus = st.selectbox(
        "Stimulus type",
        _STIMULUS_OPTIONS,
        index=_STIMULUS_OPTIONS.index(_stored_stimulus) if _stored_stimulus in _STIMULUS_OPTIONS else 0,
    )
    _stored_timepoints = int(st.session_state.get("n_timepoints", 80))
    n_timepoints = st.slider("Timepoints", 30, 200, min(max(_stored_timepoints, 30), 200))
    _stored_seed = int(st.session_state.get("seed", 42))
    seed = st.number_input("Seed", value=max(_stored_seed, 0), min_value=0)

    st.subheader("Models")
    # Per-model feature dimensionality and the alignment strength passed to the
    # synthetic feature generator.
    model_configs = {
        "CLIP ViT-L/14": {"dim": 768, "alignment": st.slider("CLIP alignment", 0.0, 1.0, 0.6, 0.05)},
        "DINOv2 ViT-S": {"dim": 384, "alignment": st.slider("DINOv2 alignment", 0.0, 1.0, 0.3, 0.05)},
        "V-JEPA2 ViT-G": {"dim": 1024, "alignment": st.slider("V-JEPA2 alignment", 0.0, 1.0, 0.8, 0.05)},
    }

    st.subheader("Methods & Statistics")
    methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
    n_perm = st.slider("Permutations", 50, 1000, 200)
    n_boot = st.slider("Bootstrap samples", 100, 2000, 500)
    apply_fdr = st.checkbox("Apply FDR correction", value=True)
    # Nothing below can run without at least one scoring method.
    if not methods:
        st.warning("Select at least one method.")
        st.stop()
# --- Generate Data ---
# Synthetic brain predictions for the selected stimulus, plus one feature matrix
# per model produced from those predictions using the model's alignment slider
# value and feature dimension.
roi_indices, n_vertices = make_roi_indices()
brain_pred = generate_realistic_predictions(n_timepoints, roi_indices, stimulus, seed=seed)
# Each model gets a distinct seed (seed + 1, seed + 2, ...) in config order.
model_features = {
    model_label: generate_correlated_features(
        brain_pred, spec["alignment"], spec["dim"], seed=seed + offset
    )
    for offset, (model_label, spec) in enumerate(model_configs.items(), start=1)
}
# --- Run Benchmark ---
def _evaluate_pair(model_name, method_name, features):
    """Score one (model, method) pair against the brain predictions.

    Runs the permutation test first, then the bootstrap, both with the
    sidebar's sample counts and seed. Returns ``(row, null_dist)``: ``row`` is
    the results-table record and ``null_dist`` the permutation null scores.
    """
    metric = ALIGNMENT_METHODS[method_name]
    observed, p_val, null_dist = permutation_test(features, brain_pred, metric, n_perm, seed)
    # The bootstrap's own point estimate is not used; the table reports the
    # permutation test's observed score.
    _, ci_lo, ci_hi = bootstrap_ci(features, brain_pred, metric, n_boot, seed=seed)
    row = {
        "Model": model_name,
        "Method": method_name,
        "Score": observed,
        "CI Lower": ci_lo,
        "CI Upper": ci_hi,
        "p-value": p_val,
    }
    return row, null_dist


with st.spinner("Computing alignment scores with statistical testing..."):
    results = []
    null_distributions = {}
    for model_name, features in model_features.items():
        for method_name in methods:
            pair_row, pair_null = _evaluate_pair(model_name, method_name, features)
            # Keyed "<model>_<method>"; the null-distribution panel reads this key.
            null_distributions[f"{model_name}_{method_name}"] = pair_null
            results.append(pair_row)
    df = pd.DataFrame(results)
log_analysis(f"Brain alignment: {len(model_features)} models x {len(methods)} methods")
# --- Noise Ceiling ---
def _ceiling_mean(method):
    """Return the first value of ``noise_ceiling`` for *method* (used as the ceiling level).

    The second returned value (presumably a spread/std) is not used on this page.
    """
    mean_score, _spread = noise_ceiling(brain_pred, ALIGNMENT_METHODS[method], seed=seed)
    return mean_score


# One ceiling level per selected method, in selection order.
ceiling_scores = {method: _ceiling_mean(method) for method in methods}
# --- Display: Alignment Scores with CIs ---
st.subheader("Alignment Scores")
col_chart, col_table = st.columns([2, 1])
with col_chart:
    # Grouped bars: one trace per method, one bar per model, with bootstrap CIs
    # as asymmetric error bars and a dashed noise-ceiling line per method.
    fig = go.Figure()
    method_colors = {"RSA": "#00D2FF", "CKA": "#FF6B6B", "Procrustes": "#A29BFE"}
    for method_name in methods:
        method_df = df[df["Method"] == method_name]
        # Plotly error bars are offsets from the bar height. The bar is the
        # permutation test's observed score while the interval comes from the
        # bootstrap, so the interval need not bracket the bar; clip at 0 so a
        # non-bracketing CI never yields a negative (inverted) offset.
        err_plus = (method_df["CI Upper"] - method_df["Score"]).clip(lower=0)
        err_minus = (method_df["Score"] - method_df["CI Lower"]).clip(lower=0)
        fig.add_trace(go.Bar(
            name=method_name,
            x=method_df["Model"],
            y=method_df["Score"],
            error_y=dict(
                type="data",
                symmetric=False,
                array=err_plus.tolist(),
                arrayminus=err_minus.tolist(),
            ),
            marker_color=method_colors.get(method_name, "#888"),
        ))
        # Noise ceiling line
        if method_name in ceiling_scores:
            fig.add_hline(
                y=ceiling_scores[method_name],
                line_dash="dash", line_color=method_colors.get(method_name, "#888"),
                opacity=0.4,
                annotation_text=f"{method_name} ceiling",
                annotation_position="top right",
            )
    fig.update_layout(
        barmode="group", yaxis_title="Alignment Score",
        height=450, template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )
    st.plotly_chart(fig, use_container_width=True)
with col_table:
    st.subheader("Results")
    # Format a copy for display; the CSV download keeps full-precision numbers.
    display_df = df.copy()
    for col in ["Score", "CI Lower", "CI Upper", "p-value"]:
        display_df[col] = display_df[col].map(lambda x: f"{x:.4f}")
    st.dataframe(display_df, use_container_width=True, hide_index=True)
    download_csv_button(df, "brain_alignment_results.csv")
# --- Null Distribution ---
with st.expander("Null Distributions (Permutation Tests)", expanded=False):
    st.markdown("The histogram shows the distribution of scores under the null hypothesis (no alignment). "
                "The red line marks the observed score. If it falls far to the right, alignment is significant.")
    cols = st.columns(min(len(null_distributions), 3))
    n_cols = len(cols)
    # df rows were appended in the same (model, method) order the null
    # distributions were stored, so walking df visits every histogram once,
    # in the same order, without having to parse model/method back out of the key.
    for panel_idx, (_, row) in enumerate(df.iterrows()):
        model_name = row["Model"]
        method_name = row["Method"]
        null_dist = null_distributions[f"{model_name}_{method_name}"]
        with cols[panel_idx % n_cols]:
            hist_fig = go.Figure()
            hist_fig.add_trace(
                go.Histogram(x=null_dist, nbinsx=30, marker_color="rgba(100,100,100,0.6)", name="Null")
            )
            hist_fig.add_vline(x=row["Score"], line_color="red", line_width=2, annotation_text="Observed")
            hist_fig.update_layout(
                title=f"{model_name} ({method_name})",
                xaxis_title="Score",
                yaxis_title="Count",
                height=250,
                template="plotly_dark",
                showlegend=False,
                margin=dict(t=40, b=30, l=30, r=10),
            )
            st.plotly_chart(hist_fig, use_container_width=True)
            st.caption(f"p = {row['p-value']:.4f}")
# --- RDM Visualization ---
with st.expander("Representational Dissimilarity Matrices", expanded=False):
    st.markdown("RDMs show pairwise dissimilarity between stimulus representations. "
                "Similar RDM structure between model and brain indicates representational alignment.")
    rdm_model_name = st.selectbox("Model for RDM", list(model_features.keys()))
    col_brain, col_model = st.columns(2)
    brain_rdm = compute_rdm(brain_pred)
    model_rdm = compute_rdm(model_features[rdm_model_name])
    # Brain (left) and selected model (right), rendered with identical styling.
    rdm_panels = (
        (col_brain, brain_rdm, "Brain RDM"),
        (col_model, model_rdm, f"{rdm_model_name} RDM"),
    )
    for panel_col, rdm_matrix, panel_title in rdm_panels:
        with panel_col:
            heatmap_fig = go.Figure(
                go.Heatmap(z=rdm_matrix, colorscale="Viridis", colorbar=dict(title="Dissimilarity"))
            )
            heatmap_fig.update_layout(
                title=panel_title,
                height=350,
                template="plotly_dark",
                xaxis_title="Stimulus",
                yaxis_title="Stimulus",
            )
            st.plotly_chart(heatmap_fig, use_container_width=True)
# --- Per-ROI Analysis with FDR ---
st.divider()
st.subheader("Per-ROI Alignment")
roi_method = st.selectbox("Method for ROI analysis", methods, key="roi_method")
score_fn = ALIGNMENT_METHODS[roi_method]
roi_data = []
roi_p_values = []
# The ROI breakdown uses the model with the highest whole-brain score for the
# chosen method.
top_model = df[df["Method"] == roi_method].sort_values("Score", ascending=False).iloc[0]["Model"]
features = model_features[top_model]
for group_name, rois in ROI_GROUPS.items():
    for roi in rois:
        if roi not in roi_indices:
            continue
        verts = roi_indices[roi]
        # Drop vertex indices that fall outside the prediction matrix's columns.
        valid = verts[verts < brain_pred.shape[1]]
        if len(valid) < 2:
            continue
        # Use the sidebar "Permutations" setting: a small fixed permutation
        # count bounds the smallest attainable p-value and so weakens the FDR
        # step below. As in the whole-brain benchmark, the first return value
        # is the observed score.
        s, p, _ = permutation_test(features, brain_pred[:, valid], score_fn, n_perm=n_perm, seed=seed)
        roi_data.append({"ROI": roi, "Group": group_name, "Score": s, "p-value": p})
        roi_p_values.append(p)

if roi_data:
    roi_df = pd.DataFrame(roi_data)
    # Benjamini-Hochberg needs more than one test; with a single ROI fall back
    # to the uncorrected p-value.
    fdr_applied = apply_fdr and len(roi_p_values) > 1
    if fdr_applied:
        corrected_p, significant = fdr_correction(roi_p_values)
        roi_df["FDR p-value"] = corrected_p
        roi_df["Significant"] = significant
        roi_df["Label"] = roi_df.apply(lambda r: f"{r['ROI']} *" if r["Significant"] else r["ROI"], axis=1)
    else:
        roi_df["Label"] = roi_df["ROI"]
        roi_df["Significant"] = roi_df["p-value"] < 0.05
    group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
    fig_roi = px.bar(roi_df, x="Label", y="Score", color="Group",
                     color_discrete_map=group_colors)
    fig_roi.update_layout(height=400, template="plotly_dark", xaxis_tickangle=45)
    st.plotly_chart(fig_roi, use_container_width=True)
    # Only describe the "*" marker when the FDR correction actually ran.
    st.caption(f"Model: {top_model} | * = significant after FDR correction (q < 0.05)" if fdr_applied else f"Model: {top_model}")
    # Carry ROIs button ("Significant" is always set by one of the branches above).
    sig_rois = roi_df.loc[roi_df["Significant"].astype(bool), "ROI"].tolist()
    if sig_rois:
        if st.button(f"Carry {len(sig_rois)} significant ROIs to other pages"):
            carry_rois(sig_rois, "Temporal Dynamics / Connectivity")
            st.success(f"Carried {len(sig_rois)} ROIs: {', '.join(sig_rois[:5])}{'...' if len(sig_rois) > 5 else ''}")
else:
    st.info("No ROI has at least two vertices inside the prediction matrix, so no per-ROI scores were computed.")
# --- Methodology ---
# Markdown needs a blank line between paragraphs; without one every method
# description, citation and note renders as a single run-on paragraph.
with st.expander("Methodology", expanded=False):
    st.markdown("""
**Representational Similarity Analysis (RSA)** compares the geometry of two representation
spaces by computing pairwise dissimilarity matrices (RDMs) and correlating their upper triangles
via Spearman rank correlation. Range: [-1, 1]. Values > 0.1 are typically meaningful.
*Kriegeskorte et al., 2008, Frontiers in Systems Neuroscience.*

**Centered Kernel Alignment (CKA)** measures similarity between representations using
HSIC (Hilbert-Schmidt Independence Criterion) normalized by self-similarities. Invariant to
orthogonal transformations and isotropic scaling. Range: [0, 1].
*Kornblith et al., 2019, ICML.*

**Procrustes** finds the optimal rotation mapping one space onto another and measures
residual distance. Score = 1 - normalized Procrustes distance. Range: [0, 1].
*Ding et al., 2021, NeurIPS.*

**Noise ceiling** estimates the maximum achievable alignment score given the noise in the
brain data, computed via split-half reliability.

**FDR correction** (Benjamini-Hochberg) controls the false discovery rate when testing
multiple ROIs simultaneously.
""")