"""Brain Alignment Benchmark - Research Grade.""" import streamlit as st import numpy as np import pandas as pd import plotly.graph_objects as go import plotly.express as px from session import init_session, log_analysis, carry_rois, download_csv_button, show_analysis_log from theme import inject_theme, section_header from utils import ( ALIGNMENT_METHODS, ROI_GROUPS, make_roi_indices, permutation_test, bootstrap_ci, fdr_correction, noise_ceiling, compute_rdm, ) from synthetic import generate_realistic_predictions, generate_correlated_features st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide") init_session() inject_theme() show_analysis_log() st.title("🎯 Brain Alignment Benchmark") st.markdown("Score how well AI model representations align with predicted brain responses, with full statistical testing.") # --- Sidebar --- with st.sidebar: st.header("Configuration") stimulus = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"], index=["visual", "auditory", "language", "multimodal"].index(st.session_state.get("stimulus_type", "visual"))) n_timepoints = st.slider("Timepoints", 30, 200, st.session_state.get("n_timepoints", 80)) seed = st.number_input("Seed", value=st.session_state.get("seed", 42), min_value=0) st.subheader("Models") model_configs = { "CLIP ViT-L/14": {"dim": 768, "alignment": st.slider("CLIP alignment", 0.0, 1.0, 0.6, 0.05)}, "DINOv2 ViT-S": {"dim": 384, "alignment": st.slider("DINOv2 alignment", 0.0, 1.0, 0.3, 0.05)}, "V-JEPA2 ViT-G": {"dim": 1024, "alignment": st.slider("V-JEPA2 alignment", 0.0, 1.0, 0.8, 0.05)}, } st.subheader("Methods & Statistics") methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"]) n_perm = st.slider("Permutations", 50, 1000, 200) n_boot = st.slider("Bootstrap samples", 100, 2000, 500) apply_fdr = st.checkbox("Apply FDR correction", value=True) if not methods: st.warning("Select at least one method.") st.stop() # --- Generate Data --- 
roi_indices, n_vertices = make_roi_indices()
brain_pred = generate_realistic_predictions(n_timepoints, roi_indices, stimulus, seed=seed)

# One synthetic feature matrix per model; each model gets its own seed offset.
model_features = {
    name: generate_correlated_features(brain_pred, cfg["alignment"], cfg["dim"], seed=seed + i + 1)
    for i, (name, cfg) in enumerate(model_configs.items())
}

# --- Run Benchmark ---
with st.spinner("Computing alignment scores with statistical testing..."):
    results = []
    # Keyed by (model, method) tuples so the pair never has to be recovered by
    # splitting a joined string.
    null_distributions = {}
    for model_name, features in model_features.items():
        for method_name in methods:
            score_fn = ALIGNMENT_METHODS[method_name]
            observed, p_val, null_dist = permutation_test(features, brain_pred, score_fn, n_perm, seed)
            _, ci_lo, ci_hi = bootstrap_ci(features, brain_pred, score_fn, n_boot, seed=seed)
            null_distributions[(model_name, method_name)] = null_dist
            results.append({
                "Model": model_name,
                "Method": method_name,
                "Score": observed,
                "CI Lower": ci_lo,
                "CI Upper": ci_hi,
                "p-value": p_val,
            })

df = pd.DataFrame(results)
log_analysis(f"Brain alignment: {len(model_features)} models x {len(methods)} methods")

# --- Noise Ceiling ---
# Only the first value returned by noise_ceiling is plotted; the second is not used.
ceiling_scores = {}
for method_name in methods:
    ceil_mean, _ = noise_ceiling(brain_pred, ALIGNMENT_METHODS[method_name], seed=seed)
    ceiling_scores[method_name] = ceil_mean

# --- Display: Alignment Scores with CIs ---
st.subheader("Alignment Scores")

col_chart, col_table = st.columns([2, 1])

