JayLacoma commited on
Commit
70d5249
Β·
verified Β·
1 Parent(s): 5f9047d

Upload roi_connectivity.py

Browse files
Files changed (1) hide show
  1. roi_connectivity.py +177 -0
roi_connectivity.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # roi_connectivity.py
2
+
3
+
4
+ import numpy as np
5
+ import mne
6
+ import json
7
+ import pandas as pd
8
+ from mne_connectivity import spectral_connectivity_epochs
9
+ from nilearn import datasets
10
+
11
+ # ───────────────────────────────────────
12
+ # 1. Utility: ROI definitions
13
+ # ───────────────────────────────────────
14
+
15
+ def get_difumo_names():
16
+ try:
17
+ atlas = datasets.fetch_atlas_difumo(dimension=512, resolution_mm=2)
18
+ return atlas.labels['difumo_names'].astype(str).tolist()
19
+ except Exception:
20
+ return [f"Component_{i}" for i in range(512)]
21
+
22
+ def define_motor_cognitive_regions():
23
+ Motor_M1 = [40, 86, 198, 268, 305, 437, 458, 465]
24
+ Motor_SMA_Premotor = [17, 18, 288, 291, 296, 297, 302, 305, 314, 315, 335, 375, 379, 448]
25
+ Motor_Medial = [101, 102, 388, 409, 498]
26
+ Thalamus = [70, 73, 297, 334, 414, 420]
27
+ Basal_Ganglia = [30, 53, 224, 260, 405, 422, 109, 110, 315, 331, 467, 479, 55, 71, 307, 223]
28
+ Cerebellum_Motor = [43, 47, 83, 84, 127, 183, 220, 221, 295, 304, 310, 311, 374, 378, 381, 403, 441, 490, 491]
29
+ Somatosensory = [44, 131, 210, 411, 413, 436]
30
+ Executive_Control = [3, 85, 104, 148, 184, 337, 377, 446, 447, 506, 507]
31
+ Interoception = [2, 387, 358, 389, 165, 469]
32
+ Error_Monitoring = [185, 219, 326, 473, 492]
33
+ return sorted(set(
34
+ Motor_M1 + Motor_SMA_Premotor + Motor_Medial + Thalamus +
35
+ Basal_Ganglia + Cerebellum_Motor + Somatosensory +
36
+ Executive_Control + Interoception + Error_Monitoring
37
+ ))
38
+
39
+ def get_band_freqs(band_name):
40
+ bands = {
41
+ "Theta": (4, 8),
42
+ "Alpha": (8, 12),
43
+ "Low_Beta": (13, 20),
44
+ "High_Beta": (20, 30),
45
+ "Low_Gamma": (30, 60),
46
+ "High_Gamma": (60, 120)
47
+ }
48
+ if band_name not in bands:
49
+ raise ValueError(f"Unknown band: {band_name}. Options: {list(bands.keys())}")
50
+ return bands[band_name]
51
+
52
+ # ───────────────────────────────────────
53
+ # 2. Epoch creation functions
54
+ # ───────────────────────────────────────
55
+
56
+ def create_task_epochs(
57
+ data_file,
58
+ events_file,
59
+ event_id_file,
60
+ condition,
61
+ tmin=0.0,
62
+ tmax=1.5,
63
+ sfreq=500.0
64
+ ):
65
+ """Create epochs from event markers."""
66
+ data = np.load(data_file)
67
+ if data.shape[0] > data.shape[1]:
68
+ data = data.T
69
+
70
+ events = mne.read_events(events_file)
71
+ with open(event_id_file, 'r') as f:
72
+ event_id = json.load(f)
73
+
74
+ ch_names = [f'C{i}' for i in range(data.shape[0])]
75
+ info = mne.create_info(ch_names, sfreq=sfreq, ch_types='misc')
76
+ raw = mne.io.RawArray(data, info, verbose=False)
77
+
78
+ epochs = mne.Epochs(
79
+ raw, events, {condition: event_id[condition]},
80
+ tmin=tmin, tmax=tmax, baseline=None,
81
+ preload=True, verbose=False, event_repeated='drop'
82
+ )
83
+ return epochs
84
+
85
+ def create_rest_epochs(
86
+ data_file,
87
+ duration=2.5,
88
+ sfreq=500.0
89
+ ):
90
+ """Create fixed-length epochs from continuous data."""
91
+ data = np.load(data_file)
92
+ if data.shape[0] > data.shape[1]:
93
+ data = data.T
94
+
95
+ ch_names = [f'C{i}' for i in range(data.shape[0])]
96
+ info = mne.create_info(ch_names, sfreq=sfreq, ch_types='misc')
97
+ raw = mne.io.RawArray(data, info, verbose=False)
98
+
99
+ events = mne.make_fixed_length_events(raw, duration=duration)
100
+ epochs = mne.Epochs(
101
+ raw, events, tmin=0, tmax=duration,
102
+ baseline=None, preload=True, verbose=False
103
+ )
104
+ return epochs
105
+
106
+ # ───────────────────────────────────────
107
+ # 3. Connectivity function
108
+ # ───────────────────────────────────────
109
+
110
+ def compute_roi_connectivity_matrix(
111
+ epochs,
112
+ band_name="Low_Beta",
113
+ method='wpli2_debiased',
114
+ sfreq=500.0
115
+ ):
116
+ """
117
+ Compute a single ROI Γ— ROI connectivity matrix from MNE Epochs object.
118
+
119
+ Parameters:
120
+ - epochs: mne.Epochs instance (already loaded and preprocessed)
121
+ - band_name: e.g., "Alpha", "Low_Beta"
122
+ - method: connectivity method (default: 'wpli2_debiased')
123
+
124
+ Returns:
125
+ - conn_df: pandas DataFrame (n_roi Γ— n_roi) with DiFuMo ROI names as labels
126
+ """
127
+ # Get ROI info
128
+ all_names = get_difumo_names()
129
+ selected_indices = define_motor_cognitive_regions()
130
+ roi_names = [all_names[i] for i in selected_indices]
131
+
132
+ # Extract data for selected ROIs
133
+ epoch_data = epochs.get_data()[:, selected_indices, :] # (n_epochs, n_roi, n_times)
134
+
135
+ # Get frequency range
136
+ fmin, fmax = get_band_freqs(band_name)
137
+
138
+ # Compute connectivity
139
+ con = spectral_connectivity_epochs(
140
+ data=epoch_data,
141
+ method=method,
142
+ mode='multitaper',
143
+ sfreq=sfreq,
144
+ fmin=fmin,
145
+ fmax=fmax,
146
+ faverage=True,
147
+ verbose=False
148
+ )
149
+ matrix = con.get_data(output='dense').squeeze()
150
+ matrix = (matrix + matrix.T) / 2
151
+ np.fill_diagonal(matrix, 0)
152
+
153
+ return pd.DataFrame(matrix, index=roi_names, columns=roi_names)
154
+
155
+
156
+ '''
157
+ # Task
158
+ epochs = create_task_epochs(
159
+ data_file=r"sub-02_task\difumo_time_courses.npy",
160
+ events_file=r"sub-02_task\sub-02_events_mne_binary-eve.fif",
161
+ event_id_file=r"sub-02_task\sub-02_event_id_binary.json",
162
+ condition="InPhase"
163
+ )
164
+
165
+ conn_matrix = compute_roi_connectivity_matrix(epochs, band_name="Alpha")
166
+ conn_matrix.to_csv("sub-02_InPhase_Alpha_matrix.csv")
167
+
168
+
169
+ # Rest
170
+ epochs = create_rest_epochs(
171
+ data_file=r"sub-02_rest\difumo_time_courses.npy",
172
+ duration=2.5
173
+ )
174
+
175
+ conn_matrix = compute_roi_connectivity_matrix(epochs, band_name="Low_Beta")
176
+
177
+ '''