siddhant-rajhans commited on
Commit
9b23ae9
·
1 Parent(s): bce4bae

Dashboard v2: research-grade rebuild

Browse files

Biologically realistic synthetic data:
- HRF convolution with canonical double-gamma response
- Modality-specific ROI activation weights (visual/auditory/language/multimodal)
- Spatial smoothing, AR(1) temporal noise, scanner drift
- Controllable alignment strength for meaningful benchmark scores

Shared session state:
- Cross-page ROI carry (alignment -> temporal dynamics -> connectivity)
- File upload (.npy) and download (CSV, JSON) support
- Analysis log tracking across pages
- Data source selector (synthetic vs uploaded)

Brain Alignment (rebuilt):
- Bootstrap CI error bars on all scores
- Null distribution histograms with observed score overlay
- RDM visualization (brain vs model side-by-side)
- FDR correction for per-ROI tests
- Noise ceiling estimation (split-half reliability)
- Methodology documentation with references

Cognitive Load (rebuilt):
- Confidence bands on timeline (bootstrap across vertices)
- Dimension correlation heatmap
- Per-ROI activation breakdown within each dimension
- Comparison mode (two stimulus types side-by-side)

Temporal Dynamics (rebuilt):
- Raw ROI timecourses showing HRF shape
- Peak latency sorted by processing hierarchy
- Lag correlation with 95% null significance band
- Optimal lag summary table
- Cross-ROI lag matrix
- 3-panel sustained/transient decomposition

Connectivity (rebuilt):
- Partial correlation option
- Correlation matrix with cluster boundary lines
- Dendrogram visualization
- Modularity Q score with interpretation
- Betweenness centrality alongside degree
- Edge weight distribution histogram
- Network graph with weighted edges and sized nodes
- Cluster-to-functional-group mapping

Files changed (8) hide show
  1. Home.py +87 -27
  2. pages/1_Brain_Alignment.py +212 -102
  3. pages/2_Cognitive_Load.py +167 -85
  4. pages/3_Temporal_Dynamics.py +200 -62
  5. pages/4_Connectivity.py +248 -127
  6. session.py +134 -0
  7. synthetic.py +230 -0
  8. utils.py +106 -5
Home.py CHANGED
@@ -1,50 +1,110 @@
1
- """CortexLab Dashboard - Home Page."""
2
 
3
  import streamlit as st
 
4
 
5
- st.set_page_config(
6
- page_title="CortexLab Dashboard",
7
- page_icon="🧠",
8
- layout="wide",
9
- initial_sidebar_state="expanded",
10
- )
11
 
12
  st.title("CortexLab Dashboard")
13
- st.markdown("**Interactive analysis toolkit for multimodal fMRI brain encoding**")
14
 
 
15
  st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  col1, col2 = st.columns(2)
18
 
19
  with col1:
20
  st.subheader("Analysis Tools")
21
  st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
22
- st.markdown("Score how brain-like any AI model's representations are using RSA, CKA, or Procrustes")
23
 
24
  st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
25
- st.markdown("Predict cognitive demand across visual, auditory, language, and executive dimensions")
26
 
27
  with col2:
28
  st.subheader("Advanced Analysis")
29
  st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
30
- st.markdown("Analyze peak response latency, lag correlations, and sustained vs transient components")
31
 
32
  st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
33
- st.markdown("Compute functional connectivity matrices, cluster networks, and graph metrics")
34
 
35
- st.divider()
 
36
 
37
- st.subheader("About")
38
- st.markdown(
39
- """
40
- CortexLab is an enhanced toolkit built on [Meta's TRIBE v2](https://github.com/facebookresearch/tribev2)
41
- for predicting how the human brain responds to video, audio, and text.
42
-
43
- This dashboard runs on **synthetic data** by default - no GPU or real fMRI data required.
44
- All analysis tools mirror the CortexLab Python API.
45
-
46
- [GitHub](https://github.com/siddhant-rajhans/cortexlab)
47
-   |  
48
- [HuggingFace](https://huggingface.co/SID2000/cortexlab)
49
- """
50
- )
 
1
+ """CortexLab Dashboard - Home Page with Data Management."""
2
 
3
  import streamlit as st
4
+ import numpy as np
5
 
6
+ from session import init_session, data_summary_widget, show_analysis_log, upload_npy_widget
7
+ from utils import make_roi_indices
8
+
9
+ st.set_page_config(page_title="CortexLab Dashboard", page_icon="🧠", layout="wide", initial_sidebar_state="expanded")
10
+ init_session()
 
11
 
12
  st.title("CortexLab Dashboard")
13
+ st.markdown("**Research-grade analysis toolkit for multimodal fMRI brain encoding**")
14
 
15
+ # --- Data Source ---
16
  st.divider()
17
+ st.subheader("Data Configuration")
18
+
19
+ col_src, col_params = st.columns([1, 2])
20
+
21
+ with col_src:
22
+ source = st.radio("Data source", ["Synthetic (realistic)", "Upload your data"], index=0)
23
+ st.session_state["data_source"] = "synthetic" if "Synthetic" in source else "uploaded"
24
+
25
+ with col_params:
26
+ if st.session_state["data_source"] == "synthetic":
27
+ c1, c2, c3, c4 = st.columns(4)
28
+ st.session_state["stimulus_type"] = c1.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"])
29
+ st.session_state["n_timepoints"] = c2.slider("Duration (TRs)", 30, 200, 80)
30
+ st.session_state["tr_seconds"] = c3.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
31
+ st.session_state["seed"] = c4.number_input("Seed", value=42, min_value=0)
32
+
33
+ # Generate on config change
34
+ roi_indices, n_vertices = make_roi_indices()
35
+ st.session_state["roi_indices"] = roi_indices
36
+ st.session_state["n_vertices"] = n_vertices
37
+
38
+ from synthetic import generate_realistic_predictions
39
+ predictions = generate_realistic_predictions(
40
+ st.session_state["n_timepoints"], roi_indices,
41
+ st.session_state["stimulus_type"], st.session_state["tr_seconds"],
42
+ seed=st.session_state["seed"],
43
+ )
44
+ st.session_state["brain_predictions"] = predictions
45
+ else:
46
+ uploaded = upload_npy_widget("Upload brain predictions (.npy, shape: timepoints x vertices)", "upload_predictions")
47
+ if uploaded is not None:
48
+ st.session_state["brain_predictions"] = uploaded
49
+ roi_indices, _ = make_roi_indices()
50
+ st.session_state["roi_indices"] = roi_indices
51
+
52
+ # --- Data Summary ---
53
+ roi_indices = st.session_state.get("roi_indices")
54
+ predictions = st.session_state.get("brain_predictions")
55
+ if predictions is not None and roi_indices is not None:
56
+ data_summary_widget(predictions, roi_indices)
57
+
58
+ # Show HRF-convolved signal preview
59
+ with st.expander("Data Preview", expanded=False):
60
+ import plotly.graph_objects as go
61
+ from utils import ROI_GROUPS
62
 
63
+ fig = go.Figure()
64
+ t = np.arange(predictions.shape[0]) * st.session_state.get("tr_seconds", 1.0)
65
+ colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
66
+ for group, rois in ROI_GROUPS.items():
67
+ vals = []
68
+ for roi in rois:
69
+ if roi in roi_indices:
70
+ verts = roi_indices[roi]
71
+ valid = verts[verts < predictions.shape[1]]
72
+ if len(valid) > 0:
73
+ vals.append(np.abs(predictions[:, valid]).mean(axis=1))
74
+ if vals:
75
+ mean_tc = np.mean(vals, axis=0)
76
+ fig.add_trace(go.Scatter(x=t, y=mean_tc, name=group, line=dict(color=colors.get(group, "#888"), width=2)))
77
+
78
+ fig.update_layout(
79
+ xaxis_title="Time (seconds)", yaxis_title="Mean |activation|",
80
+ height=300, template="plotly_dark",
81
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
82
+ )
83
+ st.plotly_chart(fig, use_container_width=True)
84
+ st.caption("Mean absolute activation per functional group. Note the hemodynamic response shape and modality-specific activation patterns.")
85
+
86
+ # --- Navigation ---
87
+ st.divider()
88
  col1, col2 = st.columns(2)
89
 
90
  with col1:
91
  st.subheader("Analysis Tools")
92
  st.page_link("pages/1_Brain_Alignment.py", label="Brain Alignment Benchmark", icon="🎯")
93
+ st.caption("RSA, CKA, Procrustes with permutation tests, bootstrap CIs, FDR correction, noise ceiling, and RDM visualization")
94
 
95
  st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load Scorer", icon="📊")
96
+ st.caption("Timeline with confidence bands, dimension correlation, per-ROI breakdown, comparison mode")
97
 
98
  with col2:
99
  st.subheader("Advanced Analysis")
100
  st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
101
+ st.caption("Raw timecourses, peak latency hierarchy, optimal lag analysis, cross-ROI lag matrix")
102
 
103
  st.page_link("pages/4_Connectivity.py", label="ROI Connectivity", icon="🔗")
104
+ st.caption("Partial correlation, modularity, betweenness centrality, dendrogram, network graph")
105
 
106
+ # --- Analysis Log ---
107
+ show_analysis_log()
108
 
109
+ st.divider()
110
+ st.caption("[GitHub](https://github.com/siddhant-rajhans/cortexlab) | [HuggingFace](https://huggingface.co/SID2000/cortexlab) | [Dashboard Repo](https://github.com/siddhant-rajhans/cortexlab-dashboard)")
 
 
 
 
 
 
 
 
 
 
 
 
pages/1_Brain_Alignment.py CHANGED
@@ -1,4 +1,4 @@
1
- """Brain Alignment Benchmark page."""
2
 
3
  import streamlit as st
4
  import numpy as np
@@ -6,135 +6,245 @@ import pandas as pd
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
 
9
  from utils import (
10
- ALIGNMENT_METHODS,
11
- make_roi_indices,
12
- generate_brain_predictions,
13
- generate_model_features,
14
- permutation_test,
15
  )
 
16
 
17
  st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
 
 
 
18
  st.title("🎯 Brain Alignment Benchmark")
19
- st.markdown("Compare how well AI model representations align with predicted brain responses.")
20
 
21
- # --- Sidebar Controls ---
22
  with st.sidebar:
23
  st.header("Configuration")
24
- n_stimuli = st.slider("Number of stimuli", 10, 200, 50)
25
- seed = st.number_input("Random seed", value=42, min_value=0)
26
-
27
- st.subheader("Models to compare")
28
- models_config = {
29
- "CLIP ViT-L/14": st.checkbox("CLIP ViT-L/14", value=True),
30
- "DINOv2 ViT-S": st.checkbox("DINOv2 ViT-S", value=True),
31
- "V-JEPA2 ViT-G": st.checkbox("V-JEPA2 ViT-G", value=True),
32
- "LLaMA 3.2-3B": st.checkbox("LLaMA 3.2-3B", value=False),
33
- }
34
- model_dims = {
35
- "CLIP ViT-L/14": 768,
36
- "DINOv2 ViT-S": 384,
37
- "V-JEPA2 ViT-G": 1024,
38
- "LLaMA 3.2-3B": 3072,
39
  }
40
- selected_models = [m for m, checked in models_config.items() if checked]
41
 
 
42
  methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
43
- run_stats = st.checkbox("Run permutation test", value=False)
44
- n_perm = st.slider("Permutations", 50, 1000, 200) if run_stats else 0
 
45
 
46
- if not selected_models or not methods:
47
- st.warning("Select at least one model and one method.")
48
  st.stop()
49
 
50
  # --- Generate Data ---
51
  roi_indices, n_vertices = make_roi_indices()
52
- brain_pred = generate_brain_predictions(n_stimuli, n_vertices, seed)
53
 
54
  model_features = {}
55
- for i, name in enumerate(selected_models):
56
- model_features[name] = generate_model_features(n_stimuli, model_dims[name], seed + i + 1)
 
 
57
 
58
  # --- Run Benchmark ---
59
- with st.spinner("Computing alignment scores..."):
60
  results = []
 
 
61
  for model_name, features in model_features.items():
62
  for method_name in methods:
63
  score_fn = ALIGNMENT_METHODS[method_name]
