JayLacoma commited on
Commit
2e99ef5
·
verified ·
1 Parent(s): 041409d

Upload 2 files

Browse files
Files changed (2) hide show
  1. ghw280_from_egig.gpsc +284 -0
  2. ica_xtra.py +978 -0
ghw280_from_egig.gpsc ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FidNz 103 203 122
2
+ FidT9 25 98 112
3
+ FidT10 176 101 114
4
+ E1 99 194 201
5
+ E2 107 203 186
6
+ E3 114 191 204
7
+ E4 99 181 216
8
+ E5 82 189 203
9
+ E6 89 202 185
10
+ E7 100 208 169
11
+ E8 117 207 170
12
+ E9 125 198 190
13
+ E10 130 185 206
14
+ E11 115 178 218
15
+ E12 100 167 226
16
+ E13 82 176 216
17
+ E14 66 181 204
18
+ E15 70 197 186
19
+ E16 80 206 171
20
+ E17 91 207 149
21
+ E18 109 207 149
22
+ E19 127 205 156
23
+ E20 136 202 174
24
+ E21 142 190 193
25
+ E22 145 173 209
26
+ E23 131 167 221
27
+ E24 118 161 229
28
+ E25 102 149 237
29
+ E26 83 159 229
30
+ E27 66 163 219
31
+ E28 52 169 204
32
+ E29 55 185 188
33
+ E30 60 197 167
34
+ E31 74 204 152
35
+ E32 103 206 129
36
+ E33 155 189 171
37
+ E34 132 147 232
38
+ E35 68 145 231
39
+ E36 46 183 166
40
+ E37 164 160 105
41
+ E38 168 145 85
42
+ E39 171 145 105
43
+ E40 173 149 125
44
+ E41 167 166 129
45
+ E42 149 169 73
46
+ E43 153 151 62
47
+ E44 163 136 68
48
+ E45 174 125 91
49
+ E46 175 124 109
50
+ E47 177 131 127
51
+ E48 178 143 144
52
+ E49 171 161 147
53
+ E50 163 175 149
54
+ E51 147 195 154
55
+ E52 136 180 93
56
+ E53 158 120 53
57
+ E54 179 114 128
58
+ E55 182 125 144
59
+ E56 178 128 199
60
+ E57 178 143 181
61
+ E58 185 125 181
62
+ E59 182 112 203
63
+ E60 169 130 215
64
+ E61 171 145 199
65
+ E62 169 160 187
66
+ E63 176 152 163
67
+ E64 183 135 161
68
+ E65 187 117 161
69
+ E66 187 108 184
70
+ E67 184 95 205
71
+ E68 174 108 219
72
+ E69 157 121 229
73
+ E70 157 144 220
74
+ E71 159 159 207
75
+ E72 158 176 191
76
+ E73 167 169 167
77
+ E74 185 106 144
78
+ E75 189 97 166
79
+ E76 187 87 184
80
+ E77 184 78 207
81
+ E78 176 90 223
82
+ E79 160 100 234
83
+ E80 144 116 238
84
+ E81 146 136 230
85
+ E82 145 156 222
86
+ E83 187 80 167
87
+ E84 160 79 238
88
+ E85 164 58 102
89
+ E86 177 109 95
90
+ E87 173 98 79
91
+ E88 170 83 78
92
+ E89 163 68 85
93
+ E90 156 45 102
94
+ E91 169 57 120
95
+ E92 177 67 135
96
+ E93 169 120 76
97
+ E94 164 103 63
98
+ E95 159 69 67
99
+ E96 155 53 83
100
+ E97 145 34 102
101
+ E98 158 40 122
102
+ E99 169 48 135
103
+ E100 174 51 153
104
+ E101 181 69 151
105
+ E102 148 48 66
106
+ E103 157 32 136
107
+ E104 155 29 206
108
+ E105 168 43 208
109
+ E106 165 34 190
110
+ E107 150 22 191
111
+ E108 140 24 211
112
+ E109 152 38 224
113
+ E110 164 52 225
114
+ E111 177 58 209
115
+ E112 176 49 189
116
+ E113 171 41 171
117
+ E114 158 27 171
118
+ E115 143 17 173
119
+ E116 133 16 194
120
+ E117 123 18 212
121
+ E118 132 31 228
122
+ E119 142 46 238
123
+ E120 153 60 239
124
+ E121 173 71 224
125
+ E122 183 68 187
126
+ E123 179 59 170
127
+ E124 165 36 154
128
+ E125 150 22 156
129
+ E126 132 11 157
130
+ E127 123 10 175
131
+ E128 114 12 196
132
+ E129 102 16 212
133
+ E130 110 25 227
134
+ E131 121 38 239
135
+ E132 129 56 247
136
+ E133 103 7 176
137
+ E134 99 37 240
138
+ E135 103 19 104
139
+ E136 112 14 121
140
+ E137 120 19 101
141
+ E138 104 21 89
142
+ E139 86 20 102
143
+ E140 94 16 121
144
+ E141 103 8 138
145
+ E142 123 12 137
146
+ E143 128 17 120
147
+ E144 133 24 102
148
+ E145 127 27 80
149
+ E146 104 23 76
150
+ E147 82 25 81
151
+ E148 70 26 99
152
+ E149 74 18 118
153
+ E150 82 11 137
154
+ E151 89 5 155
155
+ E152 112 7 156
156
+ E153 141 21 137
157
+ E154 144 27 121
158
+ E155 145 40 81
159
+ E156 131 32 64
160
+ E157 106 26 59
161
+ E158 80 30 57
162
+ E159 65 34 80
163
+ E160 55 37 97
164
+ E161 60 26 116
165
+ E162 62 18 135
166
+ E163 69 10 154
167
+ E164 59 45 61
168
+ E165 48 29 133
169
+ E166 88 79 253
170
+ E167 84 99 250
171
+ E168 102 113 250
172
+ E169 116 102 250
173
+ E170 108 81 254
174
+ E171 77 65 251
175
+ E172 70 82 250
176
+ E173 72 107 246
177
+ E174 85 121 245
178
+ E175 102 131 244
179
+ E176 118 121 245
180
+ E177 132 108 245
181
+ E178 125 85 251
182
+ E179 118 68 252
183
+ E180 108 50 247
184
+ E181 88 50 247
185
+ E182 68 51 243
186
+ E183 58 68 246
187
+ E184 53 88 243
188
+ E185 57 110 238
189
+ E186 70 127 239
190
+ E187 85 142 239
191
+ E188 119 141 239
192
+ E189 131 127 240
193
+ E190 144 94 245
194
+ E191 142 73 247
195
+ E192 37 72 233
196
+ E193 48 27 204
197
+ E194 51 34 220
198
+ E195 64 20 206
199
+ E196 54 19 185
200
+ E197 37 30 184
201
+ E198 35 38 204
202
+ E199 37 48 220
203
+ E200 59 41 234
204
+ E201 68 28 225
205
+ E202 82 17 210
206
+ E203 72 12 189
207
+ E204 61 14 171
208
+ E205 43 25 171
209
+ E206 31 39 166
210
+ E207 26 45 187
211
+ E208 23 54 205
212
+ E209 27 65 220
213
+ E210 47 55 235
214
+ E211 78 39 240
215
+ E212 89 25 228
216
+ E213 93 10 194
217
+ E214 81 8 175
218
+ E215 52 20 153
219
+ E216 40 31 149
220
+ E217 29 47 150
221
+ E218 22 55 166
222
+ E219 18 61 184
223
+ E220 18 70 203
224
+ E221 17 73 161
225
+ E222 32 65 102
226
+ E223 26 57 132
227
+ E224 34 53 115
228
+ E225 44 49 96
229
+ E226 40 64 84
230
+ E227 34 80 76
231
+ E228 30 94 73
232
+ E229 26 101 93
233
+ E230 27 116 108
234
+ E231 23 106 125
235
+ E232 17 100 142
236
+ E233 22 60 147
237
+ E234 35 43 134
238
+ E235 46 39 114
239
+ E236 50 49 79
240
+ E237 47 61 60
241
+ E238 44 100 54
242
+ E239 36 111 69
243
+ E240 30 119 89
244
+ E241 26 128 123
245
+ E242 48 119 50
246
+ E243 24 126 195
247
+ E244 30 139 197
248
+ E245 29 120 210
249
+ E246 18 107 196
250
+ E247 17 116 177
251
+ E248 22 133 179
252
+ E249 30 150 182
253
+ E250 39 153 202
254
+ E251 39 135 214
255
+ E252 42 114 227
256
+ E253 26 102 216
257
+ E254 16 89 199
258
+ E255 14 99 178
259
+ E256 15 111 158
260
+ E257 20 131 159
261
+ E258 27 147 161
262
+ E259 34 162 163
263
+ E260 39 168 186
264
+ E261 53 150 219
265
+ E262 54 130 229
266
+ E263 37 93 230
267
+ E264 24 82 220
268
+ E265 15 81 182
269
+ E266 14 90 159
270
+ E267 21 118 140
271
+ E268 26 137 137
272
+ E269 32 154 142
273
+ E270 39 157 99
274
+ E271 35 138 83
275
+ E272 30 138 104
276
+ E273 29 146 122
277
+ E274 37 163 127
278
+ E275 53 168 72
279
+ E276 50 149 60
280
+ E277 40 129 67
281
+ E278 41 172 147
282
+ E279 53 190 149
283
+ E280 67 181 90
284
+ Cz 101 96 253
ica_xtra.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ica_xtra.py
2
+
3
+ # Standard library
4
+ from pathlib import Path
5
+ import logging
6
+ from typing import Dict, List, Tuple, Optional, Union
7
+
8
+ # Third-party scientific stack
9
+ import numpy as np
10
+ from scipy.stats import median_abs_deviation, kurtosis
11
+ import matplotlib
12
+ import matplotlib.pyplot as plt
13
+
14
+ # MNE and related
15
+ import mne
16
+ from mne import io
17
+ from mne.io import constants
18
+ from mne_icalabel import label_components
19
+
20
+ # Configure MNE
21
+ mne.set_log_level('WARNING')
22
+
23
+ # ============================================================================
24
+ # CONSTANTS
25
+ # ============================================================================
26
+
27
+ CHANNEL_RENAME_MAP = {**{str(i): f'E{i}' for i in range(1, 281)}, 'REF CZ': 'Cz'}
28
+ EXPECTED_EEG_CHANNELS = {f"E{i}" for i in range(1, 281)} | {"Cz"}
29
+ PROTECTED_CHANNELS = {'E31', 'E19', 'E41', 'E274', 'E227', 'E229', 'E280', 'E52'}
30
+
31
+ # Artifact detection channels
32
+ VVEOG = ('E31', 'E19')
33
+ HEOG = ('E41', 'E274')
34
+ ECG = ('E227', 'E229')
35
+ EMG_CHS = ['E280', 'E52']
36
+ FRONTAL_CHS = ['E31', 'E19']
37
+
38
+ # ICLabel thresholds
39
+ ICALABEL_THRESHOLDS = {
40
+ 'eye blink': 0.80,
41
+ 'heart beat': 0.80,
42
+ 'muscle artifact': 0.75,
43
+ 'line noise': 0.80,
44
+ 'channel noise': 0.80
45
+ }
46
+
47
+
48
+ # ============================================================================
49
+ # UTILITY FUNCTIONS
50
+ # ============================================================================
51
+
52
+ def setup_logging(subject: str, output_path: Path, log_to_file: bool = True) -> Path:
53
+ """Setup logging and return log file path."""
54
+ log_file = output_path / f"{subject}_preproc_log.txt"
55
+ if log_to_file:
56
+ log(f"Initialized preprocessing for {subject}", log_file, log_to_file)
57
+ return log_file
58
+
59
+
60
+ def log(msg: str, log_file: Path, log_to_file: bool = True, detail: str = "normal"):
61
+ """Log message to file and optionally console."""
62
+ if log_to_file:
63
+ with open(log_file, 'a') as f:
64
+ f.write(f"{msg}\n")
65
+ if detail == "normal":
66
+ print(msg)
67
+
68
+
69
+ def parse_gpsc(filepath: Path) -> List[Tuple[str, float, float, float]]:
70
+ """Parse GPSC file efficiently."""
71
+ channels = []
72
+ with open(filepath, 'r') as f:
73
+ for line in f:
74
+ parts = line.strip().split()
75
+ if len(parts) >= 4:
76
+ try:
77
+ name = parts[0]
78
+ x, y, z = map(float, parts[1:4])
79
+ channels.append((name, x, y, z))
80
+ except ValueError:
81
+ continue
82
+ return channels
83
+
84
+
85
+ # ============================================================================
86
+ # DATA LOADING AND PREPARATION
87
+ # ============================================================================
88
+
89
+ def load_raw_data(input_path: Path, input_format: str, log_file: Path,
90
+ log_to_file: bool = True) -> mne.io.Raw:
91
+ """Load raw data from MFF or FIF format."""
92
+ if input_format == "mff":
93
+ log("Loading raw data from .mff...", log_file, log_to_file)
94
+ raw = mne.io.read_raw_egi(str(input_path), preload=True)
95
+ elif input_format == "fif":
96
+ log(f"Loading raw data from .fif: {input_path}", log_file, log_to_file)
97
+ if not input_path.is_file() or input_path.suffix != '.fif':
98
+ raise ValueError(f"Invalid .fif file: {input_path}")
99
+ raw = mne.io.read_raw_fif(str(input_path), preload=True)
100
+ else:
101
+ raise ValueError("input_format must be 'mff' or 'fif'")
102
+
103
+ return raw
104
+
105
+
106
+ def apply_channel_renaming(raw: mne.io.Raw, log_file: Path,
107
+ log_to_file: bool = True) -> mne.io.Raw:
108
+ """Apply channel renaming."""
109
+ log("Applying channel renaming...", log_file, log_to_file)
110
+
111
+ existing_map = {old: new for old, new in CHANNEL_RENAME_MAP.items()
112
+ if old in raw.ch_names}
113
+ if existing_map:
114
+ raw.rename_channels(existing_map)
115
+ log(f"Renamed {len(existing_map)} channels.", log_file, log_to_file)
116
+
117
+ return raw
118
+
119
+
120
+ def apply_montage(raw: mne.io.Raw, gpsc_file: Path, log_file: Path,
121
+ log_to_file: bool = True) -> mne.io.Raw:
122
+ """Apply GPS montage from GPSC file."""
123
+ channels = parse_gpsc(gpsc_file)
124
+ if not channels:
125
+ raise ValueError("No valid channels in .gpsc file")
126
+
127
+ # Normalize positions
128
+ gpsc_array = np.array([ch[1:4] for ch in channels])
129
+ mean_pos = gpsc_array.mean(axis=0)
130
+ log(f"Original mean position (mm): {mean_pos}", log_file, log_to_file)
131
+
132
+ ch_pos = {
133
+ ch[0]: np.array([ch[1] - mean_pos[0], ch[2] - mean_pos[1], ch[3] - mean_pos[2]]) / 1000.0
134
+ for ch in channels
135
+ }
136
+
137
+ # Create montage with fiducials
138
+ montage = mne.channels.make_dig_montage(
139
+ ch_pos=ch_pos,
140
+ nasion=ch_pos.get('FidNz'),
141
+ lpa=ch_pos.get('FidT9'),
142
+ rpa=ch_pos.get('FidT10'),
143
+ coord_frame='head'
144
+ )
145
+ raw.set_montage(montage, on_missing='warn')
146
+ log("Montage applied.", log_file, log_to_file)
147
+
148
+ return raw
149
+
150
+
151
+ # ============================================================================
152
+ # FILTERING FUNCTIONS
153
+ # ============================================================================
154
+
155
+ def apply_highpass_filter(raw: mne.io.Raw, l_freq: float, log_file: Path,
156
+ log_to_file: bool = True) -> mne.io.Raw:
157
+ """Apply highpass filter."""
158
+ log(f"Applying highpass filter at {l_freq} Hz...", log_file, log_to_file)
159
+ return raw.copy().filter(
160
+ l_freq=l_freq, h_freq=None, picks=['eeg'],
161
+ method='fir', phase='zero', fir_window='hamming',
162
+ fir_design='firwin', n_jobs=-1
163
+ )
164
+
165
+
166
+ def apply_lowpass_filter(raw: mne.io.Raw, h_freq: float, log_file: Path,
167
+ log_to_file: bool = True) -> mne.io.Raw:
168
+ """Apply lowpass filter."""
169
+ log(f"Applying lowpass filter at {h_freq} Hz...", log_file, log_to_file)
170
+ return raw.copy().filter(
171
+ l_freq=None, h_freq=h_freq, picks=['eeg'],
172
+ method='fir', phase='zero', fir_window='hamming',
173
+ fir_design='firwin', n_jobs=-1
174
+ )
175
+
176
+
177
+
178
+ def apply_notch_filter(raw: mne.io.Raw, line_freq: float, log_file: Path,
179
+ log_to_file: bool = True, max_freq: float = 100.0) -> mne.io.Raw:
180
+ """Apply notch filter up to max_freq (default: 100 Hz)."""
181
+ nyquist = raw.info["sfreq"] / 2
182
+ upper = min(nyquist, max_freq)
183
+ notch_freqs = np.arange(line_freq, upper + line_freq, line_freq)
184
+ notch_freqs = notch_freqs[notch_freqs < upper]
185
+
186
+ if len(notch_freqs) > 0:
187
+ log(f"Applying notch filter at: {notch_freqs}", log_file, log_to_file)
188
+ return raw.copy().notch_filter(
189
+ freqs=notch_freqs, picks='eeg', method='spectrum_fit',
190
+ filter_length='auto', mt_bandwidth=1.0, p_value=0.05
191
+ )
192
+ return raw
193
+
194
+ def filter_data(raw: mne.io.Raw, apply_highpass: bool, apply_lowpass: bool,
195
+ apply_notch: bool, l_freq: float, h_freq: float,
196
+ line_freq: float, log_file: Path, log_to_file: bool = True) -> mne.io.Raw:
197
+ """Apply all selected filters in sequence."""
198
+ log("Applying filters...", log_file, log_to_file)
199
+
200
+ filtered_raw = raw.copy()
201
+
202
+ if apply_highpass and l_freq is not None:
203
+ filtered_raw = apply_highpass_filter(filtered_raw, l_freq, log_file, log_to_file)
204
+
205
+ if apply_lowpass and h_freq is not None:
206
+ filtered_raw = apply_lowpass_filter(filtered_raw, h_freq, log_file, log_to_file)
207
+
208
+ if apply_notch:
209
+ filtered_raw = apply_notch_filter(filtered_raw, line_freq, log_file, log_to_file)
210
+
211
+ if not (apply_highpass or apply_lowpass or apply_notch):
212
+ log("No filters applied (all filter types disabled).", log_file, log_to_file)
213
+
214
+
215
+ # Check Cz for flat signal
216
+ if 'Cz' in filtered_raw.ch_names and 'Cz' not in filtered_raw.info['bads']:
217
+ if np.std(filtered_raw.get_data(picks=['Cz'])[0]) < 1e-6:
218
+ filtered_raw.info['bads'].append('Cz')
219
+ log("Marked Cz as bad (flat signal).", log_file, log_to_file)
220
+
221
+ return filtered_raw
222
+
223
+
224
+ # ============================================================================
225
+ # BAD CHANNEL DETECTION
226
+ # ============================================================================
227
+
228
+ def detect_bad_channels(raw: mne.io.Raw, subject: str, output_path: Path,
229
+ plot: bool, log_file: Path, log_to_file: bool = True,
230
+ mad_threshold: float = 5.0,
231
+ min_amplitude_uv: float = 0.1,
232
+ protected_channels: Optional[set] = None) -> mne.io.Raw:
233
+ """
234
+ Detect bad channels using MAD-based outlier detection.
235
+
236
+ Args:
237
+ raw: The raw EEG data.
238
+ subject: Subject identifier for logging and file naming.
239
+ output_path: Path to save output files and plots.
240
+ plot: Whether to generate and save diagnostic plots.
241
+ log_file: Path to the log file.
242
+ log_to_file: Whether to write logs to the file.
243
+ mad_threshold: Threshold for MAD-based outlier detection.
244
+ min_amplitude_uv: Minimum amplitude threshold for flat channel detection (µV).
245
+ protected_channels: Optional set of channels to protect from being marked as bad.
246
+ If None, defaults to the global PROTECTED_CHANNELS constant.
247
+
248
+ Returns:
249
+ The input `raw` object with updated `info['bads']`.
250
+ """
251
+ # Determine which channels to protect
252
+ if protected_channels is None:
253
+ protected_chs_to_use = PROTECTED_CHANNELS
254
+ else:
255
+ protected_chs_to_use = protected_channels
256
+
257
+ log(f"Detecting bad channels (flat < {min_amplitude_uv} µV, noisy Z > {mad_threshold})...",
258
+ log_file, log_to_file)
259
+
260
+ raw_eeg = raw.copy().pick_types(eeg=True)
261
+ available_chs = set(raw_eeg.ch_names)
262
+ # Use the potentially overridden protected channels set
263
+ protected_chs = protected_chs_to_use & available_chs
264
+ # Channels to analyze are those not in the protected set
265
+ eeg_chs = [ch for ch in raw_eeg.ch_names if ch not in protected_chs]
266
+
267
+ if not eeg_chs:
268
+ log("No EEG channels available for detection.", log_file, log_to_file)
269
+ return raw
270
+
271
+ # Get data in µV
272
+ raw_for_detection = raw_eeg.copy().pick(eeg_chs)
273
+ data_uv = np.nan_to_num(raw_for_detection.get_data() * 1e6, nan=0.0, posinf=0.0, neginf=0.0)
274
+
275
+ # Compute features
276
+ variance = np.var(data_uv, axis=1)
277
+ amplitude = np.ptp(data_uv, axis=1)
278
+
279
+ # Detect flat channels
280
+ flat_mask = amplitude < min_amplitude_uv
281
+ flat_channels = [raw_for_detection.ch_names[i] for i in np.where(flat_mask)[0]]
282
+
283
+ # Detect noisy channels using MAD
284
+ noisy_mask = np.zeros(len(amplitude), dtype=bool)
285
+ for feat in [variance, amplitude]:
286
+ mad = median_abs_deviation(feat, scale='normal', nan_policy='omit')
287
+ if not np.isnan(mad) and mad > 1e-12:
288
+ z_scores = (feat - np.nanmedian(feat)) / mad
289
+ noisy_mask |= (z_scores > mad_threshold)
290
+
291
+ noisy_channels = [raw_for_detection.ch_names[i] for i in np.where(noisy_mask)[0]]
292
+
293
+ # Combine and update bad channels
294
+ detected_bads = sorted(set(flat_channels + noisy_channels))
295
+ current_bads = set(raw.info['bads'])
296
+ new_bads = [ch for ch in detected_bads if ch not in current_bads]
297
+ raw.info['bads'] = sorted(current_bads | set(detected_bads))
298
+
299
+ # Log results
300
+ if protected_chs:
301
+ log(f"Protected: {sorted(protected_chs)}", log_file, log_to_file)
302
+ log(f"Detected bad channels: {detected_bads}", log_file, log_to_file)
303
+
304
+ # Save plots if enabled
305
+ if plot and detected_bads:
306
+ try:
307
+ montage = raw.get_montage()
308
+ if montage is None:
309
+ log("⚠️ No montage found — skipping topomap.", log_file, log_to_file)
310
+ else:
311
+ ch_pos = {}
312
+ for d in montage.dig:
313
+ if d['kind'] == mne.io.constants.FIFF.FIFFV_POINT_EEG:
314
+ ch_name = montage.ch_names[d['ident'] - 1]
315
+ ch_pos[ch_name] = d['r']
316
+
317
+ plot_bad_channels_topomap(ch_pos, detected_bads, subject,
318
+ output_path, log_file, log_to_file)
319
+ plot_bad_channels_timeseries(raw, detected_bads, subject,
320
+ output_path, log_file, log_to_file)
321
+
322
+ except Exception as e:
323
+ log(f"⚠️ Failed to generate/save bad channel plots: {e}", log_file, log_to_file)
324
+
325
+ return raw
326
+
327
+
328
+ def plot_bad_channels_topomap(ch_pos: Dict[str, np.ndarray], bad_channels: List[str],
329
+ subject: str, output_path: Path, log_file: Path,
330
+ log_to_file: bool = True):
331
+ """Plot and save topomap of bad channels."""
332
+ if not bad_channels:
333
+ return
334
+
335
+ valid_bads = [ch for ch in bad_channels if ch in ch_pos]
336
+ if not valid_bads:
337
+ log("⚠️ No bad channels with valid positions.", log_file, log_to_file)
338
+ return
339
+
340
+ # Unique colors
341
+ cmap = plt.get_cmap('viridis')
342
+ colors = cmap(np.linspace(0, 1, len(valid_bads)))
343
+ color_dict = {ch: colors[i] for i, ch in enumerate(valid_bads)}
344
+
345
+ # Get and scale positions
346
+ pos = np.array([ch_pos[ch][:2] for ch in valid_bads])
347
+ max_radius = np.max(np.sqrt(np.sum(pos**2, axis=1)))
348
+ pos_scaled = (pos / max_radius * 0.1 if max_radius > 0 else pos)
349
+ pos_scaled[:, 1] -= 0.02
350
+
351
+ # Plot
352
+ fig, ax = plt.subplots(figsize=(8, 6))
353
+ for i, (x, y) in enumerate(pos_scaled):
354
+ ch = valid_bads[i]
355
+ color = color_dict[ch]
356
+ ax.plot(x, y, 's', markersize=15, color=color, alpha=1.0)
357
+ ax.text(x, y + 0.01, ch, fontsize=10, ha='center', va='bottom', color=color)
358
+
359
+ # Add head outline
360
+ try:
361
+ mne.viz.plot_topomap(np.zeros(len(pos_scaled)), pos_scaled, axes=ax,
362
+ show=False, sphere=0.1, outlines='head')
363
+ except Exception as e:
364
+ log(f"Topomap background failed: {e}", log_file, log_to_file, detail="debug")
365
+
366
+ ax.set_title('Detected Bad Channels (Topomap)', fontsize=12, pad=20)
367
+ ax.set_xlim(-0.12, 0.12)
368
+ ax.set_ylim(-0.12, 0.12)
369
+ ax.set_aspect('equal')
370
+ ax.grid(True, alpha=0.3)
371
+
372
+ plt.tight_layout()
373
+ fig_path = output_path / "plots" / f"{subject}_bad_channels_topomap.png"
374
+ fig.savefig(fig_path, dpi=150, bbox_inches='tight')
375
+ plt.close(fig)
376
+ log(f"🖼️ Bad channels topomap saved: {fig_path}", log_file, log_to_file)
377
+
378
+
379
+ def plot_bad_channels_timeseries(raw: mne.io.Raw, bad_channels: List[str],
380
+ subject: str, output_path: Path, log_file: Path,
381
+ log_to_file: bool = True):
382
+ """Plot full-duration time series of bad channels."""
383
+ if not bad_channels:
384
+ return
385
+
386
+ valid_bads = [ch for ch in bad_channels if ch in raw.ch_names]
387
+ if not valid_bads:
388
+ return
389
+
390
+ # Get full data
391
+ data, times = raw[valid_bads, :]
392
+
393
+ # Unique colors
394
+ cmap = plt.get_cmap('viridis')
395
+ colors = cmap(np.linspace(0, 1, len(valid_bads)))
396
+ color_dict = {ch: colors[i] for i, ch in enumerate(valid_bads)}
397
+
398
+ # Plot
399
+ n_ch = len(valid_bads)
400
+ fig_height = min(2.2 * n_ch, 40)
401
+ fig, axes = plt.subplots(n_ch, 1, figsize=(14, fig_height), sharex=True)
402
+ if n_ch == 1:
403
+ axes = [axes]
404
+
405
+ for i, (ch, ax) in enumerate(zip(valid_bads, axes)):
406
+ color = color_dict[ch]
407
+ ax.plot(times, data[i, :] * 1e6, color=color, linewidth=1)
408
+ ax.set_ylabel(f'{ch}\n(µV)', fontsize=10)
409
+ ax.grid(True, alpha=0.3)
410
+
411
+ mean_val = np.mean(data[i, :]) * 1e6
412
+ std_val = np.std(data[i, :]) * 1e6
413
+ ax.text(0.02, 0.98, f'μ={mean_val:.1f}, σ={std_val:.1f}',
414
+ transform=ax.transAxes, va='top', fontsize=9,
415
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
416
+
417
+ axes[-1].set_xlabel('Time (s)', fontsize=12)
418
+ total_duration = times[-1]
419
+ plt.suptitle(f'Bad Channels — Full Time Series ({total_duration:.1f}s)',
420
+ fontsize=14, weight='bold')
421
+ plt.tight_layout(rect=[0, 0, 1, 0.96])
422
+
423
+ fig_path = output_path / "plots" / f"{subject}_bad_channels_timeseries.png"
424
+ fig.savefig(fig_path, dpi=120, bbox_inches='tight')
425
+ plt.close(fig)
426
+ log(f"📈 Bad channels full time series saved: {fig_path}", log_file, log_to_file)
427
+
428
+
429
+ # ============================================================================
430
+ # ICA FUNCTIONS
431
+ # ============================================================================
432
+
433
+ def create_bipolar_channels(raw: mne.io.Raw, log_file: Path,
434
+ log_to_file: bool = True,
435
+ use_artifact_detection_channels: bool = True) -> mne.io.Raw:
436
+ """Create bipolar reference channels for artifact detection."""
437
+ if not use_artifact_detection_channels:
438
+ log("Skipping creation of bipolar channels (artifact detection disabled).", log_file, log_to_file)
439
+ return raw
440
+
441
+ bipolar_specs = [
442
+ (VVEOG, 'vVEOG', 'eog', "blink detection"),
443
+ (HEOG, 'BLINK_H', 'eog', "horizontal eye movement"),
444
+ (ECG, 'ECG_BIO', 'ecg', "cardiac artifact")
445
+ ]
446
+
447
+ for (anode, cathode), name, ch_type, desc in bipolar_specs:
448
+ if anode in raw.ch_names and cathode in raw.ch_names:
449
+ raw = mne.set_bipolar_reference(
450
+ raw, anode=anode, cathode=cathode,
451
+ ch_name=name, drop_refs=False
452
+ ).set_channel_types({name: ch_type})
453
+ log(f"Created {name} ({anode}-{cathode}) for {desc}",
454
+ log_file, log_to_file, detail="debug")
455
+ else:
456
+ log(f"Skipping {name} - channels {anode}, {cathode} not found.", log_file, log_to_file, detail="debug")
457
+
458
+ return raw
459
+
460
+ def detect_artifact_components(ica: mne.preprocessing.ICA, raw: mne.io.Raw,
461
+ log_file: Path, log_to_file: bool = True,
462
+ use_artifact_detection_channels: bool = True) -> Dict[str, List[int]]:
463
+ """Detect artifact components using multiple methods."""
464
+ results = {
465
+ 'blink': [], 'horizontal': [], 'ecg': [],
466
+ 'muscle': [], 'frontal_lf': [], 'line_noise': [],
467
+ 'icalabel': [], 'extreme': []
468
+ }
469
+
470
+ if use_artifact_detection_channels:
471
+ # Blink detection
472
+ if 'vVEOG' in raw.ch_names:
473
+ try:
474
+ idx, _ = ica.find_bads_eog(raw, ch_name='vVEOG', measure='zscore', threshold=3.0)
475
+ results['blink'] = [int(i) for i in idx]
476
+ log(f"Blink: {results['blink']}", log_file, log_to_file, detail="debug")
477
+ except Exception as e:
478
+ log(f"Blink detection failed: {e}", log_file, log_to_file, detail="debug")
479
+ else:
480
+ log("vVEOG channel not found, skipping blink detection.", log_file, log_to_file, detail="debug")
481
+
482
+ # Horizontal eye movement
483
+ if 'BLINK_H' in raw.ch_names:
484
+ try:
485
+ idx, _ = ica.find_bads_eog(raw, ch_name='BLINK_H', measure='zscore', threshold=3.0)
486
+ results['horizontal'] = [int(i) for i in idx]
487
+ log(f"Horizontal: {results['horizontal']}", log_file, log_to_file, detail="debug")
488
+ except Exception as e:
489
+ log(f"Horizontal detection failed: {e}", log_file, log_to_file, detail="debug")
490
+ else:
491
+ log("BLINK_H channel not found, skipping horizontal eye movement detection.", log_file, log_to_file, detail="debug")
492
+
493
+ # ECG detection
494
+ if 'ECG_BIO' in raw.ch_names:
495
+ try:
496
+ idx, _ = ica.find_bads_ecg(raw, ch_name='ECG_BIO', method='correlation',
497
+ measure='zscore', threshold=3.0)
498
+ results['ecg'] = [int(i) for i in idx]
499
+ log(f"ECG: {results['ecg']}", log_file, log_to_file, detail="debug")
500
+ except Exception as e:
501
+ log(f"ECG detection failed: {e}", log_file, log_to_file, detail="debug")
502
+ else:
503
+ log("ECG_BIO channel not found, skipping ECG detection.", log_file, log_to_file, detail="debug")
504
+
505
+ # Muscle artifacts
506
+ for ch in EMG_CHS:
507
+ if ch in raw.ch_names:
508
+ try:
509
+ idx, _ = ica.find_bads_eog(raw, ch_name=ch, measure='zscore',
510
+ l_freq=30, h_freq=100, threshold=3.0)
511
+ results['muscle'].extend([int(i) for i in idx])
512
+ except Exception as e:
513
+ log(f"EMG detection failed for {ch}: {e}", log_file, log_to_file, detail="debug")
514
+ else:
515
+ log(f"EMG channel {ch} not found, skipping.", log_file, log_to_file, detail="debug")
516
+ results['muscle'] = list(set(results['muscle']))
517
+
518
+ # Frontal low-frequency artifacts
519
+ for ch in FRONTAL_CHS:
520
+ if ch in raw.ch_names:
521
+ try:
522
+ idx, _ = ica.find_bads_eog(raw, ch_name=ch, measure='zscore',
523
+ l_freq=1.0, h_freq=10.0, threshold=3.5)
524
+ results['frontal_lf'].extend([int(i) for i in idx])
525
+ except Exception as e:
526
+ log(f"Frontal LF failed for {ch}: {e}", log_file, log_to_file, detail="debug")
527
+ else:
528
+ log(f"Frontal channel {ch} not found, skipping.", log_file, log_to_file, detail="debug")
529
+ results['frontal_lf'] = list(set(results['frontal_lf']))
530
+
531
+ # Line noise detection (also conditional on the flag)
532
+ if use_artifact_detection_channels:
533
+ try:
534
+ sfreq = raw.info['sfreq']
535
+ src_data = ica.get_sources(raw).get_data()
536
+ for i in range(ica.n_components_):
537
+ psd, freqs = mne.time_frequency.psd_array_welch(
538
+ src_data[i], sfreq=sfreq, fmin=1, fmax=100, verbose=False
539
+ )
540
+ line_band = (freqs >= 58) & (freqs <= 62)
541
+ ref_band = (freqs >= 1) & (freqs <= 100)
542
+ flank_band = ((freqs >= 50) & (freqs < 58)) | ((freqs > 62) & (freqs <= 70))
543
+
544
+ ref_mean = psd[ref_band].mean()
545
+ flank_mean = psd[flank_band].mean()
546
+
547
+ if ref_mean > 0 and flank_mean > 0:
548
+ line_ratio = psd[line_band].mean() / ref_mean
549
+ peak_prominence = psd[line_band].max() / flank_mean
550
+
551
+ if line_ratio > 0.8 and peak_prominence > 5.0:
552
+ results['line_noise'].append(i)
553
+ except Exception as e:
554
+ log(f"Line noise detection failed: {e}", log_file, log_to_file, detail="debug")
555
+
556
+ return results
557
+
558
+ def run_icalabel(ica: mne.preprocessing.ICA, raw: mne.io.Raw,
559
+ excluded: List[int], log_file: Path,
560
+ log_to_file: bool = True) -> Tuple[List[int], Dict]:
561
+ """Run ICLabel classification on EEG channels only."""
562
+ try:
563
+ # ✅ CRITICAL FIX: Use ONLY EEG channels
564
+ raw_eeg = raw.copy().pick_types(eeg=True, exclude=[])
565
+
566
+ labels_dict = label_components(raw_eeg, ica, method="iclabel")
567
+ labels = labels_dict["labels"]
568
+ probas = labels_dict["y_pred_proba"]
569
+
570
+ new_excluded = []
571
+ label_info = {}
572
+ for i, (label, prob) in enumerate(zip(labels, probas)):
573
+ lbl = label.lower().strip()
574
+ if lbl in ICALABEL_THRESHOLDS and prob > ICALABEL_THRESHOLDS[lbl]:
575
+ if i not in excluded:
576
+ new_excluded.append(i)
577
+ label_info[i] = (label, prob)
578
+
579
+ if new_excluded:
580
+ info_strs = [f"C{i}({label}: {prob.max():.2f})"
581
+ for i, (label, prob) in label_info.items()]
582
+ log(f"ICLabel added {len(new_excluded)}: {', '.join(info_strs)}",
583
+ log_file, log_to_file)
584
+ return new_excluded, label_info
585
+ except Exception as e:
586
+ log(f"ICLabel failed: {e}", log_file, log_to_file)
587
+ # Optional: uncomment to debug
588
+ # import traceback
589
+ # log(f"Full traceback:\n{traceback.format_exc()}", log_file, log_to_file, detail="debug")
590
+ return [], {}
591
+
592
+ def detect_extreme_components(ica: mne.preprocessing.ICA, raw: mne.io.Raw,
593
+ excluded: List[int], log_file: Path,
594
+ log_to_file: bool = True) -> List[int]:
595
+ """Detect components with extreme signal characteristics."""
596
+ src_data = ica.get_sources(raw).get_data()
597
+ extreme = []
598
+
599
+ for i in range(ica.n_components_):
600
+ if i in excluded:
601
+ continue
602
+
603
+ x = src_data[i]
604
+ var = np.var(x)
605
+ kurt = kurtosis(x)
606
+ ptp = np.ptp(x)
607
+
608
+ if var < 1e-14 or kurt > 10000 or ptp > 10000:
609
+ extreme.append(i)
610
+ log(f"Excluded C{i} via signal metrics", log_file, log_to_file, detail="debug")
611
+
612
+ return extreme
613
+
614
+ def log_ica_summary(ica: mne.preprocessing.ICA, results: Dict[str, List[int]],
615
+ icalabel_info: Dict, log_file: Path, log_to_file: bool = True):
616
+ """Log ICA artifact rejection summary."""
617
+ log("\n" + "━" * 60, log_file, log_to_file)
618
+ log("🧩 ICA ARTIFACT REJECTION SUMMARY", log_file, log_to_file)
619
+ log("━" * 60, log_file, log_to_file)
620
+ log(f"{'Total components':<18} {ica.n_components_}", log_file, log_to_file)
621
+ log(f"{'Excluded':<18} {len(ica.exclude)}", log_file, log_to_file)
622
+ log("", log_file, log_to_file)
623
+
624
+ labels = {
625
+ 'blink': 'Blink',
626
+ 'horizontal': 'Horizontal eye',
627
+ 'ecg': 'ECG',
628
+ 'muscle': 'Muscle',
629
+ 'frontal_lf': 'Frontal LF',
630
+ 'line_noise': 'Line noise',
631
+ 'extreme': 'Signal metrics'
632
+ }
633
+
634
+ for key, label in labels.items():
635
+ log(f"{label:<18} {sorted(results[key])}", log_file, log_to_file)
636
+
637
+ if icalabel_info:
638
+ info_str = ", ".join([f"C{i}({lbl}: {prob.max():.2f})"
639
+ for i, (lbl, prob) in icalabel_info.items()])
640
+ log(f"{'ICLabel':<18} {info_str}", log_file, log_to_file)
641
+ else:
642
+ log(f"{'ICLabel':<18} []", log_file, log_to_file)
643
+
644
+ log(f"\n🔧 Final exclude list: {sorted(ica.exclude)}", log_file, log_to_file)
645
+ log("━" * 60, log_file, log_to_file)
646
+
647
+ def save_ica_plots(ica: mne.preprocessing.ICA, subject: str, output_path: Path,
648
+ log_file: Path, log_to_file: bool = True, cmap: str = 'plasma'):
649
+ """Save ICA component plots."""
650
+ try:
651
+ fig_components = ica.plot_components(cmap=cmap, show=False)
652
+ if not isinstance(fig_components, list):
653
+ fig_components = [fig_components]
654
+
655
+ for i, fig in enumerate(fig_components):
656
+ fig_path = output_path / "plots" / f"{subject}_ica_components_page{i}.png"
657
+ fig.savefig(fig_path, dpi=150, bbox_inches='tight')
658
+ plt.close(fig)
659
+
660
+ log(f"🖼️ Saved {len(fig_components)} ICA component page(s)", log_file, log_to_file)
661
+ except Exception as e:
662
+ log(f"⚠️ Failed to save ICA plots: {e}", log_file, log_to_file)
663
+
664
+ def run_automatic_ica_cleaning(eeg_data: mne.io.Raw, subject: str, output_path: Path,
665
+ plot: bool, log_file: Path, log_to_file: bool = True,
666
+ n_components: float = 0.99,
667
+ random_state: int = 99,
668
+ use_artifact_detection_channels: bool = True) -> Tuple[mne.io.Raw, Dict]:
669
+ """Run complete ICA artifact detection and removal pipeline."""
670
+ log("Running automatic ICA cleaning...", log_file, log_to_file)
671
+
672
+ # Work on a copy of the full EEG data (all channels are EEG)
673
+ raw_for_ica = eeg_data.copy()
674
+
675
+ # Create bipolar reference channels (vVEOG, ECG_BIO, etc.) if enabled
676
+ raw_for_ica = create_bipolar_channels(
677
+ raw_for_ica, log_file, log_to_file,
678
+ use_artifact_detection_channels=use_artifact_detection_channels
679
+ )
680
+
681
+ # Fit ICA on the full dataset (MNE automatically excludes channels in info['bads'])
682
+ log("Fitting ICA with Extended Infomax...", log_file, log_to_file, detail="debug")
683
+ ica = mne.preprocessing.ICA(
684
+ n_components=n_components,
685
+ random_state=random_state,
686
+ method='picard',
687
+ fit_params=dict(ortho=False, extended=True),
688
+ max_iter='auto'
689
+ )
690
+ ica.fit(raw_for_ica)
691
+ log(f"ICA fitted with {ica.n_components_} components", log_file, log_to_file, detail="debug")
692
+
693
+ # Detect artifacts
694
+ detection_results = detect_artifact_components(
695
+ ica, raw_for_ica, log_file, log_to_file,
696
+ use_artifact_detection_channels=use_artifact_detection_channels
697
+ )
698
+
699
+ # Build exclude list
700
+ ica.exclude = []
701
+ if use_artifact_detection_channels:
702
+ for key in ['blink', 'horizontal', 'ecg', 'muscle', 'frontal_lf', 'line_noise']:
703
+ ica.exclude.extend(detection_results[key])
704
+
705
+ # ICLabel and extreme
706
+ icalabel_excluded, icalabel_info = run_icalabel(ica, raw_for_ica, ica.exclude, log_file, log_to_file)
707
+ ica.exclude.extend(icalabel_excluded)
708
+ detection_results['icalabel'] = icalabel_excluded
709
+
710
+ extreme = detect_extreme_components(ica, raw_for_ica, ica.exclude, log_file, log_to_file)
711
+ ica.exclude.extend(extreme)
712
+ detection_results['extreme'] = extreme
713
+
714
+ # ✅ CRITICAL FIX: Remove duplicates (e.g., component 0 flagged by blink + frontal_lf)
715
+ ica.exclude = sorted(set(ica.exclude))
716
+
717
+ # Apply ICA to the ORIGINAL full data (preserves all 281 channels)
718
+ cleaned_data = ica.apply(eeg_data.copy())
719
+
720
+ # Log and save
721
+ log_ica_summary(ica, detection_results, icalabel_info, log_file, log_to_file)
722
+ if plot and ica.exclude:
723
+ save_ica_plots(ica, subject, output_path, log_file, log_to_file)
724
+
725
+ ica_object = {
726
+ 'ica_model': ica,
727
+ 'original_data': eeg_data,
728
+ 'filtered_data': raw_for_ica,
729
+ 'auto_excluded': ica.exclude.copy(),
730
+ 'detection_results': {k: sorted(v) for k, v in detection_results.items()},
731
+ 'icalabel_info': icalabel_info,
732
+ 'parameters': {
733
+ 'n_components': n_components,
734
+ 'random_state': random_state,
735
+ 'use_artifact_detection_channels': use_artifact_detection_channels
736
+ }
737
+ }
738
+
739
+ log("✅ ICA cleaning complete", log_file, log_to_file)
740
+ return cleaned_data, ica_object
741
+
742
+ EXPECTED_EEG_COUNT = 281 # E1-E280 + Cz
743
+
744
+ def verify_channel_count(raw: mne.io.Raw, log_file: Path, log_to_file: bool = True):
745
+ eeg_chs = raw.copy().pick_types(eeg=True).ch_names
746
+ count = len(eeg_chs)
747
+ if count != EXPECTED_EEG_COUNT:
748
+ log(f"⚠️ WARNING: Expected {EXPECTED_EEG_COUNT} EEG channels, found {count}", log_file, log_to_file)
749
+ log(f"Missing: {sorted(EXPECTED_EEG_CHANNELS - set(eeg_chs))}", log_file, log_to_file)
750
+ else:
751
+ log("✅ EEG channel count verified: 281", log_file, log_to_file)
752
+
753
+ def plot_psd_comparison(raw_filtered: mne.io.Raw, cleaned_data: mne.io.Raw, subject: str,
754
+ output_path: Path, log_file: Path, log_to_file: bool = True):
755
+ """Plot PSD comparison before and after ICA."""
756
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10), sharex=True)
757
+
758
+ raw_filtered.compute_psd(fmax=120, picks='eeg', exclude='bads').plot(axes=ax1, show=False)
759
+ ax1.set_title('Before ICA', fontsize=12)
760
+ ax1.set_xlabel('')
761
+
762
+ cleaned_data.compute_psd(fmax=120, picks='eeg', exclude='bads').plot(axes=ax2, show=False)
763
+ ax2.set_title('After ICA', fontsize=12)
764
+
765
+ fig.suptitle('Power Spectral Density: Before vs. After ICA', fontsize=16)
766
+ plt.subplots_adjust(top=0.94, hspace=0.3)
767
+
768
+ fig_path = output_path / "plots" / f"{subject}_psd_comparison.png"
769
+ fig.savefig(fig_path, dpi=150, bbox_inches='tight')
770
+ plt.close(fig)
771
+ log(f"📊 PSD comparison saved to: {fig_path}", log_file, log_to_file)
772
+
773
+
774
+ def save_data(cleaned_data: mne.io.Raw, subject: str, output_path: Path,
775
+ log_file: Path, log_to_file: bool = True):
776
+ """Save cleaned data to FIF file."""
777
+ sub_id = subject.replace('sub-', '')
778
+ fname = f"sub-{sub_id}_eeg_ica_cleaned_raw.fif"
779
+ full_path = output_path / fname
780
+ cleaned_data.save(str(full_path), overwrite=True)
781
+ log(f"Cleaned data saved to: {full_path}", log_file, log_to_file)
782
+
783
+
784
+ def interpolate_bads(raw: mne.io.Raw, log_file: Path, log_to_file: bool = True) -> mne.io.Raw:
785
+ """Interpolate bad channels if any exist."""
786
+ bads = raw.info['bads']
787
+ if bads:
788
+ log(f"Interpolating bad channels: {bads}", log_file, log_to_file)
789
+ raw.interpolate_bads(reset_bads=True)
790
+ else:
791
+ log("No bad channels to interpolate", log_file, log_to_file)
792
+ return raw
793
+
794
+
795
+
796
+ # ============================================================================
797
+ # MAIN PIPELINE FUNCTION
798
+ # ============================================================================
799
+
800
+ def run_preprocessing_pipeline(
801
+ subject: str,
802
+ input_path: str,
803
+ gpsc_file: str,
804
+ base_output_path: str,
805
+ plot: bool = True,
806
+ random_state: int = 99,
807
+ log_to_file: bool = True,
808
+ apply_highpass: bool = True,
809
+ apply_lowpass: bool = True,
810
+ apply_notch: bool = True,
811
+ l_freq: float = 1.0,
812
+ h_freq: float = 100.0,
813
+ line_freq: float = 60.0,
814
+ input_format: str = "mff",
815
+ append_subject_to_output: bool = True,
816
+ use_artifact_detection_channels: bool = True,
817
+ pre_ica_mad_threshold: float = 5.0,
818
+ post_ica_mad_threshold: float = 5.0,
819
+ interpolate_before_ica: bool = False # ← Now correctly placed with comma above
820
+ ):
821
+ """
822
+ Execute the complete EEG preprocessing pipeline using standalone functions.
823
+
824
+ This function replicates the functionality of the EEGICAProcessor.run() method.
825
+
826
+ Parameters
827
+ ----------
828
+ subject : str
829
+ Subject identifier (e.g., 'sub-001').
830
+ input_path : str
831
+ Path to raw EEG file (.mff or .fif).
832
+ gpsc_file : str
833
+ Path to .gpsc channel location file.
834
+ base_output_path : str
835
+ Base directory for saving outputs.
836
+ plot : bool, optional
837
+ Whether to generate diagnostic plots (default: True).
838
+ random_state : int, optional
839
+ Random seed for ICA reproducibility (default: 99).
840
+ log_to_file : bool, optional
841
+ Whether to write logs to file (default: True).
842
+ apply_highpass, apply_lowpass, apply_notch : bool, optional
843
+ Whether to apply respective filters (default: True).
844
+ l_freq, h_freq : float, optional
845
+ Highpass and lowpass cutoff frequencies (default: 1.0, 100.0 Hz).
846
+ line_freq : float, optional
847
+ Line noise frequency for notch filtering (default: 60.0 Hz).
848
+ input_format : str, optional
849
+ Format of input data: 'mff' or 'fif' (default: 'mff').
850
+ append_subject_to_output : bool, optional
851
+ Whether to create subject-specific subfolder (default: True).
852
+ use_artifact_detection_channels : bool, optional
853
+ Whether to create bipolar EOG/ECG channels (default: True).
854
+ pre_ica_mad_threshold : float, optional
855
+ MAD threshold for bad channel detection before ICA (default: 5.0).
856
+ post_ica_mad_threshold : float, optional
857
+ MAD threshold after ICA (default: 5.0).
858
+ interpolate_before_ica : bool, optional
859
+ If False, skip interpolation before ICA; bad channels are excluded from CAR/ICA
860
+ but remain in the data array. Final output always interpolates all bads post-ICA.
861
+ (default: True)
862
+ """
863
+ # Setup output directories
864
+ if append_subject_to_output:
865
+ output_path = Path(base_output_path) / subject
866
+ else:
867
+ output_path = Path(base_output_path)
868
+ output_path.mkdir(parents=True, exist_ok=True)
869
+ (output_path / "plots").mkdir(exist_ok=True)
870
+
871
+ # Setup logging
872
+ log_file = setup_logging(subject, output_path, log_to_file)
873
+ log(f"Pre-ICA interpolation: {'ENABLED' if interpolate_before_ica else 'DISABLED'}", log_file, log_to_file)
874
+
875
+ # --- Pipeline Execution ---
876
+ log("🔄 Starting preprocessing...", log_file, log_to_file)
877
+
878
+ # 1. Load and prepare data
879
+ log("🔧 Loading data...", log_file, log_to_file)
880
+ raw = load_raw_data(Path(input_path), input_format, log_file, log_to_file)
881
+ raw = apply_channel_renaming(raw, log_file, log_to_file)
882
+ raw = apply_montage(raw, Path(gpsc_file), log_file, log_to_file)
883
+ log("✅ Loading data complete", log_file, log_to_file)
884
+
885
+ # 2. Filter data
886
+ log("🔧 Filtering...", log_file, log_to_file)
887
+ raw_filtered = filter_data(
888
+ raw, apply_highpass, apply_lowpass, apply_notch,
889
+ l_freq, h_freq, line_freq, log_file, log_to_file
890
+ )
891
+ log("✅ Filtering complete", log_file, log_to_file)
892
+
893
+ # 3. Detect bad channels (before ICA, WITH protection) — HIGHER threshold
894
+ log(f"🔧 Detecting bad channels (before ICA, with protection, threshold={pre_ica_mad_threshold})...", log_file, log_to_file)
895
+ raw_filtered = detect_bad_channels(
896
+ raw_filtered, subject, output_path, plot, log_file, log_to_file,
897
+ protected_channels=PROTECTED_CHANNELS,
898
+ mad_threshold=pre_ica_mad_threshold
899
+ )
900
+ log("✅ Detecting bad channels (before ICA) complete", log_file, log_to_file)
901
+
902
+ # 4. Apply CAR (Common Average Reference)
903
+ log("🔧 Applying CAR...", log_file, log_to_file)
904
+ raw_filtered = raw_filtered.set_eeg_reference('average', verbose=False)
905
+ log("✅ Applying CAR complete", log_file, log_to_file)
906
+
907
+ # 5. Interpolate bad channels BEFORE ICA — only if requested
908
+ if interpolate_before_ica:
909
+ log("🔧 Interpolating bad channels (before ICA)...", log_file, log_to_file)
910
+ raw_filtered = interpolate_bads(raw_filtered, log_file, log_to_file)
911
+ log("✅ Interpolating bad channels (before ICA) complete", log_file, log_to_file)
912
+ else:
913
+ log("⚠️ Skipping pre-ICA interpolation — bad channels will be excluded from CAR and ICA fitting", log_file, log_to_file)
914
+
915
+ # 6. Run ICA cleaning
916
+ log("🔧 Running ICA...", log_file, log_to_file)
917
+ cleaned_data, ica_obj = run_automatic_ica_cleaning(
918
+ raw_filtered, subject, output_path, plot, log_file, log_to_file,
919
+ n_components=0.99,
920
+ random_state=random_state,
921
+ use_artifact_detection_channels=use_artifact_detection_channels
922
+ )
923
+ log("✅ Running ICA complete", log_file, log_to_file)
924
+
925
+ # 7. Detect bad channels AFTER ICA — more sensitive, no protection
926
+ log(f"🔧 Detecting bad channels (after ICA, NO protection, threshold={post_ica_mad_threshold})...", log_file, log_to_file)
927
+ post_ica_subject_name = f"{subject}_post_ica"
928
+ cleaned_data_post_ica = detect_bad_channels(
929
+ cleaned_data, post_ica_subject_name, output_path, plot, log_file, log_to_file,
930
+ protected_channels=set(), # No protection
931
+ mad_threshold=post_ica_mad_threshold
932
+ )
933
+ log("✅ Detecting bad channels (after ICA) complete", log_file, log_to_file)
934
+
935
+ # 8. Interpolate ALL bad channels AFTER ICA (to restore full 280-channel set)
936
+ log("🔧 Interpolating bad channels (after ICA)...", log_file, log_to_file)
937
+ cleaned_data_interpolated = interpolate_bads(cleaned_data_post_ica, log_file, log_to_file)
938
+ log("✅ Interpolating bad channels (after ICA) complete", log_file, log_to_file)
939
+
940
+ # 9. Plot PSD comparison
941
+ log("🔧 Plotting PSD (before vs. after ICA)...", log_file, log_to_file)
942
+ plot_psd_comparison(raw_filtered, cleaned_data_interpolated, subject, output_path, log_file, log_to_file)
943
+ log("✅ Plotting PSD complete", log_file, log_to_file)
944
+
945
+ # 10. Save final cleaned data
946
+ log("🔧 Saving final data...", log_file, log_to_file)
947
+ save_data(cleaned_data_interpolated, subject, output_path, log_file, log_to_file)
948
+ log("✅ Saving data complete", log_file, log_to_file)
949
+ verify_channel_count(cleaned_data_interpolated, log_file, log_to_file)
950
+
951
+ log("✅ FULL PREPROCESSING COMPLETE\n", log_file, log_to_file)
952
+
953
+
954
+ # ============================================================================
955
+ # EXAMPLE USAGE
956
+ # ============================================================================
957
+
958
+ if __name__ == "__main__":
959
+ # Example usage - replace these paths and parameters with your actual data
960
+ SUBJECT_ID = "sub-001"
961
+ INPUT_PATH = "/path/to/your/data/sub-001.mff" # or .fif
962
+ GPSC_FILE = "/path/to/your/channels.gpsc"
963
+ BASE_OUTPUT_PATH = "/path/to/output/directory"
964
+ INPUT_FORMAT = "mff" # or "fif"
965
+
966
+ # Run the full pipeline — with pre-ICA interpolation DISABLED
967
+ run_preprocessing_pipeline(
968
+ subject=SUBJECT_ID,
969
+ input_path=INPUT_PATH,
970
+ gpsc_file=GPSC_FILE,
971
+ base_output_path=BASE_OUTPUT_PATH,
972
+ input_format=INPUT_FORMAT,
973
+ use_artifact_detection_channels=True,
974
+ pre_ica_mad_threshold=5.0,
975
+ post_ica_mad_threshold=5.0,
976
+ interpolate_before_ica=False
977
+ )
978
+ print(f"Pipeline completed for subject {SUBJECT_ID}. Check output in {BASE_OUTPUT_PATH}/{SUBJECT_ID}")