with col_chart:
    fig = go.Figure()
    method_colors = {"RSA": "#00D2FF", "CKA": "#FF6B6B", "Procrustes": "#A29BFE"}

    for method_name in methods:
        method_df = df[df["Method"] == method_name]
        color = method_colors.get(method_name, "#888")
        fig.add_trace(go.Bar(
            name=method_name,
            x=method_df["Model"],
            y=method_df["Score"],
            # Asymmetric error bars: distances from the score to each CI bound.
            error_y=dict(
                type="data",
                symmetric=False,
                array=(method_df["CI Upper"] - method_df["Score"]).tolist(),
                arrayminus=(method_df["Score"] - method_df["CI Lower"]).tolist(),
            ),
            marker_color=color,
        ))
        # Noise ceiling line (every selected method has an entry in ceiling_scores)
        fig.add_hline(
            y=ceiling_scores[method_name],
            line_dash="dash",
            line_color=color,
            opacity=0.4,
            annotation_text=f"{method_name} ceiling",
            annotation_position="top right",
        )

    fig.update_layout(
        barmode="group",
        yaxis_title="Alignment Score",
        height=450,
        template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )
    st.plotly_chart(fig, use_container_width=True)

with col_table:
    st.subheader("Results")
    # Format a copy for display; the CSV download keeps the unformatted numbers.
    display_df = df.copy()
    for col in ["Score", "CI Lower", "CI Upper", "p-value"]:
        display_df[col] = display_df[col].map("{:.4f}".format)
    st.dataframe(display_df, use_container_width=True, hide_index=True)
    download_csv_button(df, "brain_alignment_results.csv")

# --- Null Distribution ---
with st.expander("Null Distributions (Permutation Tests)", expanded=False):
    st.markdown("The histogram shows the distribution of scores under the null hypothesis (no alignment). "
                "The red line marks the observed score. If it falls far to the right, alignment is significant.")

    # At most three plots per row; further plots wrap around into the same columns.
    cols = st.columns(min(len(null_distributions), 3))
    for i, ((model_name, method_name), null_dist) in enumerate(null_distributions.items()):
        row = df[(df["Model"] == model_name) & (df["Method"] == method_name)].iloc[0]

        with cols[i % len(cols)]:
            fig_null = go.Figure()
            fig_null.add_trace(go.Histogram(x=null_dist, nbinsx=30, marker_color="rgba(100,100,100,0.6)", name="Null"))
            fig_null.add_vline(x=row["Score"], line_color="red", line_width=2, annotation_text="Observed")
            fig_null.update_layout(
                title=f"{model_name} ({method_name})",
                xaxis_title="Score",
                yaxis_title="Count",
                height=250,
                template="plotly_dark",
                showlegend=False,
                margin=dict(t=40, b=30, l=30, r=10),
            )
            st.plotly_chart(fig_null, use_container_width=True)
            st.caption(f"p = {row['p-value']:.4f}")

