MatteoFasulo committed on
Commit
e56e6bf
·
verified ·
1 Parent(s): b4c56ea

Delete scripts/HMC.py

Browse files
Files changed (1) hide show
  1. scripts/HMC.py +0 -370
scripts/HMC.py DELETED
@@ -1,370 +0,0 @@
1
- import os
2
- from typing import Optional, Tuple
3
-
4
- import h5py
5
- import mne
6
- import numpy as np
7
- from joblib import Parallel, delayed
8
- from mne.io import read_raw_edf
9
-
10
-
11
def process_single_recording(
    raw_fn: str,
    scoring_fn: str,
    data_path: str,
    channel: str,
    start_at: int,
    duration_sec: int,
    l_freq: float,
    h_freq: float,
    sfreq: int,
    mains: int,
    window_size: int,
    stride: int,
    mapping: dict,
    verbose: bool = False,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
    """
    Process a single EDF recording and return fixed-size EMG windows and labels.

    Pipeline: load EDF + sleep-scoring annotations, crop to the requested
    time span, pick one channel, band-pass filter, notch out mains harmonics,
    resample, cut into 30-second epochs, then slide a window over each epoch.

    Args:
        raw_fn: EDF filename (relative to ``data_path``) with the signal data.
        scoring_fn: EDF filename (relative to ``data_path``) with annotations.
        data_path: Directory containing both files.
        channel: Channel name to extract; recording is skipped if absent.
        start_at: Crop start, in seconds from recording onset.
        duration_sec: Length of the span to keep, in seconds.
        l_freq: Lower band-pass cutoff in Hz (high-pass edge).
        h_freq: Upper band-pass cutoff in Hz (low-pass edge); dropped when it
            would reach the Nyquist frequency.
        sfreq: Target sampling rate in Hz after resampling.
        mains: Mains power-line frequency (e.g. 50 or 60 Hz) for notch filtering.
        window_size: Sliding-window length in samples.
        stride: Sliding-window step in samples.
        mapping: Dict from annotation description string to integer class label.
        verbose: Print per-file progress messages.

    Returns:
        Tuple of (windows, labels, filename) where windows is
        (n_windows, 1, window_size) float32 and labels is (n_windows,) int32,
        or (None, None, filename) if processing fails or yields no windows.
    """
    try:
        if verbose:
            print(f"Processing: {raw_fn}")

        full_path_raw = os.path.join(data_path, raw_fn)
        full_path_score = os.path.join(data_path, scoring_fn)

        # Load the signal and attach the sleep-scoring annotations to it.
        raw = read_raw_edf(full_path_raw, preload=True, verbose=False)
        annotation = mne.read_annotations(full_path_score)
        raw.set_annotations(annotation, emit_warning=False)

        # Crop to [start_at, end_at]; if the recording is shorter than the
        # requested span, fall back to the largest 30 s multiple that fits,
        # so the epoching below stays aligned to whole 30 s stages.
        end_at = start_at + duration_sec
        if end_at > raw.times[-1]:
            end_at = raw.times[-1] - (raw.times[-1] % 30.0)
        raw = raw.crop(tmin=start_at, tmax=end_at)

        # Keep only the requested channel; skip the file if it is missing.
        if channel not in raw.ch_names:
            print(f"Warning: Channel {channel} not found in {raw_fn}, skipping")
            return None, None, raw_fn
        raw = raw.pick([channel])

        # Band-pass filter. MNE requires h_freq < Nyquist, so when h_freq
        # reaches/exceeds Nyquist the low-pass is disabled entirely (None)
        # rather than clipped — nothing above Nyquist exists in the data anyway.
        nyq = raw.info["sfreq"] / 2.0
        h_freq_adj = h_freq if h_freq is not None and h_freq < nyq else None
        raw = raw.filter(
            l_freq=l_freq, h_freq=h_freq_adj, fir_design="firwin", verbose=False
        )

        # Notch at mains harmonics (e.g. 50,100,150 or 60,120,180) but only
        # those strictly below Nyquist.
        mains_freqs = [mains * i for i in (1, 2, 3)]
        mains_freqs = [f for f in mains_freqs if f < nyq]
        if len(mains_freqs) > 0:
            # raw.notch_filter handles multiple notch frequencies in one call
            raw.notch_filter(freqs=mains_freqs, picks=[channel], verbose=False)

        # Resample to the target sampling rate (up- or down-sample).
        if raw.info["sfreq"] != sfreq:
            raw = raw.resample(sfreq, npad="auto")

        # Cut into non-overlapping 30 s epochs driven by the annotations.
        # tmax excludes the final sample so each epoch is exactly 30 s long.
        events, event_id = mne.events_from_annotations(raw, chunk_duration=30.0)
        tmax = 30.0 - 1.0 / raw.info["sfreq"]
        epochs = mne.Epochs(
            raw=raw,
            events=events,
            event_id=event_id,
            tmin=0.0,
            tmax=tmax,
            baseline=None,
            verbose=False,
        )

        epochs_data = epochs.get_data()  # (n_epochs, 1, n_times)
        # One label per epoch: the description of the first annotation that
        # overlaps the epoch, mapped to its integer class. An unmapped
        # description raises KeyError and is caught by the broad except below.
        labels = []
        for ann in epochs.get_annotations_per_epoch():
            labels.append(mapping[str(ann[0][2])])

        n_epochs, _, n_times = epochs_data.shape
        if n_times < window_size:
            print(f"Warning: Not enough samples in {raw_fn}, skipping")
            return None, None, raw_fn

        # Sliding window within each epoch; every window inherits its
        # epoch's sleep-stage label.
        windows = []
        labels_win = []
        for i in range(n_epochs):
            for start in range(0, n_times - window_size + 1, stride):
                windows.append(epochs_data[i, 0, start : start + window_size])
                labels_win.append(labels[i])

        if len(windows) > 0:
            windows = np.stack(windows)  # (n_windows, window_size)
            windows = windows[:, np.newaxis, :]  # (n_windows, 1, window_size)

            if verbose:
                print(f"  {raw_fn}: Generated {len(windows)} windows")

            return (
                windows.astype(np.float32),
                np.array(labels_win, dtype=np.int32),
                raw_fn,
            )
        else:
            return None, None, raw_fn

    except Exception as e:
        # Deliberate best-effort: any failure skips this recording so one bad
        # file does not abort the whole parallel conversion.
        print(f"Error processing {raw_fn}: {e}")
        return None, None, raw_fn
124
-
125
-
126
def convert_hmc_to_h5(
    data_path: str,
    save_path: str,
    channel: str = "EMG chin",
    start_at: int = 15 * 60,
    duration_sec: int = 6 * 60 * 60,
    l_freq: float = 5.0,
    h_freq: float = 250.0,
    sfreq: int = 100,
    mains: int = 50,
    window_size: int = 1000,
    stride: int = 1000,
    n_jobs: int = -1,
    verbose: bool = True,
):
    """
    Convert HMC EMG dataset to HDF5 format compatible with EMGDataset.
    Uses joblib for parallel processing of individual recordings.

    Args:
        data_path: Root directory containing EDF files
        save_path: Directory to save HDF5 files (created if missing)
        channel: EMG channel name to extract
        start_at: Start time in seconds
        duration_sec: Duration to extract in seconds
        l_freq: Lower band-pass cutoff in Hz (high-pass edge)
        h_freq: Upper band-pass cutoff in Hz (low-pass edge)
        sfreq: Target sampling frequency
        mains: Mains power-line frequency for notch filtering (50 or 60 Hz)
        window_size: Window size for segmentation (samples)
        stride: Stride for sliding window (samples)
        n_jobs: Number of parallel jobs (-1 for all cores, 1 for sequential)
        verbose: Print progress

    Returns:
        None. Writes train.h5 / val.h5 / test.h5 under ``save_path``, each
        with "data" (N, 1, window_size) float32 and "label" (N,) int32
        gzip-compressed datasets. Returns early if no record pairs are found.
    """
    # Hoisted out of get_split: importing per call was wasteful (module lookup
    # on every record), and the function needs re unconditionally.
    import re

    # Annotation description -> integer class label used by all recordings.
    mapping = {
        "Sleep stage W": 0,
        "Sleep stage N1": 1,
        "Sleep stage N2": 2,
        "Sleep stage N3": 3,
        "Sleep stage R": 4,
        "Lights off@@EEG F4-A1": 0,
    }

    os.makedirs(save_path, exist_ok=True)

    # Discover record file pairs: each signal EDF must have a matching
    # "<base>_sleepscoring.edf" annotation file next to it.
    files = os.listdir(data_path)
    raw_files = [
        f
        for f in files
        if f.lower().endswith(".edf") and "sleepscoring" not in f.lower()
    ]

    records = []
    for raw_fn in raw_files:
        base = os.path.splitext(raw_fn)[0]
        scoring_fn = base + "_sleepscoring.edf"
        if scoring_fn in files:
            records.append((raw_fn, scoring_fn))
        elif verbose:
            print(f"Warning: scoring file missing for {raw_fn}")

    if len(records) == 0:
        print("No valid record pairs found!")
        return

    print(f"Found {len(records)} recording pairs")
    print(f"Using {n_jobs} parallel jobs" if n_jobs != 1 else "Running sequentially")

    # Initialize data containers for each split
    datasets = {
        "train": {"data": [], "label": []},
        "val": {"data": [], "label": []},
        "test": {"data": [], "label": []},
    }

    def get_split(filename):
        """Map a filename like "SN001.edf" to its subject-based split."""
        match = re.search(r"(\d+)", filename)
        # Version 1.1: recordings SN014, SN064, and SN135 were removed after it was detected that these recordings contained erroneous (and unfixable) signal data.
        train_subjects = range(1, 101)  # Subjects 1-100 for training
        val_subjects = range(101, 127)  # Subjects 101-126 for validation
        test_subjects = range(127, 155)  # Subjects 127-154 for testing
        if match:
            subj_num = int(match.group(1))
            if subj_num in train_subjects:
                return "train"
            elif subj_num in val_subjects:
                return "val"
            elif subj_num in test_subjects:
                return "test"
            else:
                return "train"  # default to train
        return "train"  # default

    # Process recordings in parallel; each worker returns
    # (windows, labels, filename) or (None, None, filename) on failure.
    print(f"\nProcessing {len(records)} recordings...")
    results = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0)(
        delayed(process_single_recording)(
            raw_fn=raw_fn,
            scoring_fn=scoring_fn,
            data_path=data_path,
            channel=channel,
            start_at=start_at,
            duration_sec=duration_sec,
            l_freq=l_freq,
            h_freq=h_freq,
            sfreq=sfreq,
            mains=mains,
            window_size=window_size,
            stride=stride,
            mapping=mapping,
            verbose=False,
        )
        for raw_fn, scoring_fn in records
    )

    # Collect results into splits
    processed_count = 0
    failed_count = 0

    for windows, labels, raw_fn in results:
        if windows is not None and labels is not None:
            # Determine which split this recording belongs to
            split_key = get_split(raw_fn)

            datasets[split_key]["data"].append(windows)
            datasets[split_key]["label"].append(labels)
            processed_count += 1

            if verbose:
                print(f"✓ {raw_fn}: {len(windows)} windows -> {split_key}")
        else:
            failed_count += 1
            if verbose:
                print(f"✗ {raw_fn}: Failed")

    print(f"\nProcessing complete: {processed_count} successful, {failed_count} failed")

    # Concatenate each split and write it to its own gzip-compressed HDF5 file.
    for split_name, split_data in datasets.items():
        if len(split_data["data"]) == 0:
            print(f"Warning: No data for {split_name} split")
            continue

        print(f"\nPreparing {split_name} split...")
        X = np.concatenate(split_data["data"], axis=0)  # (N, 1, window_size)
        y = np.concatenate(split_data["label"], axis=0)  # (N,)

        h5_path = os.path.join(save_path, f"{split_name}.h5")
        with h5py.File(h5_path, "w") as f:
            f.create_dataset(
                "data", data=X, dtype=np.float32, compression="gzip", compression_opts=4
            )
            f.create_dataset(
                "label", data=y, dtype=np.int32, compression="gzip", compression_opts=4
            )

        uniq, cnt = np.unique(y, return_counts=True)
        label_dist = dict(zip(uniq.tolist(), cnt.tolist()))

        print(f"{split_name.upper()}:")
        print(f"  Shape: X={X.shape}, y={y.shape}")
        print(f"  Label distribution: {label_dist}")
        print(f"  Saved to: {h5_path}")
        print(f"  File size: {os.path.getsize(h5_path) / (1024**2):.2f} MB")
