MatteoFasulo committed
Commit ca8e271 · verified · 1 Parent(s): a8ec78f

Upload 9 files

Files changed (9)
  1. scripts/HMC.py +370 -0
  2. scripts/README.md +129 -0
  3. scripts/db5.py +213 -0
  4. scripts/db6.py +186 -0
  5. scripts/db7.py +184 -0
  6. scripts/db8.py +203 -0
  7. scripts/emg2pose.py +149 -0
  8. scripts/epn.py +194 -0
  9. scripts/uci.py +229 -0
scripts/HMC.py ADDED
@@ -0,0 +1,370 @@
+ import os
+ from typing import Optional, Tuple
+
+ import h5py
+ import mne
+ import numpy as np
+ from joblib import Parallel, delayed
+ from mne.io import read_raw_edf
+
+
+ def process_single_recording(
+     raw_fn: str,
+     scoring_fn: str,
+     data_path: str,
+     channel: str,
+     start_at: int,
+     duration_sec: int,
+     l_freq: float,
+     h_freq: float,
+     sfreq: int,
+     mains: int,
+     window_size: int,
+     stride: int,
+     mapping: dict,
+     verbose: bool = False,
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
+     """
+     Process a single recording file and return windows and labels.
+
+     Returns:
+         Tuple of (windows, labels, filename) or (None, None, filename) if processing fails
+     """
+     try:
+         if verbose:
+             print(f"Processing: {raw_fn}")
+
+         full_path_raw = os.path.join(data_path, raw_fn)
+         full_path_score = os.path.join(data_path, scoring_fn)
+
+         # Load and preprocess
+         raw = read_raw_edf(full_path_raw, preload=True, verbose=False)
+         annotation = mne.read_annotations(full_path_score)
+         raw.set_annotations(annotation, emit_warning=False)
+
+         # Crop
+         end_at = start_at + duration_sec
+         if end_at > raw.times[-1]:
+             end_at = raw.times[-1] - (raw.times[-1] % 30.0)
+         raw = raw.crop(tmin=start_at, tmax=end_at)
+
+         # Pick channel
+         if channel not in raw.ch_names:
+             print(f"Warning: Channel {channel} not found in {raw_fn}, skipping")
+             return None, None, raw_fn
+         raw = raw.pick([channel])
+
+         # Filter (bandpass) with safe h_freq clipping to Nyquist
+         nyq = raw.info["sfreq"] / 2.0
+         h_freq_adj = h_freq if h_freq is not None and h_freq < nyq else None
+         raw = raw.filter(
+             l_freq=l_freq, h_freq=h_freq_adj, fir_design="firwin", verbose=False
+         )
+
+         # Notch at mains harmonics (e.g. 50,100,150 or 60,120,180) but only those < Nyquist
+         mains_freqs = [mains * i for i in (1, 2, 3)]
+         mains_freqs = [f for f in mains_freqs if f < nyq]
+         if len(mains_freqs) > 0:
+             # use raw.notch_filter which handles multiple notch freqs
+             raw.notch_filter(freqs=mains_freqs, picks=[channel], verbose=False)
+
+         # Resample to target sampling rate (upsample or downsample)
+         if raw.info["sfreq"] != sfreq:
+             raw = raw.resample(sfreq, npad="auto")
+
+         # Create 30s epochs
+         events, event_id = mne.events_from_annotations(raw, chunk_duration=30.0)
+         tmax = 30.0 - 1.0 / raw.info["sfreq"]
+         epochs = mne.Epochs(
+             raw=raw,
+             events=events,
+             event_id=event_id,
+             tmin=0.0,
+             tmax=tmax,
+             baseline=None,
+             verbose=False,
+         )
+
+         epochs_data = epochs.get_data()  # (n_epochs, 1, n_times)
+         labels = []
+         for ann in epochs.get_annotations_per_epoch():
+             labels.append(mapping[str(ann[0][2])])
+
+         n_epochs, _, n_times = epochs_data.shape
+         if n_times < window_size:
+             print(f"Warning: Not enough samples in {raw_fn}, skipping")
+             return None, None, raw_fn
+
+         # Sliding window
+         windows = []
+         labels_win = []
+         for i in range(n_epochs):
+             for start in range(0, n_times - window_size + 1, stride):
+                 windows.append(epochs_data[i, 0, start : start + window_size])
+                 labels_win.append(labels[i])
+
+         if len(windows) > 0:
+             windows = np.stack(windows)  # (n_windows, window_size)
+             windows = windows[:, np.newaxis, :]  # (n_windows, 1, window_size)
+
+             if verbose:
+                 print(f"  {raw_fn}: Generated {len(windows)} windows")
+
+             return (
+                 windows.astype(np.float32),
+                 np.array(labels_win, dtype=np.int32),
+                 raw_fn,
+             )
+         else:
+             return None, None, raw_fn
+
+     except Exception as e:
+         print(f"Error processing {raw_fn}: {e}")
+         return None, None, raw_fn
+
+
+ def convert_hmc_to_h5(
+     data_path: str,
+     save_path: str,
+     channel: str = "EMG chin",
+     start_at: int = 15 * 60,
+     duration_sec: int = 6 * 60 * 60,
+     l_freq: float = 5.0,
+     h_freq: float = 250.0,
+     sfreq: int = 100,
+     mains: int = 50,
+     window_size: int = 1000,
+     stride: int = 1000,
+     n_jobs: int = -1,
+     verbose: bool = True,
+ ):
+     """
+     Convert HMC EMG dataset to HDF5 format compatible with EMGDataset.
+     Uses joblib for parallel processing of individual recordings.
+
+     Args:
+         data_path: Root directory containing EDF files
+         save_path: Directory to save HDF5 files
+         channel: EMG channel name to extract
+         start_at: Start time in seconds
+         duration_sec: Duration to extract in seconds
+         l_freq: Lower band-pass cutoff (high-pass edge) in Hz
+         h_freq: Upper band-pass cutoff (low-pass edge) in Hz
+         sfreq: Target sampling frequency
+         window_size: Window size for segmentation
+         stride: Stride for sliding window
+         n_jobs: Number of parallel jobs (-1 for all cores, 1 for sequential)
+         verbose: Print progress
+     """
+
+     mapping = {
+         "Sleep stage W": 0,
+         "Sleep stage N1": 1,
+         "Sleep stage N2": 2,
+         "Sleep stage N3": 3,
+         "Sleep stage R": 4,
+         "Lights off@@EEG F4-A1": 0,
+     }
+
+     os.makedirs(save_path, exist_ok=True)
+
+     # Discover record file pairs
+     files = os.listdir(data_path)
+     raw_files = [
+         f
+         for f in files
+         if f.lower().endswith(".edf") and "sleepscoring" not in f.lower()
+     ]
+
+     records = []
+     for raw_fn in raw_files:
+         base = os.path.splitext(raw_fn)[0]
+         scoring_fn = base + "_sleepscoring.edf"
+         if scoring_fn in files:
+             records.append((raw_fn, scoring_fn))
+         elif verbose:
+             print(f"Warning: scoring file missing for {raw_fn}")
+
+     if len(records) == 0:
+         print("No valid record pairs found!")
+         return
+
+     print(f"Found {len(records)} recording pairs")
+     print(f"Using {n_jobs} parallel jobs" if n_jobs != 1 else "Running sequentially")
+
+     # Initialize data containers for each split
+     datasets = {
+         "train": {"data": [], "label": []},
+         "val": {"data": [], "label": []},
+         "test": {"data": [], "label": []},
+     }
+
+     # Create mapping from filename to split
+     def get_split(filename):
+         # Extract subject number from filename
+         # Example: "SN001.edf" -> 1
+         import re
+
+         match = re.search(r"(\d+)", filename)
+         # Version 1.1: recordings SN014, SN064, and SN135 were removed after it was detected that these recordings contained erroneous (and unfixable) signal data.
+         train_subjects = range(1, 101)  # Subjects 1-100 for training
+         val_subjects = range(101, 127)  # Subjects 101-126 for validation
+         test_subjects = range(127, 155)  # Subjects 127-154 for testing
+         if match:
+             subj_num = int(match.group(1))
+             if subj_num in train_subjects:
+                 return "train"
+             elif subj_num in val_subjects:
+                 return "val"
+             elif subj_num in test_subjects:
+                 return "test"
+             else:
+                 return "train"  # default to train
+         return "train"  # default
+
+     # Process recordings in parallel
+     print(f"\nProcessing {len(records)} recordings...")
+     results = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0)(
+         delayed(process_single_recording)(
+             raw_fn=raw_fn,
+             scoring_fn=scoring_fn,
+             data_path=data_path,
+             channel=channel,
+             start_at=start_at,
+             duration_sec=duration_sec,
+             l_freq=l_freq,
+             h_freq=h_freq,
+             sfreq=sfreq,
+             mains=mains,
+             window_size=window_size,
+             stride=stride,
+             mapping=mapping,
+             verbose=False,
+         )
+         for raw_fn, scoring_fn in records
+     )
+
+     # Collect results into splits
+     processed_count = 0
+     failed_count = 0
+
+     for windows, labels, raw_fn in results:
+         if windows is not None and labels is not None:
+             # Determine which split this recording belongs to
+             split_key = get_split(raw_fn)
+
+             datasets[split_key]["data"].append(windows)
+             datasets[split_key]["label"].append(labels)
+             processed_count += 1
+
+             if verbose:
+                 print(f"✓ {raw_fn}: {len(windows)} windows -> {split_key}")
+         else:
+             failed_count += 1
+             if verbose:
+                 print(f"✗ {raw_fn}: Failed")
+
+     print(f"\nProcessing complete: {processed_count} successful, {failed_count} failed")
+
+     # Concatenate and save
+     for split_name, split_data in datasets.items():
+         if len(split_data["data"]) == 0:
+             print(f"Warning: No data for {split_name} split")
+             continue
+
+         print(f"\nPreparing {split_name} split...")
+         X = np.concatenate(split_data["data"], axis=0)  # (N, 1, window_size)
+         y = np.concatenate(split_data["label"], axis=0)  # (N,)
+
+         h5_path = os.path.join(save_path, f"{split_name}.h5")
+         with h5py.File(h5_path, "w") as f:
+             f.create_dataset(
+                 "data", data=X, dtype=np.float32, compression="gzip", compression_opts=4
+             )
+             f.create_dataset(
+                 "label", data=y, dtype=np.int32, compression="gzip", compression_opts=4
+             )
+
+         uniq, cnt = np.unique(y, return_counts=True)
+         label_dist = dict(zip(uniq.tolist(), cnt.tolist()))
+
+         print(f"{split_name.upper()}:")
+         print(f"  Shape: X={X.shape}, y={y.shape}")
+         print(f"  Label distribution: {label_dist}")
+         print(f"  Saved to: {h5_path}")
+         print(f"  File size: {os.path.getsize(h5_path) / (1024**2):.2f} MB")
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(
+         description="Convert HMC EMG dataset to HDF5 format with parallel processing"
+     )
+     parser.add_argument(
+         "--data_dir",
+         type=str,
+         required=True,
+         help="Root directory containing HMC EDF files",
+     )
+     parser.add_argument(
+         "--save_dir", type=str, required=True, help="Directory to save HDF5 files"
+     )
+     parser.add_argument(
+         "--channel", type=str, default="EMG chin", help="EMG channel name"
+     )
+     parser.add_argument(
+         "--start_at", type=int, default=900, help="Start time in seconds"
+     )
+     parser.add_argument(
+         "--duration_sec", type=int, default=21600, help="Duration in seconds"
+     )
+     parser.add_argument(
+         "--l_freq", type=float, default=5.0, help="Lower band-pass cutoff frequency (Hz)"
+     )
+     parser.add_argument(
+         "--h_freq", type=float, default=200.0, help="Upper band-pass cutoff frequency (Hz)"
+     )
+     parser.add_argument(
+         "--sfreq", type=int, default=500, help="Target sampling frequency (Hz)"
+     )
+     parser.add_argument(
+         "--mains",
+         type=int,
+         default=50,
+         choices=[50, 60],
+         help="Mains frequency for notch (50 or 60 Hz)",
+     )
+     parser.add_argument(
+         "--window_size", type=int, default=1000, help="Window size for segmentation"
+     )
+     parser.add_argument(
+         "--stride", type=int, default=1000, help="Stride for sliding window"
+     )
+     parser.add_argument(
+         "--n_jobs",
+         type=int,
+         default=-1,
+         help="Number of parallel jobs (-1 for all cores)",
+     )
+     parser.add_argument(
+         "--verbose", action="store_true", help="Print detailed progress"
+     )
+
+     args = parser.parse_args()
+
+     convert_hmc_to_h5(
+         data_path=args.data_dir,
+         save_path=args.save_dir,
+         channel=args.channel,
+         start_at=args.start_at,
+         duration_sec=args.duration_sec,
+         l_freq=args.l_freq,
+         h_freq=args.h_freq,
+         mains=args.mains,
+         sfreq=args.sfreq,
+         window_size=args.window_size,
+         stride=args.stride,
+         n_jobs=args.n_jobs,
+         verbose=args.verbose,
+     )
scripts/README.md ADDED
@@ -0,0 +1,129 @@
+ # Dataset Preparation Commands
+
+ ## Overview
+
+ This document provides the commands to prepare various EMG datasets for pretraining and downstream tasks. Each dataset preparation script takes raw data, segments it into windows (overlapping whenever the stride is smaller than the window), and saves the processed data in HDF5 format for efficient loading during model training.
+
+ Remember to add the `--download_data` flag if the dataset has not been downloaded yet; the script fetches the raw files and exits, so rerun it afterwards without the flag.
+
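+ All scripts write `train.h5`, `val.h5`, and `test.h5` containing a `data` array of shape `(N, channels, window_size)` and, except for the emg2pose pretraining script (which stores windows only), a `label` array. As a quick sanity check, a produced split can be inspected with a sketch like the following (the path is an example):
+
+ ```python
+ import h5py
+ import numpy as np
+
+ # Example path; point this at any split written by the scripts below.
+ with h5py.File("train.h5", "r") as f:
+     X = f["data"][:]  # (N, channels, window_size), float32
+     print("data:", X.shape, X.dtype)
+     if "label" in f:  # absent for the emg2pose pretraining split
+         uniq, cnt = np.unique(f["label"][:], return_counts=True)
+         print("labels:", dict(zip(uniq.tolist(), cnt.tolist())))
+ ```
+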
+ ## Pretraining Datasets
+
+ For pretraining:
+
+ ### emg2pose
+
+ ```bash
+ python scripts/emg2pose.py \
+     --data_dir $SCRATCH/datasets/emg2pose_data/ \
+     --save_dir $SCRATCH/datasets/emg2pose_data/h5/ \
+     --window_size 1000 \
+     --stride 500
+ ```
+
+ ### Ninapro DB6
+
+ ```bash
+ python scripts/db6.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB6/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB6/h5/ \
+     --window_size 1000 \
+     --stride 500
+ ```
+
+ ### Ninapro DB7
+
+ ```bash
+ python scripts/db7.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB7/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB7/h5/ \
+     --window_size 1000 \
+     --stride 500
+ ```
+
+ ---
+
+ ## Downstream Datasets
+
+ For the downstream tasks:
+
+ ### Ninapro DB5 (200 ms, 75% overlap)
+
+ ```bash
+ python scripts/db5.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB5/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB5/h5/ \
+     --window_size 200 \
+     --stride 50
+ ```
+
+ ### Ninapro DB5 (1000 ms, 75% overlap)
+
+ ```bash
+ python scripts/db5.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB5/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB5/h5/ \
+     --window_size 1000 \
+     --stride 250
+ ```
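+
+ With these DB5 settings the stride is 25% of the window, so consecutive windows share 75% of their samples. The number of windows per continuous sequence follows directly from the window/stride choice; the illustrative helper below mirrors the arithmetic used by the scripts (`db5.py` zero-pads the final partial window, while the other scripts keep only complete windows):
+
+ ```python
+ def n_windows(n_samples: int, window: int, stride: int, pad_last: bool) -> int:
+     """Windows produced by sliding a `window`-frame window with step `stride`."""
+     if pad_last:
+         # db5.py starts a window every `stride` frames and zero-pads the tail
+         return -(-n_samples // stride)  # ceil division
+     # the other scripts keep only complete windows
+     return max(0, (n_samples - window) // stride + 1)
+
+ # e.g. 10 s of DB5 data at 200 Hz with window_size=200, stride=50
+ print(n_windows(2000, 200, 50, pad_last=True))   # 40
+ print(n_windows(2000, 200, 50, pad_last=False))  # 37
+ ```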
+
+ ### EMG-EPN612 (200 ms)
+
+ ```bash
+ python scripts/epn.py \
+     --data_dir $SCRATCH/datasets/EPN612/ \
+     --source_training $SCRATCH/datasets/EPN612/trainingJSON/ \
+     --source_testing $SCRATCH/datasets/EPN612/testingJSON/ \
+     --dest_dir $SCRATCH/datasets/EPN612/h5/ \
+     --window_size 200
+ ```
+
+ ### EMG-EPN612 (1000 ms)
+
+ ```bash
+ python scripts/epn.py \
+     --data_dir $SCRATCH/datasets/EPN612/ \
+     --source_training $SCRATCH/datasets/EPN612/trainingJSON/ \
+     --source_testing $SCRATCH/datasets/EPN612/testingJSON/ \
+     --dest_dir $SCRATCH/datasets/EPN612/h5/ \
+     --window_size 1000
+ ```
+
+ ### UCI EMG (200 ms, 75% overlap)
+
+ ```bash
+ python scripts/uci.py \
+     --data_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
+     --save_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
+     --window_size 200 \
+     --stride 50
+ ```
+
+ ### UCI EMG (1000 ms, 75% overlap)
+
+ ```bash
+ python scripts/uci.py \
+     --data_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
+     --save_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
+     --window_size 1000 \
+     --stride 250
+ ```
+
+ ### Ninapro DB8 (200 ms, no overlap)
+
+ ```bash
+ python scripts/db8.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB8/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB8/h5/ \
+     --window_size 200 \
+     --stride 200
+ ```
+
+ ### Ninapro DB8 (1000 ms, no overlap)
+
+ ```bash
+ python scripts/db8.py \
+     --data_dir $SCRATCH/datasets/ninapro/DB8/ \
+     --save_dir $SCRATCH/datasets/ninapro/DB8/h5/ \
+     --window_size 1000 \
+     --stride 1000
+ ```
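+
+ ### HMC sleep EMG (chin channel)
+
+ The commit also ships `scripts/HMC.py`, which converts the HMC sleep-staging recordings (EDF files plus sleep-scoring annotations) into the same HDF5 layout. A plausible invocation, sketched from the script's own argument parser and defaults (paths are examples):
+
+ ```bash
+ python scripts/HMC.py \
+     --data_dir $SCRATCH/datasets/HMC/ \
+     --save_dir $SCRATCH/datasets/HMC/h5/ \
+     --channel "EMG chin" \
+     --mains 50 \
+     --window_size 1000 \
+     --stride 1000
+ ```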
scripts/db5.py ADDED
@@ -0,0 +1,213 @@
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ import scipy.io
+ import scipy.signal as signal
+ from scipy.signal import iirnotch
+
+
+ # ==== Data augmentation functions ====
+ def random_amplitude_scale(sig, scale_range=(0.9, 1.1)):
+     scale = np.random.uniform(*scale_range)
+     return sig * scale
+
+
+ def random_time_jitter(sig, jitter_ratio=0.01):
+     T, D = sig.shape
+     std_ch = np.std(sig, axis=0)
+     noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
+     return sig + noise
+
+
+ def random_channel_dropout(sig, dropout_prob=0.05):
+     T, D = sig.shape
+     mask = np.random.rand(D) < dropout_prob
+     sig[:, mask] = 0.0
+     return sig
+
+
+ def augment_one_sample(seg):
+     out = seg.copy()
+     out = random_amplitude_scale(out, (0.9, 1.1))
+     out = random_time_jitter(out, 0.01)
+     out = random_channel_dropout(out, 0.05)
+     return out
+
+
+ def augment_train_data(data, labels, factor=3):
+     if factor <= 0 or data.shape[0] == 0:
+         return data, labels
+     aug_segs = [data]
+     aug_lbls = [labels]
+     N = data.shape[0]
+     for i in range(N):
+         seg = data[i]  # [window_size, n_ch]
+         lab = labels[i]
+         for _ in range(factor):
+             aug_segs.append(augment_one_sample(seg)[None, ...])
+             aug_lbls.append([lab])
+     new_data = np.concatenate(aug_segs, axis=0)
+     new_labels = np.concatenate(aug_lbls, axis=0).ravel()
+     return new_data, new_labels
+
+
+ # ==== Filter functions (operate at original fs=200) ====
+ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
+     b, a = iirnotch(notch_freq, Q, fs)
+     out = np.zeros_like(data)
+     for ch in range(data.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, data[:, ch])
+     return out
+
+
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
+     nyq = 0.5 * fs
+     low = lowcut / nyq
+     high = highcut / nyq
+     b, a = signal.butter(order, [low, high], btype="bandpass")
+     out = np.zeros_like(emg)
+     for c in range(emg.shape[1]):
+         out[:, c] = signal.filtfilt(b, a, emg[:, c])
+     return out
+
+
+ # ==== Window segmentation ====
+ def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
+     segs, lbls, reps = [], [], []
+     N = len(label)
+     for start in range(0, N, stride):
+         end = start + window_size
+         if end > N:
+             cut = emg[start:N]
+             pad = np.zeros((end - N, emg.shape[1]))
+             win = np.vstack([cut, pad])
+         else:
+             win = emg[start:end]
+
+         segs.append(win)
+         lbls.append(label[start])
+         reps.append(rerep[start])
+     return np.array(segs), np.array(lbls), np.array(reps)
+
+
+ # ==== Main pipeline ====
+ def main():
+     import argparse
+
+     args = argparse.ArgumentParser(description="Process EMG data from DB5.")
+     args.add_argument("--download_data", action="store_true")
+     args.add_argument("--data_dir", type=str)
+     args.add_argument("--save_dir", type=str)
+     args.add_argument(
+         "--window_size", type=int, help="Size of the sliding window for segmentation."
+     )
+     args.add_argument(
+         "--stride", type=int, help="Stride for the sliding window segmentation."
+     )
+     args = args.parse_args()
+
+     data_dir = args.data_dir
+     save_dir = args.save_dir
+     os.makedirs(save_dir, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://ninapro.hevs.ch/instructions/DB5.html
+         len_data = range(1, 11)  # 1–10
+         base_url = "https://ninapro.hevs.ch/files/DB5_Preproc/"
+         # download and unzip
+         for i in len_data:
+             url = f"{base_url}s{i}.zip"
+             os.system(f"wget -P {data_dir} {url}")
+             os.system(f"unzip -o {data_dir}/s{i}.zip -d {data_dir}")
+             os.system(f"rm {data_dir}/s{i}.zip")
+             print(f"Downloaded and unzipped subject {i}\n{data_dir}/s{i}.zip")
+         sys.exit("Data downloaded and unzipped. Rerun without --download_data.")
+
+     fs = 200.0  # original sampling rate
+     window_size, stride = args.window_size, args.stride
+     train_reps = [1, 3, 4, 6]
+     val_reps = [2]
+     test_reps = [5]
+
+     all_data = {"train": [], "val": [], "test": []}
+     all_lbls = {"train": [], "val": [], "test": []}
+
+     for subj in sorted(os.listdir(data_dir)):
+         subj_path = os.path.join(data_dir, subj)
+         if not os.path.isdir(subj_path):
+             continue
+         print(f"Processing subject {subj}...")
+         for mat in sorted(os.listdir(subj_path)):
+             if not mat.endswith(".mat"):
+                 continue
+             dd = scipy.io.loadmat(os.path.join(subj_path, mat))
+             emg = dd["emg"]  # [N,16]
+             label = dd["restimulus"].ravel().astype(int)
+             rerep = dd["rerepetition"].ravel().astype(int)
+
+             # label shift by exercise
+             if "E2" in mat:
+                 label = np.where(label != 0, label + 12, 0)
+             elif "E3" in mat:
+                 label = np.where(label != 0, label + 29, 0)
+
+             # filtering at original 200 Hz
+             emg_filt = bandpass_filter_emg(emg, 20, 90, fs=fs)
+             emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)
+
+             # z-score
+             mu = emg_filt.mean(axis=0)
+             sd = emg_filt.std(axis=0, ddof=1)
+             sd[sd == 0] = 1.0
+             emg_z = (emg_filt - mu) / sd
+
+             # segment
+             segs, lbls, reps = process_emg_features(
+                 emg_z, label, rerep, window_size, stride
+             )
+
+             # split by repetition index
+             for seg, lab, rp in zip(segs, lbls, reps):
+                 if rp in train_reps:
+                     all_data["train"].append(seg)
+                     all_lbls["train"].append(lab)
+                 elif rp in val_reps:
+                     all_data["val"].append(seg)
+                     all_lbls["val"].append(lab)
+                 elif rp in test_reps:
+                     all_data["test"].append(seg)
+                     all_lbls["test"].append(lab)
+
+     # stack, augment train, transpose, save, and print stats
+     stats = {}
+     for split in ["train", "val", "test"]:
+         X = np.stack(all_data[split], axis=0)  # [N, window_size, ch]
+         y = np.array(all_lbls[split], dtype=int)
+
+         if split == "train":
+             X, y = augment_train_data(X, y, factor=3)
+
+         # transpose to [N, ch, window_size]
+         X = X.transpose(0, 2, 1)
+
+         # save
+         with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
+             hf.create_dataset("data", data=X)
+             hf.create_dataset("label", data=y)
+
+         # compute stats
+         uniq, cnt = np.unique(y, return_counts=True)
+         stats[split] = (X.shape, dict(zip(uniq.tolist(), cnt.tolist())))
+
+     # print stats
+     for split, (shape, dist) in stats.items():
+         print(f"\n{split} → X={shape}, label distribution:")
+         for lab, count in dist.items():
+             print(f"  label {lab}: {count} samples")
+
+
+ if __name__ == "__main__":
+     main()
scripts/db6.py ADDED
@@ -0,0 +1,186 @@
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ import scipy.io
+ import scipy.signal as signal
+ from scipy.signal import iirnotch
+
+
+ # ─────────────── Filtering ──────────────────
+ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
+     """Notch-filter every channel independently."""
+     b, a = iirnotch(notch_freq, Q, fs)
+     out = np.zeros_like(data)
+     for ch in range(data.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, data[:, ch])
+     return out
+
+
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
+     nyq = 0.5 * fs
+     b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
+     out = np.zeros_like(emg)
+     for ch in range(emg.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, emg[:, ch])
+     return out
+
+
+ # ─────────────── Sliding window ──────────────
+ def sliding_window_segment(emg, label, rerepetition, window_size, stride):
+     """
+     Segment EMG with a sliding window.
+     Use the frame at the window centre as the segment label / repetition index.
+     """
+     segments, labels, reps = [], [], []
+     n_samples = len(label)
+
+     for start in range(0, n_samples - window_size + 1, stride):
+         end = start + window_size
+         emg_segment = emg[start:end]  # (win, ch)
+         centre_idx = (start + end) // 2
+         segments.append(emg_segment)
+         labels.append(label[centre_idx])
+         reps.append(rerepetition[centre_idx])
+
+     return np.array(segments), np.array(labels), np.array(reps)
+
+
+ # ─────────────── Main pipeline ───────────────
+ def main():
+     import argparse
+
+     args = argparse.ArgumentParser(description="Process EMG data from DB6.")
+     args.add_argument("--download_data", action="store_true")
+     args.add_argument("--data_dir", type=str)
+     args.add_argument("--save_dir", type=str)
+     args.add_argument(
+         "--window_size", type=int, help="Size of the sliding window for segmentation."
+     )
+     args.add_argument(
+         "--stride", type=int, help="Stride for the sliding window segmentation."
+     )
+     args = args.parse_args()
+     data_dir = args.data_dir  # input folder with .mat files
+     save_dir = args.save_dir  # output folder for .h5 files
+     os.makedirs(save_dir, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://ninapro.hevs.ch/instructions/DB6.html
+         len_data = range(1, 11)  # 1–10
+         base_url = "https://ninapro.hevs.ch/files/DB6_Preproc/"
+         # download and unzip
+         for i in len_data:
+             url_a = f"{base_url}DB6_s{i}_a.zip"
+             url_b = f"{base_url}DB6_s{i}_b.zip"
+             os.system(f"wget -P {data_dir} {url_a}")
+             os.system(f"wget -P {data_dir} {url_b}")
+             os.system(f"unzip -o {data_dir}/DB6_s{i}_a.zip -d {data_dir}")
+             os.system(f"unzip -o {data_dir}/DB6_s{i}_b.zip -d {data_dir}")
+             os.system(f"rm {data_dir}/DB6_s{i}_a.zip {data_dir}/DB6_s{i}_b.zip")
+             print(
+                 f"Downloaded and unzipped subject {i}\n{data_dir}/DB6_s{i}_a.zip and {data_dir}/DB6_s{i}_b.zip"
+             )
+         sys.exit("Data downloaded and unzipped. Rerun without --download_data.")
+
+     fs = 2000.0
+     window_size, stride = args.window_size, args.stride
+
+     train_reps = list(range(1, 9))  # 1–8
+     val_reps = [9, 10]  # 9–10
+     test_reps = [11, 12]  # 11–12
+
+     splits = {
+         "train": {"data": [], "label": []},
+         "val": {"data": [], "label": []},
+         "test": {"data": [], "label": []},
+     }
+
+     # iterate subjects
+     for subj in sorted(os.listdir(data_dir)):
+         subj_path = os.path.join(data_dir, subj)
+         if not os.path.isdir(subj_path):
+             continue
+         print(f"Processing subject {subj} ...")
+
+         subj_seg, subj_lbl, subj_rep = [], [], []
+
+         # iterate .mat files
+         for mat_file in sorted(os.listdir(subj_path)):
+             if not mat_file.endswith(".mat"):
+                 continue
+             mat_path = os.path.join(subj_path, mat_file)
+             mat = scipy.io.loadmat(mat_path)
+
+             emg = mat["emg"]  # (N, 16)
+             label = mat["restimulus"].ravel()
+             rerep = mat["rerepetition"].ravel()
+
+             # drop empty channels (index 8, 9 → 0-based)
+             emg = np.delete(emg, [8, 9], axis=1)  # now (N, 14)
+
+             # filtering
+             emg = bandpass_filter_emg(emg, 20, 450, fs=fs)
+             emg = notch_filter(emg, 50, 30, fs=fs)
+
+             # z-score per channel
+             mu = emg.mean(axis=0)
+             sd = emg.std(axis=0, ddof=1)
+             sd[sd == 0] = 1.0
+             emg = (emg - mu) / sd
+
+             # windowing
+             seg, lbl, rep = sliding_window_segment(
+                 emg, label, rerep, window_size, stride
+             )
+             subj_seg.append(seg)
+             subj_lbl.append(lbl)
+             subj_rep.append(rep)
+
+         if not subj_seg:
+             continue
+
+         seg = np.concatenate(subj_seg, axis=0)  # (M, win, 14)
+         lbl = np.concatenate(subj_lbl)
+         rep = np.concatenate(subj_rep)
+
+         # split by repetition id
+         for split_name, mask in (
+             ("train", np.isin(rep, train_reps)),
+             ("val", np.isin(rep, val_reps)),
+             ("test", np.isin(rep, test_reps)),
+         ):
+             X = seg[mask].transpose(0, 2, 1)  # (N, 14, window_size)
+             y = lbl[mask]
+             splits[split_name]["data"].append(X)
+             splits[split_name]["label"].append(y)
+
+     # concatenate, save, and report
+     for split in ["train", "val", "test"]:
+         X = (
+             np.concatenate(splits[split]["data"], axis=0)
+             if splits[split]["data"]
+             else np.empty((0, 14, window_size))
+         )
+         y = (
+             np.concatenate(splits[split]["label"], axis=0)
+             if splits[split]["label"]
+             else np.empty((0,), dtype=int)
+         )
+
+         with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
+             f.create_dataset("data", data=X.astype(np.float32))
+             f.create_dataset("label", data=y.astype(np.int64))
+
+         uniq, cnt = np.unique(y, return_counts=True)
+         print(f"\n{split.upper()} → X={X.shape}, label distribution:")
+         for u, c in zip(uniq, cnt):
+             print(f"  label {u}: {c} samples")
+
+     print("\nSaved: train.h5, val.h5, test.h5")
+
+
+ if __name__ == "__main__":
+     main()
scripts/db7.py ADDED
@@ -0,0 +1,184 @@
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ import scipy.io
+ import scipy.signal as signal
+ from scipy.signal import iirnotch
+
+
+ # ─────────────── Filtering ──────────────────
+ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
+     """Notch-filter every channel independently."""
+     b, a = iirnotch(notch_freq, Q, fs)
+     out = np.zeros_like(data)
+     for ch in range(data.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, data[:, ch])
+     return out
+
+
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
+     nyq = 0.5 * fs
+     b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
+     out = np.zeros_like(emg)
+     for ch in range(emg.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, emg[:, ch])
+     return out
+
+
+ # ─────────────── Sliding window ──────────────
+ def sliding_window_segment(emg, label, rerepetition, window_size, stride):
+     """
+     Segment EMG with a sliding window.
+     Use the frame at the window centre as the segment label / repetition index.
+     """
+     segments, labels, reps = [], [], []
+     n_samples = len(label)
+
+     for start in range(0, n_samples - window_size + 1, stride):
+         end = start + window_size
+         emg_segment = emg[start:end]  # (win, ch)
+         centre_idx = (start + end) // 2
+         segments.append(emg_segment)
+         labels.append(label[centre_idx])
+         reps.append(rerepetition[centre_idx])
+
+     return np.array(segments), np.array(labels), np.array(reps)
+
+
+ # ─────────────── Main pipeline ───────────────
+ def main():
+     import argparse
+
+     args = argparse.ArgumentParser(description="Process EMG data from DB7.")
+     args.add_argument("--download_data", action="store_true")
+     args.add_argument("--data_dir", type=str)
+     args.add_argument("--save_dir", type=str)
+     args.add_argument(
+         "--window_size",
+         type=int,
+         default=256,
+         help="Size of the sliding window for segmentation.",
+     )
+     args.add_argument(
+         "--stride",
+         type=int,
+         default=128,
+         help="Stride for the sliding window segmentation.",
+     )
+     args = args.parse_args()
+     data_dir = args.data_dir  # input folder with .mat files
+     save_dir = args.save_dir  # output folder for .h5 files
+     os.makedirs(save_dir, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://ninapro.hevs.ch/instructions/DB7.html
+         len_data = range(1, 23)  # 1–22
+         base_url = "https://ninapro.hevs.ch/files/DB7_Preproc/"
+         # download and unzip
+         for i in len_data:
+             url = f"{base_url}Subject_{i}.zip"
+             os.system(f"wget -P {data_dir} {url}")
+             os.system(f"unzip -o {data_dir}/Subject_{i}.zip -d {data_dir}/Subject_{i}")
+             os.system(f"rm {data_dir}/Subject_{i}.zip")
+             print(f"Downloaded and unzipped subject {i}\n{data_dir}/Subject_{i}.zip")
+         sys.exit("Data downloaded and unzipped. Rerun without --download_data.")
+
+     fs = 2000.0
+     window_size, stride = args.window_size, args.stride
+
+     train_reps = [1, 2, 3, 4]  # 1–4
+     val_reps = [5]  # 5
+     test_reps = [6]  # 6
+
+     splits = {
+         "train": {"data": [], "label": []},
+         "val": {"data": [], "label": []},
+         "test": {"data": [], "label": []},
+     }
+
+     # iterate subjects
+     for subj in sorted(os.listdir(data_dir)):
+         subj_path = os.path.join(data_dir, subj)
+         if not os.path.isdir(subj_path):
+             continue
+         print(f"Processing subject {subj} ...")
+
+         subj_seg, subj_lbl, subj_rep = [], [], []
+
+         # iterate .mat files
+         for mat_file in sorted(os.listdir(subj_path)):
+             if not mat_file.endswith(".mat"):
+                 continue
+             mat_path = os.path.join(subj_path, mat_file)
+             mat = scipy.io.loadmat(mat_path)
+
+             emg = mat["emg"]  # (N, 16)
+             label = mat["restimulus"].ravel()
+             rerep = mat["rerepetition"].ravel()
+
+             # filtering
+             emg = bandpass_filter_emg(emg, 20.0, 450.0, fs=fs)
+             emg = notch_filter(emg, 50.0, 30.0, fs=fs)
+
+             # z-score per channel
+             mu = emg.mean(axis=0)
+             sd = emg.std(axis=0, ddof=1)
+             sd[sd == 0] = 1.0
+             emg = (emg - mu) / sd
+
+             # windowing
+             seg, lbl, rep = sliding_window_segment(
+                 emg, label, rerep, window_size, stride
+             )
+             subj_seg.append(seg)
+             subj_lbl.append(lbl)
+             subj_rep.append(rep)
+
+         if not subj_seg:
+             continue
+
+         seg = np.concatenate(subj_seg, axis=0)  # (M, win, 16)
+         lbl = np.concatenate(subj_lbl)
+         rep = np.concatenate(subj_rep)
+
+         # split by repetition id
+         for split_name, mask in (
+             ("train", np.isin(rep, train_reps)),
+             ("val", np.isin(rep, val_reps)),
+             ("test", np.isin(rep, test_reps)),
+         ):
+             X = seg[mask].transpose(0, 2, 1)  # (N, 16, window_size)
+             y = lbl[mask]
+             splits[split_name]["data"].append(X)
+             splits[split_name]["label"].append(y)
+
+     # concatenate, save, and report
+     for split in ["train", "val", "test"]:
+         X = (
+             np.concatenate(splits[split]["data"], axis=0)
+             if splits[split]["data"]
+             else np.empty((0, 16, window_size))  # DB7 keeps all 16 channels
+         )
+         y = (
+             np.concatenate(splits[split]["label"], axis=0)
+             if splits[split]["label"]
+             else np.empty((0,), dtype=int)
+         )
+
+         with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
+             f.create_dataset("data", data=X.astype(np.float32))
+             f.create_dataset("label", data=y.astype(np.int64))
+
+         uniq, cnt = np.unique(y, return_counts=True)
+         print(f"\n{split.upper()} → X={X.shape}, label distribution:")
+         for u, c in zip(uniq, cnt):
+             print(f"  label {u}: {c} samples")
+
+     print("\nSaved: train.h5, val.h5, test.h5")
+
+
+ if __name__ == "__main__":
+     main()
scripts/db8.py ADDED
@@ -0,0 +1,203 @@
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ import scipy.io
+ import scipy.signal as signal
+ from joblib import Parallel, delayed
+ from scipy.signal import iirnotch
+ from tqdm import tqdm
+
+ _MATRIX_DOF2DOA_TRANSPOSED = np.array(
+     # https://www.frontiersin.org/articles/10.3389/fnins.2019.00891/full
+     # Open supplemental data > Data Sheet 1.PDF >
+     # > SUPPLEMENTARY METHODS > Eqn. S2
+     # https://www.frontiersin.org/articles/file/downloadfile/461612_supplementary-materials_datasheets_1_pdf/octet-stream/Data%20Sheet%201.PDF/1/461612
+     [
+         [+0.6390, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.3830, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +1.0000, +0.0000, +0.0000, +0.0000],
+         [-0.6390, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.4000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.6000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.4000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.6000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.1667],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.3333],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.1667],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.3333],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
+         [-0.1900, +0.0000, +0.0000, +0.0000, +0.0000],
+         [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
+     ],
+     dtype=np.float32,
+ )
+
+ MATRIX_DOF2DOA = _MATRIX_DOF2DOA_TRANSPOSED.T
+
+
+ # ─────────────── Filtering ──────────────────
+ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=1111.0):
+     """Notch-filter every channel independently."""
+     b, a = iirnotch(notch_freq, Q, fs)
+     out = np.zeros_like(data)
+     for ch in range(data.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, data[:, ch])
+     return out
+
+
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
+     nyq = 0.5 * fs
+     b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
+     out = np.zeros_like(emg)
+     for ch in range(emg.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, emg[:, ch])
+     return out
+
+
+ # ─────────────── Sliding window ──────────────
+ def sliding_window_segment(emg, label, window_size, stride):
+     """
+     Segment EMG with a sliding window.
+     Each window keeps the full per-frame label sequence (regression targets).
+     """
+     segments, labels = [], []
+     n_samples = len(label)
+
+     for start in range(0, n_samples - window_size + 1, stride):
+         end = start + window_size
+         emg_segment = emg[start:end]  # (win, ch)
+         label_segment = label[start:end]  # (win, DoA)
+         segments.append(emg_segment)
+         labels.append(label_segment)
+
+     return np.array(segments), np.array(labels)
+
+
+ # ─────────────── Main pipeline ───────────────
+ def process_mat_file(mat_path, window_size, stride, fs):
+     """
+     Load one .mat file, drop NaN timesteps, normalize EMG, map DoF → DoA,
+     segment, and return (split, segs, labels).
+     """
+     mat = scipy.io.loadmat(mat_path)
+     emg = mat["emg"]  # (T, 16)
+     label = mat["glove"]  # (T, DoF)
+
+     # 1) Drop timesteps with any NaNs in glove data
+     valid = ~np.isnan(label).any(axis=1)
+     emg = emg[valid]
+     label = label[valid]
+
+     # 2) Z-score per channel (note: the filter helpers above are not applied here)
+     mu = emg.mean(axis=0)
+     sd = emg.std(axis=0, ddof=1)
+     sd[sd == 0] = 1.0
+     emg = (emg - mu) / sd
+
+     # 3) DoF → DoA
+     y_doa = (MATRIX_DOF2DOA @ label.T).T
+
+     # 4) Windowing
+     segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)
+
+     # 5) Determine split
+     fname = os.path.basename(mat_path)
+     if "_A1" in fname:
+         split = "train"
+     elif "_A2" in fname:
+         split = "val"
+     elif "_A3" in fname:
+         split = "test"
+     else:
+         return None  # skip
+
+     return split, segs, labs
+
+
+ def main():
+     import argparse
+
+     args = argparse.ArgumentParser(description="Process EMG data from DB8.")
+     args.add_argument("--download_data", action="store_true")
+     args.add_argument("--data_dir", type=str, required=True)
+     args.add_argument("--save_dir", type=str, required=True)
+     args.add_argument(
+         "--window_size", type=int, help="Size of the sliding window for segmentation."
+     )
+     args.add_argument(
+         "--stride", type=int, help="Stride for the sliding window segmentation."
+     )
+     args.add_argument(
+         "--n_jobs", type=int, default=-1, help="Number of parallel jobs to run."
+     )
+     args = args.parse_args()
+     data_dir = args.data_dir  # input folder with .mat files
+     os.makedirs(args.save_dir, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://ninapro.hevs.ch/instructions/DB8.html
+         len_data = range(1, 13)  # 1–12
+         base_url = "https://ninapro.hevs.ch/files/DB8/"
+         # download (DB8 ships plain .mat files, no unzipping needed)
+         for i in len_data:
+             url_a = f"{base_url}S{i}_E1_A1.mat"
+             url_b = f"{base_url}S{i}_E1_A2.mat"
+             url_c = f"{base_url}S{i}_E1_A3.mat"
+             os.system(f"wget -P {data_dir} {url_a}")
+             os.system(f"wget -P {data_dir} {url_b}")
+             os.system(f"wget -P {data_dir} {url_c}")
+             print(
+                 f"Downloaded subject {i}\n{data_dir}/S{i}_E1_A1.mat and {data_dir}/S{i}_E1_A2.mat and {data_dir}/S{i}_E1_A3.mat"
+             )
+         sys.exit("Data downloaded. Rerun without --download_data.")
+
+     fs = 2000.0  # Hz
+
+     # collect all .mat paths
+     mat_paths = [
+         os.path.join(args.data_dir, f)
+         for f in sorted(os.listdir(args.data_dir))
+         if f.endswith(".mat")
+     ]
+
+     # run in parallel
+     results = Parallel(n_jobs=min(os.cpu_count(), args.n_jobs), verbose=5)(
+         delayed(process_mat_file)(mp, args.window_size, args.stride, fs)
+         for mp in mat_paths
+     )
+
+     # aggregate
+     splits = {k: {"data": [], "label": []} for k in ("train", "val", "test")}
+     for out in tqdm(results, desc="Processing files", unit="file"):
+         if out is None:
+             continue
+         split, segs, labs = out
+         splits[split]["data"].append(segs)
+         splits[split]["label"].append(labs)
+
+     # concatenate + save + stats
+     for split, d in tqdm(splits.items(), desc="Saving splits", unit="split"):
+         if not d["data"]:
+             continue
+
+         X = np.concatenate(d["data"], axis=0)
+         y = np.concatenate(d["label"], axis=0)
+
+         # transpose to [N, ch, window_size]
+         X = X.transpose(0, 2, 1)
+
+         print(f"Split: {split}, X shape: {X.shape}, y shape: {y.shape}")
+         # save
+         with h5py.File(os.path.join(args.save_dir, f"{split}.h5"), "w") as hf:
+             hf.create_dataset("data", data=X)
+             hf.create_dataset("label", data=y)
+
+
+ if __name__ == "__main__":
+     main()
scripts/emg2pose.py ADDED
@@ -0,0 +1,149 @@
+ import os
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ import pandas as pd
+ import scipy.io
+ import scipy.signal as signal
+ from joblib import Parallel, delayed
+ from scipy.signal import iirnotch
+ from tqdm import tqdm
+
+
+ # ==== Filter functions (operate at original fs=2000) ====
+ def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
+     b, a = iirnotch(notch_freq, Q, fs)
+     out = np.zeros_like(data)
+     for ch in range(data.shape[1]):
+         out[:, ch] = signal.filtfilt(b, a, data[:, ch])
+     return out
+
+
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
+     nyq = 0.5 * fs
+     low = lowcut / nyq
+     high = highcut / nyq
+     b, a = signal.butter(order, [low, high], btype="bandpass")
+     out = np.zeros_like(emg)
+     for c in range(emg.shape[1]):
+         out[:, c] = signal.filtfilt(b, a, emg[:, c])
+     return out
+
+
+ # ==== Window segmentation ====
+ def process_emg_features(emg, window_size=1000, stride=500):
+     segs = []
+     N = len(emg)
+     for start in range(0, N, stride):
+         end = start + window_size
+         if end > N:  # skip the last segment if it is not complete
+             continue
+         win = emg[start:end]
+         segs.append(win)
+     return np.array(segs)
+
+
+ def process_one_recording(file_path, fs=2000.0, window_size=1000, stride=500):
+     """
+     Process a single recording file to extract EMG windows,
+     for use in the main pipeline with parallel processing.
+     """
+     with h5py.File(file_path, "r") as f:
+         grp = f["emg2pose"]
+         data = grp["timeseries"]
+         emg = data["emg"][:].astype(np.float32)
+
+     # ==== Preprocessing EMG data ====
+     emg_filt = bandpass_filter_emg(emg, 20, 450, fs=fs)
+     emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)
+
+     # z-score
+     mu = emg_filt.mean(axis=0)
+     sd = emg_filt.std(axis=0, ddof=1)
+     sd[sd == 0] = 1.0
+     emg_z = (emg_filt - mu) / sd
+
+     # segment
+     segs = process_emg_features(emg_z, window_size, stride)
+
+     return segs
+
+
+ # ==== Main pipeline ====
+ def main():
+     import argparse
+
+     args = argparse.ArgumentParser(description="Process EMG data from emg2pose.")
+     args.add_argument("--data_dir", type=str)
+     args.add_argument("--save_dir", type=str)
+     args.add_argument(
+         "--window_size", type=int, help="Size of the sliding window for segmentation."
+     )
+     args.add_argument(
+         "--stride", type=int, help="Stride for the sliding window segmentation."
+     )
+     args.add_argument(
+         "--subsample", type=float, default=1.0, help="Fraction of recordings to keep per split."
+     )
+     args.add_argument(
+         "--n_jobs",
+         type=int,
+         default=-1,
+         help="Number of parallel jobs to run. -1 means using all available cores.",
+     )
+     args.add_argument(
+         "--seed", type=int, default=42, help="Random seed for reproducibility."
+     )
+     args = args.parse_args()
+
+     data_dir = args.data_dir
+     save_dir = args.save_dir
+     os.makedirs(save_dir, exist_ok=True)
+
+     fs = 2000.0  # original sampling rate
+     window_size, stride = args.window_size, args.stride
+
+     df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
+     df = df.groupby("split").apply(
+         lambda x: (
+             x.sample(frac=args.subsample, random_state=args.seed)
+             if args.subsample < 1.0
+             else x
+         )
+     )
+     df.reset_index(drop=True, inplace=True)
+
+     splits = {}
+     for split, df_ in df.groupby("split"):
+         sessions = list(df_.filename)
+         splits[split] = [
+             Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
+             for session in sessions
+         ]
+
+     all_data = {"train": [], "val": [], "test": []}
+
+     for split, files in splits.items():
+         # joblib parallelizes across files (embarrassingly parallel): ~25k recordings in total, ~17k of them in the training split.
+         results = Parallel(n_jobs=args.n_jobs)(
+             delayed(process_one_recording)(file_path, fs, window_size, stride)
+             for file_path in tqdm(files, desc=f"Processing {split} files")
+         )
+         # Collect results
+         for segs in tqdm(results, desc=f"Collecting {split} data"):
+             all_data[split].append(segs)
+
+         # stack, transpose, and save (pretraining split: no labels)
+         X = np.concatenate(all_data[split], axis=0)  # [N, window_size, ch]
+
+         # transpose to [N, ch, window_size]
+         X = X.transpose(0, 2, 1)
+
+         # save
+         with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
+             hf.create_dataset("data", data=X)
+
+
+ if __name__ == "__main__":
+     main()
scripts/epn.py ADDED
@@ -0,0 +1,194 @@
+ import glob
+ import json
+ import os
+ import sys
+
+ import h5py
+ import numpy as np
+ import scipy.signal as signal
+ from joblib import Parallel, delayed
+ from scipy.signal import iirnotch
+ from tqdm.auto import tqdm
+
+ # Sampling frequency and EMG channels
+ tfs, n_ch = 200.0, 8
+
+ # Gesture label mapping
+ gesture_map = {
+     "noGesture": 0,
+     "waveIn": 1,
+     "waveOut": 2,
+     "pinch": 3,
+     "open": 4,
+     "fist": 5,
+     "notProvided": 6,
+ }
+
+
+ # Filtering utilities
+ def bandpass_filter_emg(emg, low=20.0, high=90.0, fs=tfs, order=4):
+     nyq = 0.5 * fs
+     b, a = signal.butter(order, [low / nyq, high / nyq], btype="bandpass")
+     return signal.filtfilt(b, a, emg, axis=1)
+
+
+ def notch_filter_emg(emg, notch=50.0, Q=30.0, fs=tfs):
+     w0 = notch / (0.5 * fs)
+     b, a = iirnotch(w0, Q)
+     return signal.filtfilt(b, a, emg, axis=1)
+
+
+ # Normalization helpers
+ def zscore_per_channel(emg):
+     mean = emg.mean(axis=1, keepdims=True)
+     std = emg.std(axis=1, ddof=1, keepdims=True)
+     std[std == 0] = 1.0
+     return (emg - mean) / std
+
+
+ def adjust_length(x, max_len):
+     n_ch, seq_len = x.shape
+     if seq_len >= max_len:
+         return x[:, :max_len]
+     pad = np.zeros((n_ch, max_len - seq_len), dtype=x.dtype)
+     return np.concatenate([x, pad], axis=1)
+
+
+ # Single-sample processing
+ def extract_emg_signal(sample, seq_len):
+     emg = np.stack([v for v in sample["emg"].values()], dtype=np.float32) / 128.0
+     emg = bandpass_filter_emg(emg, 20.0, 90.0)
+     emg = notch_filter_emg(emg, 50.0, 30.0)
+     emg = zscore_per_channel(emg)
+     emg = adjust_length(emg, seq_len)
+     label = gesture_map.get(sample.get("gestureName", "notProvided"), 6)
+     return emg, label
+
+
+ # Process one user JSON for train/validation
+ def process_user_training(path, seq_len):
+     train_X, train_y, val_X, val_y = [], [], [], []
+     with open(path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+     for sample in data.get("trainingSamples", {}).values():
+         emg, lbl = extract_emg_signal(sample, seq_len)
+         if lbl != 6:
+             train_X.append(emg)
+             train_y.append(lbl)
+     for sample in data.get("testingSamples", {}).values():
+         emg, lbl = extract_emg_signal(sample, seq_len)
+         if lbl != 6:
+             val_X.append(emg)
+             val_y.append(lbl)
+     return train_X, train_y, val_X, val_y
+
+
+ # Process one user JSON for testing split
+ def process_user_testing(path, seq_len):
+     train_X, train_y, test_X, test_y = [], [], [], []
+     with open(path, "r", encoding="utf-8") as f:
+         data = json.load(f)
+     buckets = {g: [] for g in gesture_map}
+     for sample in data.get("trainingSamples", {}).values():
+         buckets.setdefault(sample.get("gestureName", "notProvided"), []).append(sample)
+     for samples in buckets.values():
+         for i, sample in enumerate(samples):
+             emg, lbl = extract_emg_signal(sample, seq_len)
+             if lbl == 6:
+                 continue
+             if i < 10:
+                 train_X.append(emg)
+                 train_y.append(lbl)
+             else:
+                 test_X.append(emg)
+                 test_y.append(lbl)
+     return train_X, train_y, test_X, test_y
+
+
+ # Save to HDF5
+ def save_h5(path, data, labels):
+     with h5py.File(path, "w") as f:
+         f.create_dataset("data", data=np.asarray(data, np.float32))
+         f.create_dataset("label", data=np.asarray(labels, np.int64))
+
+
+ # Main parallelized pipeline
+ def main():
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--download_data", action="store_true")
+     parser.add_argument("--data_dir", type=str, required=True)
+     parser.add_argument("--source_training", required=True)
+     parser.add_argument("--source_testing", required=True)
+     parser.add_argument("--dest_dir", required=True)
+     parser.add_argument("--window_size", type=int, required=True)
+     parser.add_argument("--n_jobs", type=int, default=-1)
+     args = parser.parse_args()
+     data_dir = args.data_dir
+     os.makedirs(args.dest_dir, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://zenodo.org/records/4421500
+         url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
+         os.system(f"wget -O {data_dir}/EMG-EPN612_Dataset.zip {url}")
+         os.system(f"unzip -o {data_dir}/EMG-EPN612_Dataset.zip -d {data_dir}")
+         # move the contents one level up
+         os.system(rf"mv {data_dir}/EMG-EPN612\ Dataset/* {data_dir}/")
+         os.system(rf"rmdir {data_dir}/EMG-EPN612\ Dataset")
+         # clean up zip file
+         os.system(f"rm {data_dir}/EMG-EPN612_Dataset.zip")
+         print(f"Downloaded and unzipped dataset\n{data_dir}/EMG-EPN612_Dataset.zip")
+         sys.exit("Data downloaded and unzipped. Rerun without --download_data.")
+
+     seq_len = args.window_size
+     train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []
+
+     paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))
+
+     # Parallel process training JSONs
+     results = Parallel(n_jobs=args.n_jobs)(
+         delayed(process_user_training)(p, seq_len)
+         for p in tqdm(paths, desc="Training files")
+     )
+     for tX, ty, vX, vy in results:
+         train_X.extend(tX)
+         train_y.extend(ty)
+         val_X.extend(vX)
+         val_y.extend(vy)
+
+     # Parallel process testing JSONs
+     test_results = Parallel(n_jobs=args.n_jobs)(
+         delayed(process_user_testing)(p, seq_len)
+         for p in tqdm(
+             glob.glob(os.path.join(args.source_testing, "user*", "user*.json")),
+             desc="Testing files",
+         )
+     )
+     for tX, ty, teX, tey in test_results:
+         train_X.extend(tX)
+         train_y.extend(ty)
+         test_X.extend(teX)
+         test_y.extend(tey)
+
+     # Save datasets
+     save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
+     save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
+     save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)
+
+     # Print distributions
+     for split, X, y in [
+         ("Train", train_X, train_y),
+         ("Val", val_X, val_y),
+         ("Test", test_X, test_y),
+     ]:
+         arr = np.array(y)
+         uniq, cnt = np.unique(arr, return_counts=True)
+         uniq = [i.item() for i in uniq]
+         cnt = [i.item() for i in cnt]
+         print(f"{split} → total={len(y)}, classes={dict(zip(uniq, cnt))}")
+
+
+ if __name__ == "__main__":
+     main()
scripts/uci.py ADDED
@@ -0,0 +1,229 @@
+ import os
+ import sys
+ from pathlib import Path
+
+ import h5py
+ import numpy as np
+ import scipy.signal as signal
+ from scipy.signal import iirnotch
+
+
+ # ─────────────────────────────────────────────
+ # Filtering utilities
+ # ─────────────────────────────────────────────
+ def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
+     nyq = 0.5 * fs
+     b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
+     return signal.filtfilt(b, a, emg, axis=0)
+
+
+ def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
+     b, a = iirnotch(notch_freq / (0.5 * fs), Q)
+     return signal.filtfilt(b, a, emg, axis=0)
+
+
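+ # A minimal sketch of how the two filters compose (synthetic data, assumed shapes):
+ #
+ #   rng = np.random.default_rng(0)
+ #   emg = rng.standard_normal((400, 8)).astype(np.float32)  # (samples, channels)
+ #   emg = bandpass_filter_emg(emg, 20.0, 90.0, fs=200.0)    # keep the 20–90 Hz band
+ #   emg = notch_filter_emg(emg, 50.0, 30.0, fs=200.0)       # suppress 50 Hz mains
+ #
+ # Both filters run zero-phase (filtfilt) along axis 0, so timing is preserved.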
+ # ─────────────────────────────────────────────
+ # Core I/O + preprocessing helpers
+ # ─────────────────────────────────────────────
+ def read_emg_txt(txt_path):
+     """
+     Read a txt file with columns: time ch1 … ch8 class.
+     Return float32 array of shape (N, 10).
+     """
+     data = []
+     with open(txt_path, "r") as f:
+         for line in f.readlines()[1:]:  # skip header
+             cols = line.strip().split()
+             if len(cols) == 10:
+                 data.append(list(map(float, cols)))
+     return np.asarray(data, dtype=np.float32)
+
+
+ def preprocess_emg(arr, fs=200.0, remove_class0=True):
+     """
+     1) optional removal of class-0 rows
+     2) band-pass β†’ notch β†’ Z-score (on 8 channels)
+     """
+     if remove_class0:
+         arr = arr[arr[:, -1] >= 1]
+     if arr.size == 0:
+         return arr
+
+     emg = arr[:, 1:9]  # (N, 8)
+     emg = bandpass_filter_emg(emg, 20, 90, fs)
+     emg = notch_filter_emg(emg, 50, 30, fs)
+
+     mu = emg.mean(axis=0)
+     sd = emg.std(axis=0, ddof=1)
+     sd[sd == 0] = 1.0
+     emg = (emg - mu) / sd
+
+     arr[:, 1:9] = emg
+     return arr
+
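+ # Note: the Z-score statistics are computed over the whole recording, so the
+ # normalisation is per file (per recording), not per window.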
+
+ def find_label_runs(arr):
+     """Group consecutive rows with identical class labels."""
+     runs = []
+     if arr.size == 0:
+         return runs
+     curr_lbl = int(arr[0, -1])
+     start = 0
+     for i in range(1, len(arr)):
+         lbl = int(arr[i, -1])
+         if lbl != curr_lbl:
+             runs.append((curr_lbl, arr[start:i]))
+             curr_lbl, start = lbl, i
+     runs.append((curr_lbl, arr[start:]))
+     return runs
+
+
+ def sliding_window_majority(seg_arr, window_size=1000, stride=500):
+     segs, labs = [], []
+     for start in range(0, len(seg_arr) - window_size + 1, stride):
+         win = seg_arr[start : start + window_size]
+         maj = np.argmax(np.bincount(win[:, -1].astype(int)))
+         segs.append(win[:, 1:9])  # keep 8-channel EMG
+         labs.append(maj)
+     return np.asarray(segs, dtype=np.float32), np.asarray(labs, dtype=np.int32)
+
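+ # Each window is labelled with the most frequent class among its samples; with
+ # the defaults (window_size=1000, stride=500) consecutive windows overlap by 50%.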
+
+ def users_with_gesture(
+     data_root, gesture_id, subj_range=range(1, 37), return_counts=False
+ ):
+     found = {}
+     for subj in subj_range:
+         subj_dir = os.path.join(data_root, f"{subj:02d}")
+         if not os.path.isdir(subj_dir):
+             continue
+         count = 0
+         for fname in os.listdir(subj_dir):
+             if not fname.endswith(".txt"):
+                 continue
+             txt_path = os.path.join(subj_dir, fname)
+             try:
+                 arr = read_emg_txt(txt_path)
+             except Exception:
+                 # skip files we can't parse
+                 continue
+             if arr.size == 0:
+                 continue
+             # last column is the class label (stored as float); compare as int
+             if np.any(arr[:, -1].astype(int) == int(gesture_id)):
+                 # count occurrences (rows) of that gesture in this file
+                 count += int((arr[:, -1].astype(int) == int(gesture_id)).sum())
+         if count > 0:
+             found[subj] = count
+
+     if return_counts:
+         return found  # dict subj -> count
+     else:
+         return sorted(found.keys())
+
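+ # Used below to exclude subjects whose recordings contain gesture 7, so that all
+ # three splits share the same label set.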
+
+ # ─────────────────────────────────────────────
+ # Safe concatenation utilities
+ # ─────────────────────────────────────────────
+ def concat_data(lst):  # lst of (N, window_size, 8) arrays
+     # empty fallback assumes the default 1000-sample window
+     return np.concatenate(lst, axis=0) if lst else np.empty((0, 1000, 8), np.float32)
+
+
+ def concat_label(lst):
+     return np.concatenate(lst, axis=0) if lst else np.empty((0,), np.int32)
+
+
+ # ─────────────────────────────────────────────
+ # Main
+ # ─────────────────────────────────────────────
+ if __name__ == "__main__":
+     import argparse
+
+     arg = argparse.ArgumentParser(description="Convert UCI EMG dataset to h5 format.")
+     arg.add_argument("--download_data", action="store_true")
+     arg.add_argument(
+         "--data_dir",
+         type=str,
+         required=True,
+         help="Root directory of the UCI EMG dataset",
+     )
+     arg.add_argument(
+         "--save_dir",
+         type=str,
+         required=True,
+         help="Directory to save the output h5 files",
+     )
+     # defaults mirror sliding_window_majority so omitting them cannot pass None
+     arg.add_argument(
+         "--window_size", type=int, default=1000, help="Window size for sliding window"
+     )
+     arg.add_argument("--stride", type=int, default=500, help="Stride for sliding window")
+     args = arg.parse_args()
+
+     data_root = args.data_dir
+     save_root = args.save_dir
+     os.makedirs(save_root, exist_ok=True)
+
+     # download data if requested
+     if args.download_data:
+         # https://archive.ics.uci.edu/dataset/481/emg+data+for+gestures
+         base_url = (
+             "https://archive.ics.uci.edu/static/public/481/emg+data+for+gestures.zip"
+         )
+         os.system(f"wget -O {data_root}/emg_gestures.zip '{base_url}'")
+         # extract into the parent so the archive's top-level folder can act as data_root
+         os.system(f"unzip -o {data_root}/emg_gestures.zip -d {Path(data_root).parent}")
+         os.system(f"rm {data_root}/emg_gestures.zip")
+         print("Dataset downloaded and cleaned up.")
+         sys.exit("Rerun without --download_data.")
+
+     fs = 200.0  # sampling rate of the MYO bracelet
+     window_size, stride = args.window_size, args.stride
+
+     # subject-wise split: no subject appears in more than one split
+     split_map = {
+         "train": list(range(1, 25)),  # subjects 1–24
+         "val": list(range(25, 31)),  # subjects 25–30
+         "test": list(range(31, 37)),  # subjects 31–36
+     }
+     # remove users that performed gesture 7
+     gesture_id = 7
+     gesture7_users = users_with_gesture(data_root, gesture_id)
+     print(f"Users that performed gesture {gesture_id}:", gesture7_users)
+
+     keep_subjs = []
+     for k in split_map:
+         split_map[k] = [u for u in split_map[k] if u not in gesture7_users]
+         keep_subjs.extend(split_map[k])
+     print("Subjects kept after removing gesture-7 users:", keep_subjs)
+
+     datasets = {k: {"data": [], "label": []} for k in split_map}
+
+     for subj in keep_subjs:
+         subj_dir = os.path.join(data_root, f"{subj:02d}")
+         if not os.path.isdir(subj_dir):
+             continue
+         split_key = next(k for k, v in split_map.items() if subj in v)
+
+         for fname in sorted(os.listdir(subj_dir)):
+             if not fname.endswith(".txt"):
+                 continue
+             arr = read_emg_txt(os.path.join(subj_dir, fname))
+             arr = preprocess_emg(arr, fs)
+
+             for lbl, seg_arr in find_label_runs(arr):
+                 segs, labs = sliding_window_majority(seg_arr, window_size, stride)
+                 if segs.size:
+                     datasets[split_key]["data"].append(segs)
+                     # shift labels to 0-based (class-0 rows were dropped upstream)
+                     datasets[split_key]["label"].append(labs - 1)
+
+     # concatenate, transpose & save
+     for split in ["train", "val", "test"]:
+         X = concat_data(datasets[split]["data"])  # (N, window_size, 8)
+         y = concat_label(datasets[split]["label"])
+         X = X.transpose(0, 2, 1)  # (N, 8, window_size)
+
+         with h5py.File(os.path.join(save_root, f"{split}.h5"), "w") as f:
+             f.create_dataset("data", data=X.astype(np.float32))
+             f.create_dataset("label", data=y.astype(np.int32))
+         uniq, cnt = np.unique(y, return_counts=True)
+         print(
+             f"{split.upper():5} β†’ X={X.shape}, label dist:",
+             dict(zip(uniq.tolist(), cnt.tolist())),
+         )
+
+     print("\nAll splits saved to:", save_root)
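+
+ # Example invocation (hypothetical paths, adjust to your local layout):
+ #   python scripts/uci.py --data_dir data/EMG_data_for_gestures-master \
+ #       --save_dir data/uci_h5 --window_size 1000 --stride 500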