# --- RDM Visualization ---
with st.expander("Representational Dissimilarity Matrices", expanded=False):
    st.markdown("RDMs show pairwise dissimilarity between stimulus representations. "
                "Similar RDM structure between model and brain indicates representational alignment.")

    rdm_model_name = st.selectbox("Model for RDM", list(model_features.keys()))

    col_brain, col_model = st.columns(2)
    brain_rdm = compute_rdm(brain_pred)
    model_rdm = compute_rdm(model_features[rdm_model_name])

    with col_brain:
        fig_rdm = go.Figure(go.Heatmap(z=brain_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
        fig_rdm.update_layout(title="Brain RDM", height=350, template="plotly_dark",
                              xaxis_title="Stimulus", yaxis_title="Stimulus")
        st.plotly_chart(fig_rdm, use_container_width=True)

    with col_model:
        fig_rdm2 = go.Figure(go.Heatmap(z=model_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
        fig_rdm2.update_layout(title=f"{rdm_model_name} RDM", height=350, template="plotly_dark",
                               xaxis_title="Stimulus", yaxis_title="Stimulus")
        st.plotly_chart(fig_rdm2, use_container_width=True)

# --- Per-ROI Analysis with FDR ---
st.divider()
st.subheader("Per-ROI Alignment")

roi_method = st.selectbox("Method for ROI analysis", methods, key="roi_method")
score_fn = ALIGNMENT_METHODS[roi_method]

# ROI scores are computed for the best-scoring model under the chosen method.
top_model = df[df["Method"] == roi_method].sort_values("Score", ascending=False).iloc[0]["Model"]
features = model_features[top_model]

roi_data = []
for group_name, rois in ROI_GROUPS.items():
    for roi in rois:
        if roi not in roi_indices:
            continue
        verts = roi_indices[roi]
        # Keep only vertex indices that exist as columns of brain_pred.
        valid = verts[verts < brain_pred.shape[1]]
        if len(valid) < 2:
            continue
        roi_pred = brain_pred[:, valid]
        s = score_fn(features, roi_pred)
        # Fixed at 50 permutations per ROI, independent of the sidebar setting.
        _, p, _ = permutation_test(features, roi_pred, score_fn, n_perm=50, seed=seed)
        roi_data.append({"ROI": roi, "Group": group_name, "Score": s, "p-value": p})

roi_p_values = [entry["p-value"] for entry in roi_data]

if roi_data:
    roi_df = pd.DataFrame(roi_data)

    # FDR is only run when requested AND there is more than one p-value to
    # correct; the caption below reports what was actually done.
    fdr_applied = apply_fdr and len(roi_p_values) > 1
    if fdr_applied:
        corrected_p, significant = fdr_correction(roi_p_values)
        roi_df["FDR p-value"] = corrected_p
        roi_df["Significant"] = significant
        roi_df["Label"] = roi_df.apply(lambda r: f"{r['ROI']} *" if r["Significant"] else r["ROI"], axis=1)
    else:
        roi_df["Label"] = roi_df["ROI"]
        roi_df["Significant"] = roi_df["p-value"] < 0.05

    group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
    fig_roi = px.bar(roi_df, x="Label", y="Score", color="Group",
                     color_discrete_map=group_colors)
    fig_roi.update_layout(height=400, template="plotly_dark", xaxis_tickangle=45)
    st.plotly_chart(fig_roi, use_container_width=True)

    if fdr_applied:
        st.caption(f"Model: {top_model} | * = significant after FDR correction (q < 0.05)")
    else:
        st.caption(f"Model: {top_model}")

    # Carry ROIs button
    sig_rois = roi_df[roi_df["Significant"]]["ROI"].tolist()
    if sig_rois:
        if st.button(f"Carry {len(sig_rois)} significant ROIs to other pages"):
            carry_rois(sig_rois, "Temporal Dynamics / Connectivity")
            st.success(f"Carried {len(sig_rois)} ROIs: {', '.join(sig_rois[:5])}{'...' if len(sig_rois) > 5 else ''}")
else:
    # Make the empty case visible instead of leaving a bare section header.
    st.info("No ROI had at least two valid vertices, so no per-ROI scores were computed.")

# --- Methodology ---
with st.expander("Methodology", expanded=False):
    st.markdown("""
**Representational Similarity Analysis (RSA)** compares the geometry of two representation spaces
by computing pairwise dissimilarity matrices (RDMs) and correlating their upper triangles via
Spearman rank correlation. Range: [-1, 1]. Values > 0.1 are typically meaningful.
*Kriegeskorte et al., 2008, Frontiers in Systems Neuroscience.*

**Centered Kernel Alignment (CKA)** measures similarity between representations using
HSIC (Hilbert-Schmidt Independence Criterion) normalized by self-similarities.
Invariant to orthogonal transformations and isotropic scaling. Range: [0, 1].
*Kornblith et al., 2019, ICML.*

**Procrustes** finds the optimal rotation mapping one space onto another and measures
residual distance. Score = 1 - normalized Procrustes distance. Range: [0, 1].
*Ding et al., 2021, NeurIPS.*

**Noise ceiling** estimates the maximum achievable alignment score given the noise in the
brain data, computed via split-half reliability.

**FDR correction** (Benjamini-Hochberg) controls the false discovery rate when testing
multiple ROIs simultaneously.
""")