64
- score = score_fn(features, brain_pred)
65
- row = {"Model": model_name, "Method": method_name, "Score": score}
66
-
67
- if run_stats:
68
- _, p_val = permutation_test(features, brain_pred, score_fn, n_perm, seed)
69
- row["p-value"] = p_val
70
- row["Significant"] = "Yes" if p_val < 0.05 else "No"
71
-
72
- results.append(row)
73
-
74
- df = pd.DataFrame(results)
75
-
76
- # --- Display Results ---
77
- col1, col2 = st.columns([2, 1])
78
-
79
- with col1:
80
- st.subheader("Alignment Scores")
81
- fig = px.bar(
82
- df,
83
- x="Model",
84
- y="Score",
85
- color="Method",
86
- barmode="group",
87
- color_discrete_sequence=px.colors.qualitative.Set2,
88
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  fig.update_layout(
90
- yaxis_title="Alignment Score",
91
- height=450,
92
- template="plotly_dark",
93
  )
94
  st.plotly_chart(fig, use_container_width=True)
95
 
96
- with col2:
97
- st.subheader("Results Table")
98
  display_df = df.copy()
99
- display_df["Score"] = display_df["Score"].map(lambda x: f"{x:.4f}")
100
- if "p-value" in display_df.columns:
101
- display_df["p-value"] = display_df["p-value"].map(lambda x: f"{x:.4f}")
102
  st.dataframe(display_df, use_container_width=True, hide_index=True)
103
-
104
- # --- Per-ROI Analysis ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  st.divider()
106
- st.subheader("Per-ROI Alignment (RSA)")
107
-
108
- if "RSA" in methods and len(selected_models) >= 1:
109
- from utils import ROI_GROUPS, rsa_score
110
-
111
- roi_data = []
112
- for model_name, features in model_features.items():
113
- for group_name, rois in ROI_GROUPS.items():
114
- group_scores = []
115
- for roi in rois:
116
- if roi in roi_indices:
117
- verts = roi_indices[roi]
118
- valid = verts[verts < brain_pred.shape[1]]
119
- if len(valid) >= 2:
120
- s = rsa_score(features, brain_pred[:, valid])
121
- group_scores.append(s)
122
- if group_scores:
123
- roi_data.append({
124
- "Model": model_name,
125
- "Region": group_name,
126
- "RSA Score": float(np.mean(group_scores)),
127
- })
128
-
129
- if roi_data:
130
- roi_df = pd.DataFrame(roi_data)
131
- fig2 = px.bar(
132
- roi_df,
133
- x="Region",
134
- y="RSA Score",
135
- color="Model",
136
- barmode="group",
137
- color_discrete_sequence=px.colors.qualitative.Pastel,
138
- )
139
- fig2.update_layout(height=400, template="plotly_dark")
140
- st.plotly_chart(fig2, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Brain Alignment Benchmark - Research Grade."""
2
 
3
  import streamlit as st
4
  import numpy as np
 
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
9
+ from session import init_session, log_analysis, carry_rois, download_csv_button, show_analysis_log
10
  from utils import (
11
+ ALIGNMENT_METHODS, ROI_GROUPS, make_roi_indices,
12
+ permutation_test, bootstrap_ci, fdr_correction, noise_ceiling,
13
+ compute_rdm,
 
 
14
  )
15
+ from synthetic import generate_realistic_predictions, generate_correlated_features
16
 
17
  st.set_page_config(page_title="Brain Alignment", page_icon="🎯", layout="wide")
18
+ init_session()
19
+ show_analysis_log()
20
+
21
  st.title("🎯 Brain Alignment Benchmark")
22
+ st.markdown("Score how well AI model representations align with predicted brain responses, with full statistical testing.")
23
 
24
+ # --- Sidebar ---
25
  with st.sidebar:
26
  st.header("Configuration")
27
+ stimulus = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"],
28
+ index=["visual", "auditory", "language", "multimodal"].index(st.session_state.get("stimulus_type", "visual")))
29
+ n_timepoints = st.slider("Timepoints", 30, 200, st.session_state.get("n_timepoints", 80))
30
+ seed = st.number_input("Seed", value=st.session_state.get("seed", 42), min_value=0)
31
+
32
+ st.subheader("Models")
33
+ model_configs = {
34
+ "CLIP ViT-L/14": {"dim": 768, "alignment": st.slider("CLIP alignment", 0.0, 1.0, 0.6, 0.05)},
35
+ "DINOv2 ViT-S": {"dim": 384, "alignment": st.slider("DINOv2 alignment", 0.0, 1.0, 0.3, 0.05)},
36
+ "V-JEPA2 ViT-G": {"dim": 1024, "alignment": st.slider("V-JEPA2 alignment", 0.0, 1.0, 0.8, 0.05)},
 
 
 
 
 
37
  }
 
38
 
39
+ st.subheader("Methods & Statistics")
40
  methods = st.multiselect("Methods", ["RSA", "CKA", "Procrustes"], default=["RSA", "CKA"])
41
+ n_perm = st.slider("Permutations", 50, 1000, 200)
42
+ n_boot = st.slider("Bootstrap samples", 100, 2000, 500)
43
+ apply_fdr = st.checkbox("Apply FDR correction", value=True)
44
 
45
+ if not methods:
46
+ st.warning("Select at least one method.")
47
  st.stop()
48
 
49
  # --- Generate Data ---
50
  roi_indices, n_vertices = make_roi_indices()
51
+ brain_pred = generate_realistic_predictions(n_timepoints, roi_indices, stimulus, seed=seed)
52
 
53
  model_features = {}
54
+ for i, (name, cfg) in enumerate(model_configs.items()):
55
+ model_features[name] = generate_correlated_features(
56
+ brain_pred, cfg["alignment"], cfg["dim"], seed=seed + i + 1
57
+ )
58
 
59
  # --- Run Benchmark ---
60
+ with st.spinner("Computing alignment scores with statistical testing..."):
61
  results = []
62
+ null_distributions = {}
63
+
64
  for model_name, features in model_features.items():
65
  for method_name in methods:
66
  score_fn = ALIGNMENT_METHODS[method_name]
67
+ observed, p_val, null_dist = permutation_test(features, brain_pred, score_fn, n_perm, seed)
68
+ point, ci_lo, ci_hi = bootstrap_ci(features, brain_pred, score_fn, n_boot, seed=seed)
69
+ null_distributions[f"{model_name}_{method_name}"] = null_dist
70
+
71
+ results.append({
72
+ "Model": model_name,
73
+ "Method": method_name,
74
+ "Score": observed,
75
+ "CI Lower": ci_lo,
76
+ "CI Upper": ci_hi,
77
+ "p-value": p_val,
78
+ })
79
+
80
+ df = pd.DataFrame(results)
81
+ log_analysis(f"Brain alignment: {len(model_features)} models x {len(methods)} methods")
82
+
83
+ # --- Noise Ceiling ---
84
+ ceiling_scores = {}
85
+ for method_name in methods:
86
+ score_fn = ALIGNMENT_METHODS[method_name]
87
+ ceil_mean, ceil_std = noise_ceiling(brain_pred, score_fn, seed=seed)
88
+ ceiling_scores[method_name] = ceil_mean
89
+
90
+ # --- Display: Alignment Scores with CIs ---
91
+ st.subheader("Alignment Scores")
92
+
93
+ col_chart, col_table = st.columns([2, 1])
94
+
95
+ with col_chart:
96
+ fig = go.Figure()
97
+ method_colors = {"RSA": "#00D2FF", "CKA": "#FF6B6B", "Procrustes": "#A29BFE"}
98
+ x_positions = list(model_configs.keys())
99
+
100
+ for method_name in methods:
101
+ method_df = df[df["Method"] == method_name]
102
+ fig.add_trace(go.Bar(
103
+ name=method_name,
104
+ x=method_df["Model"],
105
+ y=method_df["Score"],
106
+ error_y=dict(
107
+ type="data",
108
+ symmetric=False,
109
+ array=(method_df["CI Upper"] - method_df["Score"]).tolist(),
110
+ arrayminus=(method_df["Score"] - method_df["CI Lower"]).tolist(),
111
+ ),
112
+ marker_color=method_colors.get(method_name, "#888"),
113
+ ))
114
+ # Noise ceiling line
115
+ if method_name in ceiling_scores:
116
+ fig.add_hline(
117
+ y=ceiling_scores[method_name],
118
+ line_dash="dash", line_color=method_colors.get(method_name, "#888"),
119
+ opacity=0.4,
120
+ annotation_text=f"{method_name} ceiling",
121
+ annotation_position="top right",
122
+ )
123
+
124
  fig.update_layout(
125
+ barmode="group", yaxis_title="Alignment Score",
126
+ height=450, template="plotly_dark",
127
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
128
  )
129
  st.plotly_chart(fig, use_container_width=True)
130
 
131
+ with col_table:
132
+ st.subheader("Results")
133
  display_df = df.copy()
134
+ for col in ["Score", "CI Lower", "CI Upper", "p-value"]:
135
+ display_df[col] = display_df[col].map(lambda x: f"{x:.4f}")
 
136
  st.dataframe(display_df, use_container_width=True, hide_index=True)
137
+ download_csv_button(df, "brain_alignment_results.csv")
138
+
139
+ # --- Null Distribution ---
140
+ with st.expander("Null Distributions (Permutation Tests)", expanded=False):
141
+ st.markdown("The histogram shows the distribution of scores under the null hypothesis (no alignment). "
142
+ "The red line marks the observed score. If it falls far to the right, alignment is significant.")
143
+ cols = st.columns(min(len(null_distributions), 3))
144
+ for i, (key, null_dist) in enumerate(null_distributions.items()):
145
+ model_name, method_name = key.rsplit("_", 1)
146
+ row = df[(df["Model"] == model_name) & (df["Method"] == method_name)].iloc[0]
147
+ with cols[i % len(cols)]:
148
+ fig_null = go.Figure()
149
+ fig_null.add_trace(go.Histogram(x=null_dist, nbinsx=30, marker_color="rgba(100,100,100,0.6)", name="Null"))
150
+ fig_null.add_vline(x=row["Score"], line_color="red", line_width=2, annotation_text=f"Observed")
151
+ fig_null.update_layout(
152
+ title=f"{model_name} ({method_name})",
153
+ xaxis_title="Score", yaxis_title="Count",
154
+ height=250, template="plotly_dark", showlegend=False,
155
+ margin=dict(t=40, b=30, l=30, r=10),
156
+ )
157
+ st.plotly_chart(fig_null, use_container_width=True)
158
+ st.caption(f"p = {row['p-value']:.4f}")
159
+
160
+ # --- RDM Visualization ---
161
+ with st.expander("Representational Dissimilarity Matrices", expanded=False):
162
+ st.markdown("RDMs show pairwise dissimilarity between stimulus representations. "
163
+ "Similar RDM structure between model and brain indicates representational alignment.")
164
+ rdm_model_name = st.selectbox("Model for RDM", list(model_features.keys()))
165
+ col_brain, col_model = st.columns(2)
166
+
167
+ brain_rdm = compute_rdm(brain_pred)
168
+ model_rdm = compute_rdm(model_features[rdm_model_name])
169
+
170
+ with col_brain:
171
+ fig_rdm = go.Figure(go.Heatmap(z=brain_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
172
+ fig_rdm.update_layout(title="Brain RDM", height=350, template="plotly_dark", xaxis_title="Stimulus", yaxis_title="Stimulus")
173
+ st.plotly_chart(fig_rdm, use_container_width=True)
174
+
175
+ with col_model:
176
+ fig_rdm2 = go.Figure(go.Heatmap(z=model_rdm, colorscale="Viridis", colorbar=dict(title="Dissimilarity")))
177
+ fig_rdm2.update_layout(title=f"{rdm_model_name} RDM", height=350, template="plotly_dark", xaxis_title="Stimulus", yaxis_title="Stimulus")
178
+ st.plotly_chart(fig_rdm2, use_container_width=True)
179
+
180
+ # --- Per-ROI Analysis with FDR ---
181
  st.divider()
182
+ st.subheader("Per-ROI Alignment")
183
+
184
+ roi_method = st.selectbox("Method for ROI analysis", methods, key="roi_method")
185
+ score_fn = ALIGNMENT_METHODS[roi_method]
186
+
187
+ roi_data = []
188
+ roi_p_values = []
189
+ top_model = df[df["Method"] == roi_method].sort_values("Score", ascending=False).iloc[0]["Model"]
190
+ features = model_features[top_model]
191
+
192
+ for group_name, rois in ROI_GROUPS.items():
193
+ for roi in rois:
194
+ if roi in roi_indices:
195
+ verts = roi_indices[roi]
196
+ valid = verts[verts < brain_pred.shape[1]]
197
+ if len(valid) >= 2:
198
+ s = score_fn(features, brain_pred[:, valid])
199
+ _, p, _ = permutation_test(features, brain_pred[:, valid], score_fn, n_perm=50, seed=seed)
200
+ roi_data.append({"ROI": roi, "Group": group_name, "Score": s, "p-value": p})
201
+ roi_p_values.append(p)
202
+
203
+ if roi_data:
204
+ roi_df = pd.DataFrame(roi_data)
205
+ if apply_fdr and len(roi_p_values) > 1:
206
+ corrected_p, significant = fdr_correction(roi_p_values)
207
+ roi_df["FDR p-value"] = corrected_p
208
+ roi_df["Significant"] = significant
209
+ roi_df["Label"] = roi_df.apply(lambda r: f"{r['ROI']} *" if r["Significant"] else r["ROI"], axis=1)
210
+ else:
211
+ roi_df["Label"] = roi_df["ROI"]
212
+ roi_df["Significant"] = roi_df["p-value"] < 0.05
213
+
214
+ group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7"}
215
+ fig_roi = px.bar(roi_df, x="Label", y="Score", color="Group",
216
+ color_discrete_map=group_colors)
217
+ fig_roi.update_layout(height=400, template="plotly_dark", xaxis_tickangle=45)
218
+ st.plotly_chart(fig_roi, use_container_width=True)
219
+ st.caption(f"Model: {top_model} | * = significant after FDR correction (q < 0.05)" if apply_fdr else f"Model: {top_model}")
220
+
221
+ # Carry ROIs button
222
+ sig_rois = roi_df[roi_df["Significant"]]["ROI"].tolist() if "Significant" in roi_df.columns else []
223
+ if sig_rois:
224
+ if st.button(f"Carry {len(sig_rois)} significant ROIs to other pages"):
225
+ carry_rois(sig_rois, "Temporal Dynamics / Connectivity")
226
+ st.success(f"Carried {len(sig_rois)} ROIs: {', '.join(sig_rois[:5])}{'...' if len(sig_rois) > 5 else ''}")
227
+
228
+ # --- Methodology ---
229
+ with st.expander("Methodology", expanded=False):
230
+ st.markdown("""
231
+ **Representational Similarity Analysis (RSA)** compares the geometry of two representation
232
+ spaces by computing pairwise dissimilarity matrices (RDMs) and correlating their upper triangles
233
+ via Spearman rank correlation. Range: [-1, 1]. Values > 0.1 are typically meaningful.
234
+ *Kriegeskorte et al., 2008, Frontiers in Systems Neuroscience.*
235
+
236
+ **Centered Kernel Alignment (CKA)** measures similarity between representations using
237
+ HSIC (Hilbert-Schmidt Independence Criterion) normalized by self-similarities. Invariant to
238
+ orthogonal transformations and isotropic scaling. Range: [0, 1].
239
+ *Kornblith et al., 2019, ICML.*
240
+
241
+ **Procrustes** finds the optimal rotation mapping one space onto another and measures
242
+ residual distance. Score = 1 - normalized Procrustes distance. Range: [0, 1].
243
+ *Ding et al., 2021, NeurIPS.*
244
+
245
+ **Noise ceiling** estimates the maximum achievable alignment score given the noise in the
246
+ brain data, computed via split-half reliability.
247
+
248
+ **FDR correction** (Benjamini-Hochberg) controls the false discovery rate when testing
249
+ multiple ROIs simultaneously.
250
+ """)
pages/2_Cognitive_Load.py CHANGED
@@ -1,4 +1,4 @@
1
- """Cognitive Load Scorer page."""
2
 
3
  import streamlit as st
4
  import numpy as np
@@ -6,126 +6,208 @@ import pandas as pd
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
9
- from utils import (
10
- make_roi_indices,
11
- generate_brain_predictions,
12
- score_cognitive_load,
13
- COGNITIVE_DIMENSIONS,
14
- )
15
 
16
  st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
 
 
 
17
  st.title("📊 Cognitive Load Scorer")
18
- st.markdown("Predict cognitive demand of media content from brain activation patterns.")
19
 
20
  # --- Sidebar ---
21
  with st.sidebar:
22
  st.header("Configuration")
23
- n_timepoints = st.slider("Duration (TRs)", 20, 200, 60)
24
- tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
25
- seed = st.number_input("Random seed", value=42, min_value=0)
26
-
27
- st.subheader("Simulate content type")
28
- content_type = st.selectbox(
29
- "Content profile",
30
- ["Balanced", "Visual-heavy", "Audio-heavy", "Language-heavy", "Low engagement"],
31
- )
32
 
33
- # --- Generate Data ---
34
- roi_indices, n_vertices = make_roi_indices()
35
- predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
36
 
37
- # Apply content profile by amplifying relevant ROIs
38
- multipliers = {"Balanced": {}, "Visual-heavy": {}, "Audio-heavy": {}, "Language-heavy": {}, "Low engagement": {}}
39
- if content_type == "Visual-heavy":
40
- for roi in COGNITIVE_DIMENSIONS["Visual Complexity"]:
41
- if roi in roi_indices:
42
- predictions[:, roi_indices[roi]] *= 3.0
43
- elif content_type == "Audio-heavy":
44
- for roi in COGNITIVE_DIMENSIONS["Auditory Demand"]:
45
- if roi in roi_indices:
46
- predictions[:, roi_indices[roi]] *= 3.0
47
- elif content_type == "Language-heavy":
48
- for roi in COGNITIVE_DIMENSIONS["Language Processing"]:
49
- if roi in roi_indices:
50
- predictions[:, roi_indices[roi]] *= 3.0
51
- elif content_type == "Low engagement":
52
- predictions *= 0.2
53
 
54
- # --- Score ---
 
 
55
  averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
 
 
 
 
 
56
 
57
- # --- Display ---
58
- col1, col2, col3, col4, col5 = st.columns(5)
59
  dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
60
- cols = [col1, col2, col3, col4, col5]
61
  for col, dim in zip(cols, dims):
62
  val = averages.get(dim, 0.0)
63
- col.metric(dim, f"{val:.2f}", delta=None)
 
 
 
 
 
64
 
65
  st.divider()
66
 
67
- # --- Timeline ---
68
  st.subheader("Cognitive Load Timeline")
 
 
 
 
69
  timeline_df = pd.DataFrame(timeline)
 
70
 
71
  fig = go.Figure()
72
- colors = {"Visual Complexity": "#00D2FF", "Auditory Demand": "#FF6B6B", "Language Processing": "#A29BFE", "Executive Load": "#FFEAA7"}
73
- for dim, color in colors.items():
74
- fig.add_trace(go.Scatter(
75
- x=timeline_df["time"],
76
- y=timeline_df[dim],
77
- name=dim,
78
- line=dict(color=color, width=2),
79
- mode="lines",
80
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  fig.update_layout(
83
- xaxis_title="Time (seconds)",
84
- yaxis_title="Cognitive Load (normalized)",
85
- yaxis_range=[0, 1.05],
86
- height=400,
87
- template="plotly_dark",
88
  legend=dict(orientation="h", yanchor="bottom", y=1.02),
89
  )
90
  st.plotly_chart(fig, use_container_width=True)
91
 
92
- # --- Dimension Breakdown ---
93
  st.divider()
94
  col1, col2 = st.columns(2)
95
 
96
  with col1:
97
- st.subheader("Dimension Breakdown")
98
- dim_data = {k: v for k, v in averages.items() if k != "Overall"}
99
- fig2 = go.Figure(go.Bar(
100
- x=list(dim_data.values()),
101
- y=list(dim_data.keys()),
102
- orientation="h",
103
- marker_color=list(colors.values()),
 
 
 
 
 
 
 
104
  ))
105
- fig2.update_layout(
106
- xaxis_title="Score",
107
- xaxis_range=[0, 1],
108
- height=300,
109
- template="plotly_dark",
110
- )
111
- st.plotly_chart(fig2, use_container_width=True)
112
 
113
  with col2:
114
- st.subheader("Radar Chart")
 
115
  categories = list(dim_data.keys())
116
- values = list(dim_data.values()) + [list(dim_data.values())[0]] # close the polygon
117
-
118
- fig3 = go.Figure(go.Scatterpolar(
119
- r=values,
120
- theta=categories + [categories[0]],
121
- fill="toself",
122
- fillcolor="rgba(108, 92, 231, 0.3)",
123
- line=dict(color="#6C5CE7"),
124
  ))
125
- fig3.update_layout(
 
 
 
 
 
 
 
 
 
126
  polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
127
- height=350,
128
- template="plotly_dark",
129
- showlegend=False,
130
  )
131
- st.plotly_chart(fig3, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cognitive Load Scorer - Research Grade."""
2
 
3
  import streamlit as st
4
  import numpy as np
 
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
9
+ from session import init_session, log_analysis, download_csv_button, show_analysis_log
10
+ from utils import make_roi_indices, score_cognitive_load, COGNITIVE_DIMENSIONS, ROI_GROUPS
11
+ from synthetic import generate_realistic_predictions
 
 
 
12
 
13
  st.set_page_config(page_title="Cognitive Load", page_icon="📊", layout="wide")
14
+ init_session()
15
+ show_analysis_log()
16
+
17
  st.title("📊 Cognitive Load Scorer")
18
+ st.markdown("Predict cognitive demand from brain activation patterns across four neurocognitive dimensions.")
19
 
20
  # --- Sidebar ---
21
  with st.sidebar:
22
  st.header("Configuration")
23
+ n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
24
+ tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
25
+ seed = st.number_input("Seed", value=42, min_value=0)
 
 
 
 
 
 
26
 
27
+ st.subheader("Stimulus")
28
+ stim_type = st.selectbox("Primary stimulus", ["visual", "auditory", "language", "multimodal"])
 
29
 
30
+ st.subheader("Comparison Mode")
31
+ compare = st.checkbox("Compare two stimulus types", value=False)
32
+ if compare:
33
+ stim_type_2 = st.selectbox("Second stimulus", [s for s in ["visual", "auditory", "language", "multimodal"] if s != stim_type])
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ # --- Generate Data ---
36
+ roi_indices, n_vertices = make_roi_indices()
37
+ predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
38
  averages, timeline = score_cognitive_load(predictions, roi_indices, tr_seconds)
39
+ log_analysis(f"Cognitive load: {stim_type}, {n_timepoints} TRs")
40
+
41
+ if compare:
42
+ predictions_2 = generate_realistic_predictions(n_timepoints, roi_indices, stim_type_2, tr_seconds, seed=seed + 100)
43
+ averages_2, timeline_2 = score_cognitive_load(predictions_2, roi_indices, tr_seconds)
44
 
45
+ # --- Metric Cards ---
 
46
  dims = ["Overall", "Visual Complexity", "Auditory Demand", "Language Processing", "Executive Load"]
47
+ cols = st.columns(5)
48
  for col, dim in zip(cols, dims):
49
  val = averages.get(dim, 0.0)
50
+ if compare:
51
+ val_2 = averages_2.get(dim, 0.0)
52
+ delta = val - val_2
53
+ col.metric(dim, f"{val:.2f}", delta=f"{delta:+.2f} vs {stim_type_2}", delta_color="normal")
54
+ else:
55
+ col.metric(dim, f"{val:.2f}")
56
 
57
  st.divider()
58
 
59
+ # --- Timeline with Confidence Bands ---
60
  st.subheader("Cognitive Load Timeline")
61
+
62
+ if compare:
63
+ st.markdown(f"**Solid lines**: {stim_type} | **Dashed lines**: {stim_type_2}")
64
+
65
  timeline_df = pd.DataFrame(timeline)
66
+ dim_colors = {"Visual Complexity": "#00D2FF", "Auditory Demand": "#FF6B6B", "Language Processing": "#A29BFE", "Executive Load": "#FFEAA7"}
67
 
68
  fig = go.Figure()
69
+ for dim, color in dim_colors.items():
70
+ y = timeline_df[dim].values
71
+
72
+ # Bootstrap confidence band (resample vertices within dimension ROIs)
73
+ rng = np.random.default_rng(seed)
74
+ dim_rois = COGNITIVE_DIMENSIONS.get(dim, [])
75
+ dim_vertices = []
76
+ for roi in dim_rois:
77
+ if roi in roi_indices:
78
+ valid = roi_indices[roi]
79
+ dim_vertices.extend(valid[valid < predictions.shape[1]])
80
+
81
+ if dim_vertices:
82
+ boot_scores = []
83
+ for _ in range(50):
84
+ sample_verts = rng.choice(dim_vertices, size=max(1, len(dim_vertices) // 2), replace=True)
85
+ boot_tc = np.abs(predictions[:, sample_verts]).mean(axis=1)
86
+ baseline = max(np.median(np.abs(predictions)), 1e-8)
87
+ boot_scores.append(np.clip(boot_tc / baseline, 0, 1))
88
+ boot_arr = np.array(boot_scores)
89
+ ci_lo = np.percentile(boot_arr, 2.5, axis=0)
90
+ ci_hi = np.percentile(boot_arr, 97.5, axis=0)
91
+ t_axis = timeline_df["time"].values
92
+
93
+ fig.add_trace(go.Scatter(x=t_axis, y=ci_hi, mode="lines", line=dict(width=0), showlegend=False))
94
+ fig.add_trace(go.Scatter(x=t_axis, y=ci_lo, mode="lines", line=dict(width=0),
95
+ fill="tonexty", fillcolor=color.replace(")", ",0.15)").replace("rgb", "rgba").replace("#", "rgba(") if color.startswith("#") else color,
96
+ showlegend=False))
97
+
98
+ fig.add_trace(go.Scatter(x=timeline_df["time"], y=y, name=dim, line=dict(color=color, width=2)))
99
+
100
+ if compare:
101
+ timeline_df_2 = pd.DataFrame(timeline_2)
102
+ fig.add_trace(go.Scatter(x=timeline_df_2["time"], y=timeline_df_2[dim].values,
103
+ name=f"{dim} ({stim_type_2})", line=dict(color=color, width=1.5, dash="dash")))
104
 
105
  fig.update_layout(
106
+ xaxis_title="Time (seconds)", yaxis_title="Cognitive Load (normalized)",
107
+ yaxis_range=[0, 1.05], height=450, template="plotly_dark",
 
 
 
108
  legend=dict(orientation="h", yanchor="bottom", y=1.02),
109
  )
110
  st.plotly_chart(fig, use_container_width=True)
111
 
112
+ # --- Dimension Correlation + Radar ---
113
  st.divider()
114
  col1, col2 = st.columns(2)
115
 
116
  with col1:
117
+ st.subheader("Dimension Correlation")
118
+ st.markdown("How do the four cognitive dimensions co-vary over time?")
119
+ dim_timeseries = {}
120
+ for dim in dim_colors:
121
+ dim_timeseries[dim] = timeline_df[dim].values
122
+ dim_arr = np.array(list(dim_timeseries.values()))
123
+ corr = np.corrcoef(dim_arr)
124
+ dim_names = list(dim_colors.keys())
125
+
126
+ fig_corr = go.Figure(go.Heatmap(
127
+ z=corr, x=dim_names, y=dim_names,
128
+ colorscale="RdBu_r", zmid=0, zmin=-1, zmax=1,
129
+ colorbar=dict(title="r"),
130
+ text=np.round(corr, 2), texttemplate="%{text}",
131
  ))
132
+ fig_corr.update_layout(height=350, template="plotly_dark", xaxis_tickangle=30)
133
+ st.plotly_chart(fig_corr, use_container_width=True)
 
 
 
 
 
134
 
135
  with col2:
136
+ st.subheader("Dimension Profile")
137
+ dim_data = {k: v for k, v in averages.items() if k != "Overall"}
138
  categories = list(dim_data.keys())
139
+ values = list(dim_data.values()) + [list(dim_data.values())[0]]
140
+
141
+ fig_radar = go.Figure()
142
+ fig_radar.add_trace(go.Scatterpolar(
143
+ r=values, theta=categories + [categories[0]],
144
+ fill="toself", fillcolor="rgba(108, 92, 231, 0.3)",
145
+ line=dict(color="#6C5CE7"), name=stim_type,
 
146
  ))
147
+ if compare:
148
+ dim_data_2 = {k: v for k, v in averages_2.items() if k != "Overall"}
149
+ values_2 = list(dim_data_2.values()) + [list(dim_data_2.values())[0]]
150
+ fig_radar.add_trace(go.Scatterpolar(
151
+ r=values_2, theta=categories + [categories[0]],
152
+ fill="toself", fillcolor="rgba(255,107,107,0.2)",
153
+ line=dict(color="#FF6B6B", dash="dash"), name=stim_type_2,
154
+ ))
155
+
156
+ fig_radar.update_layout(
157
  polar=dict(radialaxis=dict(visible=True, range=[0, 1])),
158
+ height=350, template="plotly_dark",
 
 
159
  )
160
+ st.plotly_chart(fig_radar, use_container_width=True)
161
+
162
+ # --- Per-ROI Activation Breakdown ---
163
+ st.divider()
164
+ st.subheader("Per-ROI Activation Within Each Dimension")
165
+ selected_dim = st.selectbox("Dimension", list(dim_colors.keys()))
166
+ dim_rois = COGNITIVE_DIMENSIONS.get(selected_dim, [])
167
+
168
+ roi_activations = []
169
+ for roi in dim_rois:
170
+ if roi in roi_indices:
171
+ verts = roi_indices[roi]
172
+ valid = verts[verts < predictions.shape[1]]
173
+ if len(valid) > 0:
174
+ act = float(np.abs(predictions[:, valid]).mean())
175
+ roi_activations.append({"ROI": roi, "Mean Activation": act})
176
+
177
+ if roi_activations:
178
+ roi_act_df = pd.DataFrame(roi_activations).sort_values("Mean Activation", ascending=False)
179
+ fig_roi = go.Figure(go.Bar(
180
+ x=roi_act_df["Mean Activation"], y=roi_act_df["ROI"],
181
+ orientation="h", marker_color=dim_colors[selected_dim],
182
+ ))
183
+ fig_roi.update_layout(height=max(250, len(roi_activations) * 25), template="plotly_dark",
184
+ yaxis=dict(autorange="reversed"), xaxis_title="Mean |activation|")
185
+ st.plotly_chart(fig_roi, use_container_width=True)
186
+ download_csv_button(roi_act_df, f"cognitive_load_{selected_dim}_rois.csv")
187
+
188
+ # --- Methodology ---
189
+ with st.expander("Methodology", expanded=False):
190
+ st.markdown("""
191
+ **Cognitive Load Scoring** maps predicted fMRI activations onto four neurocognitive dimensions
192
+ using HCP MMP1.0 ROI groupings:
193
+
194
+ - **Visual Complexity**: V1-V4, MT, MST, FFC, VVC (ventral & dorsal visual streams)
195
+ - **Auditory Demand**: A1, belt areas, STS (auditory cortex + association areas)
196
+ - **Language Processing**: Areas 44/45 (Broca's), TPOJ (Wernicke's), STV, PSL (perisylvian language network)
197
+ - **Executive Load**: dlPFC (area 46), ACC, FEF (frontoparietal control network)
198
+
199
+ Each dimension score is the mean absolute activation across its ROIs, normalized by the
200
+ median activation across all vertices (baseline). Scores are clipped to [0, 1].
201
+
202
+ **Confidence bands** on the timeline are computed via bootstrap resampling of vertices
203
+ within each dimension's ROI group (50 resamples, 95% CI).
204
+
205
+ **Limitations**: The ROI-to-dimension mapping is based on established functional neuroanatomy
206
+ but is not exhaustive. Cognitive load is a multidimensional construct that cannot be fully
207
+ captured by fMRI activation alone. These scores should be interpreted as relative measures,
208
+ not absolute cognitive load values.
209
+
210
+ **References**:
211
+ - Glasser et al., 2016, *Nature* (HCP MMP1.0 parcellation)
212
+ - Sweller, 1988, *Cognitive Science* (Cognitive Load Theory)
213
+ """)
pages/3_Temporal_Dynamics.py CHANGED
@@ -1,4 +1,4 @@
1
- """Temporal Dynamics page."""
2
 
3
  import streamlit as st
4
  import numpy as np
@@ -6,32 +6,39 @@ import pandas as pd
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
 
9
- from utils import (
10
- make_roi_indices,
11
- generate_brain_predictions,
12
- generate_model_features,
13
- peak_latency,
14
- temporal_correlation,
15
- decompose_response,
16
- ROI_GROUPS,
17
- ALL_ROIS,
18
- )
19
 
20
  st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
 
 
 
21
  st.title("⏱️ Temporal Dynamics")
22
- st.markdown("Analyze how brain responses evolve over time.")
23
 
24
  # --- Sidebar ---
25
  with st.sidebar:
26
  st.header("Configuration")
 
 
27
  n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
28
- tr_seconds = st.slider("TR duration (seconds)", 0.5, 2.0, 1.0, 0.1)
29
- seed = st.number_input("Random seed", value=42, min_value=0)
30
 
31
  st.subheader("ROI Selection")
32
- selected_group = st.selectbox("Region group", list(ROI_GROUPS.keys()))
33
- available_rois = ROI_GROUPS[selected_group]
34
- selected_rois = st.multiselect("ROIs to analyze", available_rois, default=available_rois[:3])
 
 
 
 
 
 
 
 
 
35
 
36
  max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
37
  cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
@@ -42,78 +49,209 @@ if not selected_rois:
42
 
43
  # --- Generate Data ---
44
  roi_indices, n_vertices = make_roi_indices()
45
- predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
46
- features = generate_model_features(n_timepoints, 64, seed + 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # --- Peak Latency ---
49
- st.subheader("Peak Response Latency")
50
  latency_data = []
51
  for roi in selected_rois:
52
- lat = peak_latency(predictions, roi_indices, roi, tr_seconds)
53
- latency_data.append({"ROI": roi, "Peak Latency (s)": lat})
 
 
 
 
 
 
 
54
 
55
- lat_df = pd.DataFrame(latency_data)
56
- col1, col2 = st.columns([2, 1])
57
 
 
58
  with col1:
59
- fig = go.Figure(go.Bar(
60
- x=lat_df["ROI"],
61
- y=lat_df["Peak Latency (s)"],
62
- marker_color="#6C5CE7",
63
  ))
64
- fig.update_layout(
65
- yaxis_title="Time to peak (seconds)",
66
- height=350,
67
- template="plotly_dark",
68
  )
69
- st.plotly_chart(fig, use_container_width=True)
70
 
71
  with col2:
72
- st.dataframe(lat_df, use_container_width=True, hide_index=True)
 
73
 
74
- # --- Lag Correlation ---
75
  st.divider()
76
  st.subheader("Temporal Correlation (Brain vs Model Features)")
 
 
77
 
78
  lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
79
- fig2 = go.Figure()
80
- colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8"]
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  for i, roi in enumerate(selected_rois):
83
- corr = temporal_correlation(predictions, features, roi_indices, roi, max_lag)
84
- fig2.add_trace(go.Scatter(
85
- x=lags,
86
- y=corr,
87
- name=roi,
88
- line=dict(color=colors[i % len(colors)], width=2),
89
- ))
 
90
 
91
- fig2.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
92
- fig2.update_layout(
93
- xaxis_title="Lag (seconds)",
94
- yaxis_title="Pearson Correlation",
95
- height=400,
96
- template="plotly_dark",
97
  legend=dict(orientation="h", yanchor="bottom", y=1.02),
98
  )
99
- st.plotly_chart(fig2, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # --- Sustained vs Transient ---
102
  st.divider()
103
  st.subheader("Sustained vs Transient Decomposition")
 
104
 
105
  roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
106
  sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
107
- time_axis = np.arange(len(sustained)) * tr_seconds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- fig3 = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.08,
110
- subplot_titles=("Sustained Component", "Transient Component"))
 
111
 
112
- fig3.add_trace(go.Scatter(x=time_axis, y=sustained, name="Sustained",
113
- line=dict(color="#6C5CE7", width=2)), row=1, col=1)
114
- fig3.add_trace(go.Scatter(x=time_axis, y=transient, name="Transient",
115
- line=dict(color="#FF6B6B", width=1.5)), row=2, col=1)
116
 
117
- fig3.update_xaxes(title_text="Time (seconds)", row=2, col=1)
118
- fig3.update_layout(height=500, template="plotly_dark", showlegend=False)
119
- st.plotly_chart(fig3, use_container_width=True)
 
 
1
+ """Temporal Dynamics - Research Grade."""
2
 
3
  import streamlit as st
4
  import numpy as np
 
6
  import plotly.graph_objects as go
7
  from plotly.subplots import make_subplots
8
 
9
+ from session import init_session, log_analysis, get_carried_rois, download_csv_button, show_analysis_log
10
+ from utils import make_roi_indices, peak_latency, temporal_correlation, decompose_response, ROI_GROUPS, ALL_ROIS
11
+ from synthetic import generate_realistic_predictions, generate_correlated_features
 
 
 
 
 
 
 
12
 
13
  st.set_page_config(page_title="Temporal Dynamics", page_icon="⏱️", layout="wide")
14
+ init_session()
15
+ show_analysis_log()
16
+
17
  st.title("⏱️ Temporal Dynamics")
18
+ st.markdown("Analyze how brain responses evolve over time, including processing hierarchy and temporal coupling with model features.")
19
 
20
  # --- Sidebar ---
21
  with st.sidebar:
22
  st.header("Configuration")
23
+ stim_type = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"],
24
+ index=["visual", "auditory", "language", "multimodal"].index(st.session_state.get("stimulus_type", "visual")))
25
  n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
26
+ tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
27
+ seed = st.number_input("Seed", value=42, min_value=0)
28
 
29
  st.subheader("ROI Selection")
30
+ carried = get_carried_rois()
31
+ use_carried = False
32
+ if carried:
33
+ use_carried = st.checkbox(f"Use {len(carried)} ROIs from Brain Alignment", value=True)
34
+
35
+ if use_carried and carried:
36
+ selected_rois = carried
37
+ st.caption(f"Using: {', '.join(selected_rois[:5])}{'...' if len(selected_rois) > 5 else ''}")
38
+ else:
39
+ selected_group = st.selectbox("Region group", list(ROI_GROUPS.keys()))
40
+ available_rois = ROI_GROUPS[selected_group]
41
+ selected_rois = st.multiselect("ROIs to analyze", available_rois, default=available_rois[:4])
42
 
43
  max_lag = st.slider("Max correlation lag (TRs)", 5, 30, 15)
44
  cutoff = st.slider("Decomposition cutoff (seconds)", 1.0, 10.0, 4.0, 0.5)
 
49
 
50
  # --- Generate Data ---
51
  roi_indices, n_vertices = make_roi_indices()
52
+ predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
53
+ features = generate_correlated_features(predictions, alignment_strength=0.5, feature_dim=64, seed=seed + 1)
54
+ log_analysis(f"Temporal dynamics: {stim_type}, {len(selected_rois)} ROIs")
55
+
56
+ time_axis = np.arange(n_timepoints) * tr_seconds
57
+ colors = ["#00D2FF", "#FF6B6B", "#A29BFE", "#FFEAA7", "#55EFC4", "#FD79A8", "#74B9FF", "#E17055"]
58
+
59
+ # --- Raw ROI Timecourses ---
60
+ st.subheader("Raw ROI Timecourses")
61
+ st.markdown("Mean absolute activation over time for each selected ROI. Note the hemodynamic response shape after stimulus events.")
62
+
63
+ fig_raw = go.Figure()
64
+ for i, roi in enumerate(selected_rois):
65
+ if roi in roi_indices:
66
+ verts = roi_indices[roi]
67
+ valid = verts[verts < predictions.shape[1]]
68
+ if len(valid) > 0:
69
+ tc = np.abs(predictions[:, valid]).mean(axis=1)
70
+ fig_raw.add_trace(go.Scatter(
71
+ x=time_axis, y=tc, name=roi,
72
+ line=dict(color=colors[i % len(colors)], width=2),
73
+ ))
74
+
75
+ fig_raw.update_layout(
76
+ xaxis_title="Time (seconds)", yaxis_title="Mean |activation|",
77
+ height=400, template="plotly_dark",
78
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
79
+ )
80
+ st.plotly_chart(fig_raw, use_container_width=True)
81
+
82
+ # --- Peak Latency (sorted = processing hierarchy) ---
83
+ st.divider()
84
+ st.subheader("Peak Response Latency (Processing Hierarchy)")
85
+ st.markdown("ROIs sorted by peak latency reveal the cortical processing hierarchy: early sensory areas respond first, association cortex later.")
86
 
 
 
87
  latency_data = []
88
  for roi in selected_rois:
89
+ if roi in roi_indices:
90
+ lat = peak_latency(predictions, roi_indices, roi, tr_seconds)
91
+ # Determine functional group
92
+ group = "Other"
93
+ for g, rois in ROI_GROUPS.items():
94
+ if roi in rois:
95
+ group = g
96
+ break
97
+ latency_data.append({"ROI": roi, "Peak Latency (s)": lat, "Group": group})
98
 
99
+ lat_df = pd.DataFrame(latency_data).sort_values("Peak Latency (s)")
100
+ group_colors = {"Visual": "#00D2FF", "Auditory": "#FF6B6B", "Language": "#A29BFE", "Executive": "#FFEAA7", "Other": "#888"}
101
 
102
+ col1, col2 = st.columns([2, 1])
103
  with col1:
104
+ fig_lat = go.Figure(go.Bar(
105
+ x=lat_df["Peak Latency (s)"], y=lat_df["ROI"],
106
+ orientation="h",
107
+ marker_color=[group_colors.get(g, "#888") for g in lat_df["Group"]],
108
  ))
109
+ fig_lat.update_layout(
110
+ xaxis_title="Time to peak (seconds)", height=max(250, len(selected_rois) * 30),
111
+ template="plotly_dark", yaxis=dict(autorange="reversed"),
 
112
  )
113
+ st.plotly_chart(fig_lat, use_container_width=True)
114
 
115
  with col2:
116
+ st.dataframe(lat_df[["ROI", "Peak Latency (s)", "Group"]], use_container_width=True, hide_index=True)
117
+ download_csv_button(lat_df, "peak_latencies.csv")
118
 
119
+ # --- Lag Correlation with Significance ---
120
  st.divider()
121
  st.subheader("Temporal Correlation (Brain vs Model Features)")
122
+ st.markdown("Pearson correlation at different time lags. The peak indicates optimal temporal alignment. "
123
+ "Gray band shows 95% null range from shuffled data.")
124
 
125
  lags = np.arange(-max_lag, max_lag + 1) * tr_seconds
126
+ fig_corr = go.Figure()
 
127
 
128
+ # Null band (shuffle features, compute correlation envelope)
129
+ rng = np.random.default_rng(seed)
130
+ null_corrs = []
131
+ for _ in range(50):
132
+ shuffled = features[rng.permutation(len(features))]
133
+ for roi in selected_rois[:1]: # Use first ROI for null band
134
+ if roi in roi_indices:
135
+ nc = temporal_correlation(predictions, shuffled, roi_indices, roi, max_lag)
136
+ null_corrs.append(nc)
137
+ if null_corrs:
138
+ null_arr = np.array(null_corrs)
139
+ null_hi = np.percentile(null_arr, 97.5, axis=0)
140
+ null_lo = np.percentile(null_arr, 2.5, axis=0)
141
+ fig_corr.add_trace(go.Scatter(x=lags, y=null_hi, mode="lines", line=dict(width=0), showlegend=False))
142
+ fig_corr.add_trace(go.Scatter(x=lags, y=null_lo, mode="lines", line=dict(width=0),
143
+ fill="tonexty", fillcolor="rgba(150,150,150,0.2)",
144
+ name="95% null range"))
145
+
146
+ # Actual correlations
147
+ optimal_lags = []
148
  for i, roi in enumerate(selected_rois):
149
+ if roi in roi_indices:
150
+ corr = temporal_correlation(predictions, features, roi_indices, roi, max_lag)
151
+ fig_corr.add_trace(go.Scatter(
152
+ x=lags, y=corr, name=roi,
153
+ line=dict(color=colors[i % len(colors)], width=2),
154
+ ))
155
+ opt_idx = np.argmax(np.abs(corr))
156
+ optimal_lags.append({"ROI": roi, "Optimal Lag (s)": lags[opt_idx], "Max |r|": float(np.abs(corr[opt_idx]))})
157
 
158
+ fig_corr.add_vline(x=0, line_dash="dash", line_color="gray", opacity=0.5)
159
+ fig_corr.update_layout(
160
+ xaxis_title="Lag (seconds)", yaxis_title="Pearson Correlation",
161
+ height=400, template="plotly_dark",
 
 
162
  legend=dict(orientation="h", yanchor="bottom", y=1.02),
163
  )
164
+ st.plotly_chart(fig_corr, use_container_width=True)
165
+
166
+ # --- Optimal Lag Summary ---
167
+ if optimal_lags:
168
+ st.subheader("Optimal Lag Summary")
169
+ opt_df = pd.DataFrame(optimal_lags).sort_values("Max |r|", ascending=False)
170
+ st.dataframe(opt_df, use_container_width=True, hide_index=True)
171
+ download_csv_button(opt_df, "optimal_lags.csv")
172
+
173
+ # --- Cross-ROI Lag Matrix ---
174
+ if len(selected_rois) >= 2:
175
+ st.divider()
176
+ st.subheader("Cross-ROI Lag Matrix")
177
+ st.markdown("Optimal lag between each pair of ROIs. Positive values mean the row ROI leads the column ROI.")
178
+
179
+ n_rois = len(selected_rois)
180
+ lag_matrix = np.zeros((n_rois, n_rois))
181
+ for i, roi_a in enumerate(selected_rois):
182
+ if roi_a not in roi_indices:
183
+ continue
184
+ verts_a = roi_indices[roi_a]
185
+ valid_a = verts_a[verts_a < predictions.shape[1]]
186
+ if len(valid_a) == 0:
187
+ continue
188
+ tc_a = np.abs(predictions[:, valid_a]).mean(axis=1)
189
+ for j, roi_b in enumerate(selected_rois):
190
+ if i == j or roi_b not in roi_indices:
191
+ continue
192
+ verts_b = roi_indices[roi_b]
193
+ valid_b = verts_b[verts_b < predictions.shape[1]]
194
+ if len(valid_b) == 0:
195
+ continue
196
+ tc_b = np.abs(predictions[:, valid_b]).mean(axis=1)
197
+ # Cross-correlation to find optimal lag
198
+ corrs_ab = temporal_correlation(predictions, tc_b, roi_indices, roi_a, max_lag)
199
+ opt_idx = np.argmax(np.abs(corrs_ab))
200
+ lag_matrix[i, j] = lags[opt_idx]
201
+
202
+ fig_lagmat = go.Figure(go.Heatmap(
203
+ z=lag_matrix, x=selected_rois, y=selected_rois,
204
+ colorscale="RdBu_r", zmid=0,
205
+ colorbar=dict(title="Lag (s)"),
206
+ text=np.round(lag_matrix, 1), texttemplate="%{text}",
207
+ ))
208
+ fig_lagmat.update_layout(height=400, template="plotly_dark")
209
+ st.plotly_chart(fig_lagmat, use_container_width=True)
210
 
211
  # --- Sustained vs Transient ---
212
  st.divider()
213
  st.subheader("Sustained vs Transient Decomposition")
214
+ st.markdown("Moving-average filter separates slow sustained responses from fast transient spikes.")
215
 
216
  roi_for_decomp = st.selectbox("ROI for decomposition", selected_rois)
217
  sustained, transient = decompose_response(predictions, roi_indices, roi_for_decomp, cutoff, tr_seconds)
218
+ original = sustained + transient
219
+
220
+ fig_decomp = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.06,
221
+ subplot_titles=("Original Signal", "Sustained Component", "Transient Component"))
222
+
223
+ fig_decomp.add_trace(go.Scatter(x=time_axis, y=original, line=dict(color="#888", width=1.5)), row=1, col=1)
224
+ fig_decomp.add_trace(go.Scatter(x=time_axis, y=sustained, line=dict(color="#6C5CE7", width=2)), row=2, col=1)
225
+ fig_decomp.add_trace(go.Scatter(x=time_axis, y=transient, line=dict(color="#FF6B6B", width=1.5)), row=3, col=1)
226
+
227
+ fig_decomp.update_xaxes(title_text="Time (seconds)", row=3, col=1)
228
+ fig_decomp.update_layout(height=550, template="plotly_dark", showlegend=False)
229
+ st.plotly_chart(fig_decomp, use_container_width=True)
230
+
231
+ # --- Methodology ---
232
+ with st.expander("Methodology", expanded=False):
233
+ st.markdown("""
234
+ **Peak Latency** is the time at which mean absolute activation reaches its maximum
235
+ within an ROI. In real fMRI, early sensory cortex (V1, A1) peaks at ~5-6s post-stimulus
236
+ due to the hemodynamic response, while association cortex (dlPFC, angular gyrus) peaks
237
+ ~1-3s later reflecting higher-order processing.
238
+
239
+ **Temporal Correlation** computes Pearson correlation between the ROI timecourse and model
240
+ feature timecourse at each lag in ``[-max_lag, +max_lag]`` TRs. The lag at maximum absolute
241
+ correlation reveals the temporal offset at which model and brain are best aligned.
242
+
243
+ **Null significance band** is estimated by shuffling the model features 50 times and
244
+ computing the lag correlation each time. The 95% envelope of these null correlations
245
+ provides a significance threshold.
246
 
247
+ **Sustained vs Transient Decomposition** uses a moving-average filter with the specified
248
+ cutoff period. The sustained component captures slow, maintained responses (e.g., block
249
+ design activations), while the transient component captures fast, event-related responses.
250
 
251
+ **Cross-ROI Lag Matrix** shows the optimal temporal offset between every pair of ROIs,
252
+ revealing directional information flow (positive lag = row ROI leads column ROI).
 
 
253
 
254
+ **References**:
255
+ - Boynton et al., 1996, *J Neuroscience* (hemodynamic response function)
256
+ - Friston et al., 1998, *NeuroImage* (temporal basis functions in fMRI)
257
+ """)
pages/4_Connectivity.py CHANGED
@@ -1,4 +1,4 @@
1
- """ROI Connectivity page."""
2
 
3
  import streamlit as st
4
  import numpy as np
@@ -6,157 +6,278 @@ import pandas as pd
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
 
9
  from utils import (
10
- make_roi_indices,
11
- generate_brain_predictions,
12
- compute_connectivity,
13
- cluster_rois,
14
- graph_metrics,
15
  ROI_GROUPS,
16
  )
 
17
 
18
  st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
 
 
 
19
  st.title("🔗 ROI Connectivity Analysis")
20
- st.markdown("Functional connectivity between brain regions from predicted responses.")
21
 
22
  # --- Sidebar ---
23
  with st.sidebar:
24
  st.header("Configuration")
 
25
  n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
26
- seed = st.number_input("Random seed", value=42, min_value=0)
 
 
 
27
  n_clusters = st.slider("Number of clusters", 2, 8, 4)
28
  threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
 
 
 
 
 
 
 
29
 
30
  # --- Generate Data ---
31
  roi_indices, n_vertices = make_roi_indices()
32
- predictions = generate_brain_predictions(n_timepoints, n_vertices, seed)
33
-
34
- # --- Correlation Matrix ---
35
- corr_matrix, roi_names = compute_connectivity(predictions, roi_indices)
36
-
37
- st.subheader("Correlation Matrix")
38
- fig = go.Figure(go.Heatmap(
39
- z=corr_matrix,
40
- x=roi_names,
41
- y=roi_names,
42
- colorscale="RdBu_r",
43
- zmid=0,
44
- zmin=-1,
45
- zmax=1,
46
- colorbar=dict(title="Correlation"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  ))
48
- fig.update_layout(
49
- height=600,
50
- width=700,
51
- template="plotly_dark",
 
 
 
 
 
 
 
 
 
 
 
52
  xaxis=dict(tickangle=45, tickfont=dict(size=8)),
53
  yaxis=dict(tickfont=dict(size=8)),
54
  )
55
- st.plotly_chart(fig, use_container_width=True)
 
56
 
57
- # --- Network Clusters ---
58
  st.divider()
59
- col1, col2 = st.columns(2)
 
 
 
 
 
 
 
60
 
61
- with col1:
62
- st.subheader("Functional Network Clusters")
63
- clusters, labels = cluster_rois(corr_matrix, roi_names, n_clusters)
 
64
 
65
- cluster_data = []
66
- for cid, rois in sorted(clusters.items()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  for roi in rois:
68
- cluster_data.append({"Cluster": f"Network {cid}", "ROI": roi})
69
-
70
- cluster_df = pd.DataFrame(cluster_data)
71
- fig2 = px.bar(
72
- cluster_df.groupby("Cluster").size().reset_index(name="Count"),
73
- x="Cluster",
74
- y="Count",
75
- color="Cluster",
76
- color_discrete_sequence=px.colors.qualitative.Set2,
77
- )
78
- fig2.update_layout(height=350, template="plotly_dark", showlegend=False)
79
- st.plotly_chart(fig2, use_container_width=True)
80
-
81
- for cid, rois in sorted(clusters.items()):
82
- st.markdown(f"**Network {cid}:** {', '.join(rois)}")
83
-
84
- with col2:
85
- st.subheader("Degree Centrality")
86
- degrees = graph_metrics(corr_matrix, roi_names, threshold)
87
-
88
- # Sort by degree
89
- sorted_degrees = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
90
- deg_df = pd.DataFrame(sorted_degrees, columns=["ROI", "Degree Centrality"])
91
-
92
- fig3 = go.Figure(go.Bar(
93
- x=deg_df["Degree Centrality"],
94
- y=deg_df["ROI"],
95
- orientation="h",
96
- marker_color="#6C5CE7",
97
- ))
98
- fig3.update_layout(
99
- xaxis_title="Degree Centrality",
100
- xaxis_range=[0, 1],
101
- height=600,
102
- template="plotly_dark",
103
- yaxis=dict(autorange="reversed", tickfont=dict(size=9)),
104
- )
105
- st.plotly_chart(fig3, use_container_width=True)
106
 
107
- # --- Network Graph ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  st.divider()
109
- st.subheader("Network Graph")
110
-
111
- try:
112
- import networkx as nx
113
-
114
- G = nx.Graph()
115
- for name in roi_names:
116
- G.add_node(name)
117
-
118
- for i in range(len(roi_names)):
119
- for j in range(i + 1, len(roi_names)):
120
- if abs(corr_matrix[i, j]) > threshold:
121
- G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
122
-
123
- pos = nx.spring_layout(G, seed=seed, k=2.0)
124
-
125
- # Cluster colors
126
- color_map = px.colors.qualitative.Set2
127
- node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(len(roi_names))]
128
-
129
- edge_x, edge_y = [], []
130
- for u, v in G.edges():
131
- x0, y0 = pos[u]
132
- x1, y1 = pos[v]
133
- edge_x.extend([x0, x1, None])
134
- edge_y.extend([y0, y1, None])
135
-
136
- node_x = [pos[n][0] for n in roi_names]
137
- node_y = [pos[n][1] for n in roi_names]
138
-
139
- fig4 = go.Figure()
140
- fig4.add_trace(go.Scatter(
141
- x=edge_x, y=edge_y, mode="lines",
142
- line=dict(width=0.5, color="rgba(150,150,150,0.3)"),
143
- hoverinfo="none",
144
- ))
145
- fig4.add_trace(go.Scatter(
146
- x=node_x, y=node_y, mode="markers+text",
147
- marker=dict(size=12, color=node_colors, line=dict(width=1, color="white")),
148
- text=roi_names,
149
- textposition="top center",
150
- textfont=dict(size=8),
151
- hoverinfo="text",
152
- ))
153
- fig4.update_layout(
154
- height=500,
155
- template="plotly_dark",
156
- showlegend=False,
157
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
158
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
159
  )
160
- st.plotly_chart(fig4, use_container_width=True)
161
- except ImportError:
162
- st.info("Install `networkx` for the network graph visualization: `pip install networkx`")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ROI Connectivity - Research Grade."""
2
 
3
  import streamlit as st
4
  import numpy as np
 
6
  import plotly.graph_objects as go
7
  import plotly.express as px
8
 
9
+ from session import init_session, log_analysis, get_carried_rois, download_csv_button, show_analysis_log
10
  from utils import (
11
+ make_roi_indices, compute_connectivity, cluster_rois, graph_metrics,
12
+ partial_correlation, betweenness_centrality, modularity_score,
 
 
 
13
  ROI_GROUPS,
14
  )
15
+ from synthetic import generate_realistic_predictions
16
 
17
  st.set_page_config(page_title="ROI Connectivity", page_icon="🔗", layout="wide")
18
+ init_session()
19
+ show_analysis_log()
20
+
21
  st.title("🔗 ROI Connectivity Analysis")
22
+ st.markdown("Functional connectivity between brain regions: correlation structure, network organization, and graph topology.")
23
 
24
  # --- Sidebar ---
25
  with st.sidebar:
26
  st.header("Configuration")
27
+ stim_type = st.selectbox("Stimulus type", ["visual", "auditory", "language", "multimodal"])
28
  n_timepoints = st.slider("Duration (TRs)", 30, 200, 80)
29
+ tr_seconds = st.slider("TR (seconds)", 0.5, 2.0, 1.0, 0.1)
30
+ seed = st.number_input("Seed", value=42, min_value=0)
31
+
32
+ st.subheader("Analysis Parameters")
33
  n_clusters = st.slider("Number of clusters", 2, 8, 4)
34
  threshold = st.slider("Edge threshold", 0.1, 0.8, 0.3, 0.05)
35
+ use_partial = st.checkbox("Use partial correlation", value=False,
36
+ help="Control for shared mean signal across all ROIs")
37
+
38
+ carried = get_carried_rois()
39
+ use_carried = False
40
+ if carried:
41
+ use_carried = st.checkbox(f"Filter to {len(carried)} carried ROIs", value=False)
42
 
43
  # --- Generate Data ---
44
  roi_indices, n_vertices = make_roi_indices()
45
+ predictions = generate_realistic_predictions(n_timepoints, roi_indices, stim_type, tr_seconds, seed=seed)
46
+ log_analysis(f"Connectivity: {stim_type}, partial={use_partial}")
47
+
48
+ # Filter ROIs if carrying from alignment
49
+ active_indices = roi_indices
50
+ if use_carried and carried:
51
+ active_indices = {k: v for k, v in roi_indices.items() if k in carried}
52
+
53
+ # --- Compute Connectivity ---
54
+ if use_partial:
55
+ corr_matrix, roi_names = partial_correlation(predictions, active_indices)
56
+ corr_label = "Partial Correlation"
57
+ else:
58
+ corr_matrix, roi_names = compute_connectivity(predictions, active_indices)
59
+ corr_label = "Pearson Correlation"
60
+
61
+ n_rois = len(roi_names)
62
+
63
+ # --- Correlation Matrix with Cluster Boundaries ---
64
+ st.subheader(f"{corr_label} Matrix")
65
+
66
+ clusters, labels = cluster_rois(corr_matrix, roi_names, n_clusters)
67
+
68
+ # Sort ROIs by cluster for block-diagonal structure
69
+ sorted_idx = np.argsort(labels)
70
+ sorted_corr = corr_matrix[np.ix_(sorted_idx, sorted_idx)]
71
+ sorted_names = [roi_names[i] for i in sorted_idx]
72
+ sorted_labels = labels[sorted_idx]
73
+
74
+ fig_corr = go.Figure(go.Heatmap(
75
+ z=sorted_corr, x=sorted_names, y=sorted_names,
76
+ colorscale="RdBu_r", zmid=0, zmin=-1, zmax=1,
77
+ colorbar=dict(title="r"),
78
  ))
79
+
80
+ # Add cluster boundary lines
81
+ boundaries = []
82
+ for i in range(1, len(sorted_labels)):
83
+ if sorted_labels[i] != sorted_labels[i - 1]:
84
+ boundaries.append(i - 0.5)
85
+
86
+ for b in boundaries:
87
+ fig_corr.add_shape(type="line", x0=b, x1=b, y0=-0.5, y1=n_rois - 0.5,
88
+ line=dict(color="white", width=1.5, dash="dot"))
89
+ fig_corr.add_shape(type="line", x0=-0.5, x1=n_rois - 0.5, y0=b, y1=b,
90
+ line=dict(color="white", width=1.5, dash="dot"))
91
+
92
+ fig_corr.update_layout(
93
+ height=550, template="plotly_dark",
94
  xaxis=dict(tickangle=45, tickfont=dict(size=8)),
95
  yaxis=dict(tickfont=dict(size=8)),
96
  )
97
+ st.plotly_chart(fig_corr, use_container_width=True)
98
+ st.caption(f"White dotted lines indicate cluster boundaries ({n_clusters} clusters)")
99
 
100
+ # --- Dendrogram ---
101
  st.divider()
102
+ col_dendro, col_clusters = st.columns([1, 1])
103
+
104
+ with col_dendro:
105
+ st.subheader("Hierarchical Clustering Dendrogram")
106
+ from scipy.cluster.hierarchy import linkage, dendrogram
107
+ import matplotlib.pyplot as plt
108
+ import matplotlib
109
+ matplotlib.use("Agg")
110
 
111
+ dist = 1.0 - np.abs(corr_matrix)
112
+ np.fill_diagonal(dist, 0.0)
113
+ condensed = [dist[i, j] for i in range(n_rois) for j in range(i + 1, n_rois)]
114
+ Z = linkage(condensed, method="average")
115
 
116
+ fig_dendro, ax = plt.subplots(figsize=(8, 4))
117
+ ax.set_facecolor("#0E1117")
118
+ fig_dendro.patch.set_facecolor("#0E1117")
119
+ dendrogram(Z, labels=roi_names, leaf_rotation=90, leaf_font_size=7, ax=ax,
120
+ color_threshold=Z[-n_clusters + 1, 2] if n_clusters < n_rois else 0)
121
+ ax.tick_params(colors="white")
122
+ ax.set_ylabel("Distance (1 - |r|)", color="white")
123
+ for spine in ax.spines.values():
124
+ spine.set_color("white")
125
+ st.pyplot(fig_dendro)
126
+ plt.close()
127
+
128
+ with col_clusters:
129
+ st.subheader("Network Clusters")
130
+ mod_q = modularity_score(corr_matrix, labels)
131
+ st.metric("Modularity (Q)", f"{mod_q:.3f}",
132
+ help="Newman's modularity. Higher = stronger community structure. Q > 0.3 is typically considered meaningful.")
133
+
134
+ for cid in sorted(clusters.keys()):
135
+ rois = clusters[cid]
136
+ # Identify dominant functional group
137
+ group_counts = {}
138
  for roi in rois:
139
+ for g, g_rois in ROI_GROUPS.items():
140
+ if roi in g_rois:
141
+ group_counts[g] = group_counts.get(g, 0) + 1
142
+ dominant = max(group_counts, key=group_counts.get) if group_counts else "Mixed"
143
+ st.markdown(f"**Network {cid}** ({dominant}): {', '.join(rois)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # --- Centrality Comparison ---
146
+ st.divider()
147
+ st.subheader("Centrality Analysis")
148
+
149
+ col_deg, col_btw = st.columns(2)
150
+
151
+ degrees = graph_metrics(corr_matrix, roi_names, threshold)
152
+ btw = betweenness_centrality(corr_matrix, roi_names, threshold)
153
+
154
+ with col_deg:
155
+ st.markdown("**Degree Centrality** - fraction of ROIs connected to each node")
156
+ deg_df = pd.DataFrame(sorted(degrees.items(), key=lambda x: x[1], reverse=True), columns=["ROI", "Degree"])
157
+ fig_deg = go.Figure(go.Bar(x=deg_df["Degree"], y=deg_df["ROI"], orientation="h", marker_color="#6C5CE7"))
158
+ fig_deg.update_layout(xaxis_range=[0, 1], height=max(300, n_rois * 20), template="plotly_dark",
159
+ yaxis=dict(autorange="reversed", tickfont=dict(size=9)))
160
+ st.plotly_chart(fig_deg, use_container_width=True)
161
+
162
+ with col_btw:
163
+ st.markdown("**Betweenness Centrality** - how often a node lies on shortest paths between others")
164
+ btw_df = pd.DataFrame(sorted(btw.items(), key=lambda x: x[1], reverse=True), columns=["ROI", "Betweenness"])
165
+ fig_btw = go.Figure(go.Bar(x=btw_df["Betweenness"], y=btw_df["ROI"], orientation="h", marker_color="#FF6B6B"))
166
+ fig_btw.update_layout(height=max(300, n_rois * 20), template="plotly_dark",
167
+ yaxis=dict(autorange="reversed", tickfont=dict(size=9)))
168
+ st.plotly_chart(fig_btw, use_container_width=True)
169
+
170
+ # Combined table
171
+ centrality_df = pd.merge(deg_df, btw_df, on="ROI")
172
+ download_csv_button(centrality_df, "centrality_metrics.csv")
173
+
174
+ # --- Edge Weight Distribution ---
175
  st.divider()
176
+ col_dist, col_graph = st.columns([1, 2])
177
+
178
+ with col_dist:
179
+ st.subheader("Edge Weight Distribution")
180
+ upper_tri = corr_matrix[np.triu_indices(n_rois, k=1)]
181
+ fig_hist = go.Figure(go.Histogram(x=upper_tri, nbinsx=40, marker_color="rgba(108,92,231,0.7)"))
182
+ fig_hist.add_vline(x=threshold, line_color="red", line_dash="dash", annotation_text="Threshold")
183
+ fig_hist.add_vline(x=-threshold, line_color="red", line_dash="dash")
184
+ fig_hist.update_layout(
185
+ xaxis_title="Correlation", yaxis_title="Count",
186
+ height=350, template="plotly_dark",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
+ st.plotly_chart(fig_hist, use_container_width=True)
189
+ n_edges = np.sum(np.abs(upper_tri) > threshold)
190
+ max_edges = n_rois * (n_rois - 1) // 2
191
+ st.caption(f"{n_edges}/{max_edges} edges above threshold ({100 * n_edges / max(max_edges, 1):.1f}% density)")
192
+
193
+ # --- Network Graph ---
194
+ with col_graph:
195
+ st.subheader("Network Graph")
196
+ try:
197
+ import networkx as nx
198
+
199
+ G = nx.Graph()
200
+ for name in roi_names:
201
+ G.add_node(name)
202
+ for i in range(n_rois):
203
+ for j in range(i + 1, n_rois):
204
+ w = abs(corr_matrix[i, j])
205
+ if w > threshold:
206
+ G.add_edge(roi_names[i], roi_names[j], weight=w)
207
+
208
+ pos = nx.spring_layout(G, seed=seed, k=2.5)
209
+ color_map = px.colors.qualitative.Set2
210
+ node_colors = [color_map[(labels[i] - 1) % len(color_map)] for i in range(n_rois)]
211
+
212
+ # Edges with width proportional to weight
213
+ for u, v, data in G.edges(data=True):
214
+ x0, y0 = pos[u]
215
+ x1, y1 = pos[v]
216
+ fig_graph = go.Figure() if not hasattr(st, '_graph_fig') else st._graph_fig
217
+
218
+ fig_net = go.Figure()
219
+ for u, v, d in G.edges(data=True):
220
+ x0, y0 = pos[u]
221
+ x1, y1 = pos[v]
222
+ w = d.get("weight", 0.3)
223
+ fig_net.add_trace(go.Scatter(
224
+ x=[x0, x1, None], y=[y0, y1, None], mode="lines",
225
+ line=dict(width=w * 3, color=f"rgba(150,150,150,{min(w, 0.8)})"),
226
+ hoverinfo="none", showlegend=False,
227
+ ))
228
+
229
+ # Node sizes by degree
230
+ max_deg = max(degrees.values()) if degrees else 1
231
+ node_sizes = [8 + 20 * degrees.get(name, 0) / max(max_deg, 0.01) for name in roi_names]
232
+ node_x = [pos[n][0] for n in roi_names]
233
+ node_y = [pos[n][1] for n in roi_names]
234
+
235
+ fig_net.add_trace(go.Scatter(
236
+ x=node_x, y=node_y, mode="markers+text",
237
+ marker=dict(size=node_sizes, color=node_colors, line=dict(width=1, color="white")),
238
+ text=roi_names, textposition="top center", textfont=dict(size=7, color="white"),
239
+ hovertext=[f"{name}<br>Degree: {degrees.get(name, 0):.2f}<br>Betweenness: {btw.get(name, 0):.3f}" for name in roi_names],
240
+ hoverinfo="text", showlegend=False,
241
+ ))
242
+
243
+ fig_net.update_layout(
244
+ height=450, template="plotly_dark",
245
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
246
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
247
+ )
248
+ st.plotly_chart(fig_net, use_container_width=True)
249
+ except ImportError:
250
+ st.info("Install `networkx` for graph visualization: `pip install networkx`")
251
+
252
+ # --- Methodology ---
253
+ with st.expander("Methodology", expanded=False):
254
+ st.markdown("""
255
+ **Functional Connectivity** is computed as pairwise Pearson correlation between ROI
256
+ timecourses (mean activation across vertices within each ROI).
257
+
258
+ **Partial Correlation** controls for the shared mean signal by computing the precision
259
+ matrix (inverse covariance) and normalizing. This removes indirect correlations mediated
260
+ by a common driver.
261
+
262
+ **Hierarchical Clustering** uses agglomerative clustering with average linkage on a
263
+ distance matrix defined as ``1 - |correlation|``. The dendrogram shows the hierarchical
264
+ merging of ROIs into networks.
265
+
266
+ **Modularity (Q)** quantifies how strongly the network divides into communities compared
267
+ to a random network with the same degree distribution. Q > 0.3 typically indicates
268
+ meaningful community structure. (Newman, 2006, *PNAS*)
269
+
270
+ **Degree Centrality** is the fraction of other nodes each node is connected to (above
271
+ the correlation threshold). High degree = hub region.
272
+
273
+ **Betweenness Centrality** counts how often a node lies on the shortest path between
274
+ other node pairs. High betweenness = bridge between communities.
275
+
276
+ **Edge Weight Distribution** shows the histogram of all pairwise correlations. The
277
+ threshold (red line) determines which connections are retained for graph analysis.
278
+
279
+ **References**:
280
+ - Rubinov & Sporns, 2010, *NeuroImage* (graph metrics for brain networks)
281
+ - Newman, 2006, *PNAS* (modularity in networks)
282
+ - Smith et al., 2011, *NeuroImage* (partial correlation for fMRI connectivity)
283
+ """)
session.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared session state management and data I/O utilities.
2
+
3
+ Manages cross-page state (selected ROIs, predictions, analysis log)
4
+ and provides upload/download widgets.
5
+ """
6
+
7
+ import io
8
+ import json
9
+ from datetime import datetime
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import streamlit as st
14
+
15
+
16
+ def init_session():
17
+ """Initialize session state with defaults. Safe to call multiple times."""
18
+ defaults = {
19
+ "brain_predictions": None,
20
+ "model_features": {},
21
+ "roi_indices": None,
22
+ "n_vertices": 0,
23
+ "selected_rois": [],
24
+ "data_source": "synthetic",
25
+ "stimulus_type": "visual",
26
+ "tr_seconds": 1.0,
27
+ "n_timepoints": 80,
28
+ "seed": 42,
29
+ "analysis_log": [],
30
+ "carry_rois": [], # ROIs carried from another page
31
+ }
32
+ for key, value in defaults.items():
33
+ if key not in st.session_state:
34
+ st.session_state[key] = value
35
+
36
+
37
+ def log_analysis(description):
38
+ """Append an entry to the analysis log."""
39
+ timestamp = datetime.now().strftime("%H:%M:%S")
40
+ entry = f"[{timestamp}] {description}"
41
+ if "analysis_log" not in st.session_state:
42
+ st.session_state["analysis_log"] = []
43
+ st.session_state["analysis_log"].append(entry)
44
+
45
+
46
+ def carry_rois(rois, target_page=""):
47
+ """Store selected ROIs for cross-page workflow."""
48
+ st.session_state["carry_rois"] = list(rois)
49
+ log_analysis(f"Carried {len(rois)} ROIs to {target_page}")
50
+
51
+
52
+ def get_carried_rois():
53
+ """Retrieve ROIs carried from another page."""
54
+ return st.session_state.get("carry_rois", [])
55
+
56
+
57
+ def get_or_generate_data(roi_indices):
58
+ """Get brain predictions from session or generate new synthetic data."""
59
+ from synthetic import generate_realistic_predictions
60
+
61
+ params_key = (
62
+ st.session_state.get("n_timepoints", 80),
63
+ st.session_state.get("stimulus_type", "visual"),
64
+ st.session_state.get("seed", 42),
65
+ )
66
+
67
+ # Check if we need to regenerate
68
+ if (
69
+ st.session_state.get("brain_predictions") is None
70
+ or st.session_state.get("_data_params") != params_key
71
+ or st.session_state.get("data_source") == "synthetic"
72
+ ):
73
+ if st.session_state.get("data_source") == "uploaded" and st.session_state.get("brain_predictions") is not None:
74
+ return st.session_state["brain_predictions"]
75
+
76
+ predictions = generate_realistic_predictions(
77
+ n_timepoints=st.session_state["n_timepoints"],
78
+ roi_indices=roi_indices,
79
+ stimulus_type=st.session_state["stimulus_type"],
80
+ tr_seconds=st.session_state["tr_seconds"],
81
+ seed=st.session_state["seed"],
82
+ )
83
+ st.session_state["brain_predictions"] = predictions
84
+ st.session_state["_data_params"] = params_key
85
+
86
+ return st.session_state["brain_predictions"]
87
+
88
+
89
+ def upload_npy_widget(label, key):
90
+ """File uploader for .npy arrays with validation."""
91
+ uploaded = st.file_uploader(label, type=["npy"], key=key)
92
+ if uploaded is not None:
93
+ try:
94
+ data = np.load(io.BytesIO(uploaded.read()))
95
+ st.success(f"Loaded: shape {data.shape}, dtype {data.dtype}")
96
+ return data
97
+ except Exception as e:
98
+ st.error(f"Failed to load file: {e}")
99
+ return None
100
+
101
+
102
+ def download_csv_button(df, filename, label="Download CSV"):
103
+ """Download button for a pandas DataFrame as CSV."""
104
+ csv = df.to_csv(index=False)
105
+ st.download_button(label, csv, filename, "text/csv")
106
+
107
+
108
+ def download_json_button(data, filename, label="Download JSON"):
109
+ """Download button for a dict as JSON."""
110
+ json_str = json.dumps(data, indent=2, default=str)
111
+ st.download_button(label, json_str, filename, "application/json")
112
+
113
+
114
+ def show_analysis_log():
115
+ """Display the analysis log in the sidebar."""
116
+ log = st.session_state.get("analysis_log", [])
117
+ if log:
118
+ with st.sidebar:
119
+ with st.expander("Analysis Log", expanded=False):
120
+ for entry in reversed(log[-20:]):
121
+ st.caption(entry)
122
+
123
+
124
+ def data_summary_widget(predictions, roi_indices):
125
+ """Show a summary of the current data."""
126
+ if predictions is None:
127
+ st.info("No data loaded. Generate synthetic data or upload your own.")
128
+ return
129
+
130
+ col1, col2, col3, col4 = st.columns(4)
131
+ col1.metric("Timepoints", predictions.shape[0])
132
+ col2.metric("Vertices", predictions.shape[1])
133
+ col3.metric("ROIs", len(roi_indices))
134
+ col4.metric("Source", st.session_state.get("data_source", "synthetic").title())
synthetic.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Biologically realistic synthetic fMRI data generation.
2
+
3
+ Generates data with hemodynamic response convolution, modality-specific
4
+ activation patterns, spatial autocorrelation, temporal noise structure,
5
+ and scanner drift - mimicking real fMRI recordings.
6
+ """
7
+
8
+ import numpy as np
9
+
10
+ # --- Hemodynamic Response Function ---
11
+
12
+ def generate_hrf(tr_seconds=1.0, duration=30.0):
13
+ """Canonical double-gamma hemodynamic response function.
14
+
15
+ Models the BOLD signal: a positive peak at ~5-6s followed by a
16
+ smaller negative undershoot at ~15s.
17
+ """
18
+ t = np.arange(0, duration, tr_seconds)
19
+ # Double gamma parameters (SPM canonical)
20
+ a1, b1 = 6.0, 1.0 # positive peak
21
+ a2, b2 = 16.0, 1.0 # undershoot
22
+ c = 1.0 / 6.0 # undershoot ratio
23
+
24
+ from scipy.stats import gamma as gamma_dist
25
+ h = gamma_dist.pdf(t, a1, scale=b1) - c * gamma_dist.pdf(t, a2, scale=b2)
26
+ h = h / np.max(np.abs(h)) # normalize to [-1, 1]
27
+ return h
28
+
29
+
30
+ def generate_stimulus_events(n_timepoints, tr_seconds=1.0, n_events=5, seed=42):
31
+ """Generate random stimulus onset times as a binary event train.
32
+
33
+ Returns a (n_timepoints,) array with 1s at stimulus onsets.
34
+ Events are spaced at least 8 seconds apart.
35
+ """
36
+ rng = np.random.default_rng(seed)
37
+ total_seconds = n_timepoints * tr_seconds
38
+ min_gap = 8.0 # minimum inter-stimulus interval
39
+
40
+ events = np.zeros(n_timepoints)
41
+ onsets = []
42
+ attempts = 0
43
+ while len(onsets) < n_events and attempts < 1000:
44
+ t = rng.uniform(2.0, total_seconds - 10.0)
45
+ if all(abs(t - o) > min_gap for o in onsets):
46
+ onsets.append(t)
47
+ attempts += 1
48
+
49
+ for onset in onsets:
50
+ idx = int(onset / tr_seconds)
51
+ if 0 <= idx < n_timepoints:
52
+ events[idx] = 1.0
53
+
54
+ return events
55
+
56
+
57
+ # --- Modality-Specific Activation Weights ---
58
+
59
+ # Weight for each ROI given a stimulus modality (0 = no response, 1 = maximum)
60
+ MODALITY_WEIGHTS = {
61
+ "visual": {
62
+ # Strong visual cortex activation
63
+ "V1": 1.0, "V2": 0.95, "V3": 0.85, "V4": 0.8,
64
+ "MT": 0.75, "MST": 0.7, "FFC": 0.65, "VVC": 0.6,
65
+ # Weak cross-modal
66
+ "A1": 0.05, "LBelt": 0.04, "MBelt": 0.03, "PBelt": 0.03, "A4": 0.02, "A5": 0.02,
67
+ # Minimal language
68
+ "44": 0.08, "45": 0.07, "IFJa": 0.06, "IFJp": 0.05,
69
+ "TPOJ1": 0.1, "TPOJ2": 0.08, "STV": 0.07, "PSL": 0.06,
70
+ # Moderate executive (attention)
71
+ "46": 0.3, "9-46d": 0.25, "8Av": 0.35, "8Ad": 0.3,
72
+ "FEF": 0.4, "p32pr": 0.15, "a32pr": 0.12,
73
+ },
74
+ "auditory": {
75
+ "V1": 0.03, "V2": 0.03, "V3": 0.02, "V4": 0.02,
76
+ "MT": 0.02, "MST": 0.01, "FFC": 0.01, "VVC": 0.01,
77
+ "A1": 1.0, "LBelt": 0.95, "MBelt": 0.9, "PBelt": 0.85, "A4": 0.75, "A5": 0.7,
78
+ "44": 0.15, "45": 0.12, "IFJa": 0.1, "IFJp": 0.08,
79
+ "TPOJ1": 0.25, "TPOJ2": 0.2, "STV": 0.3, "PSL": 0.2,
80
+ "46": 0.2, "9-46d": 0.15, "8Av": 0.12, "8Ad": 0.1,
81
+ "FEF": 0.08, "p32pr": 0.1, "a32pr": 0.08,
82
+ },
83
+ "language": {
84
+ "V1": 0.05, "V2": 0.04, "V3": 0.03, "V4": 0.03,
85
+ "MT": 0.02, "MST": 0.02, "FFC": 0.1, "VVC": 0.08,
86
+ "A1": 0.3, "LBelt": 0.25, "MBelt": 0.2, "PBelt": 0.15, "A4": 0.2, "A5": 0.15,
87
+ "44": 1.0, "45": 0.95, "IFJa": 0.85, "IFJp": 0.8,
88
+ "TPOJ1": 0.9, "TPOJ2": 0.85, "STV": 0.75, "PSL": 0.7,
89
+ "46": 0.5, "9-46d": 0.45, "8Av": 0.3, "8Ad": 0.25,
90
+ "FEF": 0.15, "p32pr": 0.35, "a32pr": 0.3,
91
+ },
92
+ "multimodal": {
93
+ "V1": 0.7, "V2": 0.65, "V3": 0.55, "V4": 0.5,
94
+ "MT": 0.5, "MST": 0.45, "FFC": 0.4, "VVC": 0.35,
95
+ "A1": 0.7, "LBelt": 0.65, "MBelt": 0.55, "PBelt": 0.5, "A4": 0.45, "A5": 0.4,
96
+ "44": 0.65, "45": 0.6, "IFJa": 0.5, "IFJp": 0.45,
97
+ "TPOJ1": 0.6, "TPOJ2": 0.55, "STV": 0.5, "PSL": 0.45,
98
+ "46": 0.4, "9-46d": 0.35, "8Av": 0.3, "8Ad": 0.25,
99
+ "FEF": 0.3, "p32pr": 0.25, "a32pr": 0.2,
100
+ },
101
+ }
102
+
103
+
104
+ def generate_realistic_predictions(
105
+ n_timepoints,
106
+ roi_indices,
107
+ stimulus_type="visual",
108
+ tr_seconds=1.0,
109
+ n_events=5,
110
+ snr=2.0,
111
+ seed=42,
112
+ ):
113
+ """Generate biologically realistic fMRI-like predictions.
114
+
115
+ Parameters
116
+ ----------
117
+ n_timepoints : int
118
+ Number of TRs.
119
+ roi_indices : dict[str, np.ndarray]
120
+ ROI name -> vertex indices mapping.
121
+ stimulus_type : str
122
+ One of "visual", "auditory", "language", "multimodal".
123
+ tr_seconds : float
124
+ Repetition time in seconds.
125
+ n_events : int
126
+ Number of stimulus events.
127
+ snr : float
128
+ Signal-to-noise ratio (higher = cleaner signal).
129
+ seed : int
130
+ Random seed.
131
+ """
132
+ rng = np.random.default_rng(seed)
133
+ n_vertices = max(max(v) for v in roi_indices.values()) + 1
134
+ predictions = np.zeros((n_timepoints, n_vertices))
135
+
136
+ # 1. Generate stimulus-evoked signal
137
+ events = generate_stimulus_events(n_timepoints, tr_seconds, n_events, seed)
138
+ hrf = generate_hrf(tr_seconds)
139
+
140
+ # Convolve events with HRF
141
+ bold_signal = np.convolve(events, hrf)[:n_timepoints]
142
+
143
+ # 2. Apply modality-specific weights per ROI
144
+ weights = MODALITY_WEIGHTS.get(stimulus_type, MODALITY_WEIGHTS["multimodal"])
145
+ for roi_name, vertices in roi_indices.items():
146
+ w = weights.get(roi_name, 0.1)
147
+ # Add per-ROI latency jitter (higher-order areas respond later)
148
+ latency_shift = 0
149
+ if roi_name in ["44", "45", "IFJa", "IFJp", "46", "9-46d"]:
150
+ latency_shift = int(2.0 / tr_seconds) # ~2s later for association cortex
151
+ elif roi_name in ["TPOJ1", "TPOJ2", "STV", "PSL"]:
152
+ latency_shift = int(1.5 / tr_seconds)
153
+
154
+ shifted = np.roll(bold_signal, latency_shift) * w
155
+ # Add per-vertex variation within ROI
156
+ for v in vertices:
157
+ if v < n_vertices:
158
+ vertex_scale = 0.8 + 0.4 * rng.random()
159
+ predictions[:, v] = shifted * vertex_scale
160
+
161
+ # 3. Add temporal autocorrelation (AR(1) noise)
162
+ ar_coeff = 0.5
163
+ noise = rng.standard_normal(predictions.shape)
164
+ for t in range(1, n_timepoints):
165
+ noise[t] += ar_coeff * noise[t - 1]
166
+
167
+ # 4. Add scanner drift (low-frequency sinusoidal)
168
+ t_axis = np.arange(n_timepoints) * tr_seconds
169
+ drift = 0.1 * np.sin(2 * np.pi * t_axis / (n_timepoints * tr_seconds * 0.8))
170
+ drift = drift[:, np.newaxis]
171
+
172
+ # 5. Combine signal + noise + drift
173
+ signal_power = np.std(predictions[predictions != 0]) if np.any(predictions != 0) else 1.0
174
+ noise_power = signal_power / max(snr, 0.1)
175
+ predictions = predictions + noise * noise_power + drift
176
+
177
+ # 6. Spatial smoothing (average with neighbors within same ROI)
178
+ for roi_name, vertices in roi_indices.items():
179
+ valid = vertices[vertices < n_vertices]
180
+ if len(valid) > 1:
181
+ roi_data = predictions[:, valid].copy()
182
+ kernel = np.ones(min(3, len(valid))) / min(3, len(valid))
183
+ for t in range(n_timepoints):
184
+ predictions[t, valid] = np.convolve(roi_data[t], kernel, mode="same")
185
+
186
+ return predictions
187
+
188
+
189
+ def generate_correlated_features(
190
+ brain_predictions,
191
+ alignment_strength=0.5,
192
+ feature_dim=512,
193
+ seed=42,
194
+ ):
195
+ """Generate model features with controllable correlation to brain data.
196
+
197
+ Parameters
198
+ ----------
199
+ brain_predictions : np.ndarray
200
+ Brain data of shape (n_stimuli, n_vertices).
201
+ alignment_strength : float
202
+ 0.0 = random features, 1.0 = perfectly correlated with brain.
203
+ feature_dim : int
204
+ Output feature dimensionality.
205
+ seed : int
206
+ Random seed.
207
+
208
+ Returns
209
+ -------
210
+ np.ndarray
211
+ Features of shape (n_stimuli, feature_dim).
212
+ """
213
+ rng = np.random.default_rng(seed)
214
+ n_stimuli = brain_predictions.shape[0]
215
+
216
+ # Project brain data to feature_dim via random projection
217
+ n_vertices = brain_predictions.shape[1]
218
+ projection = rng.standard_normal((n_vertices, feature_dim)) / np.sqrt(n_vertices)
219
+ brain_projected = brain_predictions @ projection
220
+
221
+ # Generate random features
222
+ random_features = rng.standard_normal((n_stimuli, feature_dim))
223
+
224
+ # Mix: strength controls brain-alignment vs randomness
225
+ strength = np.clip(alignment_strength, 0.0, 1.0)
226
+ features = strength * brain_projected + (1 - strength) * random_features
227
+
228
+ # Standardize
229
+ features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
230
+ return features
utils.py CHANGED
@@ -79,13 +79,114 @@ ALIGNMENT_METHODS = {"RSA": rsa_score, "CKA": cka_score, "Procrustes": procruste
79
 
80
 
81
  def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
 
82
  rng = np.random.default_rng(seed)
83
  observed = method_fn(model_feat, brain_pred)
84
- count = sum(
85
- 1 for _ in range(n_perm)
86
- if method_fn(model_feat[rng.permutation(len(model_feat))], brain_pred) >= observed
87
- )
88
- return observed, (count + 1) / (n_perm + 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  # --- Cognitive Load ---
 
79
 
80
 
81
  def permutation_test(model_feat, brain_pred, method_fn, n_perm=500, seed=42):
82
+ """Returns (observed_score, p_value, null_distribution)."""
83
  rng = np.random.default_rng(seed)
84
  observed = method_fn(model_feat, brain_pred)
85
+ null_dist = []
86
+ for _ in range(n_perm):
87
+ perm_score = method_fn(model_feat[rng.permutation(len(model_feat))], brain_pred)
88
+ null_dist.append(perm_score)
89
+ null_dist = np.array(null_dist)
90
+ count = np.sum(null_dist >= observed)
91
+ p_value = (count + 1) / (n_perm + 1)
92
+ return observed, p_value, null_dist
93
+
94
+
95
+ def bootstrap_ci(model_feat, brain_pred, method_fn, n_boot=500, confidence=0.95, seed=42):
96
+ """Returns (point_estimate, ci_lower, ci_upper)."""
97
+ rng = np.random.default_rng(seed)
98
+ n = model_feat.shape[0]
99
+ point = method_fn(model_feat, brain_pred)
100
+ scores = []
101
+ for _ in range(n_boot):
102
+ idx = rng.choice(n, size=n, replace=True)
103
+ scores.append(method_fn(model_feat[idx], brain_pred[idx]))
104
+ scores = np.array(scores)
105
+ alpha = 1 - confidence
106
+ return point, float(np.percentile(scores, 100 * alpha / 2)), float(np.percentile(scores, 100 * (1 - alpha / 2)))
107
+
108
+
109
+ def fdr_correction(p_values, alpha=0.05):
110
+ """Benjamini-Hochberg FDR correction. Returns corrected p-values and significance mask."""
111
+ p = np.array(p_values)
112
+ n = len(p)
113
+ sorted_idx = np.argsort(p)
114
+ sorted_p = p[sorted_idx]
115
+ corrected = np.empty(n)
116
+ corrected[sorted_idx[-1]] = sorted_p[-1]
117
+ for i in range(n - 2, -1, -1):
118
+ corrected[sorted_idx[i]] = min(corrected[sorted_idx[i + 1]], sorted_p[i] * n / (i + 1))
119
+ return corrected, corrected < alpha
120
+
121
+
122
+ def noise_ceiling(brain_pred, method_fn, n_splits=20, seed=42):
123
+ """Estimate noise ceiling via split-half reliability."""
124
+ rng = np.random.default_rng(seed)
125
+ n = brain_pred.shape[0]
126
+ scores = []
127
+ for _ in range(n_splits):
128
+ idx = rng.permutation(n)
129
+ half = n // 2
130
+ s = method_fn(brain_pred[idx[:half]], brain_pred[idx[half:half * 2]])
131
+ scores.append(s)
132
+ return float(np.mean(scores)), float(np.std(scores))
133
+
134
+
135
+ def partial_correlation(predictions, roi_indices):
136
+ """Compute partial correlation matrix (correlation controlling for mean signal)."""
137
+ names = list(roi_indices.keys())
138
+ n = len(names)
139
+ T = predictions.shape[0]
140
+ timecourses = np.zeros((n, T))
141
+ for i, name in enumerate(names):
142
+ verts = roi_indices[name]
143
+ valid = verts[verts < predictions.shape[1]]
144
+ if len(valid) > 0:
145
+ timecourses[i] = predictions[:, valid].mean(axis=1)
146
+
147
+ # Partial correlation via precision matrix
148
+ cov = np.cov(timecourses)
149
+ try:
150
+ prec = np.linalg.inv(cov + 1e-6 * np.eye(n))
151
+ d = np.sqrt(np.diag(prec))
152
+ d[d == 0] = 1
153
+ partial = -prec / np.outer(d, d)
154
+ np.fill_diagonal(partial, 1.0)
155
+ except np.linalg.LinAlgError:
156
+ partial = np.eye(n)
157
+ return np.nan_to_num(partial, nan=0.0), names
158
+
159
+
160
+ def betweenness_centrality(corr_matrix, roi_names, threshold=0.3):
161
+ """Compute betweenness centrality from thresholded connectivity."""
162
+ import networkx as nx
163
+ n = corr_matrix.shape[0]
164
+ G = nx.Graph()
165
+ for i, name in enumerate(roi_names):
166
+ G.add_node(name)
167
+ for i in range(n):
168
+ for j in range(i + 1, n):
169
+ if abs(corr_matrix[i, j]) > threshold:
170
+ G.add_edge(roi_names[i], roi_names[j], weight=abs(corr_matrix[i, j]))
171
+ bc = nx.betweenness_centrality(G)
172
+ return {name: bc.get(name, 0.0) for name in roi_names}
173
+
174
+
175
+ def modularity_score(corr_matrix, labels):
176
+ """Compute Newman's modularity Q for a given partition."""
177
+ n = corr_matrix.shape[0]
178
+ adj = np.abs(corr_matrix).copy()
179
+ np.fill_diagonal(adj, 0)
180
+ m = adj.sum() / 2
181
+ if m == 0:
182
+ return 0.0
183
+ Q = 0.0
184
+ k = adj.sum(axis=1)
185
+ for i in range(n):
186
+ for j in range(n):
187
+ if labels[i] == labels[j]:
188
+ Q += adj[i, j] - k[i] * k[j] / (2 * m)
189
+ return float(Q / (2 * m))
190
 
191
 
192
  # --- Cognitive Load ---