Spaces:
Running
Dashboard v2: research-grade rebuild
Browse filesBiologically realistic synthetic data:
- HRF convolution with canonical double-gamma response
- Modality-specific ROI activation weights (visual/auditory/language/multimodal)
- Spatial smoothing, AR(1) temporal noise, scanner drift
- Controllable alignment strength for meaningful benchmark scores
Shared session state:
- Cross-page ROI carry (alignment -> temporal dynamics -> connectivity)
- File upload (.npy) and download (CSV, JSON) support
- Analysis log tracking across pages
- Data source selector (synthetic vs uploaded)
Brain Alignment (rebuilt):
- Bootstrap CI error bars on all scores
- Null distribution histograms with observed score overlay
- RDM visualization (brain vs model side-by-side)
- FDR correction for per-ROI tests
- Noise ceiling estimation (split-half reliability)
- Methodology documentation with references
Cognitive Load (rebuilt):
- Confidence bands on timeline (bootstrap across vertices)
- Dimension correlation heatmap
- Per-ROI activation breakdown within each dimension
- Comparison mode (two stimulus types side-by-side)
Temporal Dynamics (rebuilt):
- Raw ROI timecourses showing HRF shape
- Peak latency sorted by processing hierarchy
- Lag correlation with 95% null significance band
- Optimal lag summary table
- Cross-ROI lag matrix
- 3-panel sustained/transient decomposition
Connectivity (rebuilt):
- Partial correlation option
- Correlation matrix with cluster boundary lines
- Dendrogram visualization
- Modularity Q score with interpretation
- Betweenness centrality alongside degree
- Edge weight distribution histogram
- Network graph with weighted edges and sized nodes
- Cluster-to-functional-group mapping
- Home.py +87 -27
- pages/1_Brain_Alignment.py +212 -102
- pages/2_Cognitive_Load.py +167 -85
- pages/3_Temporal_Dynamics.py +200 -62
- pages/4_Connectivity.py +248 -127
- session.py +134 -0
- synthetic.py +230 -0
- utils.py +106 -5
|
@@ -1,50 +1,110 @@
|
|
| 1 |
-
"""CortexLab Dashboard - Home Page."""
|
| 2 |
|
| 3 |
import streamlit as st
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
)
|
| 11 |
|
| 12 |
st.title("CortexLab Dashboard")
|
| 13 |
-
st.markdown("**
|
| 14 |
|
|
|
|
| 15 |
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
col1, col2 = st.columns(2)
|
| 18 |
|
| 19 |
with col1:
|
| 20 |
st.subheader("Analysis Tools")
|
| 21 |
st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
|
| 22 |
-
st.
|
| 23 |
|
| 24 |
st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
|
| 25 |
-
st.
|
| 26 |
|
| 27 |
with col2:
|
| 28 |
st.subheader("Advanced Analysis")
|
| 29 |
st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
|
| 30 |
-
st.
|
| 31 |
|
| 32 |
st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
|
| 33 |
-
st.
|
| 34 |
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
-
st.
|
| 38 |
-
st.
|
| 39 |
-
"""
|
| 40 |
-
CortexLab is an enhanced toolkit built on [Meta's TRIBE v2](https://github.com/facebookresearch/tribev2)
|
| 41 |
-
for predicting how the human brain responds to video, audio, and text.
|
| 42 |
-
|
| 43 |
-
This dashboard runs on **synthetic data** by default - no GPU or real fMRI data required.
|
| 44 |
-
All analysis tools mirror the CortexLab Python API.
|
| 45 |
-
|
| 46 |
-
[GitHub](https://github.com/siddhant-rajhans/cortexlab)
|
| 47 |
-
|
|
| 48 |
-
[HuggingFace](https://huggingface.co/SID2000/cortexlab)
|
| 49 |
-
"""
|
| 50 |
-
)
|
|
|
|
| 1 |
+
"""CortexLab Dashboard - Home Page with Data Management."""
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
|
| 6 |
+
from session import init_session, data_summary_widget, show_analysis_log, upload_npy_widget
|
| 7 |
+
from utils import make_roi_indices
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="CortexLab Dashboard", page_icon="🧠", layout="wide", initial_sidebar_state="expanded")
|
| 10 |
+
init_session()
|
|
|
|
| 11 |
|
| 12 |
st.title("CortexLab Dashboard")
|
| 13 |
+
st.markdown("**Research-grade analysis toolkit for multimodal fMRI brain encoding**")
|
| 14 |
|
| 15 |
+
# --- Data Source ---
|
| 16 |
st.divider()
|
| 17 |
+
st.subheader("Data Configuration")
|
| 18 |
+
|
| 19 |
+
col_src, col_params = st.columns([1, 2])
|
| 20 |
+
|
| 21 |
+
with col_src:
|
| 22 |
+
source = st.radio("Data source", ["Synthetic (realistic)", "Upload your data"], index=0)
|
| 23 |
+
st.session_state["data_source"] = "synthetic" if "Synthetic" in source else "uploaded"
|
| 24 |
+
|
| 25 |
+
with col_params:
|
| 26 |
+
if st.session_state["data_source"] == "synthetic":
|
| 27 |
+
c1, c2, c3, c4 = st.columns(4)
|
| 28 |
+
st.session_state["stimulus_type"] = c1.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"])
|
| 29 |
+
st.session_state["n_timepoints"] = c2.slider("Duration (TRs)", 30, 200, 80)
|
| 30 |
+
st.session_state["tr_seconds"] = c3.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 31 |
+
st.session_state["seed"] = c4.number_input("Seed", value=42, min_value=0)
|
| 32 |
+
|
| 33 |
+
# Generate on config change
|
| 34 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 35 |
+
st.session_state["roi_indices"] = roi_indices
|
| 36 |
+
st.session_state["n_vertices"] = n_vertices
|
| 37 |
+
|
| 38 |
+
from synthetic import generate_realistic_predictions
|
| 39 |
+
predictions = generate_realistic_predictions(
|
| 40 |
+
st.session_state["n_timepoints"], roi_indices,
|
| 41 |
+
st.session_state["stimulus_type"], st.session_state["tr_seconds"],
|
| 42 |
+
seed=st.session_state["seed"],
|
| 43 |
+
)
|
| 44 |
+
st.session_state["brain_predictions"] = predictions
|
| 45 |
+
else:
|
| 46 |
+
uploaded = upload_npy_widget("Upload brain predictions (.npy, shape: timepoints x vertices)", "upload_predictions")
|
| 47 |
+
if uploaded is not None:
|
| 48 |
+
st.session_state["brain_predictions"] = uploaded
|
| 49 |
+
roi_indices, _ = make_roi_indices()
|
| 50 |
+
st.session_state["roi_indices"] = roi_indices
|
| 51 |
+
|
| 52 |
+
# --- Data Summary ---
|
| 53 |
+
roi_indices = st.session_state.get("roi_indices")
|
| 54 |
+
predictions = st.session_state.get("brain_predictions")
|
| 55 |
+
if predictions is not None and roi_indices is not None:
|
| 56 |
+
data_summary_widget(predictions, roi_indices)
|
| 57 |
+
|
| 58 |
+
# Show HRF-convolved signal preview
|
| 59 |
+
with st.expander("Data Preview", expanded=False):
|
| 60 |
+
import plotly.graph_objects as go
|
| 61 |
+
from utils import ROI_GROUPS
|
| 62 |
|
| 63 |
+
fig = go.Figure()
|
| 64 |
+
t = np.arange(predictions.shape[0]) * st.session_state.get("tr_seconds", 1.0)
|
| 65 |
+
colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
|
| 66 |
+
for group, rois in ROI_GROUPS.items():
|
| 67 |
+
vals = []
|
| 68 |
+
for roi in rois:
|
| 69 |
+
if roi in roi_indices:
|
| 70 |
+
verts = roi_indices[roi]
|
| 71 |
+
valid = verts[verts < predictions.shape[1]]
|
| 72 |
+
if len(valid) > 0:
|
| 73 |
+
vals.append(np.abs(predictions[:, valid]).mean(axis=1))
|
| 74 |
+
if vals:
|
| 75 |
+
mean_tc = np.mean(vals, axis=0)
|
| 76 |
+
fig.add_trace(go.Scatter(x=t, y=mean_tc, name=group, line=dict(color=colors.get(group, "#888"), width=2)))
|
| 77 |
+
|
| 78 |
+
fig.update_layout(
|
| 79 |
+
xaxis_title="Time (seconds)", yaxis_title="Mean |activation|",
|
| 80 |
+
height=300, template="plotly_dark",
|
| 81 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 82 |
+
)
|
| 83 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 84 |
+
st.caption("Mean absolute activation per functional group. Note the hemodynamic response shape and modality-specific activation patterns.")
|
| 85 |
+
|
| 86 |
+
# --- Navigation ---
|
| 87 |
+
st.divider()
|
| 88 |
col1, col2 = st.columns(2)
|
| 89 |
|
| 90 |
with col1:
|
| 91 |
st.subheader("Analysis Tools")
|
| 92 |
st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
|
| 93 |
+
st.caption("RSA, CKA, Procrustes with permutation tests, bootstrap CIs, FDR correction, noise ceiling, and RDM visualization")
|
| 94 |
|
| 95 |
st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
|
| 96 |
+
st.caption("Timeline with confidence bands, dimension correlation, per-ROI breakdown, comparison mode")
|
| 97 |
|
| 98 |
with col2:
|
| 99 |
st.subheader("Advanced Analysis")
|
| 100 |
st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
|
| 101 |
+
st.caption("Raw timecourses, peak latency hierarchy, optimal lag analysis, cross-ROI lag matrix")
|
| 102 |
|
| 103 |
st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
|
| 104 |
+
st.caption("Partial correlation, modularity, betweenness centrality, dendrogram, network graph")
|
| 105 |
|
| 106 |
+
# --- Analysis Log ---
|
| 107 |
+
show_analysis_log()
|
| 108 |
|
| 109 |
+
st.divider()
|
| 110 |
+
st.caption("[GitHub](https://github.com/siddhant-rajhans/cortexlab) | [HuggingFace](https://huggingface.co/SID2000/cortexlab) | [Dashboard Repo](https://github.com/siddhant-rajhans/cortexlab-dashboard)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Brain Alignment Benchmark
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
@@ -6,135 +6,245 @@ import pandas as pd
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
|
|
|
| 9 |
from utils import (
|
| 10 |
-
ALIGNMENT_METHODS,
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
generate_model_features,
|
| 14 |
-
permutation_test,
|
| 15 |
)
|
|
|
|
| 16 |
|
| 17 |
st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
|
|
|
|
|
|
|
|
|
|
| 18 |
st.title("🎯 Brain Alignment Benchmark")
|
| 19 |
-
st.markdown("
|
| 20 |
|
| 21 |
-
# --- Sidebar
|
| 22 |
with st.sidebar:
|
| 23 |
st.header("Configuration")
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
st.
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
"
|
| 32 |
-
"
|
| 33 |
-
|
| 34 |
-
model_dims = {
|
| 35 |
-
"CLIP ViT-L/14": 768,
|
| 36 |
-
"DINOv2 ViT-S": 384,
|
| 37 |
-
"V-JEPA2 ViT-G": 1024,
|
| 38 |
-
"LLaMA 3.2-3B": 3072,
|
| 39 |
}
|
| 40 |
-
selected_models = [m for m, checked in models_config.items() if checked]
|
| 41 |
|
|
|
|
| 42 |
methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
|
| 43 |
-
|
| 44 |
-
|
|
|
|
| 45 |
|
| 46 |
-
if not
|
| 47 |
-
st.warning("Select at least one
|
| 48 |
st.stop()
|
| 49 |
|
| 50 |
# --- Generate Data ---
|
| 51 |
roi_indices, n_vertices = make_roi_indices()
|
| 52 |
-
brain_pred =
|
| 53 |
|
| 54 |
model_features = {}
|
| 55 |
-
for i, name in enumerate(
|
| 56 |
-
model_features[name] =
|
|
|
|
|
|
|
| 57 |
|
| 58 |
# --- Run Benchmark ---
|
| 59 |
-
with st.spinner("Computing alignment scores..."):
|
| 60 |
results = []
|
|
|
|
|
|
|
| 61 |
for model_name, features in model_features.items():
|
| 62 |
for method_name in methods:
|
| 63 |
score_fn = ALIGNMENT_METHODS[method_name]
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
fig.update_layout(
|
| 90 |
-
yaxis_title="Alignment Score",
|
| 91 |
-
height=450,
|
| 92 |
-
|
| 93 |
)
|
| 94 |
st.plotly_chart(fig, use_container_width=True)
|
| 95 |
|
| 96 |
-
with
|
| 97 |
-
st.subheader("Results
|
| 98 |
display_df = df.copy()
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
display_df["p-value"] = display_df["p-value"].map(lambda x: f"{x:.4f}")
|
| 102 |
st.dataframe(display_df, use_container_width=True, hide_index=True)
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
st.divider()
|
| 106 |
-
st.subheader("Per-ROI Alignment
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
if
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Brain Alignment Benchmark - Research Grade."""
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
+
from session import init_session, log_analysis, carry_rois, download_csv_button, show_analysis_log
|
| 10 |
from utils import (
|
| 11 |
+
ALIGNMENT_METHODS, ROI_GROUPS, make_roi_indices,
|
| 12 |
+
permutation_test, bootstrap_ci, fdr_correction, noise_ceiling,
|
| 13 |
+
compute_rdm,
|
|
|
|
|
|
|
| 14 |
)
|
| 15 |
+
from synthetic import generate_realistic_predictions, generate_correlated_features
|
| 16 |
|
| 17 |
st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
|
| 18 |
+
init_session()
|
| 19 |
+
show_analysis_log()
|
| 20 |
+
|
| 21 |
st.title("🎯 Brain Alignment Benchmark")
|
| 22 |
+
st.markdown("Score how well AI model representations align with predicted brain responses, with full statistical testing.")
|
| 23 |
|
| 24 |
+
# --- Sidebar ---
|
| 25 |
with st.sidebar:
|
| 26 |
st.header("Configuration")
|
| 27 |
+
stimulus = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"],
|
| 28 |
+
index=["visual", "auditory", "language", "multimodal"].index(st.session_state.get("stimulus_type", "visual")))
|
| 29 |
+
n_timepoints = st.slider("Timepoints", 30, 200, st.session_state.get("n_timepoints", 80))
|
| 30 |
+
seed = st.number_input("Seed", value=st.session_state.get("seed", 42), min_value=0)
|
| 31 |
+
|
| 32 |
+
st.subheader("Models")
|
| 33 |
+
model_configs = {
|
| 34 |
+
"CLIP ViT-L/14": {"dim": 768, "alignment": st.slider("CLIP alignment", 0.0, 1.0, 0.6, 0.05)},
|
| 35 |
+
"DINOv2 ViT-S": {"dim": 384, "alignment": st.slider("DINOv2 alignment", 0.0, 1.0, 0.3, 0.05)},
|
| 36 |
+
"V-JEPA2 ViT-G": {"dim": 1024, "alignment": st.slider("V-JEPA2 alignment", 0.0, 1.0, 0.8, 0.05)},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
}
|
|
|
|
| 38 |
|
| 39 |
+
st.subheader("Methods & Statistics")
|
| 40 |
methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
|
| 41 |
+
n_perm = st.slider("Permutations", 50, 1000, 200)
|
| 42 |
+
n_boot = st.slider("Bootstrap samples", 100, 2000, 500)
|
| 43 |
+
apply_fdr = st.checkbox("Apply FDR correction", value=True)
|
| 44 |
|
| 45 |
+
if not methods:
|
| 46 |
+
st.warning("Select at least one method.")
|
| 47 |
st.stop()
|
| 48 |
|
| 49 |
# --- Generate Data ---
|
| 50 |
roi_indices, n_vertices = make_roi_indices()
|
| 51 |
+
brain_pred = generate_realistic_predictions(n_timepoints, roi_indices, stimulus, seed=seed)
|
| 52 |
|
| 53 |
model_features = {}
|
| 54 |
+
for i, (name, cfg) in enumerate(model_configs.items()):
|
| 55 |
+
model_features[name] = generate_correlated_features(
|
| 56 |
+
brain_pred, cfg["alignment"], cfg["dim"], seed=seed + i + 1
|
| 57 |
+
)
|
| 58 |
|
| 59 |
# --- Run Benchmark ---
|
| 60 |
+
with st.spinner("Computing alignment scores with statistical testing..."):
|
| 61 |
results = []
|
| 62 |
+
null_distributions = {}
|
| 63 |
+
|
| 64 |
for model_name, features in model_features.items():
|
| 65 |
for method_name in methods:
|
| 66 |
score_fn = ALIGNMENT_METHODS[method_name]
|
| 67 |
+
observed, p_val, null_dist = permutation_test(features, brain_pred, score_fn, n_perm, seed)
|
| 68 |
+
point, ci_lo, ci_hi = bootstrap_ci(features, brain_pred, score_fn, n_boot, seed=seed)
|
| 69 |
+
null_distributions[f"{model_name}_{method_name}"] = null_dist
|
| 70 |
+
|
| 71 |
+
results.append({
|
| 72 |
+
"Model": model_name,
|
| 73 |
+
"Method": method_name,
|
| 74 |
+
"Score": observed,
|
| 75 |
+
"CI Lower": ci_lo,
|
| 76 |
+
"CI Upper": ci_hi,
|
| 77 |
+
"p-value": p_val,
|
| 78 |
+
})
|
| 79 |
+
|
| 80 |
+
df = pd.DataFrame(results)
|
| 81 |
+
log_analysis(f"Brain alignment: {len(model_features)} models x {len(methods)} methods")
|
| 82 |
+
|
| 83 |
+
# --- Noise Ceiling ---
|
| 84 |
+
ceiling_scores = {}
|
| 85 |
+
for method_name in methods:
|
| 86 |
+
score_fn = ALIGNMENT_METHODS[method_name]
|
| 87 |
+
ceil_mean, ceil_std = noise_ceiling(brain_pred, score_fn, seed=seed)
|
| 88 |
+
ceiling_scores[method_name] = ceil_mean
|
| 89 |
+
|
| 90 |
+
# --- Display: Alignment Scores with CIs ---
|
| 91 |
+
st.subheader("Alignment Scores")
|
| 92 |
+
|
| 93 |
+
col_chart, col_table = st.columns([2, 1])
|
| 94 |
+
|
| 95 |
+
with col_chart:
|
| 96 |
+
fig = go.Figure()
|
| 97 |
+
method_colors = {"RSA": "#00D2FF", "CKA": "#FF6B6B", "Procrustes": "#A29BFE"}
|
| 98 |
+
x_positions = list(model_configs.keys())
|
| 99 |
+
|
| 100 |
+
for method_name in methods:
|
| 101 |
+
method_df = df[df["Method"] == method_name]
|
| 102 |
+
fig.add_trace(go.Bar(
|
| 103 |
+
name=method_name,
|
| 104 |
+
x=method_df["Model"],
|
| 105 |
+
y=method_df["Score"],
|
| 106 |
+
error_y=dict(
|
| 107 |
+
type="data",
|
| 108 |
+
symmetric=False,
|
| 109 |
+
array=(method_df["CI Upper"] - method_df["Score"]).tolist(),
|
| 110 |
+
arrayminus=(method_df["Score"] - method_df["CI Lower"]).tolist(),
|
| 111 |
+
),
|
| 112 |
+
marker_color=method_colors.get(method_name, "#888"),
|
| 113 |
+
))
|
| 114 |
+
# Noise ceiling line
|
| 115 |
+
if method_name in ceiling_scores:
|
| 116 |
+
fig.add_hline(
|
| 117 |
+
y=ceiling_scores[method_name],
|
| 118 |
+
line_dash="dash", line_color=method_colors.get(method_name, "#888"),
|
| 119 |
+
opacity=0.4,
|
| 120 |
+
annotation_text=f"{method_name} ceiling",
|
| 121 |
+
annotation_position="top right",
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
fig.update_layout(
|
| 125 |
+
barmode="group", yaxis_title="Alignment Score",
|
| 126 |
+
height=450, template="plotly_dark",
|
| 127 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 128 |
)
|
| 129 |
st.plotly_chart(fig, use_container_width=True)
|
| 130 |
|
| 131 |
+
with col_table:
|
| 132 |
+
st.subheader("Results")
|
| 133 |
display_df = df.copy()
|
| 134 |
+
for col in ["Score", "CI Lower", "CI Upper", "p-value"]:
|
| 135 |
+
display_df[col] = display_df[col].map(lambda x: f"{x:.4f}")
|
|
|
|
| 136 |
st.dataframe(display_df, use_container_width=True, hide_index=True)
|
| 137 |
+
download_csv_button(df, "brain_alignment_results.csv")
|
| 138 |
+
|
| 139 |
+
# --- Null Distribution ---
|
| 140 |
+
with st.expander("Null Distributions (Permutation Tests)", expanded=False):
|
| 141 |
+
st.markdown("The histogram shows the distribution of scores under the null hypothesis (no alignment). "
|
| 142 |
+
"The red line marks the observed score. If it falls far to the right, alignment is significant.")
|
| 143 |
+
cols = st.columns(min(len(null_distributions), 3))
|
| 144 |
+
for i, (key, null_dist) in enumerate(null_distributions.items()):
|
| 145 |
+
model_name, method_name = key.rsplit("_", 1)
|
| 146 |
+
row = df[(df["Model"] == model_name) & (df["Method"] == method_name)].iloc[0]
|
| 147 |
+
with cols[i % len(cols)]:
|
| 148 |
+
fig_null = go.Figure()
|
| 149 |
+
fig_null.add_trace(go.Histogram(x=null_dist, nbinsx=30, marker_color="rgba(100,100,100,0.6)", name="Null"))
|
| 150 |
+
fig_null.add_vline(x=row["Score"], line_color="red", line_width=2, annotation_text=f"Observed")
|
| 151 |
+
fig_null.update_layout(
|
| 152 |
+
title=f"{model_name} ({method_name})",
|
| 153 |
+
xaxis_title="Score", yaxis_title="Count",
|
| 154 |
+
height=250, template="plotly_dark", showlegend=False,
|
| 155 |
+
margin=dict(t=40, b=30, l=30, r=10),
|
| 156 |
+
)
|
| 157 |
+
st.plotly_chart(fig_null, use_container_width=True)
|
| 158 |
+
st.caption(f"p = {row['p-value']:.4f}")
|
| 159 |
+
|
| 160 |
+
# --- RDM Visualization ---
|
| 161 |
+
with st.expander("Representational Dissimilarity Matrices", expanded=False):
|
| 162 |
+
st.markdown("RDMs show pairwise dissimilarity between stimulus representations. "
|
| 163 |
+
"Similar RDM structure between model and brain indicates representational alignment.")
|
| 164 |
+
rdm_model_name = st.selectbox("Model for RDM", list(model_features.keys()))
|
| 165 |
+
col_brain, col_model = st.columns(2)
|
| 166 |
+
|
| 167 |
+
brain_rdm = compute_rdm(brain_pred)
|
| 168 |
+
model_rdm = compute_rdm(model_features[rdm_model_name])
|
| 169 |
+
|
| 170 |
+
with col_brain:
|
| 171 |
+
fig_rdm = go.Figure(go.Heatmap(z=brain_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
|
| 172 |
+
fig_rdm.update_layout(title="Brain RDM", height=350, template="plotly_dark", xaxis_title="Stimulus", yaxis_title="Stimulus")
|
| 173 |
+
st.plotly_chart(fig_rdm, use_container_width=True)
|
| 174 |
+
|
| 175 |
+
with col_model:
|
| 176 |
+
fig_rdm2 = go.Figure(go.Heatmap(z=model_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
|
| 177 |
+
fig_rdm2.update_layout(title=f"{rdm_model_name} RDM", height=350, template="plotly_dark", xaxis_title="Stimulus", yaxis_title="Stimulus")
|
| 178 |
+
st.plotly_chart(fig_rdm2, use_container_width=True)
|
| 179 |
+
|
| 180 |
+
# --- Per-ROI Analysis with FDR ---
|
| 181 |
st.divider()
|
| 182 |
+
st.subheader("Per-ROI Alignment")
|
| 183 |
+
|
| 184 |
+
roi_method = st.selectbox("Method for ROI analysis", methods, key="roi_method")
|
| 185 |
+
score_fn = ALIGNMENT_METHODS[roi_method]
|
| 186 |
+
|
| 187 |
+
roi_data = []
|
| 188 |
+
roi_p_values = []
|
| 189 |
+
top_model = df[df["Method"] == roi_method].sort_values("Score", ascending=False).iloc[0]["Model"]
|
| 190 |
+
features = model_features[top_model]
|
| 191 |
+
|
| 192 |
+
for group_name, rois in ROI_GROUPS.items():
|
| 193 |
+
for roi in rois:
|
| 194 |
+
if roi in roi_indices:
|
| 195 |
+
verts = roi_indices[roi]
|
| 196 |
+
valid = verts[verts < brain_pred.shape[1]]
|
| 197 |
+
if len(valid) >= 2:
|
| 198 |
+
s = score_fn(features, brain_pred[:, valid])
|
| 199 |
+
_, p, _ = permutation_test(features, brain_pred[:, valid], score_fn, n_perm=50, seed=seed)
|
| 200 |
+
roi_data.append({"ROI": roi, "Group": group_name, "Score": s, "p-value": p})
|
| 201 |
+
roi_p_values.append(p)
|
| 202 |
+
|
| 203 |
+
if roi_data:
|
| 204 |
+
roi_df = pd.DataFrame(roi_data)
|
| 205 |
+
if apply_fdr and len(roi_p_values) > 1:
|
| 206 |
+
corrected_p, significant = fdr_correction(roi_p_values)
|
| 207 |
+
roi_df["FDR p-value"] = corrected_p
|
| 208 |
+
roi_df["Significant"] = significant
|
| 209 |
+
roi_df["Label"] = roi_df.apply(lambda r: f"{r['ROI']} *" if r["Significant"] else r["ROI"], axis=1)
|
| 210 |
+
else:
|
| 211 |
+
roi_df["Label"] = roi_df["ROI"]
|
| 212 |
+
roi_df["Significant"] = roi_df["p-value"] < 0.05
|
| 213 |
+
|
| 214 |
+
group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
|
| 215 |
+
fig_roi = px.bar(roi_df, x="Label", y="Score", color="Group",
|
| 216 |
+
color_discrete_map=group_colors)
|
| 217 |
+
fig_roi.update_layout(height=400, template="plotly_dark", xaxis_tickangle=45)
|
| 218 |
+
st.plotly_chart(fig_roi, use_container_width=True)
|
| 219 |
+
st.caption(f"Model: {top_model} | * = significant after FDR correction (q < 0.05)" if apply_fdr else f"Model: {top_model}")
|
| 220 |
+
|
| 221 |
+
# Carry ROIs button
|
| 222 |
+
sig_rois = roi_df[roi_df["Significant"]]["ROI"].tolist() if "Significant" in roi_df.columns else []
|
| 223 |
+
if sig_rois:
|
| 224 |
+
if st.button(f"Carry {len(sig_rois)} significant ROIs to other pages"):
|
| 225 |
+
carry_rois(sig_rois, "Temporal Dynamics / Connectivity")
|
| 226 |
+
st.success(f"Carried {len(sig_rois)} ROIs: {', '.join(sig_rois[:5])}{'...' if len(sig_rois) > 5 else ''}")
|
| 227 |
+
|
| 228 |
+
# --- Methodology ---
|
| 229 |
+
with st.expander("Methodology", expanded=False):
|
| 230 |
+
st.markdown("""
|
| 231 |
+
**Representational Similarity Analysis (RSA)** compares the geometry of two representation
|
| 232 |
+
spaces by computing pairwise dissimilarity matrices (RDMs) and correlating their upper triangles
|
| 233 |
+
via Spearman rank correlation. Range: [-1, 1]. Values > 0.1 are typically meaningful.
|
| 234 |
+
*Kriegeskorte et al., 2008, Frontiers in Systems Neuroscience.*
|
| 235 |
+
|
| 236 |
+
**Centered Kernel Alignment (CKA)** measures similarity between representations using
|
| 237 |
+
HSIC (Hilbert-Schmidt Independence Criterion) normalized by self-similarities. Invariant to
|
| 238 |
+
orthogonal transformations and isotropic scaling. Range: [0, 1].
|
| 239 |
+
*Kornblith et al., 2019, ICML.*
|
| 240 |
+
|
| 241 |
+
**Procrustes** finds the optimal rotation mapping one space onto another and measures
|
| 242 |
+
residual distance. Score = 1 - normalized Procrustes distance. Range: [0, 1].
|
| 243 |
+
*Ding et al., 2021, NeurIPS.*
|
| 244 |
+
|
| 245 |
+
**Noise ceiling** estimates the maximum achievable alignment score given the noise in the
|
| 246 |
+
brain data, computed via split-half reliability.
|
| 247 |
+
|
| 248 |
+
**FDR correction** (Benjamini-Hochberg) controls the false discovery rate when testing
|
| 249 |
+
multiple ROIs simultaneously.
|
| 250 |
+
""")
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Cognitive Load Scorer
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
@@ -6,126 +6,208 @@ import pandas as pd
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
score_cognitive_load,
|
| 13 |
-
COGNITIVE_DIMENSIONS,
|
| 14 |
-
)
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
|
|
|
|
|
|
|
|
|
|
| 17 |
st.title("📊 Cognitive Load Scorer")
|
| 18 |
-
st.markdown("Predict cognitive demand
|
| 19 |
|
| 20 |
# --- Sidebar ---
|
| 21 |
with st.sidebar:
|
| 22 |
st.header("Configuration")
|
| 23 |
-
n_timepoints = st.slider("Duration (TRs)",
|
| 24 |
-
tr_seconds = st.slider("TR
|
| 25 |
-
seed = st.number_input("
|
| 26 |
-
|
| 27 |
-
st.subheader("Simulate content type")
|
| 28 |
-
content_type = st.selectbox(
|
| 29 |
-
"Content profile",
|
| 30 |
-
["Balanced", "Visual-heavy", "Audio-heavy", "Language-heavy", "Low engagement"],
|
| 31 |
-
)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
if
|
| 40 |
-
|
| 41 |
-
if roi in roi_indices:
|
| 42 |
-
predictions[:, roi_indices[roi]] *= 3.0
|
| 43 |
-
elif content_type == "Audio-heavy":
|
| 44 |
-
for roi in COGNITIVE_DIMENSIONS["Auditory Demand"]:
|
| 45 |
-
if roi in roi_indices:
|
| 46 |
-
predictions[:, roi_indices[roi]] *= 3.0
|
| 47 |
-
elif content_type == "Language-heavy":
|
| 48 |
-
for roi in COGNITIVE_DIMENSIONS["Language Processing"]:
|
| 49 |
-
if roi in roi_indices:
|
| 50 |
-
predictions[:, roi_indices[roi]] *= 3.0
|
| 51 |
-
elif content_type == "Low engagement":
|
| 52 |
-
predictions *= 0.2
|
| 53 |
|
| 54 |
-
# ---
|
|
|
|
|
|
|
| 55 |
averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
# ---
|
| 58 |
-
col1, col2, col3, col4, col5 = st.columns(5)
|
| 59 |
dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
|
| 60 |
-
cols =
|
| 61 |
for col, dim in zip(cols, dims):
|
| 62 |
val = averages.get(dim, 0.0)
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
st.divider()
|
| 66 |
|
| 67 |
-
# --- Timeline ---
|
| 68 |
st.subheader("Cognitive Load Timeline")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
timeline_df = pd.DataFrame(timeline)
|
|
|
|
| 70 |
|
| 71 |
fig = go.Figure()
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
fig.update_layout(
|
| 83 |
-
xaxis_title="Time (seconds)",
|
| 84 |
-
|
| 85 |
-
yaxis_range=[0, 1.05],
|
| 86 |
-
height=400,
|
| 87 |
-
template="plotly_dark",
|
| 88 |
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 89 |
)
|
| 90 |
st.plotly_chart(fig, use_container_width=True)
|
| 91 |
|
| 92 |
-
# --- Dimension
|
| 93 |
st.divider()
|
| 94 |
col1, col2 = st.columns(2)
|
| 95 |
|
| 96 |
with col1:
|
| 97 |
-
st.subheader("Dimension
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
))
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
xaxis_range=[0, 1],
|
| 108 |
-
height=300,
|
| 109 |
-
template="plotly_dark",
|
| 110 |
-
)
|
| 111 |
-
st.plotly_chart(fig2, use_container_width=True)
|
| 112 |
|
| 113 |
with col2:
|
| 114 |
-
st.subheader("
|
|
|
|
| 115 |
categories = list(dim_data.keys())
|
| 116 |
-
values = list(dim_data.values()) + [list(dim_data.values())[0]]
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
theta=categories + [categories[0]],
|
| 121 |
-
fill="toself",
|
| 122 |
-
|
| 123 |
-
line=dict(color="#6C5CE7"),
|
| 124 |
))
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
|
| 127 |
-
height=350,
|
| 128 |
-
template="plotly_dark",
|
| 129 |
-
showlegend=False,
|
| 130 |
)
|
| 131 |
-
st.plotly_chart(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cognitive Load Scorer - Research Grade."""
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
+
from session import init_session, log_analysis, download_csv_button, show_analysis_log
|
| 10 |
+
from utils import make_roi_indices, score_cognitive_load, COGNITIVE_DIMENSIONS, ROI_GROUPS
|
| 11 |
+
from synthetic import generate_realistic_predictions
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
|
| 14 |
+
init_session()
|
| 15 |
+
show_analysis_log()
|
| 16 |
+
|
| 17 |
st.title("📊 Cognitive Load Scorer")
|
| 18 |
+
st.markdown("Predict cognitive demand from brain activation patterns across four neurocognitive dimensions.")
|
| 19 |
|
| 20 |
# --- Sidebar ---
|
| 21 |
with st.sidebar:
|
| 22 |
st.header("Configuration")
|
| 23 |
+
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 24 |
+
tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 25 |
+
seed = st.number_input("Seed", value=42, min_value=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
st.subheader("Stimulus")
|
| 28 |
+
stim_type = st.selectbox("Primary stimulus", ["visual", "auditory", "language", "multimodal"])
|
|
|
|
| 29 |
|
| 30 |
+
st.subheader("Comparison Mode")
|
| 31 |
+
compare = st.checkbox("Compare two stimulus types", value=False)
|
| 32 |
+
if compare:
|
| 33 |
+
stim_type_2 = st.selectbox("Second stimulus", [s for s in ["visual", "auditory", "language", "multimodal"] if s != stim_type])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
+
# --- Generate Data ---
|
| 36 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 37 |
+
predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
|
| 38 |
averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
|
| 39 |
+
log_analysis(f"Cognitive load: {stim_type}, {n_timepoints} TRs")
|
| 40 |
+
|
| 41 |
+
if compare:
|
| 42 |
+
predictions_2 = generate_realistic_predictions(n_timepoints, roi_indices, stim_type_2, tr_seconds, seed=seed + 100)
|
| 43 |
+
averages_2, timeline_2 = score_cognitive_load(predictions_2, roi_indices, tr_seconds)
|
| 44 |
|
| 45 |
+
# --- Metric Cards ---
|
|
|
|
| 46 |
dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
|
| 47 |
+
cols = st.columns(5)
|
| 48 |
for col, dim in zip(cols, dims):
|
| 49 |
val = averages.get(dim, 0.0)
|
| 50 |
+
if compare:
|
| 51 |
+
val_2 = averages_2.get(dim, 0.0)
|
| 52 |
+
delta = val - val_2
|
| 53 |
+
col.metric(dim, f"{val:.2f}", delta=f"{delta:+.2f} vs {stim_type_2}", delta_color="normal")
|
| 54 |
+
else:
|
| 55 |
+
col.metric(dim, f"{val:.2f}")
|
| 56 |
|
| 57 |
st.divider()
|
| 58 |
|
| 59 |
+
# --- Timeline with Confidence Bands ---
|
| 60 |
st.subheader("Cognitive Load Timeline")
|
| 61 |
+
|
| 62 |
+
if compare:
|
| 63 |
+
st.markdown(f"**Solid lines**: {stim_type} | **Dashed lines**: {stim_type_2}")
|
| 64 |
+
|
| 65 |
timeline_df = pd.DataFrame(timeline)
|
| 66 |
+
dim_colors = {"Visual Complexity": "#00D2FF", "Auditory Demand": "#FF6B6B", "Language Processing": "#A29BFE", "Executive Load": "#FFEAA7"}
|
| 67 |
|
| 68 |
fig = go.Figure()
|
| 69 |
+
for dim, color in dim_colors.items():
|
| 70 |
+
y = timeline_df[dim].values
|
| 71 |
+
|
| 72 |
+
# Bootstrap confidence band (resample vertices within dimension ROIs)
|
| 73 |
+
rng = np.random.default_rng(seed)
|
| 74 |
+
dim_rois = COGNITIVE_DIMENSIONS.get(dim, [])
|
| 75 |
+
dim_vertices = []
|
| 76 |
+
for roi in dim_rois:
|
| 77 |
+
if roi in roi_indices:
|
| 78 |
+
valid = roi_indices[roi]
|
| 79 |
+
dim_vertices.extend(valid[valid < predictions.shape[1]])
|
| 80 |
+
|
| 81 |
+
if dim_vertices:
|
| 82 |
+
boot_scores = []
|
| 83 |
+
for _ in range(50):
|
| 84 |
+
sample_verts = rng.choice(dim_vertices, size=max(1, len(dim_vertices) // 2), replace=True)
|
| 85 |
+
boot_tc = np.abs(predictions[:, sample_verts]).mean(axis=1)
|
| 86 |
+
baseline = max(np.median(np.abs(predictions)), 1e-8)
|
| 87 |
+
boot_scores.append(np.clip(boot_tc / baseline, 0, 1))
|
| 88 |
+
boot_arr = np.array(boot_scores)
|
| 89 |
+
ci_lo = np.percentile(boot_arr, 2.5, axis=0)
|
| 90 |
+
ci_hi = np.percentile(boot_arr, 97.5, axis=0)
|
| 91 |
+
t_axis = timeline_df["time"].values
|
| 92 |
+
|
| 93 |
+
fig.add_trace(go.Scatter(x=t_axis, y=ci_hi, mode="lines", line=dict(width=0), showlegend=False))
|
| 94 |
+
fig.add_trace(go.Scatter(x=t_axis, y=ci_lo, mode="lines", line=dict(width=0),
|
| 95 |
+
fill="tonexty", fillcolor=color.replace(")", ",0.15)").replace("rgb", "rgba").replace("#", "rgba(") if color.startswith("#") else color,
|
| 96 |
+
showlegend=False))
|
| 97 |
+
|
| 98 |
+
fig.add_trace(go.Scatter(x=timeline_df["time"], y=y, name=dim, line=dict(color=color, width=2)))
|
| 99 |
+
|
| 100 |
+
if compare:
|
| 101 |
+
timeline_df_2 = pd.DataFrame(timeline_2)
|
| 102 |
+
fig.add_trace(go.Scatter(x=timeline_df_2["time"], y=timeline_df_2[dim].values,
|
| 103 |
+
name=f"{dim} ({stim_type_2})", line=dict(color=color, width=1.5, dash="dash")))
|
| 104 |
|
| 105 |
fig.update_layout(
|
| 106 |
+
xaxis_title="Time (seconds)", yaxis_title="Cognitive Load (normalized)",
|
| 107 |
+
yaxis_range=[0, 1.05], height=450, template="plotly_dark",
|
|
|
|
|
|
|
|
|
|
| 108 |
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 109 |
)
|
| 110 |
st.plotly_chart(fig, use_container_width=True)
|
| 111 |
|
| 112 |
+
# --- Dimension Correlation + Radar ---
|
| 113 |
st.divider()
|
| 114 |
col1, col2 = st.columns(2)
|
| 115 |
|
| 116 |
with col1:
|
| 117 |
+
st.subheader("Dimension Correlation")
|
| 118 |
+
st.markdown("How do the four cognitive dimensions co-vary over time?")
|
| 119 |
+
dim_timeseries = {}
|
| 120 |
+
for dim in dim_colors:
|
| 121 |
+
dim_timeseries[dim] = timeline_df[dim].values
|
| 122 |
+
dim_arr = np.array(list(dim_timeseries.values()))
|
| 123 |
+
corr = np.corrcoef(dim_arr)
|
| 124 |
+
dim_names = list(dim_colors.keys())
|
| 125 |
+
|
| 126 |
+
fig_corr = go.Figure(go.Heatmap(
|
| 127 |
+
z=corr, x=dim_names, y=dim_names,
|
| 128 |
+
colorscale="RdBu_r", zmid=0, zmin=-1, zmax=1,
|
| 129 |
+
colorbar=dict(title="r"),
|
| 130 |
+
text=np.round(corr, 2), texttemplate="%{text}",
|
| 131 |
))
|
| 132 |
+
fig_corr.update_layout(height=350, template="plotly_dark", xaxis_tickangle=30)
|
| 133 |
+
st.plotly_chart(fig_corr, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
with col2:
|
| 136 |
+
st.subheader("Dimension Profile")
|
| 137 |
+
dim_data = {k: v for k, v in averages.items() if k != "Overall"}
|
| 138 |
categories = list(dim_data.keys())
|
| 139 |
+
values = list(dim_data.values()) + [list(dim_data.values())[0]]
|
| 140 |
+
|
| 141 |
+
fig_radar = go.Figure()
|
| 142 |
+
fig_radar.add_trace(go.Scatterpolar(
|
| 143 |
+
r=values, theta=categories + [categories[0]],
|
| 144 |
+
fill="toself", fillcolor="rgba(108, 92, 231, 0.3)",
|
| 145 |
+
line=dict(color="#6C5CE7"), name=stim_type,
|
|
|
|
| 146 |
))
|
| 147 |
+
if compare:
|
| 148 |
+
dim_data_2 = {k: v for k, v in averages_2.items() if k != "Overall"}
|
| 149 |
+
values_2 = list(dim_data_2.values()) + [list(dim_data_2.values())[0]]
|
| 150 |
+
fig_radar.add_trace(go.Scatterpolar(
|
| 151 |
+
r=values_2, theta=categories + [categories[0]],
|
| 152 |
+
fill="toself", fillcolor="rgba(255,107,107,0.2)",
|
| 153 |
+
line=dict(color="#FF6B6B", dash="dash"), name=stim_type_2,
|
| 154 |
+
))
|
| 155 |
+
|
| 156 |
+
fig_radar.update_layout(
|
| 157 |
polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
|
| 158 |
+
height=350, template="plotly_dark",
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
+
st.plotly_chart(fig_radar, use_container_width=True)
|
| 161 |
+
|
| 162 |
+
# --- Per-ROI Activation Breakdown ---
|
| 163 |
+
st.divider()
|
| 164 |
+
st.subheader("Per-ROI Activation Within Each Dimension")
|
| 165 |
+
selected_dim = st.selectbox("Dimension", list(dim_colors.keys()))
|
| 166 |
+
dim_rois = COGNITIVE_DIMENSIONS.get(selected_dim, [])
|
| 167 |
+
|
| 168 |
+
roi_activations = []
|
| 169 |
+
for roi in dim_rois:
|
| 170 |
+
if roi in roi_indices:
|
| 171 |
+
verts = roi_indices[roi]
|
| 172 |
+
valid = verts[verts < predictions.shape[1]]
|
| 173 |
+
if len(valid) > 0:
|
| 174 |
+
act = float(np.abs(predictions[:, valid]).mean())
|
| 175 |
+
roi_activations.append({"ROI": roi, "Mean Activation": act})
|
| 176 |
+
|
| 177 |
+
if roi_activations:
|
| 178 |
+
roi_act_df = pd.DataFrame(roi_activations).sort_values("Mean Activation", ascending=False)
|
| 179 |
+
fig_roi = go.Figure(go.Bar(
|
| 180 |
+
x=roi_act_df["Mean Activation"], y=roi_act_df["ROI"],
|
| 181 |
+
orientation="h", marker_color=dim_colors[selected_dim],
|
| 182 |
+
))
|
| 183 |
+
fig_roi.update_layout(height=max(250, len(roi_activations) * 25), template="plotly_dark",
|
| 184 |
+
yaxis=dict(autorange="reversed"), xaxis_title="Mean |activation|")
|
| 185 |
+
st.plotly_chart(fig_roi, use_container_width=True)
|
| 186 |
+
download_csv_button(roi_act_df, f"cognitive_load_{selected_dim}_rois.csv")
|
| 187 |
+
|
| 188 |
+
# --- Methodology ---
|
| 189 |
+
with st.expander("Methodology", expanded=False):
|
| 190 |
+
st.markdown("""
|
| 191 |
+
**Cognitive Load Scoring** maps predicted fMRI activations onto four neurocognitive dimensions
|
| 192 |
+
using HCP MMP1.0 ROI groupings:
|
| 193 |
+
|
| 194 |
+
- **Visual Complexity**: V1-V4, MT, MST, FFC, VVC (ventral & dorsal visual streams)
|
| 195 |
+
- **Auditory Demand**: A1, belt areas, STS (auditory cortex + association areas)
|
| 196 |
+
- **Language Processing**: Areas 44/45 (Broca's), TPOJ (Wernicke's), STV, PSL (perisylvian language network)
|
| 197 |
+
- **Executive Load**: dlPFC (area 46), ACC, FEF (frontoparietal control network)
|
| 198 |
+
|
| 199 |
+
Each dimension score is the mean absolute activation across its ROIs, normalized by the
|
| 200 |
+
median activation across all vertices (baseline). Scores are clipped to [0, 1].
|
| 201 |
+
|
| 202 |
+
**Confidence bands** on the timeline are computed via bootstrap resampling of vertices
|
| 203 |
+
within each dimension's ROI group (50 resamples, 95% CI).
|
| 204 |
+
|
| 205 |
+
**Limitations**: The ROI-to-dimension mapping is based on established functional neuroanatomy
|
| 206 |
+
but is not exhaustive. Cognitive load is a multidimensional construct that cannot be fully
|
| 207 |
+
captured by fMRI activation alone. These scores should be interpreted as relative measures,
|
| 208 |
+
not absolute cognitive load values.
|
| 209 |
+
|
| 210 |
+
**References**:
|
| 211 |
+
- Glasser et al., 2016, *Nature* (HCP MMP1.0 parcellation)
|
| 212 |
+
- Sweller, 1988, *Cognitive Science* (Cognitive Load Theory)
|
| 213 |
+
""")
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""Temporal Dynamics
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
@@ -6,32 +6,39 @@ import pandas as pd
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
from plotly.subplots import make_subplots
|
| 8 |
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
generate_model_features,
|
| 13 |
-
peak_latency,
|
| 14 |
-
temporal_correlation,
|
| 15 |
-
decompose_response,
|
| 16 |
-
ROI_GROUPS,
|
| 17 |
-
ALL_ROIS,
|
| 18 |
-
)
|
| 19 |
|
| 20 |
st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
|
|
|
|
|
|
|
|
|
|
| 21 |
st.title("⏱️ Temporal Dynamics")
|
| 22 |
-
st.markdown("Analyze how brain responses evolve over time.")
|
| 23 |
|
| 24 |
# --- Sidebar ---
|
| 25 |
with st.sidebar:
|
| 26 |
st.header("Configuration")
|
|
|
|
|
|
|
| 27 |
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 28 |
-
tr_seconds = st.slider("TR
|
| 29 |
-
seed = st.number_input("
|
| 30 |
|
| 31 |
st.subheader("ROI Selection")
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
|
| 37 |
cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
|
|
@@ -42,78 +49,209 @@ if not selected_rois:
|
|
| 42 |
|
| 43 |
# --- Generate Data ---
|
| 44 |
roi_indices, n_vertices = make_roi_indices()
|
| 45 |
-
predictions =
|
| 46 |
-
features =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
# --- Peak Latency ---
|
| 49 |
-
st.subheader("Peak Response Latency")
|
| 50 |
latency_data = []
|
| 51 |
for roi in selected_rois:
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
lat_df = pd.DataFrame(latency_data)
|
| 56 |
-
|
| 57 |
|
|
|
|
| 58 |
with col1:
|
| 59 |
-
|
| 60 |
-
x=lat_df["ROI"],
|
| 61 |
-
|
| 62 |
-
marker_color="#
|
| 63 |
))
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
template="plotly_dark",
|
| 68 |
)
|
| 69 |
-
st.plotly_chart(
|
| 70 |
|
| 71 |
with col2:
|
| 72 |
-
st.dataframe(lat_df, use_container_width=True, hide_index=True)
|
|
|
|
| 73 |
|
| 74 |
-
# --- Lag Correlation ---
|
| 75 |
st.divider()
|
| 76 |
st.subheader("Temporal Correlation (Brain vs Model Features)")
|
|
|
|
|
|
|
| 77 |
|
| 78 |
lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
|
| 79 |
-
|
| 80 |
-
colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8"]
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
for i, roi in enumerate(selected_rois):
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
xaxis_title="Lag (seconds)",
|
| 94 |
-
|
| 95 |
-
height=400,
|
| 96 |
-
template="plotly_dark",
|
| 97 |
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 98 |
)
|
| 99 |
-
st.plotly_chart(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# --- Sustained vs Transient ---
|
| 102 |
st.divider()
|
| 103 |
st.subheader("Sustained vs Transient Decomposition")
|
|
|
|
| 104 |
|
| 105 |
roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
|
| 106 |
sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
fig3.add_trace(go.Scatter(x=time_axis, y=transient, name="Transient",
|
| 115 |
-
line=dict(color="#FF6B6B", width=1.5)), row=2, col=1)
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
| 1 |
+
"""Temporal Dynamics - Research Grade."""
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
from plotly.subplots import make_subplots
|
| 8 |
|
| 9 |
+
from session import init_session, log_analysis, get_carried_rois, download_csv_button, show_analysis_log
|
| 10 |
+
from utils import make_roi_indices, peak_latency, temporal_correlation, decompose_response, ROI_GROUPS, ALL_ROIS
|
| 11 |
+
from synthetic import generate_realistic_predictions, generate_correlated_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
|
| 14 |
+
init_session()
|
| 15 |
+
show_analysis_log()
|
| 16 |
+
|
| 17 |
st.title("⏱️ Temporal Dynamics")
|
| 18 |
+
st.markdown("Analyze how brain responses evolve over time, including processing hierarchy and temporal coupling with model features.")
|
| 19 |
|
| 20 |
# --- Sidebar ---
|
| 21 |
with st.sidebar:
|
| 22 |
st.header("Configuration")
|
| 23 |
+
stim_type = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"],
|
| 24 |
+
index=["visual", "auditory", "language", "multimodal"].index(st.session_state.get("stimulus_type", "visual")))
|
| 25 |
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 26 |
+
tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 27 |
+
seed = st.number_input("Seed", value=42, min_value=0)
|
| 28 |
|
| 29 |
st.subheader("ROI Selection")
|
| 30 |
+
carried = get_carried_rois()
|
| 31 |
+
use_carried = False
|
| 32 |
+
if carried:
|
| 33 |
+
use_carried = st.checkbox(f"Use {len(carried)} ROIs from Brain Alignment", value=True)
|
| 34 |
+
|
| 35 |
+
if use_carried and carried:
|
| 36 |
+
selected_rois = carried
|
| 37 |
+
st.caption(f"Using: {', '.join(selected_rois[:5])}{'...' if len(selected_rois) > 5 else ''}")
|
| 38 |
+
else:
|
| 39 |
+
selected_group = st.selectbox("Region group", list(ROI_GROUPS.keys()))
|
| 40 |
+
available_rois = ROI_GROUPS[selected_group]
|
| 41 |
+
selected_rois = st.multiselect("ROIs to analyze", available_rois, default=available_rois[:4])
|
| 42 |
|
| 43 |
max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
|
| 44 |
cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
|
|
|
|
| 49 |
|
| 50 |
# --- Generate Data ---
|
| 51 |
roi_indices, n_vertices = make_roi_indices()
|
| 52 |
+
predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
|
| 53 |
+
features = generate_correlated_features(predictions, alignment_strength=0.5, feature_dim=64, seed=seed + 1)
|
| 54 |
+
log_analysis(f"Temporal dynamics: {stim_type}, {len(selected_rois)} ROIs")
|
| 55 |
+
|
| 56 |
+
time_axis = np.arange(n_timepoints) * tr_seconds
|
| 57 |
+
colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8", "#74B9FF", "#E17055"]
|
| 58 |
+
|
| 59 |
+
# --- Raw ROI Timecourses ---
|
| 60 |
+
st.subheader("Raw ROI Timecourses")
|
| 61 |
+
st.markdown("Mean absolute activation over time for each selected ROI. Note the hemodynamic response shape after stimulus events.")
|
| 62 |
+
|
| 63 |
+
fig_raw = go.Figure()
|
| 64 |
+
for i, roi in enumerate(selected_rois):
|
| 65 |
+
if roi in roi_indices:
|
| 66 |
+
verts = roi_indices[roi]
|
| 67 |
+
valid = verts[verts < predictions.shape[1]]
|
| 68 |
+
if len(valid) > 0:
|
| 69 |
+
tc = np.abs(predictions[:, valid]).mean(axis=1)
|
| 70 |
+
fig_raw.add_trace(go.Scatter(
|
| 71 |
+
x=time_axis, y=tc, name=roi,
|
| 72 |
+
line=dict(color=colors[i % len(colors)], width=2),
|
| 73 |
+
))
|
| 74 |
+
|
| 75 |
+
fig_raw.update_layout(
|
| 76 |
+
xaxis_title="Time (seconds)", yaxis_title="Mean |activation|",
|
| 77 |
+
height=400, template="plotly_dark",
|
| 78 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 79 |
+
)
|
| 80 |
+
st.plotly_chart(fig_raw, use_container_width=True)
|
| 81 |
+
|
| 82 |
+
# --- Peak Latency (sorted = processing hierarchy) ---
|
| 83 |
+
st.divider()
|
| 84 |
+
st.subheader("Peak Response Latency (Processing Hierarchy)")
|
| 85 |
+
st.markdown("ROIs sorted by peak latency reveal the cortical processing hierarchy: early sensory areas respond first, association cortex later.")
|
| 86 |
|
|
|
|
|
|
|
| 87 |
latency_data = []
|
| 88 |
for roi in selected_rois:
|
| 89 |
+
if roi in roi_indices:
|
| 90 |
+
lat = peak_latency(predictions, roi_indices, roi, tr_seconds)
|
| 91 |
+
# Determine functional group
|
| 92 |
+
group = "Other"
|
| 93 |
+
for g, rois in ROI_GROUPS.items():
|
| 94 |
+
if roi in rois:
|
| 95 |
+
group = g
|
| 96 |
+
break
|
| 97 |
+
latency_data.append({"ROI": roi, "Peak Latency (s)": lat, "Group": group})
|
| 98 |
|
| 99 |
+
lat_df = pd.DataFrame(latency_data).sort_values("Peak Latency (s)")
|
| 100 |
+
group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7", "Other": "#888"}
|
| 101 |
|
| 102 |
+
col1, col2 = st.columns([2, 1])
|
| 103 |
with col1:
|
| 104 |
+
fig_lat = go.Figure(go.Bar(
|
| 105 |
+
x=lat_df["Peak Latency (s)"], y=lat_df["ROI"],
|
| 106 |
+
orientation="h",
|
| 107 |
+
marker_color=[group_colors.get(g, "#888") for g in lat_df["Group"]],
|
| 108 |
))
|
| 109 |
+
fig_lat.update_layout(
|
| 110 |
+
xaxis_title="Time to peak (seconds)", height=max(250, len(selected_rois) * 30),
|
| 111 |
+
template="plotly_dark", yaxis=dict(autorange="reversed"),
|
|
|
|
| 112 |
)
|
| 113 |
+
st.plotly_chart(fig_lat, use_container_width=True)
|
| 114 |
|
| 115 |
with col2:
|
| 116 |
+
st.dataframe(lat_df[["ROI", "Peak Latency (s)", "Group"]], use_container_width=True, hide_index=True)
|
| 117 |
+
download_csv_button(lat_df, "peak_latencies.csv")
|
| 118 |
|
| 119 |
+
# --- Lag Correlation with Significance ---
|
| 120 |
st.divider()
|
| 121 |
st.subheader("Temporal Correlation (Brain vs Model Features)")
|
| 122 |
+
st.markdown("Pearson correlation at different time lags. The peak indicates optimal temporal alignment. "
|
| 123 |
+
"Gray band shows 95% null range from shuffled data.")
|
| 124 |
|
| 125 |
lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
|
| 126 |
+
fig_corr = go.Figure()
|
|
|
|
| 127 |
|
| 128 |
+
# Null band (shuffle features, compute correlation envelope)
|
| 129 |
+
rng = np.random.default_rng(seed)
|
| 130 |
+
null_corrs = []
|
| 131 |
+
for _ in range(50):
|
| 132 |
+
shuffled = features[rng.permutation(len(features))]
|
| 133 |
+
for roi in selected_rois[:1]: # Use first ROI for null band
|
| 134 |
+
if roi in roi_indices:
|
| 135 |
+
nc = temporal_correlation(predictions, shuffled, roi_indices, roi, max_lag)
|
| 136 |
+
null_corrs.append(nc)
|
| 137 |
+
if null_corrs:
|
| 138 |
+
null_arr = np.array(null_corrs)
|
| 139 |
+
null_hi = np.percentile(null_arr, 97.5, axis=0)
|
| 140 |
+
null_lo = np.percentile(null_arr, 2.5, axis=0)
|
| 141 |
+
fig_corr.add_trace(go.Scatter(x=lags, y=null_hi, mode="lines", line=dict(width=0), showlegend=False))
|
| 142 |
+
fig_corr.add_trace(go.Scatter(x=lags, y=null_lo, mode="lines", line=dict(width=0),
|
| 143 |
+
fill="tonexty", fillcolor="rgba(150,150,150,0.2)",
|
| 144 |
+
name="95% null range"))
|
| 145 |
+
|
| 146 |
+
# Actual correlations
|
| 147 |
+
optimal_lags = []
|
| 148 |
for i, roi in enumerate(selected_rois):
|
| 149 |
+
if roi in roi_indices:
|
| 150 |
+
corr = temporal_correlation(predictions, features, roi_indices, roi, max_lag)
|
| 151 |
+
fig_corr.add_trace(go.Scatter(
|
| 152 |
+
x=lags, y=corr, name=roi,
|
| 153 |
+
line=dict(color=colors[i % len(colors)], width=2),
|
| 154 |
+
))
|
| 155 |
+
opt_idx = np.argmax(np.abs(corr))
|
| 156 |
+
optimal_lags.append({"ROI": roi, "Optimal Lag (s)": lags[opt_idx], "Max |r|": float(np.abs(corr[opt_idx]))})
|
| 157 |
|
| 158 |
+
fig_corr.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
|
| 159 |
+
fig_corr.update_layout(
|
| 160 |
+
xaxis_title="Lag (seconds)", yaxis_title="Pearson Correlation",
|
| 161 |
+
height=400, template="plotly_dark",
|
|
|
|
|
|
|
| 162 |
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 163 |
)
|
| 164 |
+
st.plotly_chart(fig_corr, use_container_width=True)
|
| 165 |
+
|
| 166 |
+
# --- Optimal Lag Summary ---
|
| 167 |
+
if optimal_lags:
|
| 168 |
+
st.subheader("Optimal Lag Summary")
|
| 169 |
+
opt_df = pd.DataFrame(optimal_lags).sort_values("Max |r|", ascending=False)
|
| 170 |
+
st.dataframe(opt_df, use_container_width=True, hide_index=True)
|
| 171 |
+
download_csv_button(opt_df, "optimal_lags.csv")
|
| 172 |
+
|
| 173 |
+
# --- Cross-ROI Lag Matrix ---
|
| 174 |
+
if len(selected_rois) >= 2:
|
| 175 |
+
st.divider()
|
| 176 |
+
st.subheader("Cross-ROI Lag Matrix")
|
| 177 |
+
st.markdown("Optimal lag between each pair of ROIs. Positive values mean the row ROI leads the column ROI.")
|
| 178 |
+
|
| 179 |
+
n_rois = len(selected_rois)
|
| 180 |
+
lag_matrix = np.zeros((n_rois, n_rois))
|
| 181 |
+
for i, roi_a in enumerate(selected_rois):
|
| 182 |
+
if roi_a not in roi_indices:
|
| 183 |
+
continue
|
| 184 |
+
verts_a = roi_indices[roi_a]
|
| 185 |
+
valid_a = verts_a[verts_a < predictions.shape[1]]
|
| 186 |
+
if len(valid_a) == 0:
|
| 187 |
+
continue
|
| 188 |
+
tc_a = np.abs(predictions[:, valid_a]).mean(axis=1)
|
| 189 |
+
for j, roi_b in enumerate(selected_rois):
|
| 190 |
+
if i == j or roi_b not in roi_indices:
|
| 191 |
+
continue
|
| 192 |
+
verts_b = roi_indices[roi_b]
|
| 193 |
+
valid_b = verts_b[verts_b < predictions.shape[1]]
|
| 194 |
+
if len(valid_b) == 0:
|
| 195 |
+
continue
|
| 196 |
+
tc_b = np.abs(predictions[:, valid_b]).mean(axis=1)
|
| 197 |
+
# Cross-correlation to find optimal lag
|
| 198 |
+
corrs_ab = temporal_correlation(predictions, tc_b, roi_indices, roi_a, max_lag)
|
| 199 |
+
opt_idx = np.argmax(np.abs(corrs_ab))
|
| 200 |
+
lag_matrix[i, j] = lags[opt_idx]
|
| 201 |
+
|
| 202 |
+
fig_lagmat = go.Figure(go.Heatmap(
|
| 203 |
+
z=lag_matrix, x=selected_rois, y=selected_rois,
|
| 204 |
+
colorscale="RdBu_r", zmid=0,
|
| 205 |
+
colorbar=dict(title="Lag (s)"),
|
| 206 |
+
text=np.round(lag_matrix, 1), texttemplate="%{text}",
|
| 207 |
+
))
|
| 208 |
+
fig_lagmat.update_layout(height=400, template="plotly_dark")
|
| 209 |
+
st.plotly_chart(fig_lagmat, use_container_width=True)
|
| 210 |
|
| 211 |
# --- Sustained vs Transient ---
|
| 212 |
st.divider()
|
| 213 |
st.subheader("Sustained vs Transient Decomposition")
|
| 214 |
+
st.markdown("Moving-average filter separates slow sustained responses from fast transient spikes.")
|
| 215 |
|
| 216 |
roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
|
| 217 |
sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
|
| 218 |
+
original = sustained + transient
|
| 219 |
+
|
| 220 |
+
fig_decomp = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.06,
|
| 221 |
+
subplot_titles=("Original Signal", "Sustained Component", "Transient Component"))
|
| 222 |
+
|
| 223 |
+
fig_decomp.add_trace(go.Scatter(x=time_axis, y=original, line=dict(color="#888", width=1.5)), row=1, col=1)
|
| 224 |
+
fig_decomp.add_trace(go.Scatter(x=time_axis, y=sustained, line=dict(color="#6C5CE7", width=2)), row=2, col=1)
|
| 225 |
+
fig_decomp.add_trace(go.Scatter(x=time_axis, y=transient, line=dict(color="#FF6B6B", width=1.5)), row=3, col=1)
|
| 226 |
+
|
| 227 |
+
fig_decomp.update_xaxes(title_text="Time (seconds)", row=3, col=1)
|
| 228 |
+
fig_decomp.update_layout(height=550, template="plotly_dark", showlegend=False)
|
| 229 |
+
st.plotly_chart(fig_decomp, use_container_width=True)
|
| 230 |
+
|
| 231 |
+
# --- Methodology ---
|
| 232 |
+
with st.expander("Methodology", expanded=False):
|
| 233 |
+
st.markdown("""
|
| 234 |
+
**Peak Latency** is the time at which mean absolute activation reaches its maximum
|
| 235 |
+
within an ROI. In real fMRI, early sensory cortex (V1, A1) peaks at ~5-6s post-stimulus
|
| 236 |
+
due to the hemodynamic response, while association cortex (dlPFC, angular gyrus) peaks
|
| 237 |
+
~1-3s later reflecting higher-order processing.
|
| 238 |
+
|
| 239 |
+
**Temporal Correlation** computes Pearson correlation between the ROI timecourse and model
|
| 240 |
+
feature timecourse at each lag in ``[-max_lag, +max_lag]`` TRs. The lag at maximum absolute
|
| 241 |
+
correlation reveals the temporal offset at which model and brain are best aligned.
|
| 242 |
+
|
| 243 |
+
**Null significance band** is estimated by shuffling the model features 50 times and
|
| 244 |
+
computing the lag correlation each time. The 95% envelope of these null correlations
|
| 245 |
+
provides a significance threshold.
|
| 246 |
|
| 247 |
+
**Sustained vs Transient Decomposition** uses a moving-average filter with the specified
|
| 248 |
+
cutoff period. The sustained component captures slow, maintained responses (e.g., block
|
| 249 |
+
design activations), while the transient component captures fast, event-related responses.
|
| 250 |
|
| 251 |
+
**Cross-ROI Lag Matrix** shows the optimal temporal offset between every pair of ROIs,
|
| 252 |
+
revealing directional information flow (positive lag = row ROI leads column ROI).
|
|
|
|
|
|
|
| 253 |
|
| 254 |
+
**References**:
|
| 255 |
+
- Boynton et al., 1996, *J Neuroscience* (hemodynamic response function)
|
| 256 |
+
- Friston et al., 1998, *NeuroImage* (temporal basis functions in fMRI)
|
| 257 |
+
""")
|
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""ROI Connectivity
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
@@ -6,157 +6,278 @@ import pandas as pd
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
|
|
|
| 9 |
from utils import (
|
| 10 |
-
make_roi_indices,
|
| 11 |
-
|
| 12 |
-
compute_connectivity,
|
| 13 |
-
cluster_rois,
|
| 14 |
-
graph_metrics,
|
| 15 |
ROI_GROUPS,
|
| 16 |
)
|
|
|
|
| 17 |
|
| 18 |
st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
|
|
|
|
|
|
|
|
|
|
| 19 |
st.title("🔗 ROI Connectivity Analysis")
|
| 20 |
-
st.markdown("Functional connectivity between brain regions
|
| 21 |
|
| 22 |
# --- Sidebar ---
|
| 23 |
with st.sidebar:
|
| 24 |
st.header("Configuration")
|
|
|
|
| 25 |
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
n_clusters = st.slider("Number of clusters", 2, 8, 4)
|
| 28 |
threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# --- Generate Data ---
|
| 31 |
roi_indices, n_vertices = make_roi_indices()
|
| 32 |
-
predictions =
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
))
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
xaxis=dict(tickangle=45, tickfont=dict(size=8)),
|
| 53 |
yaxis=dict(tickfont=dict(size=8)),
|
| 54 |
)
|
| 55 |
-
st.plotly_chart(
|
|
|
|
| 56 |
|
| 57 |
-
# ---
|
| 58 |
st.divider()
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
for roi in rois:
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
x="Cluster",
|
| 74 |
-
y="Count",
|
| 75 |
-
color="Cluster",
|
| 76 |
-
color_discrete_sequence=px.colors.qualitative.Set2,
|
| 77 |
-
)
|
| 78 |
-
fig2.update_layout(height=350, template="plotly_dark", showlegend=False)
|
| 79 |
-
st.plotly_chart(fig2, use_container_width=True)
|
| 80 |
-
|
| 81 |
-
for cid, rois in sorted(clusters.items()):
|
| 82 |
-
st.markdown(f"**Network {cid}:** {', '.join(rois)}")
|
| 83 |
-
|
| 84 |
-
with col2:
|
| 85 |
-
st.subheader("Degree Centrality")
|
| 86 |
-
degrees = graph_metrics(corr_matrix, roi_names, threshold)
|
| 87 |
-
|
| 88 |
-
# Sort by degree
|
| 89 |
-
sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
| 90 |
-
deg_df = pd.DataFrame(sorted_degrees, columns=["ROI", "Degree Centrality"])
|
| 91 |
-
|
| 92 |
-
fig3 = go.Figure(go.Bar(
|
| 93 |
-
x=deg_df["Degree Centrality"],
|
| 94 |
-
y=deg_df["ROI"],
|
| 95 |
-
orientation="h",
|
| 96 |
-
marker_color="#6C5CE7",
|
| 97 |
-
))
|
| 98 |
-
fig3.update_layout(
|
| 99 |
-
xaxis_title="Degree Centrality",
|
| 100 |
-
xaxis_range=[0, 1],
|
| 101 |
-
height=600,
|
| 102 |
-
template="plotly_dark",
|
| 103 |
-
yaxis=dict(autorange="reversed", tickfont=dict(size=9)),
|
| 104 |
-
)
|
| 105 |
-
st.plotly_chart(fig3, use_container_width=True)
|
| 106 |
|
| 107 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
st.divider()
|
| 109 |
-
st.
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if abs(corr_matrix[i, j]) > threshold:
|
| 121 |
-
G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
|
| 122 |
-
|
| 123 |
-
pos = nx.spring_layout(G, seed=seed, k=2.0)
|
| 124 |
-
|
| 125 |
-
# Cluster colors
|
| 126 |
-
color_map = px.colors.qualitative.Set2
|
| 127 |
-
node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(len(roi_names))]
|
| 128 |
-
|
| 129 |
-
edge_x, edge_y = [], []
|
| 130 |
-
for u, v in G.edges():
|
| 131 |
-
x0, y0 = pos[u]
|
| 132 |
-
x1, y1 = pos[v]
|
| 133 |
-
edge_x.extend([x0, x1, None])
|
| 134 |
-
edge_y.extend([y0, y1, None])
|
| 135 |
-
|
| 136 |
-
node_x = [pos[n][0] for n in roi_names]
|
| 137 |
-
node_y = [pos[n][1] for n in roi_names]
|
| 138 |
-
|
| 139 |
-
fig4 = go.Figure()
|
| 140 |
-
fig4.add_trace(go.Scatter(
|
| 141 |
-
x=edge_x, y=edge_y, mode="lines",
|
| 142 |
-
line=dict(width=0.5, color="rgba(150,150,150,0.3)"),
|
| 143 |
-
hoverinfo="none",
|
| 144 |
-
))
|
| 145 |
-
fig4.add_trace(go.Scatter(
|
| 146 |
-
x=node_x, y=node_y, mode="markers+text",
|
| 147 |
-
marker=dict(size=12, color=node_colors, line=dict(width=1, color="white")),
|
| 148 |
-
text=roi_names,
|
| 149 |
-
textposition="top center",
|
| 150 |
-
textfont=dict(size=8),
|
| 151 |
-
hoverinfo="text",
|
| 152 |
-
))
|
| 153 |
-
fig4.update_layout(
|
| 154 |
-
height=500,
|
| 155 |
-
template="plotly_dark",
|
| 156 |
-
showlegend=False,
|
| 157 |
-
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 158 |
-
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 159 |
)
|
| 160 |
-
st.plotly_chart(
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ROI Connectivity - Research Grade."""
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
import numpy as np
|
|
|
|
| 6 |
import plotly.graph_objects as go
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
+
from session import init_session, log_analysis, get_carried_rois, download_csv_button, show_analysis_log
|
| 10 |
from utils import (
|
| 11 |
+
make_roi_indices, compute_connectivity, cluster_rois, graph_metrics,
|
| 12 |
+
partial_correlation, betweenness_centrality, modularity_score,
|
|
|
|
|
|
|
|
|
|
| 13 |
ROI_GROUPS,
|
| 14 |
)
|
| 15 |
+
from synthetic import generate_realistic_predictions
|
| 16 |
|
| 17 |
st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
|
| 18 |
+
init_session()
|
| 19 |
+
show_analysis_log()
|
| 20 |
+
|
| 21 |
st.title("🔗 ROI Connectivity Analysis")
|
| 22 |
+
st.markdown("Functional connectivity between brain regions: correlation structure, network organization, and graph topology.")
|
| 23 |
|
| 24 |
# --- Sidebar ---
|
| 25 |
with st.sidebar:
|
| 26 |
st.header("Configuration")
|
| 27 |
+
stim_type = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"])
|
| 28 |
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 29 |
+
tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 30 |
+
seed = st.number_input("Seed", value=42, min_value=0)
|
| 31 |
+
|
| 32 |
+
st.subheader("Analysis Parameters")
|
| 33 |
n_clusters = st.slider("Number of clusters", 2, 8, 4)
|
| 34 |
threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
|
| 35 |
+
use_partial = st.checkbox("Use partial correlation", value=False,
|
| 36 |
+
help="Control for shared mean signal across all ROIs")
|
| 37 |
+
|
| 38 |
+
carried = get_carried_rois()
|
| 39 |
+
use_carried = False
|
| 40 |
+
if carried:
|
| 41 |
+
use_carried = st.checkbox(f"Filter to {len(carried)} carried ROIs", value=False)
|
| 42 |
|
| 43 |
# --- Generate Data ---
|
| 44 |
roi_indices, n_vertices = make_roi_indices()
|
| 45 |
+
predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
|
| 46 |
+
log_analysis(f"Connectivity: {stim_type}, partial={use_partial}")
|
| 47 |
+
|
| 48 |
+
# Filter ROIs if carrying from alignment
|
| 49 |
+
active_indices = roi_indices
|
| 50 |
+
if use_carried and carried:
|
| 51 |
+
active_indices = {k: v for k, v in roi_indices.items() if k in carried}
|
| 52 |
+
|
| 53 |
+
# --- Compute Connectivity ---
|
| 54 |
+
if use_partial:
|
| 55 |
+
corr_matrix, roi_names = partial_correlation(predictions, active_indices)
|
| 56 |
+
corr_label = "Partial Correlation"
|
| 57 |
+
else:
|
| 58 |
+
corr_matrix, roi_names = compute_connectivity(predictions, active_indices)
|
| 59 |
+
corr_label = "Pearson Correlation"
|
| 60 |
+
|
| 61 |
+
n_rois = len(roi_names)
|
| 62 |
+
|
| 63 |
+
# --- Correlation Matrix with Cluster Boundaries ---
|
| 64 |
+
st.subheader(f"{corr_label} Matrix")
|
| 65 |
+
|
| 66 |
+
clusters, labels = cluster_rois(corr_matrix, roi_names, n_clusters)
|
| 67 |
+
|
| 68 |
+
# Sort ROIs by cluster for block-diagonal structure
|
| 69 |
+
sorted_idx = np.argsort(labels)
|
| 70 |
+
sorted_corr = corr_matrix[np.ix_(sorted_idx, sorted_idx)]
|
| 71 |
+
sorted_names = [roi_names[i] for i in sorted_idx]
|
| 72 |
+
sorted_labels = labels[sorted_idx]
|
| 73 |
+
|
| 74 |
+
fig_corr = go.Figure(go.Heatmap(
|
| 75 |
+
z=sorted_corr, x=sorted_names, y=sorted_names,
|
| 76 |
+
colorscale="RdBu_r", zmid=0, zmin=-1, zmax=1,
|
| 77 |
+
colorbar=dict(title="r"),
|
| 78 |
))
|
| 79 |
+
|
| 80 |
+
# Add cluster boundary lines
|
| 81 |
+
boundaries = []
|
| 82 |
+
for i in range(1, len(sorted_labels)):
|
| 83 |
+
if sorted_labels[i] != sorted_labels[i - 1]:
|
| 84 |
+
boundaries.append(i - 0.5)
|
| 85 |
+
|
| 86 |
+
for b in boundaries:
|
| 87 |
+
fig_corr.add_shape(type="line", x0=b, x1=b, y0=-0.5, y1=n_rois - 0.5,
|
| 88 |
+
line=dict(color="white", width=1.5, dash="dot"))
|
| 89 |
+
fig_corr.add_shape(type="line", x0=-0.5, x1=n_rois - 0.5, y0=b, y1=b,
|
| 90 |
+
line=dict(color="white", width=1.5, dash="dot"))
|
| 91 |
+
|
| 92 |
+
fig_corr.update_layout(
|
| 93 |
+
height=550, template="plotly_dark",
|
| 94 |
xaxis=dict(tickangle=45, tickfont=dict(size=8)),
|
| 95 |
yaxis=dict(tickfont=dict(size=8)),
|
| 96 |
)
|
| 97 |
+
st.plotly_chart(fig_corr, use_container_width=True)
|
| 98 |
+
st.caption(f"White dotted lines indicate cluster boundaries ({n_clusters} clusters)")
|
| 99 |
|
| 100 |
+
# --- Dendrogram ---
|
| 101 |
st.divider()
|
| 102 |
+
col_dendro, col_clusters = st.columns([1, 1])
|
| 103 |
+
|
| 104 |
+
with col_dendro:
|
| 105 |
+
st.subheader("Hierarchical Clustering Dendrogram")
|
| 106 |
+
from scipy.cluster.hierarchy import linkage, dendrogram
|
| 107 |
+
import matplotlib.pyplot as plt
|
| 108 |
+
import matplotlib
|
| 109 |
+
matplotlib.use("Agg")
|
| 110 |
|
| 111 |
+
dist = 1.0 - np.abs(corr_matrix)
|
| 112 |
+
np.fill_diagonal(dist, 0.0)
|
| 113 |
+
condensed = [dist[i, j] for i in range(n_rois) for j in range(i + 1, n_rois)]
|
| 114 |
+
Z = linkage(condensed, method="average")
|
| 115 |
|
| 116 |
+
fig_dendro, ax = plt.subplots(figsize=(8, 4))
|
| 117 |
+
ax.set_facecolor("#0E1117")
|
| 118 |
+
fig_dendro.patch.set_facecolor("#0E1117")
|
| 119 |
+
dendrogram(Z, labels=roi_names, leaf_rotation=90, leaf_font_size=7, ax=ax,
|
| 120 |
+
color_threshold=Z[-n_clusters + 1, 2] if n_clusters < n_rois else 0)
|
| 121 |
+
ax.tick_params(colors="white")
|
| 122 |
+
ax.set_ylabel("Distance (1 - |r|)", color="white")
|
| 123 |
+
for spine in ax.spines.values():
|
| 124 |
+
spine.set_color("white")
|
| 125 |
+
st.pyplot(fig_dendro)
|
| 126 |
+
plt.close()
|
| 127 |
+
|
| 128 |
+
with col_clusters:
|
| 129 |
+
st.subheader("Network Clusters")
|
| 130 |
+
mod_q = modularity_score(corr_matrix, labels)
|
| 131 |
+
st.metric("Modularity (Q)", f"{mod_q:.3f}",
|
| 132 |
+
help="Newman's modularity. Higher = stronger community structure. Q > 0.3 is typically considered meaningful.")
|
| 133 |
+
|
| 134 |
+
for cid in sorted(clusters.keys()):
|
| 135 |
+
rois = clusters[cid]
|
| 136 |
+
# Identify dominant functional group
|
| 137 |
+
group_counts = {}
|
| 138 |
for roi in rois:
|
| 139 |
+
for g, g_rois in ROI_GROUPS.items():
|
| 140 |
+
if roi in g_rois:
|
| 141 |
+
group_counts[g] = group_counts.get(g, 0) + 1
|
| 142 |
+
dominant = max(group_counts, key=group_counts.get) if group_counts else "Mixed"
|
| 143 |
+
st.markdown(f"**Network {cid}** ({dominant}): {', '.join(rois)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
+
# --- Centrality Comparison ---
|
| 146 |
+
st.divider()
|
| 147 |
+
st.subheader("Centrality Analysis")
|
| 148 |
+
|
| 149 |
+
col_deg, col_btw = st.columns(2)
|
| 150 |
+
|
| 151 |
+
degrees = graph_metrics(corr_matrix, roi_names, threshold)
|
| 152 |
+
btw = betweenness_centrality(corr_matrix, roi_names, threshold)
|
| 153 |
+
|
| 154 |
+
with col_deg:
|
| 155 |
+
st.markdown("**Degree Centrality** - fraction of ROIs connected to each node")
|
| 156 |
+
deg_df = pd.DataFrame(sorted(degrees.items(), key=lambda x: x[1], reverse=True), columns=["ROI", "Degree"])
|
| 157 |
+
fig_deg = go.Figure(go.Bar(x=deg_df["Degree"], y=deg_df["ROI"], orientation="h", marker_color="#6C5CE7"))
|
| 158 |
+
fig_deg.update_layout(xaxis_range=[0, 1], height=max(300, n_rois * 20), template="plotly_dark",
|
| 159 |
+
yaxis=dict(autorange="reversed", tickfont=dict(size=9)))
|
| 160 |
+
st.plotly_chart(fig_deg, use_container_width=True)
|
| 161 |
+
|
| 162 |
+
with col_btw:
|
| 163 |
+
st.markdown("**Betweenness Centrality** - how often a node lies on shortest paths between others")
|
| 164 |
+
btw_df = pd.DataFrame(sorted(btw.items(), key=lambda x: x[1], reverse=True), columns=["ROI", "Betweenness"])
|
| 165 |
+
fig_btw = go.Figure(go.Bar(x=btw_df["Betweenness"], y=btw_df["ROI"], orientation="h", marker_color="#FF6B6B"))
|
| 166 |
+
fig_btw.update_layout(height=max(300, n_rois * 20), template="plotly_dark",
|
| 167 |
+
yaxis=dict(autorange="reversed", tickfont=dict(size=9)))
|
| 168 |
+
st.plotly_chart(fig_btw, use_container_width=True)
|
| 169 |
+
|
| 170 |
+
# Combined table
|
| 171 |
+
centrality_df = pd.merge(deg_df, btw_df, on="ROI")
|
| 172 |
+
download_csv_button(centrality_df, "centrality_metrics.csv")
|
| 173 |
+
|
| 174 |
+
# --- Edge Weight Distribution ---
|
| 175 |
st.divider()
|
| 176 |
+
col_dist, col_graph = st.columns([1, 2])
|
| 177 |
+
|
| 178 |
+
with col_dist:
|
| 179 |
+
st.subheader("Edge Weight Distribution")
|
| 180 |
+
upper_tri = corr_matrix[np.triu_indices(n_rois, k=1)]
|
| 181 |
+
fig_hist = go.Figure(go.Histogram(x=upper_tri, nbinsx=40, marker_color="rgba(108,92,231,0.7)"))
|
| 182 |
+
fig_hist.add_vline(x=threshold, line_color="red", line_dash="dash", annotation_text="Threshold")
|
| 183 |
+
fig_hist.add_vline(x=-threshold, line_color="red", line_dash="dash")
|
| 184 |
+
fig_hist.update_layout(
|
| 185 |
+
xaxis_title="Correlation", yaxis_title="Count",
|
| 186 |
+
height=350, template="plotly_dark",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
+
st.plotly_chart(fig_hist, use_container_width=True)
|
| 189 |
+
n_edges = np.sum(np.abs(upper_tri) > threshold)
|
| 190 |
+
max_edges = n_rois * (n_rois - 1) // 2
|
| 191 |
+
st.caption(f"{n_edges}/{max_edges} edges above threshold ({100 * n_edges / max(max_edges, 1):.1f}% density)")
|
| 192 |
+
|
| 193 |
+
# --- Network Graph ---
|
| 194 |
+
with col_graph:
|
| 195 |
+
st.subheader("Network Graph")
|
| 196 |
+
try:
|
| 197 |
+
import networkx as nx
|
| 198 |
+
|
| 199 |
+
G = nx.Graph()
|
| 200 |
+
for name in roi_names:
|
| 201 |
+
G.add_node(name)
|
| 202 |
+
for i in range(n_rois):
|
| 203 |
+
for j in range(i + 1, n_rois):
|
| 204 |
+
w = abs(corr_matrix[i, j])
|
| 205 |
+
if w > threshold:
|
| 206 |
+
G.add_edge(roi_names[i], roi_names[j], weight=w)
|
| 207 |
+
|
| 208 |
+
pos = nx.spring_layout(G, seed=seed, k=2.5)
|
| 209 |
+
color_map = px.colors.qualitative.Set2
|
| 210 |
+
node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(n_rois)]
|
| 211 |
+
|
| 212 |
+
# Edges with width proportional to weight
|
| 213 |
+
for u, v, data in G.edges(data=True):
|
| 214 |
+
x0, y0 = pos[u]
|
| 215 |
+
x1, y1 = pos[v]
|
| 216 |
+
fig_graph = go.Figure() if not hasattr(st, '_graph_fig') else st._graph_fig
|
| 217 |
+
|
| 218 |
+
fig_net = go.Figure()
|
| 219 |
+
for u, v, d in G.edges(data=True):
|
| 220 |
+
x0, y0 = pos[u]
|
| 221 |
+
x1, y1 = pos[v]
|
| 222 |
+
w = d.get("weight", 0.3)
|
| 223 |
+
fig_net.add_trace(go.Scatter(
|
| 224 |
+
x=[x0, x1, None], y=[y0, y1, None], mode="lines",
|
| 225 |
+
line=dict(width=w * 3, color=f"rgba(150,150,150,{min(w, 0.8)})"),
|
| 226 |
+
hoverinfo="none", showlegend=False,
|
| 227 |
+
))
|
| 228 |
+
|
| 229 |
+
# Node sizes by degree
|
| 230 |
+
max_deg = max(degrees.values()) if degrees else 1
|
| 231 |
+
node_sizes = [8 + 20 * degrees.get(name, 0) / max(max_deg, 0.01) for name in roi_names]
|
| 232 |
+
node_x = [pos[n][0] for n in roi_names]
|
| 233 |
+
node_y = [pos[n][1] for n in roi_names]
|
| 234 |
+
|
| 235 |
+
fig_net.add_trace(go.Scatter(
|
| 236 |
+
x=node_x, y=node_y, mode="markers+text",
|
| 237 |
+
marker=dict(size=node_sizes, color=node_colors, line=dict(width=1, color="white")),
|
| 238 |
+
text=roi_names, textposition="top center", textfont=dict(size=7, color="white"),
|
| 239 |
+
hovertext=[f"{name}<br>Degree: {degrees.get(name, 0):.2f}<br>Betweenness: {btw.get(name, 0):.3f}" for name in roi_names],
|
| 240 |
+
hoverinfo="text", showlegend=False,
|
| 241 |
+
))
|
| 242 |
+
|
| 243 |
+
fig_net.update_layout(
|
| 244 |
+
height=450, template="plotly_dark",
|
| 245 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 246 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 247 |
+
)
|
| 248 |
+
st.plotly_chart(fig_net, use_container_width=True)
|
| 249 |
+
except ImportError:
|
| 250 |
+
st.info("Install `networkx` for graph visualization: `pip install networkx`")
|
| 251 |
+
|
| 252 |
+
# --- Methodology ---
|
| 253 |
+
with st.expander("Methodology", expanded=False):
|
| 254 |
+
st.markdown("""
|
| 255 |
+
**Functional Connectivity** is computed as pairwise Pearson correlation between ROI
|
| 256 |
+
timecourses (mean activation across vertices within each ROI).
|
| 257 |
+
|
| 258 |
+
**Partial Correlation** controls for the shared mean signal by computing the precision
|
| 259 |
+
matrix (inverse covariance) and normalizing. This removes indirect correlations mediated
|
| 260 |
+
by a common driver.
|
| 261 |
+
|
| 262 |
+
**Hierarchical Clustering** uses agglomerative clustering with average linkage on a
|
| 263 |
+
distance matrix defined as ``1 - |correlation|``. The dendrogram shows the hierarchical
|
| 264 |
+
merging of ROIs into networks.
|
| 265 |
+
|
| 266 |
+
**Modularity (Q)** quantifies how strongly the network divides into communities compared
|
| 267 |
+
to a random network with the same degree distribution. Q > 0.3 typically indicates
|
| 268 |
+
meaningful community structure. (Newman, 2006, *PNAS*)
|
| 269 |
+
|
| 270 |
+
**Degree Centrality** is the fraction of other nodes each node is connected to (above
|
| 271 |
+
the correlation threshold). High degree = hub region.
|
| 272 |
+
|
| 273 |
+
**Betweenness Centrality** counts how often a node lies on the shortest path between
|
| 274 |
+
other node pairs. High betweenness = bridge between communities.
|
| 275 |
+
|
| 276 |
+
**Edge Weight Distribution** shows the histogram of all pairwise correlations. The
|
| 277 |
+
threshold (red line) determines which connections are retained for graph analysis.
|
| 278 |
+
|
| 279 |
+
**References**:
|
| 280 |
+
- Rubinov & Sporns, 2010, *NeuroImage* (graph metrics for brain networks)
|
| 281 |
+
- Newman, 2006, *PNAS* (modularity in networks)
|
| 282 |
+
- Smith et al., 2011, *NeuroImage* (partial correlation for fMRI connectivity)
|
| 283 |
+
""")
|
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared session state management and data I/O utilities.
|
| 2 |
+
|
| 3 |
+
Manages cross-page state (selected ROIs, predictions, analysis log)
|
| 4 |
+
and provides upload/download widgets.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import io
|
| 8 |
+
import json
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import streamlit as st
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def init_session():
|
| 17 |
+
"""Initialize session state with defaults. Safe to call multiple times."""
|
| 18 |
+
defaults = {
|
| 19 |
+
"brain_predictions": None,
|
| 20 |
+
"model_features": {},
|
| 21 |
+
"roi_indices": None,
|
| 22 |
+
"n_vertices": 0,
|
| 23 |
+
"selected_rois": [],
|
| 24 |
+
"data_source": "synthetic",
|
| 25 |
+
"stimulus_type": "visual",
|
| 26 |
+
"tr_seconds": 1.0,
|
| 27 |
+
"n_timepoints": 80,
|
| 28 |
+
"seed": 42,
|
| 29 |
+
"analysis_log": [],
|
| 30 |
+
"carry_rois": [], # ROIs carried from another page
|
| 31 |
+
}
|
| 32 |
+
for key, value in defaults.items():
|
| 33 |
+
if key not in st.session_state:
|
| 34 |
+
st.session_state[key] = value
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def log_analysis(description):
|
| 38 |
+
"""Append an entry to the analysis log."""
|
| 39 |
+
timestamp = datetime.now().strftime("%H:%M:%S")
|
| 40 |
+
entry = f"[{timestamp}] {description}"
|
| 41 |
+
if "analysis_log" not in st.session_state:
|
| 42 |
+
st.session_state["analysis_log"] = []
|
| 43 |
+
st.session_state["analysis_log"].append(entry)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def carry_rois(rois, target_page=""):
|
| 47 |
+
"""Store selected ROIs for cross-page workflow."""
|
| 48 |
+
st.session_state["carry_rois"] = list(rois)
|
| 49 |
+
log_analysis(f"Carried {len(rois)} ROIs to {target_page}")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_carried_rois():
|
| 53 |
+
"""Retrieve ROIs carried from another page."""
|
| 54 |
+
return st.session_state.get("carry_rois", [])
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_or_generate_data(roi_indices):
|
| 58 |
+
"""Get brain predictions from session or generate new synthetic data."""
|
| 59 |
+
from synthetic import generate_realistic_predictions
|
| 60 |
+
|
| 61 |
+
params_key = (
|
| 62 |
+
st.session_state.get("n_timepoints", 80),
|
| 63 |
+
st.session_state.get("stimulus_type", "visual"),
|
| 64 |
+
st.session_state.get("seed", 42),
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Check if we need to regenerate
|
| 68 |
+
if (
|
| 69 |
+
st.session_state.get("brain_predictions") is None
|
| 70 |
+
or st.session_state.get("_data_params") != params_key
|
| 71 |
+
or st.session_state.get("data_source") == "synthetic"
|
| 72 |
+
):
|
| 73 |
+
if st.session_state.get("data_source") == "uploaded" and st.session_state.get("brain_predictions") is not None:
|
| 74 |
+
return st.session_state["brain_predictions"]
|
| 75 |
+
|
| 76 |
+
predictions = generate_realistic_predictions(
|
| 77 |
+
n_timepoints=st.session_state["n_timepoints"],
|
| 78 |
+
roi_indices=roi_indices,
|
| 79 |
+
stimulus_type=st.session_state["stimulus_type"],
|
| 80 |
+
tr_seconds=st.session_state["tr_seconds"],
|
| 81 |
+
seed=st.session_state["seed"],
|
| 82 |
+
)
|
| 83 |
+
st.session_state["brain_predictions"] = predictions
|
| 84 |
+
st.session_state["_data_params"] = params_key
|
| 85 |
+
|
| 86 |
+
return st.session_state["brain_predictions"]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def upload_npy_widget(label, key):
|
| 90 |
+
"""File uploader for .npy arrays with validation."""
|
| 91 |
+
uploaded = st.file_uploader(label, type=["npy"], key=key)
|
| 92 |
+
if uploaded is not None:
|
| 93 |
+
try:
|
| 94 |
+
data = np.load(io.BytesIO(uploaded.read()))
|
| 95 |
+
st.success(f"Loaded: shape {data.shape}, dtype {data.dtype}")
|
| 96 |
+
return data
|
| 97 |
+
except Exception as e:
|
| 98 |
+
st.error(f"Failed to load file: {e}")
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def download_csv_button(df, filename, label="Download CSV"):
|
| 103 |
+
"""Download button for a pandas DataFrame as CSV."""
|
| 104 |
+
csv = df.to_csv(index=False)
|
| 105 |
+
st.download_button(label, csv, filename, "text/csv")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def download_json_button(data, filename, label="Download JSON"):
|
| 109 |
+
"""Download button for a dict as JSON."""
|
| 110 |
+
json_str = json.dumps(data, indent=2, default=str)
|
| 111 |
+
st.download_button(label, json_str, filename, "application/json")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def show_analysis_log():
|
| 115 |
+
"""Display the analysis log in the sidebar."""
|
| 116 |
+
log = st.session_state.get("analysis_log", [])
|
| 117 |
+
if log:
|
| 118 |
+
with st.sidebar:
|
| 119 |
+
with st.expander("Analysis Log", expanded=False):
|
| 120 |
+
for entry in reversed(log[-20:]):
|
| 121 |
+
st.caption(entry)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def data_summary_widget(predictions, roi_indices):
|
| 125 |
+
"""Show a summary of the current data."""
|
| 126 |
+
if predictions is None:
|
| 127 |
+
st.info("No data loaded. Generate synthetic data or upload your own.")
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 131 |
+
col1.metric("Timepoints", predictions.shape[0])
|
| 132 |
+
col2.metric("Vertices", predictions.shape[1])
|
| 133 |
+
col3.metric("ROIs", len(roi_indices))
|
| 134 |
+
col4.metric("Source", st.session_state.get("data_source", "synthetic").title())
|
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Biologically realistic synthetic fMRI data generation.
|
| 2 |
+
|
| 3 |
+
Generates data with hemodynamic response convolution, modality-specific
|
| 4 |
+
activation patterns, spatial autocorrelation, temporal noise structure,
|
| 5 |
+
and scanner drift - mimicking real fMRI recordings.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
# --- Hemodynamic Response Function ---
|
| 11 |
+
|
| 12 |
+
def generate_hrf(tr_seconds=1.0, duration=30.0):
|
| 13 |
+
"""Canonical double-gamma hemodynamic response function.
|
| 14 |
+
|
| 15 |
+
Models the BOLD signal: a positive peak at ~5-6s followed by a
|
| 16 |
+
smaller negative undershoot at ~15s.
|
| 17 |
+
"""
|
| 18 |
+
t = np.arange(0, duration, tr_seconds)
|
| 19 |
+
# Double gamma parameters (SPM canonical)
|
| 20 |
+
a1, b1 = 6.0, 1.0 # positive peak
|
| 21 |
+
a2, b2 = 16.0, 1.0 # undershoot
|
| 22 |
+
c = 1.0 / 6.0 # undershoot ratio
|
| 23 |
+
|
| 24 |
+
from scipy.stats import gamma as gamma_dist
|
| 25 |
+
h = gamma_dist.pdf(t, a1, scale=b1) - c * gamma_dist.pdf(t, a2, scale=b2)
|
| 26 |
+
h = h / np.max(np.abs(h)) # normalize to [-1, 1]
|
| 27 |
+
return h
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def generate_stimulus_events(n_timepoints, tr_seconds=1.0, n_events=5, seed=42):
|
| 31 |
+
"""Generate random stimulus onset times as a binary event train.
|
| 32 |
+
|
| 33 |
+
Returns a (n_timepoints,) array with 1s at stimulus onsets.
|
| 34 |
+
Events are spaced at least 8 seconds apart.
|
| 35 |
+
"""
|
| 36 |
+
rng = np.random.default_rng(seed)
|
| 37 |
+
total_seconds = n_timepoints * tr_seconds
|
| 38 |
+
min_gap = 8.0 # minimum inter-stimulus interval
|
| 39 |
+
|
| 40 |
+
events = np.zeros(n_timepoints)
|
| 41 |
+
onsets = []
|
| 42 |
+
attempts = 0
|
| 43 |
+
while len(onsets) < n_events and attempts < 1000:
|
| 44 |
+
t = rng.uniform(2.0, total_seconds - 10.0)
|
| 45 |
+
if all(abs(t - o) > min_gap for o in onsets):
|
| 46 |
+
onsets.append(t)
|
| 47 |
+
attempts += 1
|
| 48 |
+
|
| 49 |
+
for onset in onsets:
|
| 50 |
+
idx = int(onset / tr_seconds)
|
| 51 |
+
if 0 <= idx < n_timepoints:
|
| 52 |
+
events[idx] = 1.0
|
| 53 |
+
|
| 54 |
+
return events
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# --- Modality-Specific Activation Weights ---
|
| 58 |
+
|
| 59 |
+
# Weight for each ROI given a stimulus modality (0 = no response, 1 = maximum)
|
| 60 |
+
MODALITY_WEIGHTS = {
|
| 61 |
+
"visual": {
|
| 62 |
+
# Strong visual cortex activation
|
| 63 |
+
"V1": 1.0, "V2": 0.95, "V3": 0.85, "V4": 0.8,
|
| 64 |
+
"MT": 0.75, "MST": 0.7, "FFC": 0.65, "VVC": 0.6,
|
| 65 |
+
# Weak cross-modal
|
| 66 |
+
"A1": 0.05, "LBelt": 0.04, "MBelt": 0.03, "PBelt": 0.03, "A4": 0.02, "A5": 0.02,
|
| 67 |
+
# Minimal language
|
| 68 |
+
"44": 0.08, "45": 0.07, "IFJa": 0.06, "IFJp": 0.05,
|
| 69 |
+
"TPOJ1": 0.1, "TPOJ2": 0.08, "STV": 0.07, "PSL": 0.06,
|
| 70 |
+
# Moderate executive (attention)
|
| 71 |
+
"46": 0.3, "9-46d": 0.25, "8Av": 0.35, "8Ad": 0.3,
|
| 72 |
+
"FEF": 0.4, "p32pr": 0.15, "a32pr": 0.12,
|
| 73 |
+
},
|
| 74 |
+
"auditory": {
|
| 75 |
+
"V1": 0.03, "V2": 0.03, "V3": 0.02, "V4": 0.02,
|
| 76 |
+
"MT": 0.02, "MST": 0.01, "FFC": 0.01, "VVC": 0.01,
|
| 77 |
+
"A1": 1.0, "LBelt": 0.95, "MBelt": 0.9, "PBelt": 0.85, "A4": 0.75, "A5": 0.7,
|
| 78 |
+
"44": 0.15, "45": 0.12, "IFJa": 0.1, "IFJp": 0.08,
|
| 79 |
+
"TPOJ1": 0.25, "TPOJ2": 0.2, "STV": 0.3, "PSL": 0.2,
|
| 80 |
+
"46": 0.2, "9-46d": 0.15, "8Av": 0.12, "8Ad": 0.1,
|
| 81 |
+
"FEF": 0.08, "p32pr": 0.1, "a32pr": 0.08,
|
| 82 |
+
},
|
| 83 |
+
"language": {
|
| 84 |
+
"V1": 0.05, "V2": 0.04, "V3": 0.03, "V4": 0.03,
|
| 85 |
+
"MT": 0.02, "MST": 0.02, "FFC": 0.1, "VVC": 0.08,
|
| 86 |
+
"A1": 0.3, "LBelt": 0.25, "MBelt": 0.2, "PBelt": 0.15, "A4": 0.2, "A5": 0.15,
|
| 87 |
+
"44": 1.0, "45": 0.95, "IFJa": 0.85, "IFJp": 0.8,
|
| 88 |
+
"TPOJ1": 0.9, "TPOJ2": 0.85, "STV": 0.75, "PSL": 0.7,
|
| 89 |
+
"46": 0.5, "9-46d": 0.45, "8Av": 0.3, "8Ad": 0.25,
|
| 90 |
+
"FEF": 0.15, "p32pr": 0.35, "a32pr": 0.3,
|
| 91 |
+
},
|
| 92 |
+
"multimodal": {
|
| 93 |
+
"V1": 0.7, "V2": 0.65, "V3": 0.55, "V4": 0.5,
|
| 94 |
+
"MT": 0.5, "MST": 0.45, "FFC": 0.4, "VVC": 0.35,
|
| 95 |
+
"A1": 0.7, "LBelt": 0.65, "MBelt": 0.55, "PBelt": 0.5, "A4": 0.45, "A5": 0.4,
|
| 96 |
+
"44": 0.65, "45": 0.6, "IFJa": 0.5, "IFJp": 0.45,
|
| 97 |
+
"TPOJ1": 0.6, "TPOJ2": 0.55, "STV": 0.5, "PSL": 0.45,
|
| 98 |
+
"46": 0.4, "9-46d": 0.35, "8Av": 0.3, "8Ad": 0.25,
|
| 99 |
+
"FEF": 0.3, "p32pr": 0.25, "a32pr": 0.2,
|
| 100 |
+
},
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def generate_realistic_predictions(
|
| 105 |
+
n_timepoints,
|
| 106 |
+
roi_indices,
|
| 107 |
+
stimulus_type="visual",
|
| 108 |
+
tr_seconds=1.0,
|
| 109 |
+
n_events=5,
|
| 110 |
+
snr=2.0,
|
| 111 |
+
seed=42,
|
| 112 |
+
):
|
| 113 |
+
"""Generate biologically realistic fMRI-like predictions.
|
| 114 |
+
|
| 115 |
+
Parameters
|
| 116 |
+
----------
|
| 117 |
+
n_timepoints : int
|
| 118 |
+
Number of TRs.
|
| 119 |
+
roi_indices : dict[str, np.ndarray]
|
| 120 |
+
ROI name -> vertex indices mapping.
|
| 121 |
+
stimulus_type : str
|
| 122 |
+
One of "visual", "auditory", "language", "multimodal".
|
| 123 |
+
tr_seconds : float
|
| 124 |
+
Repetition time in seconds.
|
| 125 |
+
n_events : int
|
| 126 |
+
Number of stimulus events.
|
| 127 |
+
snr : float
|
| 128 |
+
Signal-to-noise ratio (higher = cleaner signal).
|
| 129 |
+
seed : int
|
| 130 |
+
Random seed.
|
| 131 |
+
"""
|
| 132 |
+
rng = np.random.default_rng(seed)
|
| 133 |
+
n_vertices = max(max(v) for v in roi_indices.values()) + 1
|
| 134 |
+
predictions = np.zeros((n_timepoints, n_vertices))
|
| 135 |
+
|
| 136 |
+
# 1. Generate stimulus-evoked signal
|
| 137 |
+
events = generate_stimulus_events(n_timepoints, tr_seconds, n_events, seed)
|
| 138 |
+
hrf = generate_hrf(tr_seconds)
|
| 139 |
+
|
| 140 |
+
# Convolve events with HRF
|
| 141 |
+
bold_signal = np.convolve(events, hrf)[:n_timepoints]
|
| 142 |
+
|
| 143 |
+
# 2. Apply modality-specific weights per ROI
|
| 144 |
+
weights = MODALITY_WEIGHTS.get(stimulus_type, MODALITY_WEIGHTS["multimodal"])
|
| 145 |
+
for roi_name, vertices in roi_indices.items():
|
| 146 |
+
w = weights.get(roi_name, 0.1)
|
| 147 |
+
# Add per-ROI latency jitter (higher-order areas respond later)
|
| 148 |
+
latency_shift = 0
|
| 149 |
+
if roi_name in ["44", "45", "IFJa", "IFJp", "46", "9-46d"]:
|
| 150 |
+
latency_shift = int(2.0 / tr_seconds) # ~2s later for association cortex
|
| 151 |
+
elif roi_name in ["TPOJ1", "TPOJ2", "STV", "PSL"]:
|
| 152 |
+
latency_shift = int(1.5 / tr_seconds)
|
| 153 |
+
|
| 154 |
+
shifted = np.roll(bold_signal, latency_shift) * w
|
| 155 |
+
# Add per-vertex variation within ROI
|
| 156 |
+
for v in vertices:
|
| 157 |
+
if v < n_vertices:
|
| 158 |
+
vertex_scale = 0.8 + 0.4 * rng.random()
|
| 159 |
+
predictions[:, v] = shifted * vertex_scale
|
| 160 |
+
|
| 161 |
+
# 3. Add temporal autocorrelation (AR(1) noise)
|
| 162 |
+
ar_coeff = 0.5
|
| 163 |
+
noise = rng.standard_normal(predictions.shape)
|
| 164 |
+
for t in range(1, n_timepoints):
|
| 165 |
+
noise[t] += ar_coeff * noise[t - 1]
|
| 166 |
+
|
| 167 |
+
# 4. Add scanner drift (low-frequency sinusoidal)
|
| 168 |
+
t_axis = np.arange(n_timepoints) * tr_seconds
|
| 169 |
+
drift = 0.1 * np.sin(2 * np.pi * t_axis / (n_timepoints * tr_seconds * 0.8))
|
| 170 |
+
drift = drift[:, np.newaxis]
|
| 171 |
+
|
| 172 |
+
# 5. Combine signal + noise + drift
|
| 173 |
+
signal_power = np.std(predictions[predictions != 0]) if np.any(predictions != 0) else 1.0
|
| 174 |
+
noise_power = signal_power / max(snr, 0.1)
|
| 175 |
+
predictions = predictions + noise * noise_power + drift
|
| 176 |
+
|
| 177 |
+
# 6. Spatial smoothing (average with neighbors within same ROI)
|
| 178 |
+
for roi_name, vertices in roi_indices.items():
|
| 179 |
+
valid = vertices[vertices < n_vertices]
|
| 180 |
+
if len(valid) > 1:
|
| 181 |
+
roi_data = predictions[:, valid].copy()
|
| 182 |
+
kernel = np.ones(min(3, len(valid))) / min(3, len(valid))
|
| 183 |
+
for t in range(n_timepoints):
|
| 184 |
+
predictions[t, valid] = np.convolve(roi_data[t], kernel, mode="same")
|
| 185 |
+
|
| 186 |
+
return predictions
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def generate_correlated_features(
|
| 190 |
+
brain_predictions,
|
| 191 |
+
alignment_strength=0.5,
|
| 192 |
+
feature_dim=512,
|
| 193 |
+
seed=42,
|
| 194 |
+
):
|
| 195 |
+
"""Generate model features with controllable correlation to brain data.
|
| 196 |
+
|
| 197 |
+
Parameters
|
| 198 |
+
----------
|
| 199 |
+
brain_predictions : np.ndarray
|
| 200 |
+
Brain data of shape (n_stimuli, n_vertices).
|
| 201 |
+
alignment_strength : float
|
| 202 |
+
0.0 = random features, 1.0 = perfectly correlated with brain.
|
| 203 |
+
feature_dim : int
|
| 204 |
+
Output feature dimensionality.
|
| 205 |
+
seed : int
|
| 206 |
+
Random seed.
|
| 207 |
+
|
| 208 |
+
Returns
|
| 209 |
+
-------
|
| 210 |
+
np.ndarray
|
| 211 |
+
Features of shape (n_stimuli, feature_dim).
|
| 212 |
+
"""
|
| 213 |
+
rng = np.random.default_rng(seed)
|
| 214 |
+
n_stimuli = brain_predictions.shape[0]
|
| 215 |
+
|
| 216 |
+
# Project brain data to feature_dim via random projection
|
| 217 |
+
n_vertices = brain_predictions.shape[1]
|
| 218 |
+
projection = rng.standard_normal((n_vertices, feature_dim)) / np.sqrt(n_vertices)
|
| 219 |
+
brain_projected = brain_predictions @ projection
|
| 220 |
+
|
| 221 |
+
# Generate random features
|
| 222 |
+
random_features = rng.standard_normal((n_stimuli, feature_dim))
|
| 223 |
+
|
| 224 |
+
# Mix: strength controls brain-alignment vs randomness
|
| 225 |
+
strength = np.clip(alignment_strength, 0.0, 1.0)
|
| 226 |
+
features = strength * brain_projected + (1 - strength) * random_features
|
| 227 |
+
|
| 228 |
+
# Standardize
|
| 229 |
+
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
|
| 230 |
+
return features
|
|
@@ -79,13 +79,114 @@ ALIGNMENT_METHODS = {"RSA": rsa_score, "CKA": cka_score, "Procrustes": procruste
|
|
| 79 |
|
| 80 |
|
| 81 |
def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
|
|
|
|
| 82 |
rng = np.random.default_rng(seed)
|
| 83 |
observed = method_fn(model_feat, brain_pred)
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
|
| 91 |
# --- Cognitive Load ---
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
|
| 82 |
+
"""Returns (observed_score, p_value, null_distribution)."""
|
| 83 |
rng = np.random.default_rng(seed)
|
| 84 |
observed = method_fn(model_feat, brain_pred)
|
| 85 |
+
null_dist = []
|
| 86 |
+
for _ in range(n_perm):
|
| 87 |
+
perm_score = method_fn(model_feat[rng.permutation(len(model_feat))], brain_pred)
|
| 88 |
+
null_dist.append(perm_score)
|
| 89 |
+
null_dist = np.array(null_dist)
|
| 90 |
+
count = np.sum(null_dist >= observed)
|
| 91 |
+
p_value = (count + 1) / (n_perm + 1)
|
| 92 |
+
return observed, p_value, null_dist
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def bootstrap_ci(model_feat, brain_pred, method_fn, n_boot=500, confidence=0.95, seed=42):
|
| 96 |
+
"""Returns (point_estimate, ci_lower, ci_upper)."""
|
| 97 |
+
rng = np.random.default_rng(seed)
|
| 98 |
+
n = model_feat.shape[0]
|
| 99 |
+
point = method_fn(model_feat, brain_pred)
|
| 100 |
+
scores = []
|
| 101 |
+
for _ in range(n_boot):
|
| 102 |
+
idx = rng.choice(n, size=n, replace=True)
|
| 103 |
+
scores.append(method_fn(model_feat[idx], brain_pred[idx]))
|
| 104 |
+
scores = np.array(scores)
|
| 105 |
+
alpha = 1 - confidence
|
| 106 |
+
return point, float(np.percentile(scores, 100 * alpha / 2)), float(np.percentile(scores, 100 * (1 - alpha / 2)))
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def fdr_correction(p_values, alpha=0.05):
|
| 110 |
+
"""Benjamini-Hochberg FDR correction. Returns corrected p-values and significance mask."""
|
| 111 |
+
p = np.array(p_values)
|
| 112 |
+
n = len(p)
|
| 113 |
+
sorted_idx = np.argsort(p)
|
| 114 |
+
sorted_p = p[sorted_idx]
|
| 115 |
+
corrected = np.empty(n)
|
| 116 |
+
corrected[sorted_idx[-1]] = sorted_p[-1]
|
| 117 |
+
for i in range(n - 2, -1, -1):
|
| 118 |
+
corrected[sorted_idx[i]] = min(corrected[sorted_idx[i + 1]], sorted_p[i] * n / (i + 1))
|
| 119 |
+
return corrected, corrected < alpha
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def noise_ceiling(brain_pred, method_fn, n_splits=20, seed=42):
|
| 123 |
+
"""Estimate noise ceiling via split-half reliability."""
|
| 124 |
+
rng = np.random.default_rng(seed)
|
| 125 |
+
n = brain_pred.shape[0]
|
| 126 |
+
scores = []
|
| 127 |
+
for _ in range(n_splits):
|
| 128 |
+
idx = rng.permutation(n)
|
| 129 |
+
half = n // 2
|
| 130 |
+
s = method_fn(brain_pred[idx[:half]], brain_pred[idx[half:half * 2]])
|
| 131 |
+
scores.append(s)
|
| 132 |
+
return float(np.mean(scores)), float(np.std(scores))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def partial_correlation(predictions, roi_indices):
|
| 136 |
+
"""Compute partial correlation matrix (correlation controlling for mean signal)."""
|
| 137 |
+
names = list(roi_indices.keys())
|
| 138 |
+
n = len(names)
|
| 139 |
+
T = predictions.shape[0]
|
| 140 |
+
timecourses = np.zeros((n, T))
|
| 141 |
+
for i, name in enumerate(names):
|
| 142 |
+
verts = roi_indices[name]
|
| 143 |
+
valid = verts[verts < predictions.shape[1]]
|
| 144 |
+
if len(valid) > 0:
|
| 145 |
+
timecourses[i] = predictions[:, valid].mean(axis=1)
|
| 146 |
+
|
| 147 |
+
# Partial correlation via precision matrix
|
| 148 |
+
cov = np.cov(timecourses)
|
| 149 |
+
try:
|
| 150 |
+
prec = np.linalg.inv(cov + 1e-6 * np.eye(n))
|
| 151 |
+
d = np.sqrt(np.diag(prec))
|
| 152 |
+
d[d == 0] = 1
|
| 153 |
+
partial = -prec / np.outer(d, d)
|
| 154 |
+
np.fill_diagonal(partial, 1.0)
|
| 155 |
+
except np.linalg.LinAlgError:
|
| 156 |
+
partial = np.eye(n)
|
| 157 |
+
return np.nan_to_num(partial, nan=0.0), names
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def betweenness_centrality(corr_matrix, roi_names, threshold=0.3):
|
| 161 |
+
"""Compute betweenness centrality from thresholded connectivity."""
|
| 162 |
+
import networkx as nx
|
| 163 |
+
n = corr_matrix.shape[0]
|
| 164 |
+
G = nx.Graph()
|
| 165 |
+
for i, name in enumerate(roi_names):
|
| 166 |
+
G.add_node(name)
|
| 167 |
+
for i in range(n):
|
| 168 |
+
for j in range(i + 1, n):
|
| 169 |
+
if abs(corr_matrix[i, j]) > threshold:
|
| 170 |
+
G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
|
| 171 |
+
bc = nx.betweenness_centrality(G)
|
| 172 |
+
return {name: bc.get(name, 0.0) for name in roi_names}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def modularity_score(corr_matrix, labels):
|
| 176 |
+
"""Compute Newman's modularity Q for a given partition."""
|
| 177 |
+
n = corr_matrix.shape[0]
|
| 178 |
+
adj = np.abs(corr_matrix).copy()
|
| 179 |
+
np.fill_diagonal(adj, 0)
|
| 180 |
+
m = adj.sum() / 2
|
| 181 |
+
if m == 0:
|
| 182 |
+
return 0.0
|
| 183 |
+
Q = 0.0
|
| 184 |
+
k = adj.sum(axis=1)
|
| 185 |
+
for i in range(n):
|
| 186 |
+
for j in range(n):
|
| 187 |
+
if labels[i] == labels[j]:
|
| 188 |
+
Q += adj[i, j] - k[i] * k[j] / (2 * m)
|
| 189 |
+
return float(Q / (2 * m))
|
| 190 |
|
| 191 |
|
| 192 |
# --- Cognitive Load ---
|