296
-
297
-
298
if __name__ == "__main__":
    import argparse

    # Command-line front end for convert_hmc_to_h5.
    cli = argparse.ArgumentParser(
        description="Convert HMC EMG dataset to HDF5 format with parallel processing"
    )

    # Input/output locations (both mandatory).
    cli.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Root directory containing HMC EDF files",
    )
    cli.add_argument(
        "--save_dir", type=str, required=True, help="Directory to save HDF5 files"
    )

    # Signal-extraction options.
    cli.add_argument(
        "--channel", type=str, default="EMG chin", help="EMG channel name"
    )
    cli.add_argument(
        "--start_at", type=int, default=900, help="Start time in seconds"
    )
    cli.add_argument(
        "--duration_sec", type=int, default=21600, help="Duration in seconds"
    )

    # Filtering and resampling options.
    cli.add_argument(
        "--l_freq", type=float, default=5.0, help="Low-pass filter frequency"
    )
    cli.add_argument(
        "--h_freq", type=float, default=200.0, help="High-pass filter frequency"
    )
    cli.add_argument(
        "--sfreq", type=int, default=500, help="Target sampling frequency (Hz)"
    )
    cli.add_argument(
        "--mains",
        type=int,
        default=50,
        choices=[50, 60],
        help="Mains frequency for notch (50 or 60 Hz)",
    )

    # Windowing and execution options.
    cli.add_argument(
        "--window_size", type=int, default=1000, help="Window size for segmentation"
    )
    cli.add_argument(
        "--stride", type=int, default=1000, help="Stride for sliding window"
    )
    cli.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of parallel jobs (-1 for all cores)",
    )
    cli.add_argument(
        "--verbose", action="store_true", help="Print detailed progress"
    )

    opts = cli.parse_args()

    convert_hmc_to_h5(
        data_path=opts.data_dir,
        save_path=opts.save_dir,
        channel=opts.channel,
        start_at=opts.start_at,
        duration_sec=opts.duration_sec,
        l_freq=opts.l_freq,
        h_freq=opts.h_freq,
        mains=opts.mains,
        sfreq=opts.sfreq,
        window_size=opts.window_size,
        stride=opts.stride,
        n_jobs=opts.n_jobs,
        verbose=opts.verbose,
    )