Spaces:
Running
Running
siddhant-rajhans commited on
Commit ·
bce4bae
0
Parent(s):
Initial CortexLab Dashboard - interactive analysis toolkit
Browse filesStreamlit multipage app with 4 analysis pages:
- Brain Alignment Benchmark: compare AI models with RSA/CKA/Procrustes + permutation tests
- Cognitive Load Scorer: timeline, radar chart, dimension breakdown with content profiles
- Temporal Dynamics: peak latency, lag correlation, sustained/transient decomposition
- ROI Connectivity: correlation heatmap, network clustering, degree centrality, graph viz
Runs on synthetic data by default - no GPU or fMRI data required.
- .gitignore +6 -0
- .streamlit/config.toml +9 -0
- Home.py +50 -0
- README.md +28 -0
- pages/1_Brain_Alignment.py +140 -0
- pages/2_Cognitive_Load.py +131 -0
- pages/3_Temporal_Dynamics.py +119 -0
- pages/4_Connectivity.py +162 -0
- requirements.txt +8 -0
- utils.py +222 -0
.gitignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.streamlit/secrets.toml
|
| 4 |
+
.env
|
| 5 |
+
venv/
|
| 6 |
+
.venv/
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[theme]
|
| 2 |
+
primaryColor = "#6C5CE7"
|
| 3 |
+
backgroundColor = "#0E1117"
|
| 4 |
+
secondaryBackgroundColor = "#1A1A2E"
|
| 5 |
+
textColor = "#FAFAFA"
|
| 6 |
+
font = "sans serif"
|
| 7 |
+
|
| 8 |
+
[server]
|
| 9 |
+
headless = true
|
Home.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CortexLab Dashboard - Home Page."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
|
| 5 |
+
st.set_page_config(
|
| 6 |
+
page_title="CortexLab Dashboard",
|
| 7 |
+
page_icon="🧠",
|
| 8 |
+
layout="wide",
|
| 9 |
+
initial_sidebar_state="expanded",
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
st.title("CortexLab Dashboard")
|
| 13 |
+
st.markdown("**Interactive analysis toolkit for multimodal fMRI brain encoding**")
|
| 14 |
+
|
| 15 |
+
st.divider()
|
| 16 |
+
|
| 17 |
+
col1, col2 = st.columns(2)
|
| 18 |
+
|
| 19 |
+
with col1:
|
| 20 |
+
st.subheader("Analysis Tools")
|
| 21 |
+
st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
|
| 22 |
+
st.markdown("Score how brain-like any AI model's representations are using RSA, CKA, or Procrustes")
|
| 23 |
+
|
| 24 |
+
st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
|
| 25 |
+
st.markdown("Predict cognitive demand across visual, auditory, language, and executive dimensions")
|
| 26 |
+
|
| 27 |
+
with col2:
|
| 28 |
+
st.subheader("Advanced Analysis")
|
| 29 |
+
st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
|
| 30 |
+
st.markdown("Analyze peak response latency, lag correlations, and sustained vs transient components")
|
| 31 |
+
|
| 32 |
+
st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
|
| 33 |
+
st.markdown("Compute functional connectivity matrices, cluster networks, and graph metrics")
|
| 34 |
+
|
| 35 |
+
st.divider()
|
| 36 |
+
|
| 37 |
+
st.subheader("About")
|
| 38 |
+
st.markdown(
|
| 39 |
+
"""
|
| 40 |
+
CortexLab is an enhanced toolkit built on [Meta's TRIBE v2](https://github.com/facebookresearch/tribev2)
|
| 41 |
+
for predicting how the human brain responds to video, audio, and text.
|
| 42 |
+
|
| 43 |
+
This dashboard runs on **synthetic data** by default - no GPU or real fMRI data required.
|
| 44 |
+
All analysis tools mirror the CortexLab Python API.
|
| 45 |
+
|
| 46 |
+
[GitHub](https://github.com/siddhant-rajhans/cortexlab)
|
| 47 |
+
|
|
| 48 |
+
[HuggingFace](https://huggingface.co/SID2000/cortexlab)
|
| 49 |
+
"""
|
| 50 |
+
)
|
README.md
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CortexLab Dashboard
|
| 2 |
+
|
| 3 |
+
Interactive analysis dashboard for [CortexLab](https://github.com/siddhant-rajhans/cortexlab) - multimodal fMRI brain encoding toolkit.
|
| 4 |
+
|
| 5 |
+
## Pages
|
| 6 |
+
|
| 7 |
+
- **Brain Alignment Benchmark** - Compare AI model representations against brain responses (RSA, CKA, Procrustes)
|
| 8 |
+
- **Cognitive Load Scorer** - Visualize cognitive demand across visual, auditory, language, and executive dimensions
|
| 9 |
+
- **Temporal Dynamics** - Peak latency, lag correlations, sustained vs transient response decomposition
|
| 10 |
+
- **ROI Connectivity** - Correlation matrices, network clustering, degree centrality, graph visualization
|
| 11 |
+
|
| 12 |
+
## Quick Start
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
pip install -r requirements.txt
|
| 16 |
+
streamlit run Home.py
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
Runs on **synthetic data** by default - no GPU or real fMRI data required.
|
| 20 |
+
|
| 21 |
+
## Links
|
| 22 |
+
|
| 23 |
+
- [CortexLab Library](https://github.com/siddhant-rajhans/cortexlab)
|
| 24 |
+
- [CortexLab on HuggingFace](https://huggingface.co/SID2000/cortexlab)
|
| 25 |
+
|
| 26 |
+
## License
|
| 27 |
+
|
| 28 |
+
CC BY-NC 4.0
|
pages/1_Brain_Alignment.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Brain Alignment Benchmark page."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
|
| 9 |
+
from utils import (
|
| 10 |
+
ALIGNMENT_METHODS,
|
| 11 |
+
make_roi_indices,
|
| 12 |
+
generate_brain_predictions,
|
| 13 |
+
generate_model_features,
|
| 14 |
+
permutation_test,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
|
| 18 |
+
st.title("🎯 Brain Alignment Benchmark")
|
| 19 |
+
st.markdown("Compare how well AI model representations align with predicted brain responses.")
|
| 20 |
+
|
| 21 |
+
# --- Sidebar Controls ---
|
| 22 |
+
with st.sidebar:
|
| 23 |
+
st.header("Configuration")
|
| 24 |
+
n_stimuli = st.slider("Number of stimuli", 10, 200, 50)
|
| 25 |
+
seed = st.number_input("Random seed", value=42, min_value=0)
|
| 26 |
+
|
| 27 |
+
st.subheader("Models to compare")
|
| 28 |
+
models_config = {
|
| 29 |
+
"CLIP ViT-L/14": st.checkbox("CLIP ViT-L/14", value=True),
|
| 30 |
+
"DINOv2 ViT-S": st.checkbox("DINOv2 ViT-S", value=True),
|
| 31 |
+
"V-JEPA2 ViT-G": st.checkbox("V-JEPA2 ViT-G", value=True),
|
| 32 |
+
"LLaMA 3.2-3B": st.checkbox("LLaMA 3.2-3B", value=False),
|
| 33 |
+
}
|
| 34 |
+
model_dims = {
|
| 35 |
+
"CLIP ViT-L/14": 768,
|
| 36 |
+
"DINOv2 ViT-S": 384,
|
| 37 |
+
"V-JEPA2 ViT-G": 1024,
|
| 38 |
+
"LLaMA 3.2-3B": 3072,
|
| 39 |
+
}
|
| 40 |
+
selected_models = [m for m, checked in models_config.items() if checked]
|
| 41 |
+
|
| 42 |
+
methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
|
| 43 |
+
run_stats = st.checkbox("Run permutation test", value=False)
|
| 44 |
+
n_perm = st.slider("Permutations", 50, 1000, 200) if run_stats else 0
|
| 45 |
+
|
| 46 |
+
if not selected_models or not methods:
|
| 47 |
+
st.warning("Select at least one model and one method.")
|
| 48 |
+
st.stop()
|
| 49 |
+
|
| 50 |
+
# --- Generate Data ---
|
| 51 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 52 |
+
brain_pred = generate_brain_predictions(n_stimuli, n_vertices, seed)
|
| 53 |
+
|
| 54 |
+
model_features = {}
|
| 55 |
+
for i, name in enumerate(selected_models):
|
| 56 |
+
model_features[name] = generate_model_features(n_stimuli, model_dims[name], seed + i + 1)
|
| 57 |
+
|
| 58 |
+
# --- Run Benchmark ---
|
| 59 |
+
with st.spinner("Computing alignment scores..."):
|
| 60 |
+
results = []
|
| 61 |
+
for model_name, features in model_features.items():
|
| 62 |
+
for method_name in methods:
|
| 63 |
+
score_fn = ALIGNMENT_METHODS[method_name]
|
| 64 |
+
score = score_fn(features, brain_pred)
|
| 65 |
+
row = {"Model": model_name, "Method": method_name, "Score": score}
|
| 66 |
+
|
| 67 |
+
if run_stats:
|
| 68 |
+
_, p_val = permutation_test(features, brain_pred, score_fn, n_perm, seed)
|
| 69 |
+
row["p-value"] = p_val
|
| 70 |
+
row["Significant"] = "Yes" if p_val < 0.05 else "No"
|
| 71 |
+
|
| 72 |
+
results.append(row)
|
| 73 |
+
|
| 74 |
+
df = pd.DataFrame(results)
|
| 75 |
+
|
| 76 |
+
# --- Display Results ---
|
| 77 |
+
col1, col2 = st.columns([2, 1])
|
| 78 |
+
|
| 79 |
+
with col1:
|
| 80 |
+
st.subheader("Alignment Scores")
|
| 81 |
+
fig = px.bar(
|
| 82 |
+
df,
|
| 83 |
+
x="Model",
|
| 84 |
+
y="Score",
|
| 85 |
+
color="Method",
|
| 86 |
+
barmode="group",
|
| 87 |
+
color_discrete_sequence=px.colors.qualitative.Set2,
|
| 88 |
+
)
|
| 89 |
+
fig.update_layout(
|
| 90 |
+
yaxis_title="Alignment Score",
|
| 91 |
+
height=450,
|
| 92 |
+
template="plotly_dark",
|
| 93 |
+
)
|
| 94 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 95 |
+
|
| 96 |
+
with col2:
|
| 97 |
+
st.subheader("Results Table")
|
| 98 |
+
display_df = df.copy()
|
| 99 |
+
display_df["Score"] = display_df["Score"].map(lambda x: f"{x:.4f}")
|
| 100 |
+
if "p-value" in display_df.columns:
|
| 101 |
+
display_df["p-value"] = display_df["p-value"].map(lambda x: f"{x:.4f}")
|
| 102 |
+
st.dataframe(display_df, use_container_width=True, hide_index=True)
|
| 103 |
+
|
| 104 |
+
# --- Per-ROI Analysis ---
|
| 105 |
+
st.divider()
|
| 106 |
+
st.subheader("Per-ROI Alignment (RSA)")
|
| 107 |
+
|
| 108 |
+
if "RSA" in methods and len(selected_models) >= 1:
|
| 109 |
+
from utils import ROI_GROUPS, rsa_score
|
| 110 |
+
|
| 111 |
+
roi_data = []
|
| 112 |
+
for model_name, features in model_features.items():
|
| 113 |
+
for group_name, rois in ROI_GROUPS.items():
|
| 114 |
+
group_scores = []
|
| 115 |
+
for roi in rois:
|
| 116 |
+
if roi in roi_indices:
|
| 117 |
+
verts = roi_indices[roi]
|
| 118 |
+
valid = verts[verts < brain_pred.shape[1]]
|
| 119 |
+
if len(valid) >= 2:
|
| 120 |
+
s = rsa_score(features, brain_pred[:, valid])
|
| 121 |
+
group_scores.append(s)
|
| 122 |
+
if group_scores:
|
| 123 |
+
roi_data.append({
|
| 124 |
+
"Model": model_name,
|
| 125 |
+
"Region": group_name,
|
| 126 |
+
"RSA Score": float(np.mean(group_scores)),
|
| 127 |
+
})
|
| 128 |
+
|
| 129 |
+
if roi_data:
|
| 130 |
+
roi_df = pd.DataFrame(roi_data)
|
| 131 |
+
fig2 = px.bar(
|
| 132 |
+
roi_df,
|
| 133 |
+
x="Region",
|
| 134 |
+
y="RSA Score",
|
| 135 |
+
color="Model",
|
| 136 |
+
barmode="group",
|
| 137 |
+
color_discrete_sequence=px.colors.qualitative.Pastel,
|
| 138 |
+
)
|
| 139 |
+
fig2.update_layout(height=400, template="plotly_dark")
|
| 140 |
+
st.plotly_chart(fig2, use_container_width=True)
|
pages/2_Cognitive_Load.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cognitive Load Scorer page."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
|
| 9 |
+
from utils import (
|
| 10 |
+
make_roi_indices,
|
| 11 |
+
generate_brain_predictions,
|
| 12 |
+
score_cognitive_load,
|
| 13 |
+
COGNITIVE_DIMENSIONS,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
|
| 17 |
+
st.title("📊 Cognitive Load Scorer")
|
| 18 |
+
st.markdown("Predict cognitive demand of media content from brain activation patterns.")
|
| 19 |
+
|
| 20 |
+
# --- Sidebar ---
|
| 21 |
+
with st.sidebar:
|
| 22 |
+
st.header("Configuration")
|
| 23 |
+
n_timepoints = st.slider("Duration (TRs)", 20, 200, 60)
|
| 24 |
+
tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 25 |
+
seed = st.number_input("Random seed", value=42, min_value=0)
|
| 26 |
+
|
| 27 |
+
st.subheader("Simulate content type")
|
| 28 |
+
content_type = st.selectbox(
|
| 29 |
+
"Content profile",
|
| 30 |
+
["Balanced", "Visual-heavy", "Audio-heavy", "Language-heavy", "Low engagement"],
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
# --- Generate Data ---
|
| 34 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 35 |
+
predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
|
| 36 |
+
|
| 37 |
+
# Apply content profile by amplifying relevant ROIs
|
| 38 |
+
multipliers = {"Balanced": {}, "Visual-heavy": {}, "Audio-heavy": {}, "Language-heavy": {}, "Low engagement": {}}
|
| 39 |
+
if content_type == "Visual-heavy":
|
| 40 |
+
for roi in COGNITIVE_DIMENSIONS["Visual Complexity"]:
|
| 41 |
+
if roi in roi_indices:
|
| 42 |
+
predictions[:, roi_indices[roi]] *= 3.0
|
| 43 |
+
elif content_type == "Audio-heavy":
|
| 44 |
+
for roi in COGNITIVE_DIMENSIONS["Auditory Demand"]:
|
| 45 |
+
if roi in roi_indices:
|
| 46 |
+
predictions[:, roi_indices[roi]] *= 3.0
|
| 47 |
+
elif content_type == "Language-heavy":
|
| 48 |
+
for roi in COGNITIVE_DIMENSIONS["Language Processing"]:
|
| 49 |
+
if roi in roi_indices:
|
| 50 |
+
predictions[:, roi_indices[roi]] *= 3.0
|
| 51 |
+
elif content_type == "Low engagement":
|
| 52 |
+
predictions *= 0.2
|
| 53 |
+
|
| 54 |
+
# --- Score ---
|
| 55 |
+
averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
|
| 56 |
+
|
| 57 |
+
# --- Display ---
|
| 58 |
+
col1, col2, col3, col4, col5 = st.columns(5)
|
| 59 |
+
dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
|
| 60 |
+
cols = [col1, col2, col3, col4, col5]
|
| 61 |
+
for col, dim in zip(cols, dims):
|
| 62 |
+
val = averages.get(dim, 0.0)
|
| 63 |
+
col.metric(dim, f"{val:.2f}", delta=None)
|
| 64 |
+
|
| 65 |
+
st.divider()
|
| 66 |
+
|
| 67 |
+
# --- Timeline ---
|
| 68 |
+
st.subheader("Cognitive Load Timeline")
|
| 69 |
+
timeline_df = pd.DataFrame(timeline)
|
| 70 |
+
|
| 71 |
+
fig = go.Figure()
|
| 72 |
+
colors = {"Visual Complexity": "#00D2FF", "Auditory Demand": "#FF6B6B", "Language Processing": "#A29BFE", "Executive Load": "#FFEAA7"}
|
| 73 |
+
for dim, color in colors.items():
|
| 74 |
+
fig.add_trace(go.Scatter(
|
| 75 |
+
x=timeline_df["time"],
|
| 76 |
+
y=timeline_df[dim],
|
| 77 |
+
name=dim,
|
| 78 |
+
line=dict(color=color, width=2),
|
| 79 |
+
mode="lines",
|
| 80 |
+
))
|
| 81 |
+
|
| 82 |
+
fig.update_layout(
|
| 83 |
+
xaxis_title="Time (seconds)",
|
| 84 |
+
yaxis_title="Cognitive Load (normalized)",
|
| 85 |
+
yaxis_range=[0, 1.05],
|
| 86 |
+
height=400,
|
| 87 |
+
template="plotly_dark",
|
| 88 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 89 |
+
)
|
| 90 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 91 |
+
|
| 92 |
+
# --- Dimension Breakdown ---
|
| 93 |
+
st.divider()
|
| 94 |
+
col1, col2 = st.columns(2)
|
| 95 |
+
|
| 96 |
+
with col1:
|
| 97 |
+
st.subheader("Dimension Breakdown")
|
| 98 |
+
dim_data = {k: v for k, v in averages.items() if k != "Overall"}
|
| 99 |
+
fig2 = go.Figure(go.Bar(
|
| 100 |
+
x=list(dim_data.values()),
|
| 101 |
+
y=list(dim_data.keys()),
|
| 102 |
+
orientation="h",
|
| 103 |
+
marker_color=list(colors.values()),
|
| 104 |
+
))
|
| 105 |
+
fig2.update_layout(
|
| 106 |
+
xaxis_title="Score",
|
| 107 |
+
xaxis_range=[0, 1],
|
| 108 |
+
height=300,
|
| 109 |
+
template="plotly_dark",
|
| 110 |
+
)
|
| 111 |
+
st.plotly_chart(fig2, use_container_width=True)
|
| 112 |
+
|
| 113 |
+
with col2:
|
| 114 |
+
st.subheader("Radar Chart")
|
| 115 |
+
categories = list(dim_data.keys())
|
| 116 |
+
values = list(dim_data.values()) + [list(dim_data.values())[0]] # close the polygon
|
| 117 |
+
|
| 118 |
+
fig3 = go.Figure(go.Scatterpolar(
|
| 119 |
+
r=values,
|
| 120 |
+
theta=categories + [categories[0]],
|
| 121 |
+
fill="toself",
|
| 122 |
+
fillcolor="rgba(108, 92, 231, 0.3)",
|
| 123 |
+
line=dict(color="#6C5CE7"),
|
| 124 |
+
))
|
| 125 |
+
fig3.update_layout(
|
| 126 |
+
polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
|
| 127 |
+
height=350,
|
| 128 |
+
template="plotly_dark",
|
| 129 |
+
showlegend=False,
|
| 130 |
+
)
|
| 131 |
+
st.plotly_chart(fig3, use_container_width=True)
|
pages/3_Temporal_Dynamics.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Temporal Dynamics page."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
from plotly.subplots import make_subplots
|
| 8 |
+
|
| 9 |
+
from utils import (
|
| 10 |
+
make_roi_indices,
|
| 11 |
+
generate_brain_predictions,
|
| 12 |
+
generate_model_features,
|
| 13 |
+
peak_latency,
|
| 14 |
+
temporal_correlation,
|
| 15 |
+
decompose_response,
|
| 16 |
+
ROI_GROUPS,
|
| 17 |
+
ALL_ROIS,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
|
| 21 |
+
st.title("⏱️ Temporal Dynamics")
|
| 22 |
+
st.markdown("Analyze how brain responses evolve over time.")
|
| 23 |
+
|
| 24 |
+
# --- Sidebar ---
|
| 25 |
+
with st.sidebar:
|
| 26 |
+
st.header("Configuration")
|
| 27 |
+
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 28 |
+
tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
|
| 29 |
+
seed = st.number_input("Random seed", value=42, min_value=0)
|
| 30 |
+
|
| 31 |
+
st.subheader("ROI Selection")
|
| 32 |
+
selected_group = st.selectbox("Region group", list(ROI_GROUPS.keys()))
|
| 33 |
+
available_rois = ROI_GROUPS[selected_group]
|
| 34 |
+
selected_rois = st.multiselect("ROIs to analyze", available_rois, default=available_rois[:3])
|
| 35 |
+
|
| 36 |
+
max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
|
| 37 |
+
cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
|
| 38 |
+
|
| 39 |
+
if not selected_rois:
|
| 40 |
+
st.warning("Select at least one ROI.")
|
| 41 |
+
st.stop()
|
| 42 |
+
|
| 43 |
+
# --- Generate Data ---
|
| 44 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 45 |
+
predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
|
| 46 |
+
features = generate_model_features(n_timepoints, 64, seed + 1)
|
| 47 |
+
|
| 48 |
+
# --- Peak Latency ---
|
| 49 |
+
st.subheader("Peak Response Latency")
|
| 50 |
+
latency_data = []
|
| 51 |
+
for roi in selected_rois:
|
| 52 |
+
lat = peak_latency(predictions, roi_indices, roi, tr_seconds)
|
| 53 |
+
latency_data.append({"ROI": roi, "Peak Latency (s)": lat})
|
| 54 |
+
|
| 55 |
+
lat_df = pd.DataFrame(latency_data)
|
| 56 |
+
col1, col2 = st.columns([2, 1])
|
| 57 |
+
|
| 58 |
+
with col1:
|
| 59 |
+
fig = go.Figure(go.Bar(
|
| 60 |
+
x=lat_df["ROI"],
|
| 61 |
+
y=lat_df["Peak Latency (s)"],
|
| 62 |
+
marker_color="#6C5CE7",
|
| 63 |
+
))
|
| 64 |
+
fig.update_layout(
|
| 65 |
+
yaxis_title="Time to peak (seconds)",
|
| 66 |
+
height=350,
|
| 67 |
+
template="plotly_dark",
|
| 68 |
+
)
|
| 69 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 70 |
+
|
| 71 |
+
with col2:
|
| 72 |
+
st.dataframe(lat_df, use_container_width=True, hide_index=True)
|
| 73 |
+
|
| 74 |
+
# --- Lag Correlation ---
|
| 75 |
+
st.divider()
|
| 76 |
+
st.subheader("Temporal Correlation (Brain vs Model Features)")
|
| 77 |
+
|
| 78 |
+
lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
|
| 79 |
+
fig2 = go.Figure()
|
| 80 |
+
colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8"]
|
| 81 |
+
|
| 82 |
+
for i, roi in enumerate(selected_rois):
|
| 83 |
+
corr = temporal_correlation(predictions, features, roi_indices, roi, max_lag)
|
| 84 |
+
fig2.add_trace(go.Scatter(
|
| 85 |
+
x=lags,
|
| 86 |
+
y=corr,
|
| 87 |
+
name=roi,
|
| 88 |
+
line=dict(color=colors[i % len(colors)], width=2),
|
| 89 |
+
))
|
| 90 |
+
|
| 91 |
+
fig2.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
|
| 92 |
+
fig2.update_layout(
|
| 93 |
+
xaxis_title="Lag (seconds)",
|
| 94 |
+
yaxis_title="Pearson Correlation",
|
| 95 |
+
height=400,
|
| 96 |
+
template="plotly_dark",
|
| 97 |
+
legend=dict(orientation="h", yanchor="bottom", y=1.02),
|
| 98 |
+
)
|
| 99 |
+
st.plotly_chart(fig2, use_container_width=True)
|
| 100 |
+
|
| 101 |
+
# --- Sustained vs Transient ---
|
| 102 |
+
st.divider()
|
| 103 |
+
st.subheader("Sustained vs Transient Decomposition")
|
| 104 |
+
|
| 105 |
+
roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
|
| 106 |
+
sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
|
| 107 |
+
time_axis = np.arange(len(sustained)) * tr_seconds
|
| 108 |
+
|
| 109 |
+
fig3 = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
|
| 110 |
+
subplot_titles=("Sustained Component", "Transient Component"))
|
| 111 |
+
|
| 112 |
+
fig3.add_trace(go.Scatter(x=time_axis, y=sustained, name="Sustained",
|
| 113 |
+
line=dict(color="#6C5CE7", width=2)), row=1, col=1)
|
| 114 |
+
fig3.add_trace(go.Scatter(x=time_axis, y=transient, name="Transient",
|
| 115 |
+
line=dict(color="#FF6B6B", width=1.5)), row=2, col=1)
|
| 116 |
+
|
| 117 |
+
fig3.update_xaxes(title_text="Time (seconds)", row=2, col=1)
|
| 118 |
+
fig3.update_layout(height=500, template="plotly_dark", showlegend=False)
|
| 119 |
+
st.plotly_chart(fig3, use_container_width=True)
|
pages/4_Connectivity.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ROI Connectivity page."""
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
|
| 9 |
+
from utils import (
|
| 10 |
+
make_roi_indices,
|
| 11 |
+
generate_brain_predictions,
|
| 12 |
+
compute_connectivity,
|
| 13 |
+
cluster_rois,
|
| 14 |
+
graph_metrics,
|
| 15 |
+
ROI_GROUPS,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
|
| 19 |
+
st.title("🔗 ROI Connectivity Analysis")
|
| 20 |
+
st.markdown("Functional connectivity between brain regions from predicted responses.")
|
| 21 |
+
|
| 22 |
+
# --- Sidebar ---
|
| 23 |
+
with st.sidebar:
|
| 24 |
+
st.header("Configuration")
|
| 25 |
+
n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
|
| 26 |
+
seed = st.number_input("Random seed", value=42, min_value=0)
|
| 27 |
+
n_clusters = st.slider("Number of clusters", 2, 8, 4)
|
| 28 |
+
threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
|
| 29 |
+
|
| 30 |
+
# --- Generate Data ---
|
| 31 |
+
roi_indices, n_vertices = make_roi_indices()
|
| 32 |
+
predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
|
| 33 |
+
|
| 34 |
+
# --- Correlation Matrix ---
|
| 35 |
+
corr_matrix, roi_names = compute_connectivity(predictions, roi_indices)
|
| 36 |
+
|
| 37 |
+
st.subheader("Correlation Matrix")
|
| 38 |
+
fig = go.Figure(go.Heatmap(
|
| 39 |
+
z=corr_matrix,
|
| 40 |
+
x=roi_names,
|
| 41 |
+
y=roi_names,
|
| 42 |
+
colorscale="RdBu_r",
|
| 43 |
+
zmid=0,
|
| 44 |
+
zmin=-1,
|
| 45 |
+
zmax=1,
|
| 46 |
+
colorbar=dict(title="Correlation"),
|
| 47 |
+
))
|
| 48 |
+
fig.update_layout(
|
| 49 |
+
height=600,
|
| 50 |
+
width=700,
|
| 51 |
+
template="plotly_dark",
|
| 52 |
+
xaxis=dict(tickangle=45, tickfont=dict(size=8)),
|
| 53 |
+
yaxis=dict(tickfont=dict(size=8)),
|
| 54 |
+
)
|
| 55 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 56 |
+
|
| 57 |
+
# --- Network Clusters ---
|
| 58 |
+
st.divider()
|
| 59 |
+
col1, col2 = st.columns(2)
|
| 60 |
+
|
| 61 |
+
with col1:
|
| 62 |
+
st.subheader("Functional Network Clusters")
|
| 63 |
+
clusters, labels = cluster_rois(corr_matrix, roi_names, n_clusters)
|
| 64 |
+
|
| 65 |
+
cluster_data = []
|
| 66 |
+
for cid, rois in sorted(clusters.items()):
|
| 67 |
+
for roi in rois:
|
| 68 |
+
cluster_data.append({"Cluster": f"Network {cid}", "ROI": roi})
|
| 69 |
+
|
| 70 |
+
cluster_df = pd.DataFrame(cluster_data)
|
| 71 |
+
fig2 = px.bar(
|
| 72 |
+
cluster_df.groupby("Cluster").size().reset_index(name="Count"),
|
| 73 |
+
x="Cluster",
|
| 74 |
+
y="Count",
|
| 75 |
+
color="Cluster",
|
| 76 |
+
color_discrete_sequence=px.colors.qualitative.Set2,
|
| 77 |
+
)
|
| 78 |
+
fig2.update_layout(height=350, template="plotly_dark", showlegend=False)
|
| 79 |
+
st.plotly_chart(fig2, use_container_width=True)
|
| 80 |
+
|
| 81 |
+
for cid, rois in sorted(clusters.items()):
|
| 82 |
+
st.markdown(f"**Network {cid}:** {', '.join(rois)}")
|
| 83 |
+
|
| 84 |
+
with col2:
|
| 85 |
+
st.subheader("Degree Centrality")
|
| 86 |
+
degrees = graph_metrics(corr_matrix, roi_names, threshold)
|
| 87 |
+
|
| 88 |
+
# Sort by degree
|
| 89 |
+
sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
|
| 90 |
+
deg_df = pd.DataFrame(sorted_degrees, columns=["ROI", "Degree Centrality"])
|
| 91 |
+
|
| 92 |
+
fig3 = go.Figure(go.Bar(
|
| 93 |
+
x=deg_df["Degree Centrality"],
|
| 94 |
+
y=deg_df["ROI"],
|
| 95 |
+
orientation="h",
|
| 96 |
+
marker_color="#6C5CE7",
|
| 97 |
+
))
|
| 98 |
+
fig3.update_layout(
|
| 99 |
+
xaxis_title="Degree Centrality",
|
| 100 |
+
xaxis_range=[0, 1],
|
| 101 |
+
height=600,
|
| 102 |
+
template="plotly_dark",
|
| 103 |
+
yaxis=dict(autorange="reversed", tickfont=dict(size=9)),
|
| 104 |
+
)
|
| 105 |
+
st.plotly_chart(fig3, use_container_width=True)
|
| 106 |
+
|
| 107 |
+
# --- Network Graph ---
|
| 108 |
+
st.divider()
|
| 109 |
+
st.subheader("Network Graph")
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
import networkx as nx
|
| 113 |
+
|
| 114 |
+
G = nx.Graph()
|
| 115 |
+
for name in roi_names:
|
| 116 |
+
G.add_node(name)
|
| 117 |
+
|
| 118 |
+
for i in range(len(roi_names)):
|
| 119 |
+
for j in range(i + 1, len(roi_names)):
|
| 120 |
+
if abs(corr_matrix[i, j]) > threshold:
|
| 121 |
+
G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
|
| 122 |
+
|
| 123 |
+
pos = nx.spring_layout(G, seed=seed, k=2.0)
|
| 124 |
+
|
| 125 |
+
# Cluster colors
|
| 126 |
+
color_map = px.colors.qualitative.Set2
|
| 127 |
+
node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(len(roi_names))]
|
| 128 |
+
|
| 129 |
+
edge_x, edge_y = [], []
|
| 130 |
+
for u, v in G.edges():
|
| 131 |
+
x0, y0 = pos[u]
|
| 132 |
+
x1, y1 = pos[v]
|
| 133 |
+
edge_x.extend([x0, x1, None])
|
| 134 |
+
edge_y.extend([y0, y1, None])
|
| 135 |
+
|
| 136 |
+
node_x = [pos[n][0] for n in roi_names]
|
| 137 |
+
node_y = [pos[n][1] for n in roi_names]
|
| 138 |
+
|
| 139 |
+
fig4 = go.Figure()
|
| 140 |
+
fig4.add_trace(go.Scatter(
|
| 141 |
+
x=edge_x, y=edge_y, mode="lines",
|
| 142 |
+
line=dict(width=0.5, color="rgba(150,150,150,0.3)"),
|
| 143 |
+
hoverinfo="none",
|
| 144 |
+
))
|
| 145 |
+
fig4.add_trace(go.Scatter(
|
| 146 |
+
x=node_x, y=node_y, mode="markers+text",
|
| 147 |
+
marker=dict(size=12, color=node_colors, line=dict(width=1, color="white")),
|
| 148 |
+
text=roi_names,
|
| 149 |
+
textposition="top center",
|
| 150 |
+
textfont=dict(size=8),
|
| 151 |
+
hoverinfo="text",
|
| 152 |
+
))
|
| 153 |
+
fig4.update_layout(
|
| 154 |
+
height=500,
|
| 155 |
+
template="plotly_dark",
|
| 156 |
+
showlegend=False,
|
| 157 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 158 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
| 159 |
+
)
|
| 160 |
+
st.plotly_chart(fig4, use_container_width=True)
|
| 161 |
+
except ImportError:
|
| 162 |
+
st.info("Install `networkx` for the network graph visualization: `pip install networkx`")
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.30
|
| 2 |
+
plotly>=5.18
|
| 3 |
+
numpy>=2.0
|
| 4 |
+
scipy>=1.12
|
| 5 |
+
pandas>=2.0
|
| 6 |
+
networkx>=3.2
|
| 7 |
+
matplotlib>=3.8
|
| 8 |
+
seaborn>=0.13
|
utils.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared utilities for the CortexLab dashboard.
|
| 2 |
+
|
| 3 |
+
Provides synthetic data generation and analysis functions that mirror
|
| 4 |
+
CortexLab's API without requiring the full library or GPU.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from scipy.stats import spearmanr
|
| 9 |
+
from scipy.cluster.hierarchy import linkage, fcluster
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# --- ROI Definitions ---
|
| 13 |
+
|
| 14 |
+
ROI_GROUPS = {
|
| 15 |
+
"Executive": ["46", "9-46d", "8Av", "8Ad", "FEF", "p32pr", "a32pr"],
|
| 16 |
+
"Visual": ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"],
|
| 17 |
+
"Auditory": ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"],
|
| 18 |
+
"Language": ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2", "STV", "PSL"],
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
ALL_ROIS = [roi for group in ROI_GROUPS.values() for roi in group]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def make_roi_indices(n_vertices_per_roi=20):
|
| 25 |
+
"""Create ROI -> vertex index mapping."""
|
| 26 |
+
indices = {}
|
| 27 |
+
offset = 0
|
| 28 |
+
for roi in ALL_ROIS:
|
| 29 |
+
indices[roi] = np.arange(offset, offset + n_vertices_per_roi)
|
| 30 |
+
offset += n_vertices_per_roi
|
| 31 |
+
return indices, offset
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# --- Brain Alignment ---
|
| 35 |
+
|
| 36 |
+
def compute_rdm(features):
|
| 37 |
+
norms = np.linalg.norm(features, axis=1, keepdims=True)
|
| 38 |
+
norms = np.where(norms > 0, norms, 1.0)
|
| 39 |
+
normalised = features / norms
|
| 40 |
+
return 1.0 - normalised @ normalised.T
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def rsa_score(model_features, brain_features):
|
| 44 |
+
rdm_m = compute_rdm(model_features)
|
| 45 |
+
rdm_b = compute_rdm(brain_features)
|
| 46 |
+
idx = np.triu_indices(rdm_m.shape[0], k=1)
|
| 47 |
+
corr, _ = spearmanr(rdm_m[idx], rdm_b[idx])
|
| 48 |
+
return float(corr) if not np.isnan(corr) else 0.0
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def cka_score(X, Y):
|
| 52 |
+
n = X.shape[0]
|
| 53 |
+
X = X - X.mean(axis=0)
|
| 54 |
+
Y = Y - Y.mean(axis=0)
|
| 55 |
+
XX = X @ X.T
|
| 56 |
+
YY = Y @ Y.T
|
| 57 |
+
hsic_xy = np.trace(XX @ YY) / (n - 1) ** 2
|
| 58 |
+
hsic_xx = np.trace(XX @ XX) / (n - 1) ** 2
|
| 59 |
+
hsic_yy = np.trace(YY @ YY) / (n - 1) ** 2
|
| 60 |
+
denom = np.sqrt(hsic_xx * hsic_yy)
|
| 61 |
+
return float(hsic_xy / denom) if denom > 1e-12 else 0.0
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def procrustes_score(X, Y):
|
| 65 |
+
min_dim = min(X.shape[1], Y.shape[1])
|
| 66 |
+
X, Y = X[:, :min_dim], Y[:, :min_dim]
|
| 67 |
+
X = X - X.mean(axis=0)
|
| 68 |
+
Y = Y - Y.mean(axis=0)
|
| 69 |
+
nx, ny = np.linalg.norm(X), np.linalg.norm(Y)
|
| 70 |
+
if nx < 1e-12 or ny < 1e-12:
|
| 71 |
+
return 0.0
|
| 72 |
+
X, Y = X / nx, Y / ny
|
| 73 |
+
U, _, Vt = np.linalg.svd(Y.T @ X, full_matrices=False)
|
| 74 |
+
rotated = Y @ (U @ Vt)
|
| 75 |
+
return float(max(0.0, 1.0 - np.linalg.norm(X - rotated)))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
ALIGNMENT_METHODS = {"RSA": rsa_score, "CKA": cka_score, "Procrustes": procrustes_score}
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
|
| 82 |
+
rng = np.random.default_rng(seed)
|
| 83 |
+
observed = method_fn(model_feat, brain_pred)
|
| 84 |
+
count = sum(
|
| 85 |
+
1 for _ in range(n_perm)
|
| 86 |
+
if method_fn(model_feat[rng.permutation(len(model_feat))], brain_pred) >= observed
|
| 87 |
+
)
|
| 88 |
+
return observed, (count + 1) / (n_perm + 1)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- Cognitive Load ---
|
| 92 |
+
|
| 93 |
+
COGNITIVE_DIMENSIONS = {
|
| 94 |
+
"Executive Load": ["46", "9-46d", "8Av", "8Ad", "FEF", "p32pr", "a32pr"],
|
| 95 |
+
"Visual Complexity": ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"],
|
| 96 |
+
"Auditory Demand": ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"],
|
| 97 |
+
"Language Processing": ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2", "STV", "PSL"],
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def score_cognitive_load(predictions, roi_indices, tr_seconds=1.0):
|
| 102 |
+
baseline = max(float(np.median(np.abs(predictions))), 1e-8)
|
| 103 |
+
timeline = []
|
| 104 |
+
dim_scores = {d: [] for d in COGNITIVE_DIMENSIONS}
|
| 105 |
+
|
| 106 |
+
for t in range(predictions.shape[0]):
|
| 107 |
+
row = {}
|
| 108 |
+
for dim, rois in COGNITIVE_DIMENSIONS.items():
|
| 109 |
+
vals = []
|
| 110 |
+
for roi in rois:
|
| 111 |
+
if roi in roi_indices:
|
| 112 |
+
verts = roi_indices[roi]
|
| 113 |
+
valid = verts[verts < predictions.shape[1]]
|
| 114 |
+
if len(valid) > 0:
|
| 115 |
+
vals.append(np.abs(predictions[t, valid]).mean())
|
| 116 |
+
score = min(float(np.mean(vals)) / baseline, 1.0) if vals else 0.0
|
| 117 |
+
dim_scores[dim].append(score)
|
| 118 |
+
row[dim] = score
|
| 119 |
+
row["time"] = t * tr_seconds
|
| 120 |
+
timeline.append(row)
|
| 121 |
+
|
| 122 |
+
averages = {d: float(np.mean(v)) for d, v in dim_scores.items()}
|
| 123 |
+
averages["Overall"] = float(np.mean(list(averages.values())))
|
| 124 |
+
return averages, timeline
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# --- Temporal Dynamics ---
|
| 128 |
+
|
| 129 |
+
def peak_latency(predictions, roi_indices, roi_name, tr_seconds=1.0):
|
| 130 |
+
verts = roi_indices.get(roi_name, np.array([]))
|
| 131 |
+
valid = verts[verts < predictions.shape[1]]
|
| 132 |
+
if len(valid) == 0:
|
| 133 |
+
return 0.0
|
| 134 |
+
tc = np.abs(predictions[:, valid]).mean(axis=1)
|
| 135 |
+
return float(np.argmax(tc) * tr_seconds)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def temporal_correlation(predictions, features, roi_indices, roi_name, max_lag=10):
|
| 139 |
+
verts = roi_indices.get(roi_name, np.array([]))
|
| 140 |
+
valid = verts[verts < predictions.shape[1]]
|
| 141 |
+
if len(valid) == 0:
|
| 142 |
+
return np.zeros(2 * max_lag + 1)
|
| 143 |
+
brain_tc = np.abs(predictions[:, valid]).mean(axis=1)
|
| 144 |
+
model_tc = features.mean(axis=1) if features.ndim > 1 else features
|
| 145 |
+
n = min(len(brain_tc), len(model_tc))
|
| 146 |
+
brain_tc, model_tc = brain_tc[:n], model_tc[:n]
|
| 147 |
+
|
| 148 |
+
corrs = []
|
| 149 |
+
for lag in range(-max_lag, max_lag + 1):
|
| 150 |
+
if lag >= 0:
|
| 151 |
+
b, m = brain_tc[lag:], model_tc[:n - lag]
|
| 152 |
+
else:
|
| 153 |
+
b, m = brain_tc[:n + lag], model_tc[-lag:]
|
| 154 |
+
if len(b) < 2:
|
| 155 |
+
corrs.append(0.0)
|
| 156 |
+
continue
|
| 157 |
+
bz, mz = b - b.mean(), m - m.mean()
|
| 158 |
+
denom = np.sqrt((bz ** 2).sum() * (mz ** 2).sum())
|
| 159 |
+
corrs.append(float((bz * mz).sum() / denom) if denom > 1e-12 else 0.0)
|
| 160 |
+
return np.array(corrs)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def decompose_response(predictions, roi_indices, roi_name, cutoff_seconds=4.0, tr_seconds=1.0):
|
| 164 |
+
verts = roi_indices.get(roi_name, np.array([]))
|
| 165 |
+
valid = verts[verts < predictions.shape[1]]
|
| 166 |
+
if len(valid) == 0:
|
| 167 |
+
return np.zeros(predictions.shape[0]), np.zeros(predictions.shape[0])
|
| 168 |
+
tc = np.abs(predictions[:, valid]).mean(axis=1)
|
| 169 |
+
window = max(1, int(cutoff_seconds / tr_seconds))
|
| 170 |
+
sustained = np.convolve(tc, np.ones(window) / window, mode="same")
|
| 171 |
+
return sustained, tc - sustained
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# --- Connectivity ---
|
| 175 |
+
|
| 176 |
+
def compute_connectivity(predictions, roi_indices):
|
| 177 |
+
names = list(roi_indices.keys())
|
| 178 |
+
n = len(names)
|
| 179 |
+
T = predictions.shape[0]
|
| 180 |
+
timecourses = np.zeros((n, T))
|
| 181 |
+
for i, name in enumerate(names):
|
| 182 |
+
verts = roi_indices[name]
|
| 183 |
+
valid = verts[verts < predictions.shape[1]]
|
| 184 |
+
if len(valid) > 0:
|
| 185 |
+
timecourses[i] = predictions[:, valid].mean(axis=1)
|
| 186 |
+
corr = np.corrcoef(timecourses) if T >= 2 else np.eye(n)
|
| 187 |
+
return np.nan_to_num(corr, nan=0.0), names
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def cluster_rois(corr_matrix, roi_names, n_clusters=4):
|
| 191 |
+
n = corr_matrix.shape[0]
|
| 192 |
+
n_clusters = min(n_clusters, n)
|
| 193 |
+
dist = 1.0 - np.abs(corr_matrix)
|
| 194 |
+
np.fill_diagonal(dist, 0.0)
|
| 195 |
+
condensed = [dist[i, j] for i in range(n) for j in range(i + 1, n)]
|
| 196 |
+
Z = linkage(condensed, method="average")
|
| 197 |
+
labels = fcluster(Z, t=n_clusters, criterion="maxclust")
|
| 198 |
+
clusters = {}
|
| 199 |
+
for name, cid in zip(roi_names, labels):
|
| 200 |
+
clusters.setdefault(int(cid), []).append(name)
|
| 201 |
+
return clusters, labels
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def graph_metrics(corr_matrix, roi_names, threshold=0.3):
|
| 205 |
+
n = corr_matrix.shape[0]
|
| 206 |
+
adj = (np.abs(corr_matrix) > threshold).astype(float)
|
| 207 |
+
np.fill_diagonal(adj, 0.0)
|
| 208 |
+
degree = adj.sum(axis=1)
|
| 209 |
+
max_d = max(n - 1, 1)
|
| 210 |
+
return {name: float(degree[i] / max_d) for i, name in enumerate(roi_names)}
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# --- Synthetic Data Generators ---
|
| 214 |
+
|
| 215 |
+
def generate_brain_predictions(n_timepoints=60, n_vertices=580, seed=42):
|
| 216 |
+
rng = np.random.default_rng(seed)
|
| 217 |
+
return rng.standard_normal((n_timepoints, n_vertices))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def generate_model_features(n_stimuli=60, feature_dim=512, seed=42):
|
| 221 |
+
rng = np.random.default_rng(seed)
|
| 222 |
+
return rng.standard_normal((n_stimuli, feature_dim))
|