siddhant-rajhans commited on
Commit
bce4bae
·
0 Parent(s):

Initial CortexLab Dashboard - interactive analysis toolkit

Browse files

Streamlit multipage app with 4 analysis pages:
- Brain Alignment Benchmark: compare AI models with RSA/CKA/Procrustes + permutation tests
- Cognitive Load Scorer: timeline, radar chart, dimension breakdown with content profiles
- Temporal Dynamics: peak latency, lag correlation, sustained/transient decomposition
- ROI Connectivity: correlation heatmap, network clustering, degree centrality, graph viz

Runs on synthetic data by default - no GPU or fMRI data required.

.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ .streamlit/secrets.toml
4
+ .env
5
+ venv/
6
+ .venv/
.streamlit/config.toml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor = "#6C5CE7"
3
+ backgroundColor = "#0E1117"
4
+ secondaryBackgroundColor = "#1A1A2E"
5
+ textColor = "#FAFAFA"
6
+ font = "sans serif"
7
+
8
+ [server]
9
+ headless = true
Home.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CortexLab Dashboard - Home Page."""
2
+
3
+ import streamlit as st
4
+
5
+ st.set_page_config(
6
+ page_title="CortexLab Dashboard",
7
+ page_icon="🧠",
8
+ layout="wide",
9
+ initial_sidebar_state="expanded",
10
+ )
11
+
12
+ st.title("CortexLab Dashboard")
13
+ st.markdown("**Interactive analysis toolkit for multimodal fMRI brain encoding**")
14
+
15
+ st.divider()
16
+
17
+ col1, col2 = st.columns(2)
18
+
19
+ with col1:
20
+ st.subheader("Analysis Tools")
21
+ st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
22
+ st.markdown("Score how brain-like any AI model's representations are using RSA, CKA, or Procrustes")
23
+
24
+ st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
25
+ st.markdown("Predict cognitive demand across visual, auditory, language, and executive dimensions")
26
+
27
+ with col2:
28
+ st.subheader("Advanced Analysis")
29
+ st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
30
+ st.markdown("Analyze peak response latency, lag correlations, and sustained vs transient components")
31
+
32
+ st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
33
+ st.markdown("Compute functional connectivity matrices, cluster networks, and graph metrics")
34
+
35
+ st.divider()
36
+
37
+ st.subheader("About")
38
+ st.markdown(
39
+ """
40
+ CortexLab is an enhanced toolkit built on [Meta's TRIBE v2](https://github.com/facebookresearch/tribev2)
41
+ for predicting how the human brain responds to video, audio, and text.
42
+
43
+ This dashboard runs on **synthetic data** by default - no GPU or real fMRI data required.
44
+ All analysis tools mirror the CortexLab Python API.
45
+
46
+ [GitHub](https://github.com/siddhant-rajhans/cortexlab)
47
+   |  
48
+ [HuggingFace](https://huggingface.co/SID2000/cortexlab)
49
+ """
50
+ )
README.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CortexLab Dashboard
2
+
3
+ Interactive analysis dashboard for [CortexLab](https://github.com/siddhant-rajhans/cortexlab) - multimodal fMRI brain encoding toolkit.
4
+
5
+ ## Pages
6
+
7
+ - **Brain Alignment Benchmark** - Compare AI model representations against brain responses (RSA, CKA, Procrustes)
8
+ - **Cognitive Load Scorer** - Visualize cognitive demand across visual, auditory, language, and executive dimensions
9
+ - **Temporal Dynamics** - Peak latency, lag correlations, sustained vs transient response decomposition
10
+ - **ROI Connectivity** - Correlation matrices, network clustering, degree centrality, graph visualization
11
+
12
+ ## Quick Start
13
+
14
+ ```bash
15
+ pip install -r requirements.txt
16
+ streamlit run Home.py
17
+ ```
18
+
19
+ Runs on **synthetic data** by default - no GPU or real fMRI data required.
20
+
21
+ ## Links
22
+
23
+ - [CortexLab Library](https://github.com/siddhant-rajhans/cortexlab)
24
+ - [CortexLab on HuggingFace](https://huggingface.co/SID2000/cortexlab)
25
+
26
+ ## License
27
+
28
+ CC BY-NC 4.0
pages/1_Brain_Alignment.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Brain Alignment Benchmark page."""
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ import plotly.express as px
8
+
9
+ from utils import (
10
+ ALIGNMENT_METHODS,
11
+ make_roi_indices,
12
+ generate_brain_predictions,
13
+ generate_model_features,
14
+ permutation_test,
15
+ )
16
+
17
+ st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
18
+ st.title("🎯 Brain Alignment Benchmark")
19
+ st.markdown("Compare how well AI model representations align with predicted brain responses.")
20
+
21
+ # --- Sidebar Controls ---
22
+ with st.sidebar:
23
+ st.header("Configuration")
24
+ n_stimuli = st.slider("Number of stimuli", 10, 200, 50)
25
+ seed = st.number_input("Random seed", value=42, min_value=0)
26
+
27
+ st.subheader("Models to compare")
28
+ models_config = {
29
+ "CLIP ViT-L/14": st.checkbox("CLIP ViT-L/14", value=True),
30
+ "DINOv2 ViT-S": st.checkbox("DINOv2 ViT-S", value=True),
31
+ "V-JEPA2 ViT-G": st.checkbox("V-JEPA2 ViT-G", value=True),
32
+ "LLaMA 3.2-3B": st.checkbox("LLaMA 3.2-3B", value=False),
33
+ }
34
+ model_dims = {
35
+ "CLIP ViT-L/14": 768,
36
+ "DINOv2 ViT-S": 384,
37
+ "V-JEPA2 ViT-G": 1024,
38
+ "LLaMA 3.2-3B": 3072,
39
+ }
40
+ selected_models = [m for m, checked in models_config.items() if checked]
41
+
42
+ methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
43
+ run_stats = st.checkbox("Run permutation test", value=False)
44
+ n_perm = st.slider("Permutations", 50, 1000, 200) if run_stats else 0
45
+
46
+ if not selected_models or not methods:
47
+ st.warning("Select at least one model and one method.")
48
+ st.stop()
49
+
50
+ # --- Generate Data ---
51
+ roi_indices, n_vertices = make_roi_indices()
52
+ brain_pred = generate_brain_predictions(n_stimuli, n_vertices, seed)
53
+
54
+ model_features = {}
55
+ for i, name in enumerate(selected_models):
56
+ model_features[name] = generate_model_features(n_stimuli, model_dims[name], seed + i + 1)
57
+
58
+ # --- Run Benchmark ---
59
+ with st.spinner("Computing alignment scores..."):
60
+ results = []
61
+ for model_name, features in model_features.items():
62
+ for method_name in methods:
63
+ score_fn = ALIGNMENT_METHODS[method_name]
64
+ score = score_fn(features, brain_pred)
65
+ row = {"Model": model_name, "Method": method_name, "Score": score}
66
+
67
+ if run_stats:
68
+ _, p_val = permutation_test(features, brain_pred, score_fn, n_perm, seed)
69
+ row["p-value"] = p_val
70
+ row["Significant"] = "Yes" if p_val < 0.05 else "No"
71
+
72
+ results.append(row)
73
+
74
+ df = pd.DataFrame(results)
75
+
76
+ # --- Display Results ---
77
+ col1, col2 = st.columns([2, 1])
78
+
79
+ with col1:
80
+ st.subheader("Alignment Scores")
81
+ fig = px.bar(
82
+ df,
83
+ x="Model",
84
+ y="Score",
85
+ color="Method",
86
+ barmode="group",
87
+ color_discrete_sequence=px.colors.qualitative.Set2,
88
+ )
89
+ fig.update_layout(
90
+ yaxis_title="Alignment Score",
91
+ height=450,
92
+ template="plotly_dark",
93
+ )
94
+ st.plotly_chart(fig, use_container_width=True)
95
+
96
+ with col2:
97
+ st.subheader("Results Table")
98
+ display_df = df.copy()
99
+ display_df["Score"] = display_df["Score"].map(lambda x: f"{x:.4f}")
100
+ if "p-value" in display_df.columns:
101
+ display_df["p-value"] = display_df["p-value"].map(lambda x: f"{x:.4f}")
102
+ st.dataframe(display_df, use_container_width=True, hide_index=True)
103
+
104
+ # --- Per-ROI Analysis ---
105
+ st.divider()
106
+ st.subheader("Per-ROI Alignment (RSA)")
107
+
108
+ if "RSA" in methods and len(selected_models) >= 1:
109
+ from utils import ROI_GROUPS, rsa_score
110
+
111
+ roi_data = []
112
+ for model_name, features in model_features.items():
113
+ for group_name, rois in ROI_GROUPS.items():
114
+ group_scores = []
115
+ for roi in rois:
116
+ if roi in roi_indices:
117
+ verts = roi_indices[roi]
118
+ valid = verts[verts < brain_pred.shape[1]]
119
+ if len(valid) >= 2:
120
+ s = rsa_score(features, brain_pred[:, valid])
121
+ group_scores.append(s)
122
+ if group_scores:
123
+ roi_data.append({
124
+ "Model": model_name,
125
+ "Region": group_name,
126
+ "RSA Score": float(np.mean(group_scores)),
127
+ })
128
+
129
+ if roi_data:
130
+ roi_df = pd.DataFrame(roi_data)
131
+ fig2 = px.bar(
132
+ roi_df,
133
+ x="Region",
134
+ y="RSA Score",
135
+ color="Model",
136
+ barmode="group",
137
+ color_discrete_sequence=px.colors.qualitative.Pastel,
138
+ )
139
+ fig2.update_layout(height=400, template="plotly_dark")
140
+ st.plotly_chart(fig2, use_container_width=True)
pages/2_Cognitive_Load.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cognitive Load Scorer page."""
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ import plotly.express as px
8
+
9
+ from utils import (
10
+ make_roi_indices,
11
+ generate_brain_predictions,
12
+ score_cognitive_load,
13
+ COGNITIVE_DIMENSIONS,
14
+ )
15
+
16
+ st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
17
+ st.title("📊 Cognitive Load Scorer")
18
+ st.markdown("Predict cognitive demand of media content from brain activation patterns.")
19
+
20
+ # --- Sidebar ---
21
+ with st.sidebar:
22
+ st.header("Configuration")
23
+ n_timepoints = st.slider("Duration (TRs)", 20, 200, 60)
24
+ tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
25
+ seed = st.number_input("Random seed", value=42, min_value=0)
26
+
27
+ st.subheader("Simulate content type")
28
+ content_type = st.selectbox(
29
+ "Content profile",
30
+ ["Balanced", "Visual-heavy", "Audio-heavy", "Language-heavy", "Low engagement"],
31
+ )
32
+
33
+ # --- Generate Data ---
34
+ roi_indices, n_vertices = make_roi_indices()
35
+ predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
36
+
37
+ # Apply content profile by amplifying relevant ROIs
38
+ multipliers = {"Balanced": {}, "Visual-heavy": {}, "Audio-heavy": {}, "Language-heavy": {}, "Low engagement": {}}
39
+ if content_type == "Visual-heavy":
40
+ for roi in COGNITIVE_DIMENSIONS["Visual Complexity"]:
41
+ if roi in roi_indices:
42
+ predictions[:, roi_indices[roi]] *= 3.0
43
+ elif content_type == "Audio-heavy":
44
+ for roi in COGNITIVE_DIMENSIONS["Auditory Demand"]:
45
+ if roi in roi_indices:
46
+ predictions[:, roi_indices[roi]] *= 3.0
47
+ elif content_type == "Language-heavy":
48
+ for roi in COGNITIVE_DIMENSIONS["Language Processing"]:
49
+ if roi in roi_indices:
50
+ predictions[:, roi_indices[roi]] *= 3.0
51
+ elif content_type == "Low engagement":
52
+ predictions *= 0.2
53
+
54
+ # --- Score ---
55
+ averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
56
+
57
+ # --- Display ---
58
+ col1, col2, col3, col4, col5 = st.columns(5)
59
+ dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
60
+ cols = [col1, col2, col3, col4, col5]
61
+ for col, dim in zip(cols, dims):
62
+ val = averages.get(dim, 0.0)
63
+ col.metric(dim, f"{val:.2f}", delta=None)
64
+
65
+ st.divider()
66
+
67
+ # --- Timeline ---
68
+ st.subheader("Cognitive Load Timeline")
69
+ timeline_df = pd.DataFrame(timeline)
70
+
71
+ fig = go.Figure()
72
+ colors = {"Visual Complexity": "#00D2FF", "Auditory Demand": "#FF6B6B", "Language Processing": "#A29BFE", "Executive Load": "#FFEAA7"}
73
+ for dim, color in colors.items():
74
+ fig.add_trace(go.Scatter(
75
+ x=timeline_df["time"],
76
+ y=timeline_df[dim],
77
+ name=dim,
78
+ line=dict(color=color, width=2),
79
+ mode="lines",
80
+ ))
81
+
82
+ fig.update_layout(
83
+ xaxis_title="Time (seconds)",
84
+ yaxis_title="Cognitive Load (normalized)",
85
+ yaxis_range=[0, 1.05],
86
+ height=400,
87
+ template="plotly_dark",
88
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
89
+ )
90
+ st.plotly_chart(fig, use_container_width=True)
91
+
92
+ # --- Dimension Breakdown ---
93
+ st.divider()
94
+ col1, col2 = st.columns(2)
95
+
96
+ with col1:
97
+ st.subheader("Dimension Breakdown")
98
+ dim_data = {k: v for k, v in averages.items() if k != "Overall"}
99
+ fig2 = go.Figure(go.Bar(
100
+ x=list(dim_data.values()),
101
+ y=list(dim_data.keys()),
102
+ orientation="h",
103
+ marker_color=list(colors.values()),
104
+ ))
105
+ fig2.update_layout(
106
+ xaxis_title="Score",
107
+ xaxis_range=[0, 1],
108
+ height=300,
109
+ template="plotly_dark",
110
+ )
111
+ st.plotly_chart(fig2, use_container_width=True)
112
+
113
+ with col2:
114
+ st.subheader("Radar Chart")
115
+ categories = list(dim_data.keys())
116
+ values = list(dim_data.values()) + [list(dim_data.values())[0]] # close the polygon
117
+
118
+ fig3 = go.Figure(go.Scatterpolar(
119
+ r=values,
120
+ theta=categories + [categories[0]],
121
+ fill="toself",
122
+ fillcolor="rgba(108, 92, 231, 0.3)",
123
+ line=dict(color="#6C5CE7"),
124
+ ))
125
+ fig3.update_layout(
126
+ polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
127
+ height=350,
128
+ template="plotly_dark",
129
+ showlegend=False,
130
+ )
131
+ st.plotly_chart(fig3, use_container_width=True)
pages/3_Temporal_Dynamics.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Temporal Dynamics page."""
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ from plotly.subplots import make_subplots
8
+
9
+ from utils import (
10
+ make_roi_indices,
11
+ generate_brain_predictions,
12
+ generate_model_features,
13
+ peak_latency,
14
+ temporal_correlation,
15
+ decompose_response,
16
+ ROI_GROUPS,
17
+ ALL_ROIS,
18
+ )
19
+
20
+ st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
21
+ st.title("⏱️ Temporal Dynamics")
22
+ st.markdown("Analyze how brain responses evolve over time.")
23
+
24
+ # --- Sidebar ---
25
+ with st.sidebar:
26
+ st.header("Configuration")
27
+ n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
28
+ tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
29
+ seed = st.number_input("Random seed", value=42, min_value=0)
30
+
31
+ st.subheader("ROI Selection")
32
+ selected_group = st.selectbox("Region group", list(ROI_GROUPS.keys()))
33
+ available_rois = ROI_GROUPS[selected_group]
34
+ selected_rois = st.multiselect("ROIs to analyze", available_rois, default=available_rois[:3])
35
+
36
+ max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
37
+ cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
38
+
39
+ if not selected_rois:
40
+ st.warning("Select at least one ROI.")
41
+ st.stop()
42
+
43
+ # --- Generate Data ---
44
+ roi_indices, n_vertices = make_roi_indices()
45
+ predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
46
+ features = generate_model_features(n_timepoints, 64, seed + 1)
47
+
48
+ # --- Peak Latency ---
49
+ st.subheader("Peak Response Latency")
50
+ latency_data = []
51
+ for roi in selected_rois:
52
+ lat = peak_latency(predictions, roi_indices, roi, tr_seconds)
53
+ latency_data.append({"ROI": roi, "Peak Latency (s)": lat})
54
+
55
+ lat_df = pd.DataFrame(latency_data)
56
+ col1, col2 = st.columns([2, 1])
57
+
58
+ with col1:
59
+ fig = go.Figure(go.Bar(
60
+ x=lat_df["ROI"],
61
+ y=lat_df["Peak Latency (s)"],
62
+ marker_color="#6C5CE7",
63
+ ))
64
+ fig.update_layout(
65
+ yaxis_title="Time to peak (seconds)",
66
+ height=350,
67
+ template="plotly_dark",
68
+ )
69
+ st.plotly_chart(fig, use_container_width=True)
70
+
71
+ with col2:
72
+ st.dataframe(lat_df, use_container_width=True, hide_index=True)
73
+
74
+ # --- Lag Correlation ---
75
+ st.divider()
76
+ st.subheader("Temporal Correlation (Brain vs Model Features)")
77
+
78
+ lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
79
+ fig2 = go.Figure()
80
+ colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8"]
81
+
82
+ for i, roi in enumerate(selected_rois):
83
+ corr = temporal_correlation(predictions, features, roi_indices, roi, max_lag)
84
+ fig2.add_trace(go.Scatter(
85
+ x=lags,
86
+ y=corr,
87
+ name=roi,
88
+ line=dict(color=colors[i % len(colors)], width=2),
89
+ ))
90
+
91
+ fig2.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
92
+ fig2.update_layout(
93
+ xaxis_title="Lag (seconds)",
94
+ yaxis_title="Pearson Correlation",
95
+ height=400,
96
+ template="plotly_dark",
97
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
98
+ )
99
+ st.plotly_chart(fig2, use_container_width=True)
100
+
101
+ # --- Sustained vs Transient ---
102
+ st.divider()
103
+ st.subheader("Sustained vs Transient Decomposition")
104
+
105
+ roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
106
+ sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
107
+ time_axis = np.arange(len(sustained)) * tr_seconds
108
+
109
+ fig3 = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
110
+ subplot_titles=("Sustained Component", "Transient Component"))
111
+
112
+ fig3.add_trace(go.Scatter(x=time_axis, y=sustained, name="Sustained",
113
+ line=dict(color="#6C5CE7", width=2)), row=1, col=1)
114
+ fig3.add_trace(go.Scatter(x=time_axis, y=transient, name="Transient",
115
+ line=dict(color="#FF6B6B", width=1.5)), row=2, col=1)
116
+
117
+ fig3.update_xaxes(title_text="Time (seconds)", row=2, col=1)
118
+ fig3.update_layout(height=500, template="plotly_dark", showlegend=False)
119
+ st.plotly_chart(fig3, use_container_width=True)
pages/4_Connectivity.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ROI Connectivity page."""
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ import pandas as pd
6
+ import plotly.graph_objects as go
7
+ import plotly.express as px
8
+
9
+ from utils import (
10
+ make_roi_indices,
11
+ generate_brain_predictions,
12
+ compute_connectivity,
13
+ cluster_rois,
14
+ graph_metrics,
15
+ ROI_GROUPS,
16
+ )
17
+
18
+ st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
19
+ st.title("🔗 ROI Connectivity Analysis")
20
+ st.markdown("Functional connectivity between brain regions from predicted responses.")
21
+
22
+ # --- Sidebar ---
23
+ with st.sidebar:
24
+ st.header("Configuration")
25
+ n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
26
+ seed = st.number_input("Random seed", value=42, min_value=0)
27
+ n_clusters = st.slider("Number of clusters", 2, 8, 4)
28
+ threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
29
+
30
+ # --- Generate Data ---
31
+ roi_indices, n_vertices = make_roi_indices()
32
+ predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
33
+
34
+ # --- Correlation Matrix ---
35
+ corr_matrix, roi_names = compute_connectivity(predictions, roi_indices)
36
+
37
+ st.subheader("Correlation Matrix")
38
+ fig = go.Figure(go.Heatmap(
39
+ z=corr_matrix,
40
+ x=roi_names,
41
+ y=roi_names,
42
+ colorscale="RdBu_r",
43
+ zmid=0,
44
+ zmin=-1,
45
+ zmax=1,
46
+ colorbar=dict(title="Correlation"),
47
+ ))
48
+ fig.update_layout(
49
+ height=600,
50
+ width=700,
51
+ template="plotly_dark",
52
+ xaxis=dict(tickangle=45, tickfont=dict(size=8)),
53
+ yaxis=dict(tickfont=dict(size=8)),
54
+ )
55
+ st.plotly_chart(fig, use_container_width=True)
56
+
57
+ # --- Network Clusters ---
58
+ st.divider()
59
+ col1, col2 = st.columns(2)
60
+
61
+ with col1:
62
+ st.subheader("Functional Network Clusters")
63
+ clusters, labels = cluster_rois(corr_matrix, roi_names, n_clusters)
64
+
65
+ cluster_data = []
66
+ for cid, rois in sorted(clusters.items()):
67
+ for roi in rois:
68
+ cluster_data.append({"Cluster": f"Network {cid}", "ROI": roi})
69
+
70
+ cluster_df = pd.DataFrame(cluster_data)
71
+ fig2 = px.bar(
72
+ cluster_df.groupby("Cluster").size().reset_index(name="Count"),
73
+ x="Cluster",
74
+ y="Count",
75
+ color="Cluster",
76
+ color_discrete_sequence=px.colors.qualitative.Set2,
77
+ )
78
+ fig2.update_layout(height=350, template="plotly_dark", showlegend=False)
79
+ st.plotly_chart(fig2, use_container_width=True)
80
+
81
+ for cid, rois in sorted(clusters.items()):
82
+ st.markdown(f"**Network {cid}:** {', '.join(rois)}")
83
+
84
+ with col2:
85
+ st.subheader("Degree Centrality")
86
+ degrees = graph_metrics(corr_matrix, roi_names, threshold)
87
+
88
+ # Sort by degree
89
+ sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
90
+ deg_df = pd.DataFrame(sorted_degrees, columns=["ROI", "Degree Centrality"])
91
+
92
+ fig3 = go.Figure(go.Bar(
93
+ x=deg_df["Degree Centrality"],
94
+ y=deg_df["ROI"],
95
+ orientation="h",
96
+ marker_color="#6C5CE7",
97
+ ))
98
+ fig3.update_layout(
99
+ xaxis_title="Degree Centrality",
100
+ xaxis_range=[0, 1],
101
+ height=600,
102
+ template="plotly_dark",
103
+ yaxis=dict(autorange="reversed", tickfont=dict(size=9)),
104
+ )
105
+ st.plotly_chart(fig3, use_container_width=True)
106
+
107
+ # --- Network Graph ---
108
+ st.divider()
109
+ st.subheader("Network Graph")
110
+
111
+ try:
112
+ import networkx as nx
113
+
114
+ G = nx.Graph()
115
+ for name in roi_names:
116
+ G.add_node(name)
117
+
118
+ for i in range(len(roi_names)):
119
+ for j in range(i + 1, len(roi_names)):
120
+ if abs(corr_matrix[i, j]) > threshold:
121
+ G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
122
+
123
+ pos = nx.spring_layout(G, seed=seed, k=2.0)
124
+
125
+ # Cluster colors
126
+ color_map = px.colors.qualitative.Set2
127
+ node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(len(roi_names))]
128
+
129
+ edge_x, edge_y = [], []
130
+ for u, v in G.edges():
131
+ x0, y0 = pos[u]
132
+ x1, y1 = pos[v]
133
+ edge_x.extend([x0, x1, None])
134
+ edge_y.extend([y0, y1, None])
135
+
136
+ node_x = [pos[n][0] for n in roi_names]
137
+ node_y = [pos[n][1] for n in roi_names]
138
+
139
+ fig4 = go.Figure()
140
+ fig4.add_trace(go.Scatter(
141
+ x=edge_x, y=edge_y, mode="lines",
142
+ line=dict(width=0.5, color="rgba(150,150,150,0.3)"),
143
+ hoverinfo="none",
144
+ ))
145
+ fig4.add_trace(go.Scatter(
146
+ x=node_x, y=node_y, mode="markers+text",
147
+ marker=dict(size=12, color=node_colors, line=dict(width=1, color="white")),
148
+ text=roi_names,
149
+ textposition="top center",
150
+ textfont=dict(size=8),
151
+ hoverinfo="text",
152
+ ))
153
+ fig4.update_layout(
154
+ height=500,
155
+ template="plotly_dark",
156
+ showlegend=False,
157
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
158
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
159
+ )
160
+ st.plotly_chart(fig4, use_container_width=True)
161
+ except ImportError:
162
+ st.info("Install `networkx` for the network graph visualization: `pip install networkx`")
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.30
2
+ plotly>=5.18
3
+ numpy>=2.0
4
+ scipy>=1.12
5
+ pandas>=2.0
6
+ networkx>=3.2
7
+ matplotlib>=3.8
8
+ seaborn>=0.13
utils.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared utilities for the CortexLab dashboard.
2
+
3
+ Provides synthetic data generation and analysis functions that mirror
4
+ CortexLab's API without requiring the full library or GPU.
5
+ """
6
+
7
+ import numpy as np
8
+ from scipy.stats import spearmanr
9
+ from scipy.cluster.hierarchy import linkage, fcluster
10
+
11
+
12
+ # --- ROI Definitions ---
13
+
14
+ ROI_GROUPS = {
15
+ "Executive": ["46", "9-46d", "8Av", "8Ad", "FEF", "p32pr", "a32pr"],
16
+ "Visual": ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"],
17
+ "Auditory": ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"],
18
+ "Language": ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2", "STV", "PSL"],
19
+ }
20
+
21
+ ALL_ROIS = [roi for group in ROI_GROUPS.values() for roi in group]
22
+
23
+
24
+ def make_roi_indices(n_vertices_per_roi=20):
25
+ """Create ROI -> vertex index mapping."""
26
+ indices = {}
27
+ offset = 0
28
+ for roi in ALL_ROIS:
29
+ indices[roi] = np.arange(offset, offset + n_vertices_per_roi)
30
+ offset += n_vertices_per_roi
31
+ return indices, offset
32
+
33
+
34
+ # --- Brain Alignment ---
35
+
36
+ def compute_rdm(features):
37
+ norms = np.linalg.norm(features, axis=1, keepdims=True)
38
+ norms = np.where(norms > 0, norms, 1.0)
39
+ normalised = features / norms
40
+ return 1.0 - normalised @ normalised.T
41
+
42
+
43
+ def rsa_score(model_features, brain_features):
44
+ rdm_m = compute_rdm(model_features)
45
+ rdm_b = compute_rdm(brain_features)
46
+ idx = np.triu_indices(rdm_m.shape[0], k=1)
47
+ corr, _ = spearmanr(rdm_m[idx], rdm_b[idx])
48
+ return float(corr) if not np.isnan(corr) else 0.0
49
+
50
+
51
+ def cka_score(X, Y):
52
+ n = X.shape[0]
53
+ X = X - X.mean(axis=0)
54
+ Y = Y - Y.mean(axis=0)
55
+ XX = X @ X.T
56
+ YY = Y @ Y.T
57
+ hsic_xy = np.trace(XX @ YY) / (n - 1) ** 2
58
+ hsic_xx = np.trace(XX @ XX) / (n - 1) ** 2
59
+ hsic_yy = np.trace(YY @ YY) / (n - 1) ** 2
60
+ denom = np.sqrt(hsic_xx * hsic_yy)
61
+ return float(hsic_xy / denom) if denom > 1e-12 else 0.0
62
+
63
+
64
+ def procrustes_score(X, Y):
65
+ min_dim = min(X.shape[1], Y.shape[1])
66
+ X, Y = X[:, :min_dim], Y[:, :min_dim]
67
+ X = X - X.mean(axis=0)
68
+ Y = Y - Y.mean(axis=0)
69
+ nx, ny = np.linalg.norm(X), np.linalg.norm(Y)
70
+ if nx < 1e-12 or ny < 1e-12:
71
+ return 0.0
72
+ X, Y = X / nx, Y / ny
73
+ U, _, Vt = np.linalg.svd(Y.T @ X, full_matrices=False)
74
+ rotated = Y @ (U @ Vt)
75
+ return float(max(0.0, 1.0 - np.linalg.norm(X - rotated)))
76
+
77
+
78
+ ALIGNMENT_METHODS = {"RSA": rsa_score, "CKA": cka_score, "Procrustes": procrustes_score}
79
+
80
+
81
+ def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
82
+ rng = np.random.default_rng(seed)
83
+ observed = method_fn(model_feat, brain_pred)
84
+ count = sum(
85
+ 1 for _ in range(n_perm)
86
+ if method_fn(model_feat[rng.permutation(len(model_feat))], brain_pred) >= observed
87
+ )
88
+ return observed, (count + 1) / (n_perm + 1)
89
+
90
+
91
+ # --- Cognitive Load ---
92
+
93
+ COGNITIVE_DIMENSIONS = {
94
+ "Executive Load": ["46", "9-46d", "8Av", "8Ad", "FEF", "p32pr", "a32pr"],
95
+ "Visual Complexity": ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"],
96
+ "Auditory Demand": ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"],
97
+ "Language Processing": ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2", "STV", "PSL"],
98
+ }
99
+
100
+
101
+ def score_cognitive_load(predictions, roi_indices, tr_seconds=1.0):
102
+ baseline = max(float(np.median(np.abs(predictions))), 1e-8)
103
+ timeline = []
104
+ dim_scores = {d: [] for d in COGNITIVE_DIMENSIONS}
105
+
106
+ for t in range(predictions.shape[0]):
107
+ row = {}
108
+ for dim, rois in COGNITIVE_DIMENSIONS.items():
109
+ vals = []
110
+ for roi in rois:
111
+ if roi in roi_indices:
112
+ verts = roi_indices[roi]
113
+ valid = verts[verts < predictions.shape[1]]
114
+ if len(valid) > 0:
115
+ vals.append(np.abs(predictions[t, valid]).mean())
116
+ score = min(float(np.mean(vals)) / baseline, 1.0) if vals else 0.0
117
+ dim_scores[dim].append(score)
118
+ row[dim] = score
119
+ row["time"] = t * tr_seconds
120
+ timeline.append(row)
121
+
122
+ averages = {d: float(np.mean(v)) for d, v in dim_scores.items()}
123
+ averages["Overall"] = float(np.mean(list(averages.values())))
124
+ return averages, timeline
125
+
126
+
127
+ # --- Temporal Dynamics ---
128
+
129
+ def peak_latency(predictions, roi_indices, roi_name, tr_seconds=1.0):
130
+ verts = roi_indices.get(roi_name, np.array([]))
131
+ valid = verts[verts < predictions.shape[1]]
132
+ if len(valid) == 0:
133
+ return 0.0
134
+ tc = np.abs(predictions[:, valid]).mean(axis=1)
135
+ return float(np.argmax(tc) * tr_seconds)
136
+
137
+
138
+ def temporal_correlation(predictions, features, roi_indices, roi_name, max_lag=10):
139
+ verts = roi_indices.get(roi_name, np.array([]))
140
+ valid = verts[verts < predictions.shape[1]]
141
+ if len(valid) == 0:
142
+ return np.zeros(2 * max_lag + 1)
143
+ brain_tc = np.abs(predictions[:, valid]).mean(axis=1)
144
+ model_tc = features.mean(axis=1) if features.ndim > 1 else features
145
+ n = min(len(brain_tc), len(model_tc))
146
+ brain_tc, model_tc = brain_tc[:n], model_tc[:n]
147
+
148
+ corrs = []
149
+ for lag in range(-max_lag, max_lag + 1):
150
+ if lag >= 0:
151
+ b, m = brain_tc[lag:], model_tc[:n - lag]
152
+ else:
153
+ b, m = brain_tc[:n + lag], model_tc[-lag:]
154
+ if len(b) < 2:
155
+ corrs.append(0.0)
156
+ continue
157
+ bz, mz = b - b.mean(), m - m.mean()
158
+ denom = np.sqrt((bz ** 2).sum() * (mz ** 2).sum())
159
+ corrs.append(float((bz * mz).sum() / denom) if denom > 1e-12 else 0.0)
160
+ return np.array(corrs)
161
+
162
+
163
+ def decompose_response(predictions, roi_indices, roi_name, cutoff_seconds=4.0, tr_seconds=1.0):
164
+ verts = roi_indices.get(roi_name, np.array([]))
165
+ valid = verts[verts < predictions.shape[1]]
166
+ if len(valid) == 0:
167
+ return np.zeros(predictions.shape[0]), np.zeros(predictions.shape[0])
168
+ tc = np.abs(predictions[:, valid]).mean(axis=1)
169
+ window = max(1, int(cutoff_seconds / tr_seconds))
170
+ sustained = np.convolve(tc, np.ones(window) / window, mode="same")
171
+ return sustained, tc - sustained
172
+
173
+
174
+ # --- Connectivity ---
175
+
176
+ def compute_connectivity(predictions, roi_indices):
177
+ names = list(roi_indices.keys())
178
+ n = len(names)
179
+ T = predictions.shape[0]
180
+ timecourses = np.zeros((n, T))
181
+ for i, name in enumerate(names):
182
+ verts = roi_indices[name]
183
+ valid = verts[verts < predictions.shape[1]]
184
+ if len(valid) > 0:
185
+ timecourses[i] = predictions[:, valid].mean(axis=1)
186
+ corr = np.corrcoef(timecourses) if T >= 2 else np.eye(n)
187
+ return np.nan_to_num(corr, nan=0.0), names
188
+
189
+
190
+ def cluster_rois(corr_matrix, roi_names, n_clusters=4):
191
+ n = corr_matrix.shape[0]
192
+ n_clusters = min(n_clusters, n)
193
+ dist = 1.0 - np.abs(corr_matrix)
194
+ np.fill_diagonal(dist, 0.0)
195
+ condensed = [dist[i, j] for i in range(n) for j in range(i + 1, n)]
196
+ Z = linkage(condensed, method="average")
197
+ labels = fcluster(Z, t=n_clusters, criterion="maxclust")
198
+ clusters = {}
199
+ for name, cid in zip(roi_names, labels):
200
+ clusters.setdefault(int(cid), []).append(name)
201
+ return clusters, labels
202
+
203
+
204
+ def graph_metrics(corr_matrix, roi_names, threshold=0.3):
205
+ n = corr_matrix.shape[0]
206
+ adj = (np.abs(corr_matrix) > threshold).astype(float)
207
+ np.fill_diagonal(adj, 0.0)
208
+ degree = adj.sum(axis=1)
209
+ max_d = max(n - 1, 1)
210
+ return {name: float(degree[i] / max_d) for i, name in enumerate(roi_names)}
211
+
212
+
213
+ # --- Synthetic Data Generators ---
214
+
215
+ def generate_brain_predictions(n_timepoints=60, n_vertices=580, seed=42):
216
+ rng = np.random.default_rng(seed)
217
+ return rng.standard_normal((n_timepoints, n_vertices))
218
+
219
+
220
+ def generate_model_features(n_stimuli=60, feature_dim=512, seed=42):
221
+ rng = np.random.default_rng(seed)
222
+ return rng.standard_normal((n_stimuli, feature_dim))