File size: 11,365 Bytes
9b23ae9
bce4bae
 
 
 
 
 
 
9b23ae9
8643122
bce4bae
9b23ae9
 
 
bce4bae
9b23ae9
bce4bae
 
9b23ae9
8643122
9b23ae9
 
bce4bae
9b23ae9
bce4bae
9b23ae9
bce4bae
 
9b23ae9
 
 
 
 
 
 
 
 
 
bce4bae
 
9b23ae9
bce4bae
9b23ae9
 
 
bce4bae
9b23ae9
 
bce4bae
 
 
 
9b23ae9
bce4bae
 
9b23ae9
 
 
 
bce4bae
 
9b23ae9
bce4bae
9b23ae9
 
bce4bae
 
 
9b23ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bce4bae
9b23ae9
 
 
bce4bae
 
 
9b23ae9
 
bce4bae
9b23ae9
 
bce4bae
9b23ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bce4bae
9b23ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""Brain Alignment Benchmark - Research Grade."""

import streamlit as st
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

from session import init_session, log_analysis, carry_rois, download_csv_button, show_analysis_log
from theme import inject_theme, section_header
from utils import (
    ALIGNMENT_METHODS, ROI_GROUPS, make_roi_indices,
    permutation_test, bootstrap_ci, fdr_correction, noise_ceiling,
    compute_rdm,
)
from synthetic import generate_realistic_predictions, generate_correlated_features

# Page chrome: configure the page first, then run the project-level session /
# theme / analysis-log helpers (imported from `session` and `theme`).
st.set_page_config(layout="wide", page_icon="🎯", page_title="Brain Alignment")
init_session()
inject_theme()
show_analysis_log()

# Page heading and one-line description of what the benchmark does.
st.title("🎯 Brain Alignment Benchmark")
st.markdown(
    "Score how well AI model representations align with predicted brain responses, "
    "with full statistical testing."
)

# --- Sidebar ---
# Collects every user-tunable input for the page. Defaults restored from
# st.session_state are sanitised before being handed to a widget: a saved
# value that is not one of the options, or lies outside the widget's range,
# would otherwise abort the page with an exception.
stimulus_options = ["visual", "auditory", "language", "multimodal"]
timepoints_min, timepoints_max = 30, 200

with st.sidebar:
    st.header("Configuration")

    saved_stimulus = st.session_state.get("stimulus_type", "visual")
    stimulus = st.selectbox(
        "Stimulus type",
        stimulus_options,
        # Unknown saved value -> fall back to the first option instead of
        # letting list.index() raise ValueError.
        index=stimulus_options.index(saved_stimulus) if saved_stimulus in stimulus_options else 0,
    )

    # Clamp the restored default into the slider's range; int() keeps the
    # default the same numeric type as the integer bounds.
    saved_timepoints = int(st.session_state.get("n_timepoints", 80))
    n_timepoints = st.slider(
        "Timepoints",
        timepoints_min,
        timepoints_max,
        min(max(saved_timepoints, timepoints_min), timepoints_max),
    )

    # The seed must be a non-negative integer to satisfy min_value=0.
    saved_seed = int(st.session_state.get("seed", 42))
    seed = st.number_input("Seed", value=max(saved_seed, 0), min_value=0)

    st.subheader("Models")
    # One entry per model: feature dimensionality plus a user-chosen
    # "alignment" value in [0, 1] that is passed to the synthetic feature
    # generator further down the page.
    model_configs = {
        "CLIP ViT-L/14": {"dim": 768, "alignment": st.slider("CLIP alignment", 0.0, 1.0, 0.6, 0.05)},
        "DINOv2 ViT-S": {"dim": 384, "alignment": st.slider("DINOv2 alignment", 0.0, 1.0, 0.3, 0.05)},
        "V-JEPA2 ViT-G": {"dim": 1024, "alignment": st.slider("V-JEPA2 alignment", 0.0, 1.0, 0.8, 0.05)},
    }

    st.subheader("Methods & Statistics")
    methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
    n_perm = st.slider("Permutations", 50, 1000, 200)
    n_boot = st.slider("Bootstrap samples", 100, 2000, 500)
    apply_fdr = st.checkbox("Apply FDR correction", value=True)

# Nothing below can run without at least one alignment method selected.
if not methods:
    st.warning("Select at least one method.")
    st.stop()

# --- Generate Data ---
# Synthetic brain predictions for the chosen stimulus, then one synthetic
# feature matrix per configured model. Each model gets its own seed
# (seed + 1, seed + 2, ...) in the order the models are configured.
roi_indices, n_vertices = make_roi_indices()
brain_pred = generate_realistic_predictions(n_timepoints, roi_indices, stimulus, seed=seed)

model_features = {
    name: generate_correlated_features(brain_pred, cfg["alignment"], cfg["dim"], seed=seed + offset)
    for offset, (name, cfg) in enumerate(model_configs.items(), start=1)
}

# --- Run Benchmark ---
# For every (model, method) pair, run a permutation test and a bootstrap
# confidence interval against the brain predictions. Rows are collected in
# model-major order; the null distribution of each pair is stored under the
# key "<model>_<method>" for the histogram section further down.
result_columns = ("Model", "Method", "Score", "CI Lower", "CI Upper", "p-value")

with st.spinner("Computing alignment scores with statistical testing..."):
    null_distributions = {}
    results = []

    for model_name, features in model_features.items():
        for method_name in methods:
            score_fn = ALIGNMENT_METHODS[method_name]
            observed, p_val, null_dist = permutation_test(features, brain_pred, score_fn, n_perm, seed)
            # The first value returned by bootstrap_ci is not used on this page.
            _point, ci_lo, ci_hi = bootstrap_ci(features, brain_pred, score_fn, n_boot, seed=seed)

            null_distributions[f"{model_name}_{method_name}"] = null_dist
            row_values = (model_name, method_name, observed, ci_lo, ci_hi, p_val)
            results.append(dict(zip(result_columns, row_values)))

    df = pd.DataFrame(results)
    log_analysis(f"Brain alignment: {len(model_features)} models x {len(methods)} methods")

# --- Noise Ceiling ---
# One ceiling value per selected method. noise_ceiling returns a
# (mean, std) pair; only the mean is kept for the chart's reference lines.
ceiling_scores = {}
for method_name in methods:
    ceil_mean, _ceil_std = noise_ceiling(brain_pred, ALIGNMENT_METHODS[method_name], seed=seed)
    ceiling_scores[method_name] = ceil_mean

# --- Display: Alignment Scores with CIs ---
# Grouped bar chart (one bar group per model, one bar per method) with
# asymmetric error bars from the bootstrap CI and a dashed horizontal line
# per method marking its noise ceiling.
st.subheader("Alignment Scores")

col_chart, col_table = st.columns([2, 1])

with col_chart:
    fig = go.Figure()
    method_colors = {"RSA": "#00D2FF", "CKA": "#FF6B6B", "Procrustes": "#A29BFE"}

    for method_name in methods:
        method_df = df[df["Method"] == method_name]
        scores = method_df["Score"]
        # Bar and ceiling line share a colour; unknown methods fall back to grey.
        color = method_colors.get(method_name, "#888")

        fig.add_trace(go.Bar(
            name=method_name,
            x=method_df["Model"],
            y=scores,
            # Error-bar lengths are offsets from the score, not absolute
            # CI bounds, hence the subtraction on each side.
            error_y=dict(
                type="data",
                symmetric=False,
                array=(method_df["CI Upper"] - scores).tolist(),
                arrayminus=(scores - method_df["CI Lower"]).tolist(),
            ),
            marker_color=color,
        ))
        # Noise ceiling line
        if method_name in ceiling_scores:
            fig.add_hline(
                y=ceiling_scores[method_name],
                line_dash="dash", line_color=color,
                opacity=0.4,
                annotation_text=f"{method_name} ceiling",
                annotation_position="top right",
            )

    fig.update_layout(
        barmode="group", yaxis_title="Alignment Score",
        height=450, template="plotly_dark",
        legend=dict(orientation="h", yanchor="bottom", y=1.02),
    )
    st.plotly_chart(fig, use_container_width=True)

with col_table:
    st.subheader("Results")
    # The on-screen table shows numbers as 4-decimal strings; the CSV
    # download keeps the raw numeric values from `df`.
    formatted_cols = ["Score", "CI Lower", "CI Upper", "p-value"]
    display_df = df.assign(**{col: df[col].map("{:.4f}".format) for col in formatted_cols})
    st.dataframe(display_df, use_container_width=True, hide_index=True)
    download_csv_button(df, "brain_alignment_results.csv")

# --- Null Distribution ---
# One small histogram per (model, method) pair, laid out in up to three
# columns and filled round-robin.
with st.expander("Null Distributions (Permutation Tests)", expanded=False):
    st.markdown("The histogram shows the distribution of scores under the null hypothesis (no alignment). "
                "The red line marks the observed score. If it falls far to the right, alignment is significant.")
    n_cols = min(len(null_distributions), 3)
    cols = st.columns(n_cols)

    for idx, key in enumerate(null_distributions):
        null_dist = null_distributions[key]
        # Keys are "<model>_<method>"; the selectable method names contain no
        # underscore, so splitting on the last "_" recovers both parts.
        model_name, method_name = key.rsplit("_", 1)
        is_pair = (df["Model"] == model_name) & (df["Method"] == method_name)
        row = df[is_pair].iloc[0]

        with cols[idx % n_cols]:
            histogram = go.Histogram(x=null_dist, nbinsx=30, marker_color="rgba(100,100,100,0.6)", name="Null")
            fig_null = go.Figure()
            fig_null.add_trace(histogram)
            fig_null.add_vline(x=row["Score"], line_color="red", line_width=2, annotation_text="Observed")
            fig_null.update_layout(
                title=f"{model_name} ({method_name})",
                xaxis_title="Score",
                yaxis_title="Count",
                height=250,
                template="plotly_dark",
                showlegend=False,
                margin=dict(t=40, b=30, l=30, r=10),
            )
            st.plotly_chart(fig_null, use_container_width=True)
            st.caption(f"p = {row['p-value']:.4f}")

# --- RDM Visualization ---
# Side-by-side heatmaps: the brain RDM on the left, the RDM of the
# user-selected model on the right. Both panels share the same styling.
with st.expander("Representational Dissimilarity Matrices", expanded=False):
    st.markdown("RDMs show pairwise dissimilarity between stimulus representations. "
                "Similar RDM structure between model and brain indicates representational alignment.")
    rdm_model_name = st.selectbox("Model for RDM", list(model_features))
    col_brain, col_model = st.columns(2)

    brain_rdm = compute_rdm(brain_pred)
    model_rdm = compute_rdm(model_features[rdm_model_name])

    rdm_panels = (
        (col_brain, "Brain RDM", brain_rdm),
        (col_model, f"{rdm_model_name} RDM", model_rdm),
    )
    for panel_col, panel_title, rdm in rdm_panels:
        with panel_col:
            heatmap = go.Heatmap(z=rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity"))
            fig_rdm = go.Figure(heatmap)
            fig_rdm.update_layout(
                title=panel_title,
                height=350,
                template="plotly_dark",
                xaxis_title="Stimulus",
                yaxis_title="Stimulus",
            )
            st.plotly_chart(fig_rdm, use_container_width=True)

# --- Per-ROI Analysis with FDR ---
# Breaks the alignment of the best-scoring model (for the chosen method) down
# by ROI: one score and one permutation p-value per ROI, optional FDR
# correction across ROIs, a bar chart, and a button that hands the
# significant ROIs to other pages.
st.divider()
st.subheader("Per-ROI Alignment")

roi_method = st.selectbox("Method for ROI analysis", methods, key="roi_method")
score_fn = ALIGNMENT_METHODS[roi_method]

# Model with the highest overall score for the selected method.
top_model = df[df["Method"] == roi_method].sort_values("Score", ascending=False).iloc[0]["Model"]
features = model_features[top_model]

roi_data = []
roi_p_values = []
n_brain_vertices = brain_pred.shape[1]

for group_name, rois in ROI_GROUPS.items():
    for roi in rois:
        if roi not in roi_indices:
            continue
        verts = roi_indices[roi]
        # Keep only vertex ids that exist in brain_pred.
        # assumes verts is a NumPy integer array (boolean-mask indexing) — TODO confirm
        valid = verts[verts < n_brain_vertices]
        if len(valid) < 2:
            continue
        # Slice once and reuse for both the score and its permutation test.
        roi_brain = brain_pred[:, valid]
        s = score_fn(features, roi_brain)
        # NOTE(review): fixed at 50 permutations regardless of the sidebar
        # setting, presumably to keep this loop fast; p-value resolution is
        # correspondingly coarse — confirm this is intended.
        _, p, _ = permutation_test(features, roi_brain, score_fn, n_perm=50, seed=seed)
        roi_data.append({"ROI": roi, "Group": group_name, "Score": s, "p-value": p})
        roi_p_values.append(p)

if roi_data:
    roi_df = pd.DataFrame(roi_data)
    # FDR is only meaningful (and only applied) with more than one test.
    fdr_applied = apply_fdr and len(roi_p_values) > 1
    if fdr_applied:
        corrected_p, significant = fdr_correction(roi_p_values)
        roi_df["FDR p-value"] = corrected_p
        roi_df["Significant"] = significant
        roi_df["Label"] = roi_df.apply(lambda r: f"{r['ROI']} *" if r["Significant"] else r["ROI"], axis=1)
    else:
        roi_df["Label"] = roi_df["ROI"]
        roi_df["Significant"] = roi_df["p-value"] < 0.05

    group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
    fig_roi = px.bar(roi_df, x="Label", y="Score", color="Group",
                     color_discrete_map=group_colors)
    fig_roi.update_layout(height=400, template="plotly_dark", xaxis_tickangle=45)
    st.plotly_chart(fig_roi, use_container_width=True)

    # Explain the "*" marker only when FDR was actually run and the labels
    # can carry it; previously the legend appeared whenever the checkbox was
    # ticked, even if the correction had been skipped.
    roi_caption = f"Model: {top_model}"
    if fdr_applied:
        roi_caption += " | * = significant after FDR correction (q < 0.05)"
    st.caption(roi_caption)

    # Carry ROIs button ("Significant" is set on both branches above).
    sig_rois = roi_df[roi_df["Significant"]]["ROI"].tolist()
    if sig_rois:
        if st.button(f"Carry {len(sig_rois)} significant ROIs to other pages"):
            carry_rois(sig_rois, "Temporal Dynamics / Connectivity")
            st.success(f"Carried {len(sig_rois)} ROIs: {', '.join(sig_rois[:5])}{'...' if len(sig_rois) > 5 else ''}")
else:
    # Without this the section would render as a bare header.
    st.info("No ROIs with at least two valid vertices were available for per-ROI analysis.")

# --- Methodology ---
# Static reference text describing each alignment method and statistical
# procedure offered on this page; rendered as Markdown in a collapsed expander.
METHODOLOGY_MD = """
**Representational Similarity Analysis (RSA)** compares the geometry of two representation
spaces by computing pairwise dissimilarity matrices (RDMs) and correlating their upper triangles
via Spearman rank correlation. Range: [-1, 1]. Values > 0.1 are typically meaningful.
*Kriegeskorte et al., 2008, Frontiers in Systems Neuroscience.*

**Centered Kernel Alignment (CKA)** measures similarity between representations using
HSIC (Hilbert-Schmidt Independence Criterion) normalized by self-similarities. Invariant to
orthogonal transformations and isotropic scaling. Range: [0, 1].
*Kornblith et al., 2019, ICML.*

**Procrustes** finds the optimal rotation mapping one space onto another and measures
residual distance. Score = 1 - normalized Procrustes distance. Range: [0, 1].
*Ding et al., 2021, NeurIPS.*

**Noise ceiling** estimates the maximum achievable alignment score given the noise in the
brain data, computed via split-half reliability.

**FDR correction** (Benjamini-Hochberg) controls the false discovery rate when testing
multiple ROIs simultaneously.
"""

with st.expander("Methodology", expanded=False):
    st.markdown(METHODOLOGY_MD)