Upload 9 files
Browse files- scripts/HMC.py +370 -0
- scripts/README.md +129 -0
- scripts/db5.py +213 -0
- scripts/db6.py +186 -0
- scripts/db7.py +184 -0
- scripts/db8.py +203 -0
- scripts/emg2pose.py +149 -0
- scripts/epn.py +194 -0
- scripts/uci.py +229 -0
scripts/HMC.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import mne
|
| 6 |
+
import numpy as np
|
| 7 |
+
from joblib import Parallel, delayed
|
| 8 |
+
from mne.io import read_raw_edf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def process_single_recording(
|
| 12 |
+
raw_fn: str,
|
| 13 |
+
scoring_fn: str,
|
| 14 |
+
data_path: str,
|
| 15 |
+
channel: str,
|
| 16 |
+
start_at: int,
|
| 17 |
+
duration_sec: int,
|
| 18 |
+
l_freq: float,
|
| 19 |
+
h_freq: float,
|
| 20 |
+
sfreq: int,
|
| 21 |
+
mains: int,
|
| 22 |
+
window_size: int,
|
| 23 |
+
stride: int,
|
| 24 |
+
mapping: dict,
|
| 25 |
+
verbose: bool = False,
|
| 26 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]:
|
| 27 |
+
"""
|
| 28 |
+
Process a single recording file and return windows and labels.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
Tuple of (windows, labels, filename) or (None, None, filename) if processing fails
|
| 32 |
+
"""
|
| 33 |
+
try:
|
| 34 |
+
if verbose:
|
| 35 |
+
print(f"Processing: {raw_fn}")
|
| 36 |
+
|
| 37 |
+
full_path_raw = os.path.join(data_path, raw_fn)
|
| 38 |
+
full_path_score = os.path.join(data_path, scoring_fn)
|
| 39 |
+
|
| 40 |
+
# Load and preprocess
|
| 41 |
+
raw = read_raw_edf(full_path_raw, preload=True, verbose=False)
|
| 42 |
+
annotation = mne.read_annotations(full_path_score)
|
| 43 |
+
raw.set_annotations(annotation, emit_warning=False)
|
| 44 |
+
|
| 45 |
+
# Crop
|
| 46 |
+
end_at = start_at + duration_sec
|
| 47 |
+
if end_at > raw.times[-1]:
|
| 48 |
+
end_at = raw.times[-1] - (raw.times[-1] % 30.0)
|
| 49 |
+
raw = raw.crop(tmin=start_at, tmax=end_at)
|
| 50 |
+
|
| 51 |
+
# Pick channel
|
| 52 |
+
if channel not in raw.ch_names:
|
| 53 |
+
print(f"Warning: Channel {channel} not found in {raw_fn}, skipping")
|
| 54 |
+
return None, None, raw_fn
|
| 55 |
+
raw = raw.pick([channel])
|
| 56 |
+
|
| 57 |
+
# Filter (bandpass) with safe h_freq clipping to Nyquist
|
| 58 |
+
nyq = raw.info["sfreq"] / 2.0
|
| 59 |
+
h_freq_adj = h_freq if h_freq is not None and h_freq < nyq else None
|
| 60 |
+
raw = raw.filter(
|
| 61 |
+
l_freq=l_freq, h_freq=h_freq_adj, fir_design="firwin", verbose=False
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Notch at mains harmonics (e.g. 50,100,150 or 60,120,180) but only those < Nyquist
|
| 65 |
+
mains_freqs = [mains * i for i in (1, 2, 3)]
|
| 66 |
+
mains_freqs = [f for f in mains_freqs if f < nyq]
|
| 67 |
+
if len(mains_freqs) > 0:
|
| 68 |
+
# use raw.notch_filter which handles multiple notch freqs
|
| 69 |
+
raw.notch_filter(freqs=mains_freqs, picks=[channel], verbose=False)
|
| 70 |
+
|
| 71 |
+
# Resample to target sampling rate (upsample or downsample)
|
| 72 |
+
if raw.info["sfreq"] != sfreq:
|
| 73 |
+
raw = raw.resample(sfreq, npad="auto")
|
| 74 |
+
|
| 75 |
+
# Create 30s epochs
|
| 76 |
+
events, event_id = mne.events_from_annotations(raw, chunk_duration=30.0)
|
| 77 |
+
tmax = 30.0 - 1.0 / raw.info["sfreq"]
|
| 78 |
+
epochs = mne.Epochs(
|
| 79 |
+
raw=raw,
|
| 80 |
+
events=events,
|
| 81 |
+
event_id=event_id,
|
| 82 |
+
tmin=0.0,
|
| 83 |
+
tmax=tmax,
|
| 84 |
+
baseline=None,
|
| 85 |
+
verbose=False,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
epochs_data = epochs.get_data() # (n_epochs, 1, n_times)
|
| 89 |
+
labels = []
|
| 90 |
+
for ann in epochs.get_annotations_per_epoch():
|
| 91 |
+
labels.append(mapping[str(ann[0][2])])
|
| 92 |
+
|
| 93 |
+
n_epochs, _, n_times = epochs_data.shape
|
| 94 |
+
if n_times < window_size:
|
| 95 |
+
print(f"Warning: Not enough samples in {raw_fn}, skipping")
|
| 96 |
+
return None, None, raw_fn
|
| 97 |
+
|
| 98 |
+
# Sliding window
|
| 99 |
+
windows = []
|
| 100 |
+
labels_win = []
|
| 101 |
+
for i in range(n_epochs):
|
| 102 |
+
for start in range(0, n_times - window_size + 1, stride):
|
| 103 |
+
windows.append(epochs_data[i, 0, start : start + window_size])
|
| 104 |
+
labels_win.append(labels[i])
|
| 105 |
+
|
| 106 |
+
if len(windows) > 0:
|
| 107 |
+
windows = np.stack(windows) # (n_windows, window_size)
|
| 108 |
+
windows = windows[:, np.newaxis, :] # (n_windows, 1, window_size)
|
| 109 |
+
|
| 110 |
+
if verbose:
|
| 111 |
+
print(f" {raw_fn}: Generated {len(windows)} windows")
|
| 112 |
+
|
| 113 |
+
return (
|
| 114 |
+
windows.astype(np.float32),
|
| 115 |
+
np.array(labels_win, dtype=np.int32),
|
| 116 |
+
raw_fn,
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
return None, None, raw_fn
|
| 120 |
+
|
| 121 |
+
except Exception as e:
|
| 122 |
+
print(f"Error processing {raw_fn}: {e}")
|
| 123 |
+
return None, None, raw_fn
|
| 124 |
+
|
| 125 |
def convert_hmc_to_h5(
    data_path: str,
    save_path: str,
    channel: str = "EMG chin",
    start_at: int = 15 * 60,
    duration_sec: int = 6 * 60 * 60,
    l_freq: float = 5.0,
    h_freq: float = 250.0,
    sfreq: int = 100,
    mains: int = 50,
    window_size: int = 1000,
    stride: int = 1000,
    n_jobs: int = -1,
    verbose: bool = True,
):
    """
    Convert HMC EMG dataset to HDF5 format compatible with EMGDataset.
    Uses joblib for parallel processing of individual recordings.

    Args:
        data_path: Root directory containing EDF files
        save_path: Directory to save HDF5 files
        channel: EMG channel name to extract
        start_at: Start time in seconds
        duration_sec: Duration to extract in seconds
        l_freq: Lower band edge of the band-pass filter (high-pass cutoff, Hz)
        h_freq: Upper band edge of the band-pass filter (low-pass cutoff, Hz)
        sfreq: Target sampling frequency
        mains: Mains frequency (50 or 60 Hz) whose harmonics are notch-filtered
        window_size: Window size for segmentation
        stride: Stride for sliding window
        n_jobs: Number of parallel jobs (-1 for all cores, 1 for sequential)
        verbose: Print progress
    """
    import re

    # Annotation description -> integer class. "Lights off@@..." maps to wake.
    mapping = {
        "Sleep stage W": 0,
        "Sleep stage N1": 1,
        "Sleep stage N2": 2,
        "Sleep stage N3": 3,
        "Sleep stage R": 4,
        "Lights off@@EEG F4-A1": 0,
    }

    os.makedirs(save_path, exist_ok=True)

    # Discover record file pairs: signal EDFs paired with their
    # "<base>_sleepscoring.edf" annotation files.
    files = os.listdir(data_path)
    raw_files = [
        f
        for f in files
        if f.lower().endswith(".edf") and "sleepscoring" not in f.lower()
    ]

    records = []
    for raw_fn in raw_files:
        base = os.path.splitext(raw_fn)[0]
        scoring_fn = base + "_sleepscoring.edf"
        if scoring_fn in files:
            records.append((raw_fn, scoring_fn))
        elif verbose:
            print(f"Warning: scoring file missing for {raw_fn}")

    if len(records) == 0:
        print("No valid record pairs found!")
        return

    print(f"Found {len(records)} recording pairs")
    print(f"Using {n_jobs} parallel jobs" if n_jobs != 1 else "Running sequentially")

    # Initialize data containers for each split
    datasets = {
        "train": {"data": [], "label": []},
        "val": {"data": [], "label": []},
        "test": {"data": [], "label": []},
    }

    # Version 1.1: recordings SN014, SN064, and SN135 were removed upstream
    # after erroneous (and unfixable) signal data was detected; the subject
    # ranges below simply skip over the missing files.
    train_subjects = range(1, 101)  # Subjects 1-100 for training
    val_subjects = range(101, 127)  # Subjects 101-126 for validation
    test_subjects = range(127, 155)  # Subjects 127-154 for testing

    def get_split(filename):
        """Map a recording filename (e.g. 'SN001.edf') to its split by subject number."""
        match = re.search(r"(\d+)", filename)
        if match:
            subj_num = int(match.group(1))
            if subj_num in train_subjects:
                return "train"
            elif subj_num in val_subjects:
                return "val"
            elif subj_num in test_subjects:
                return "test"
        return "train"  # default for unparseable or out-of-range names

    # Process recordings in parallel
    print(f"\nProcessing {len(records)} recordings...")
    results = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0)(
        delayed(process_single_recording)(
            raw_fn=raw_fn,
            scoring_fn=scoring_fn,
            data_path=data_path,
            channel=channel,
            start_at=start_at,
            duration_sec=duration_sec,
            l_freq=l_freq,
            h_freq=h_freq,
            sfreq=sfreq,
            mains=mains,
            window_size=window_size,
            stride=stride,
            mapping=mapping,
            verbose=False,
        )
        for raw_fn, scoring_fn in records
    )

    # Collect results into splits
    processed_count = 0
    failed_count = 0

    for windows, labels, raw_fn in results:
        if windows is not None and labels is not None:
            # Determine which split this recording belongs to
            split_key = get_split(raw_fn)

            datasets[split_key]["data"].append(windows)
            datasets[split_key]["label"].append(labels)
            processed_count += 1

            if verbose:
                print(f"β {raw_fn}: {len(windows)} windows -> {split_key}")
        else:
            failed_count += 1
            if verbose:
                print(f"β {raw_fn}: Failed")

    print(f"\nProcessing complete: {processed_count} successful, {failed_count} failed")

    # Concatenate and save one HDF5 file per split
    for split_name, split_data in datasets.items():
        if len(split_data["data"]) == 0:
            print(f"Warning: No data for {split_name} split")
            continue

        print(f"\nPreparing {split_name} split...")
        X = np.concatenate(split_data["data"], axis=0)  # (N, 1, window_size)
        y = np.concatenate(split_data["label"], axis=0)  # (N,)

        h5_path = os.path.join(save_path, f"{split_name}.h5")
        with h5py.File(h5_path, "w") as f:
            f.create_dataset(
                "data", data=X, dtype=np.float32, compression="gzip", compression_opts=4
            )
            f.create_dataset(
                "label", data=y, dtype=np.int32, compression="gzip", compression_opts=4
            )

        uniq, cnt = np.unique(y, return_counts=True)
        label_dist = dict(zip(uniq.tolist(), cnt.tolist()))

        print(f"{split_name.upper()}:")
        print(f"  Shape: X={X.shape}, y={y.shape}")
        print(f"  Label distribution: {label_dist}")
        print(f"  Saved to: {h5_path}")
        print(f"  File size: {os.path.getsize(h5_path) / (1024**2):.2f} MB")
|
| 297 |
+
|
| 298 |
+
if __name__ == "__main__":
|
| 299 |
+
import argparse
|
| 300 |
+
|
| 301 |
+
parser = argparse.ArgumentParser(
|
| 302 |
+
description="Convert HMC EMG dataset to HDF5 format with parallel processing"
|
| 303 |
+
)
|
| 304 |
+
parser.add_argument(
|
| 305 |
+
"--data_dir",
|
| 306 |
+
type=str,
|
| 307 |
+
required=True,
|
| 308 |
+
help="Root directory containing HMC EDF files",
|
| 309 |
+
)
|
| 310 |
+
parser.add_argument(
|
| 311 |
+
"--save_dir", type=str, required=True, help="Directory to save HDF5 files"
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument(
|
| 314 |
+
"--channel", type=str, default="EMG chin", help="EMG channel name"
|
| 315 |
+
)
|
| 316 |
+
parser.add_argument(
|
| 317 |
+
"--start_at", type=int, default=900, help="Start time in seconds"
|
| 318 |
+
)
|
| 319 |
+
parser.add_argument(
|
| 320 |
+
"--duration_sec", type=int, default=21600, help="Duration in seconds"
|
| 321 |
+
)
|
| 322 |
+
parser.add_argument(
|
| 323 |
+
"--l_freq", type=float, default=5.0, help="Low-pass filter frequency"
|
| 324 |
+
)
|
| 325 |
+
parser.add_argument(
|
| 326 |
+
"--h_freq", type=float, default=200.0, help="High-pass filter frequency"
|
| 327 |
+
)
|
| 328 |
+
parser.add_argument(
|
| 329 |
+
"--sfreq", type=int, default=500, help="Target sampling frequency (Hz)"
|
| 330 |
+
)
|
| 331 |
+
parser.add_argument(
|
| 332 |
+
"--mains",
|
| 333 |
+
type=int,
|
| 334 |
+
default=50,
|
| 335 |
+
choices=[50, 60],
|
| 336 |
+
help="Mains frequency for notch (50 or 60 Hz)",
|
| 337 |
+
)
|
| 338 |
+
parser.add_argument(
|
| 339 |
+
"--window_size", type=int, default=1000, help="Window size for segmentation"
|
| 340 |
+
)
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--stride", type=int, default=1000, help="Stride for sliding window"
|
| 343 |
+
)
|
| 344 |
+
parser.add_argument(
|
| 345 |
+
"--n_jobs",
|
| 346 |
+
type=int,
|
| 347 |
+
default=-1,
|
| 348 |
+
help="Number of parallel jobs (-1 for all cores)",
|
| 349 |
+
)
|
| 350 |
+
parser.add_argument(
|
| 351 |
+
"--verbose", action="store_true", help="Print detailed progress"
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
args = parser.parse_args()
|
| 355 |
+
|
| 356 |
+
convert_hmc_to_h5(
|
| 357 |
+
data_path=args.data_dir,
|
| 358 |
+
save_path=args.save_dir,
|
| 359 |
+
channel=args.channel,
|
| 360 |
+
start_at=args.start_at,
|
| 361 |
+
duration_sec=args.duration_sec,
|
| 362 |
+
l_freq=args.l_freq,
|
| 363 |
+
h_freq=args.h_freq,
|
| 364 |
+
mains=args.mains,
|
| 365 |
+
sfreq=args.sfreq,
|
| 366 |
+
window_size=args.window_size,
|
| 367 |
+
stride=args.stride,
|
| 368 |
+
n_jobs=args.n_jobs,
|
| 369 |
+
verbose=args.verbose,
|
| 370 |
+
)
|
scripts/README.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataset Preparation Commands
|
| 2 |
+
|
| 3 |
+
## Overview
|
| 4 |
+
|
| 5 |
+
This document provides the commands to prepare various EMG datasets for pretraining and downstream tasks. Each dataset preparation script takes in raw data, processes it into overlapping windows, and saves the processed data in HDF5 format for efficient loading during model training.
|
| 6 |
+
|
| 7 |
+
Remember to add the flag `--download_data` if the dataset is not downloaded yet.
|
| 8 |
+
|
| 9 |
+
## Pretraining Datasets
|
| 10 |
+
|
| 11 |
+
For the pretraining:
|
| 12 |
+
|
| 13 |
+
### emg2pose
|
| 14 |
+
|
| 15 |
+
```bash
|
| 16 |
+
python scripts/emg2pose.py \
|
| 17 |
+
--data_dir $SCRATCH/datasets/emg2pose_data/ \
|
| 18 |
+
--save_dir $SCRATCH/datasets/emg2pose_data/h5/ \
|
| 19 |
+
--window_size 1000 \
|
| 20 |
+
--stride 500
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
### Ninapro DB6
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
python scripts/db6.py \
|
| 27 |
+
--data_dir $SCRATCH/datasets/ninapro/DB6/ \
|
| 28 |
+
--save_dir $SCRATCH/datasets/ninapro/DB6/h5/ \
|
| 29 |
+
--window_size 1000 \
|
| 30 |
+
--stride 500
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
### Ninapro DB7
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
python scripts/db7.py \
|
| 37 |
+
--data_dir $SCRATCH/datasets/ninapro/DB7/ \
|
| 38 |
+
--save_dir $SCRATCH/datasets/ninapro/DB7/h5/ \
|
| 39 |
+
--window_size 1000 \
|
| 40 |
+
--stride 500
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Downstream Datasets
|
| 46 |
+
|
| 47 |
+
For the downstream tasks:
|
| 48 |
+
|
| 49 |
+
### Ninapro DB5 (200 ms windows, 75% overlap: stride = 25% of window)
|
| 50 |
+
|
| 51 |
+
```bash
|
| 52 |
+
python scripts/db5.py \
|
| 53 |
+
--data_dir $SCRATCH/datasets/ninapro/DB5/ \
|
| 54 |
+
--save_dir $SCRATCH/datasets/ninapro/DB5/h5/ \
|
| 55 |
+
--window_size 200 \
|
| 56 |
+
--stride 50
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### Ninapro DB5 (1000 ms windows, 75% overlap: stride = 25% of window)
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
python scripts/db5.py \
|
| 63 |
+
--data_dir $SCRATCH/datasets/ninapro/DB5/ \
|
| 64 |
+
--save_dir $SCRATCH/datasets/ninapro/DB5/h5/ \
|
| 65 |
+
--window_size 1000 \
|
| 66 |
+
--stride 250
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### EMG-EPN612 (200 ms)
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python scripts/epn.py \
|
| 73 |
+
--data_dir $SCRATCH/datasets/EPN612/ \
|
| 74 |
+
--source_training $SCRATCH/datasets/EPN612/trainingJSON/ \
|
| 75 |
+
--source_testing $SCRATCH/datasets/EPN612/testingJSON/ \
|
| 76 |
+
--dest_dir $SCRATCH/datasets/EPN612/h5/ \
|
| 77 |
+
--window_size 200
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### EMG-EPN612 (1000 ms)
|
| 81 |
+
|
| 82 |
+
```bash
|
| 83 |
+
python scripts/epn.py \
|
| 84 |
+
--data_dir $SCRATCH/datasets/EPN612/ \
|
| 85 |
+
--source_training $SCRATCH/datasets/EPN612/trainingJSON/ \
|
| 86 |
+
--source_testing $SCRATCH/datasets/EPN612/testingJSON/ \
|
| 87 |
+
--dest_dir $SCRATCH/datasets/EPN612/h5/ \
|
| 88 |
+
--window_size 1000
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### UCI EMG (200 ms windows, 75% overlap: stride = 25% of window)
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
python scripts/uci.py \
|
| 95 |
+
--data_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
|
| 96 |
+
--save_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
|
| 97 |
+
--window_size 200 \
|
| 98 |
+
--stride 50
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
### UCI EMG (1000 ms windows, 75% overlap: stride = 25% of window)
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
python scripts/uci.py \
|
| 105 |
+
--data_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/ \
|
| 106 |
+
--save_dir $SCRATCH/datasets/UCI_EMG/EMG_data_for_gestures-master/h5/ \
|
| 107 |
+
--window_size 1000 \
|
| 108 |
+
--stride 250
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### Ninapro DB8 (200 ms, no overlap)
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
python scripts/db8.py \
|
| 115 |
+
--data_dir $SCRATCH/datasets/ninapro/DB8/ \
|
| 116 |
+
--save_dir $SCRATCH/datasets/ninapro/DB8/h5/ \
|
| 117 |
+
--window_size 200 \
|
| 118 |
+
--stride 200
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
### Ninapro DB8 (1000 ms, no overlap)
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
python scripts/db8.py \
|
| 125 |
+
--data_dir $SCRATCH/datasets/ninapro/DB8/ \
|
| 126 |
+
--save_dir $SCRATCH/datasets/ninapro/DB8/h5/ \
|
| 127 |
+
--window_size 1000 \
|
| 128 |
+
--stride 1000
|
| 129 |
+
```
|
scripts/db5.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.io
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from scipy.signal import iirnotch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ==== Data augmentation functions ====
|
| 12 |
+
def random_amplitude_scale(sig, scale_range=(0.9, 1.1)):
|
| 13 |
+
scale = np.random.uniform(*scale_range)
|
| 14 |
+
return sig * scale
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def random_time_jitter(sig, jitter_ratio=0.01):
|
| 18 |
+
T, D = sig.shape
|
| 19 |
+
std_ch = np.std(sig, axis=0)
|
| 20 |
+
noise = np.random.randn(T, D) * (jitter_ratio * std_ch)
|
| 21 |
+
return sig + noise
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def random_channel_dropout(sig, dropout_prob=0.05):
|
| 25 |
+
T, D = sig.shape
|
| 26 |
+
mask = np.random.rand(D) < dropout_prob
|
| 27 |
+
sig[:, mask] = 0.0
|
| 28 |
+
return sig
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def augment_one_sample(seg):
|
| 32 |
+
out = seg.copy()
|
| 33 |
+
out = random_amplitude_scale(out, (0.9, 1.1))
|
| 34 |
+
out = random_time_jitter(out, 0.01)
|
| 35 |
+
out = random_channel_dropout(out, 0.05)
|
| 36 |
+
return out
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def augment_train_data(data, labels, factor=3):
|
| 40 |
+
if factor <= 0 or data.shape[0] == 0:
|
| 41 |
+
return data, labels
|
| 42 |
+
aug_segs = [data]
|
| 43 |
+
aug_lbls = [labels]
|
| 44 |
+
N = data.shape[0]
|
| 45 |
+
for i in range(N):
|
| 46 |
+
seg = data[i] # [window_size, n_ch]
|
| 47 |
+
lab = labels[i]
|
| 48 |
+
for _ in range(factor):
|
| 49 |
+
aug_segs.append(augment_one_sample(seg)[None, ...])
|
| 50 |
+
aug_lbls.append([lab])
|
| 51 |
+
new_data = np.concatenate(aug_segs, axis=0)
|
| 52 |
+
new_labels = np.concatenate(aug_lbls, axis=0).ravel()
|
| 53 |
+
return new_data, new_labels
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ==== Filter functions (operate at original fs=200) ====
|
| 57 |
+
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
|
| 58 |
+
b, a = iirnotch(notch_freq, Q, fs)
|
| 59 |
+
out = np.zeros_like(data)
|
| 60 |
+
for ch in range(data.shape[1]):
|
| 61 |
+
out[:, ch] = signal.filtfilt(b, a, data[:, ch])
|
| 62 |
+
return out
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
|
| 66 |
+
nyq = 0.5 * fs
|
| 67 |
+
low = lowcut / nyq
|
| 68 |
+
high = highcut / nyq
|
| 69 |
+
b, a = signal.butter(order, [low, high], btype="bandpass")
|
| 70 |
+
out = np.zeros_like(emg)
|
| 71 |
+
for c in range(emg.shape[1]):
|
| 72 |
+
out[:, c] = signal.filtfilt(b, a, emg[:, c])
|
| 73 |
+
return out
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ==== Window segmentation ====
|
| 77 |
+
def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
|
| 78 |
+
segs, lbls, reps = [], [], []
|
| 79 |
+
N = len(label)
|
| 80 |
+
for start in range(0, N, stride):
|
| 81 |
+
end = start + window_size
|
| 82 |
+
if end > N:
|
| 83 |
+
cut = emg[start:N]
|
| 84 |
+
pad = np.zeros((end - N, emg.shape[1]))
|
| 85 |
+
win = np.vstack([cut, pad])
|
| 86 |
+
else:
|
| 87 |
+
win = emg[start:end]
|
| 88 |
+
|
| 89 |
+
segs.append(win)
|
| 90 |
+
lbls.append(label[start])
|
| 91 |
+
reps.append(rerep[start])
|
| 92 |
+
return np.array(segs), np.array(lbls), np.array(reps)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ==== Main pipeline ====
|
| 96 |
+
def main():
|
| 97 |
+
import argparse
|
| 98 |
+
|
| 99 |
+
args = argparse.ArgumentParser(description="Process EMG data from DB5.")
|
| 100 |
+
args.add_argument("--download_data", action="store_true")
|
| 101 |
+
args.add_argument("--data_dir", type=str)
|
| 102 |
+
args.add_argument("--save_dir", type=str)
|
| 103 |
+
args.add_argument(
|
| 104 |
+
"--window_size", type=int, help="Size of the sliding window for segmentation."
|
| 105 |
+
)
|
| 106 |
+
args.add_argument(
|
| 107 |
+
"--stride", type=int, help="Stride for the sliding window segmentation."
|
| 108 |
+
)
|
| 109 |
+
args = args.parse_args()
|
| 110 |
+
|
| 111 |
+
data_dir = args.data_dir
|
| 112 |
+
save_dir = args.save_dir
|
| 113 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 114 |
+
|
| 115 |
+
# download data if requested
|
| 116 |
+
if args.download_data:
|
| 117 |
+
# https://ninapro.hevs.ch/instructions/DB5.html
|
| 118 |
+
len_data = range(1, 11) # 1β10
|
| 119 |
+
base_url = "https://ninapro.hevs.ch/files/DB5_Preproc/"
|
| 120 |
+
# download and unzip
|
| 121 |
+
for i in len_data:
|
| 122 |
+
url = f"{base_url}s{i}.zip"
|
| 123 |
+
os.system(f"wget -P {data_dir} {url}")
|
| 124 |
+
os.system(f"unzip -o {data_dir}/s{i}.zip -d {data_dir}")
|
| 125 |
+
os.system(f"rm {data_dir}/s{i}.zip")
|
| 126 |
+
print(f"Downloaded and unzipped subject {i}\n{data_dir}/s{i}.zip")
|
| 127 |
+
sys.exit("Data downloaded and unzipped. Rerun without --download_data.")
|
| 128 |
+
|
| 129 |
+
fs = 200.0 # original sampling rate
|
| 130 |
+
window_size, stride = args.window_size, args.stride
|
| 131 |
+
train_reps = [1, 3, 4, 6]
|
| 132 |
+
val_reps = [2]
|
| 133 |
+
test_reps = [5]
|
| 134 |
+
|
| 135 |
+
all_data = {"train": [], "val": [], "test": []}
|
| 136 |
+
all_lbls = {"train": [], "val": [], "test": []}
|
| 137 |
+
|
| 138 |
+
for subj in sorted(os.listdir(data_dir)):
|
| 139 |
+
subj_path = os.path.join(data_dir, subj)
|
| 140 |
+
if not os.path.isdir(subj_path):
|
| 141 |
+
continue
|
| 142 |
+
print(f"Processing subject {subj}...")
|
| 143 |
+
for mat in sorted(os.listdir(subj_path)):
|
| 144 |
+
if not mat.endswith(".mat"):
|
| 145 |
+
continue
|
| 146 |
+
dd = scipy.io.loadmat(os.path.join(subj_path, mat))
|
| 147 |
+
emg = dd["emg"] # [N,16]
|
| 148 |
+
label = dd["restimulus"].ravel().astype(int)
|
| 149 |
+
rerep = dd["rerepetition"].ravel().astype(int)
|
| 150 |
+
|
| 151 |
+
# label shift by exercise
|
| 152 |
+
if "E2" in mat:
|
| 153 |
+
label = np.where(label != 0, label + 12, 0)
|
| 154 |
+
elif "E3" in mat:
|
| 155 |
+
label = np.where(label != 0, label + 29, 0)
|
| 156 |
+
|
| 157 |
+
# filtering at original 200 Hz
|
| 158 |
+
emg_filt = bandpass_filter_emg(emg, 20, 90, fs=fs)
|
| 159 |
+
emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)
|
| 160 |
+
|
| 161 |
+
# z-score
|
| 162 |
+
mu = emg_filt.mean(axis=0)
|
| 163 |
+
sd = emg_filt.std(axis=0, ddof=1)
|
| 164 |
+
sd[sd == 0] = 1.0
|
| 165 |
+
emg_z = (emg_filt - mu) / sd
|
| 166 |
+
|
| 167 |
+
# segment
|
| 168 |
+
segs, lbls, reps = process_emg_features(
|
| 169 |
+
emg_z, label, rerep, window_size, stride
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# split by repetition index
|
| 173 |
+
for seg, lab, rp in zip(segs, lbls, reps):
|
| 174 |
+
if rp in train_reps:
|
| 175 |
+
all_data["train"].append(seg)
|
| 176 |
+
all_lbls["train"].append(lab)
|
| 177 |
+
elif rp in val_reps:
|
| 178 |
+
all_data["val"].append(seg)
|
| 179 |
+
all_lbls["val"].append(lab)
|
| 180 |
+
elif rp in test_reps:
|
| 181 |
+
all_data["test"].append(seg)
|
| 182 |
+
all_lbls["test"].append(lab)
|
| 183 |
+
|
| 184 |
+
# stack, augment train, transpose, save, and print stats
|
| 185 |
+
stats = {}
|
| 186 |
+
for split in ["train", "val", "test"]:
|
| 187 |
+
X = np.stack(all_data[split], axis=0) # [N, window_size, ch]
|
| 188 |
+
y = np.array(all_lbls[split], dtype=int)
|
| 189 |
+
|
| 190 |
+
if split == "train":
|
| 191 |
+
X, y = augment_train_data(X, y, factor=3)
|
| 192 |
+
|
| 193 |
+
# transpose to [N, ch, window_size]
|
| 194 |
+
X = X.transpose(0, 2, 1)
|
| 195 |
+
|
| 196 |
+
# save
|
| 197 |
+
with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
|
| 198 |
+
hf.create_dataset("data", data=X)
|
| 199 |
+
hf.create_dataset("label", data=y)
|
| 200 |
+
|
| 201 |
+
# compute stats
|
| 202 |
+
uniq, cnt = np.unique(y, return_counts=True)
|
| 203 |
+
stats[split] = (X.shape, dict(zip(uniq.tolist(), cnt.tolist())))
|
| 204 |
+
|
| 205 |
+
# print stats
|
| 206 |
+
for split, (shape, dist) in stats.items():
|
| 207 |
+
print(f"\n{split} β X={shape}, label distribution:")
|
| 208 |
+
for lab, count in dist.items():
|
| 209 |
+
print(f" label {lab}: {count} samples")
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
main()
|
scripts/db6.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.io
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from scipy.signal import iirnotch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# βββββββββββββββ Filtering ββββββββββββββββββ
|
| 12 |
+
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Apply a zero-phase IIR notch filter to every channel of `data`.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Multi-channel signal, time along axis 0.
    notch_freq : float
        Centre frequency to reject in Hz (power-line interference).
    Q : float
        Quality factor; higher means a narrower notch.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `data`.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # filtfilt filters all channels in one vectorized call (axis=0 = time);
    # result is identical to the per-channel loop, just faster.
    filtered = signal.filtfilt(b, a, data, axis=0)
    # The original implementation wrote into np.zeros_like(data), which
    # cast back to the input dtype — keep that behaviour.
    return filtered.astype(data.dtype, copy=False)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Zero-phase Butterworth band-pass filter applied to every channel.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
        EMG signal, time along axis 0.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth filter order (effective order doubles via filtfilt).

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `emg`.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    # Vectorized over channels (axis=0 = time) — identical to looping
    # channel by channel, but a single C-level call.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    # Preserve the input dtype, as the original zeros_like buffer did.
    return filtered.astype(emg.dtype, copy=False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# βββββββββββββββ Sliding window ββββββββββββββ
|
| 31 |
+
def sliding_window_segment(emg, label, rerepetition, window_size, stride):
    """
    Segment EMG with a sliding window.

    The frame at the window centre supplies the segment's label and
    repetition index.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
    label : array-like, shape (n_samples,)
        Per-frame gesture labels.
    rerepetition : array-like, shape (n_samples,)
        Per-frame repetition ids.
    window_size : int
        Window length in samples.
    stride : int
        Hop between consecutive window starts.

    Returns
    -------
    (segments, labels, reps)
        segments: (n_windows, window_size, n_channels); labels and reps
        are 1-D with one entry per window.
    """
    segments, labels, reps = [], [], []
    n_samples = len(label)

    for start in range(0, n_samples - window_size + 1, stride):
        end = start + window_size
        centre_idx = (start + end) // 2
        segments.append(emg[start:end])  # (window_size, n_channels)
        labels.append(label[centre_idx])
        reps.append(rerepetition[centre_idx])

    if not segments:
        # Recording shorter than one window: np.array([]) would yield
        # shape-(0,) arrays that break downstream transpose/concatenate,
        # so return correctly-shaped empties instead.
        n_ch = emg.shape[1] if emg.ndim > 1 else 1
        return (
            np.empty((0, window_size, n_ch), dtype=emg.dtype),
            np.empty((0,), dtype=np.asarray(label).dtype),
            np.empty((0,), dtype=np.asarray(rerepetition).dtype),
        )

    return np.array(segments), np.array(labels), np.array(reps)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# βββββββββββββββ Main pipeline βββββββββββββββ
|
| 51 |
+
def main():
    """CLI entry point: preprocess Ninapro DB6 into train/val/test HDF5 files.

    Per subject directory: load each .mat session, drop the two empty
    electrode channels, band-pass + notch filter, z-score per channel,
    cut into sliding windows, split windows by repetition id, and write
    one {split}.h5 containing "data" (N, 14, window_size) float32 and
    "label" (N,) int64 datasets.
    """
    import argparse

    args = argparse.ArgumentParser(description="Process EMG data from DB6.")
    args.add_argument("--download_data", action="store_true")
    args.add_argument("--data_dir", type=str)
    args.add_argument("--save_dir", type=str)
    args.add_argument(
        "--window_size", type=int, help="Size of the sliding window for segmentation."
    )
    args.add_argument(
        "--stride", type=int, help="Stride for the sliding window segmentation."
    )
    args = args.parse_args()
    data_dir = args.data_dir  # input folder with .mat files
    save_dir = args.save_dir  # output folder for .h5 files
    os.makedirs(save_dir, exist_ok=True)

    # Download mode: fetch and unzip all subjects, then exit.
    # NOTE(review): os.system with unquoted paths breaks on directories
    # containing spaces — confirm data_dir is space-free.
    if args.download_data:
        # https://ninapro.hevs.ch/instructions/DB6.html
        len_data = range(1, 11)  # subjects 1-10
        base_url = "https://ninapro.hevs.ch/files/DB6_Preproc/"
        # download and unzip (each subject split across _a and _b archives)
        for i in len_data:
            url_a = f"{base_url}DB6_s{i}_a.zip"
            url_b = f"{base_url}DB6_s{i}_b.zip"
            os.system(f"wget -P {data_dir} {url_a}")
            os.system(f"wget -P {data_dir} {url_b}")
            os.system(f"unzip -o {data_dir}/DB6_s{i}_a.zip -d {data_dir}")
            os.system(f"unzip -o {data_dir}/DB6_s{i}_b.zip -d {data_dir}")
            os.system(f"rm {data_dir}/DB6_s{i}_a.zip {data_dir}/DB6_s{i}_b.zip")
            print(
                f"Downloaded and unzipped subject {i}\n{data_dir}/DB6_s{i}_a.zip and {data_dir}/DB6_s{i}_b.zip"
            )
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    fs = 2000.0  # DB6 sampling rate in Hz
    window_size, stride = args.window_size, args.stride

    # Repetition-based split (each gesture repeated 12 times in DB6).
    train_reps = list(range(1, 9))  # repetitions 1-8
    val_reps = [9, 10]  # repetitions 9-10
    test_reps = [11, 12]  # repetitions 11-12

    splits = {
        "train": {"data": [], "label": []},
        "val": {"data": [], "label": []},
        "test": {"data": [], "label": []},
    }

    # iterate subjects (one sub-directory each)
    for subj in sorted(os.listdir(data_dir)):
        subj_path = os.path.join(data_dir, subj)
        if not os.path.isdir(subj_path):
            continue
        print(f"Processing subject {subj} ...")

        subj_seg, subj_lbl, subj_rep = [], [], []

        # iterate .mat session files of this subject
        for mat_file in sorted(os.listdir(subj_path)):
            if not mat_file.endswith(".mat"):
                continue
            mat_path = os.path.join(subj_path, mat_file)
            mat = scipy.io.loadmat(mat_path)

            emg = mat["emg"]  # (N, 16)
            label = mat["restimulus"].ravel()
            rerep = mat["rerepetition"].ravel()

            # drop empty channels (indices 8 and 9, 0-based)
            emg = np.delete(emg, [8, 9], axis=1)  # now (N, 14)

            # band-pass 20-450 Hz then 50 Hz notch (power-line)
            emg = bandpass_filter_emg(emg, 20, 450, fs=fs)
            emg = notch_filter(emg, 50, 30, fs=fs)

            # z-score per channel (guard against constant channels)
            mu = emg.mean(axis=0)
            sd = emg.std(axis=0, ddof=1)
            sd[sd == 0] = 1.0
            emg = (emg - mu) / sd

            # sliding-window segmentation; labels taken at window centre
            seg, lbl, rep = sliding_window_segment(
                emg, label, rerep, window_size, stride
            )
            subj_seg.append(seg)
            subj_lbl.append(lbl)
            subj_rep.append(rep)

        if not subj_seg:
            continue

        seg = np.concatenate(subj_seg, axis=0)  # (M, win, 14)
        lbl = np.concatenate(subj_lbl)
        rep = np.concatenate(subj_rep)

        # split by repetition id
        for split_name, mask in (
            ("train", np.isin(rep, train_reps)),
            ("val", np.isin(rep, val_reps)),
            ("test", np.isin(rep, test_reps)),
        ):
            # transpose each window to channels-first for the model
            X = seg[mask].transpose(0, 2, 1)  # (N, 14, window_size)
            y = lbl[mask]
            splits[split_name]["data"].append(X)
            splits[split_name]["label"].append(y)

    # concatenate, save, and report per-split label distribution
    for split in ["train", "val", "test"]:
        X = (
            np.concatenate(splits[split]["data"], axis=0)
            if splits[split]["data"]
            else np.empty((0, 14, window_size))
        )
        y = (
            np.concatenate(splits[split]["label"], axis=0)
            if splits[split]["label"]
            else np.empty((0,), dtype=int)
        )

        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int64))

        uniq, cnt = np.unique(y, return_counts=True)
        print(f"\n{split.upper()} β X={X.shape}, label distribution:")
        for u, c in zip(uniq, cnt):
            print(f"  label {u}: {c} samples")

    print("\nSaved: train.h5, val.h5, test.h5")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
main()
|
scripts/db7.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.io
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from scipy.signal import iirnotch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# βββββββββββββββ Filtering ββββββββββββββββββ
|
| 12 |
+
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Apply a zero-phase IIR notch filter to every channel of `data`.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Multi-channel signal, time along axis 0.
    notch_freq : float
        Centre frequency to reject in Hz (power-line interference).
    Q : float
        Quality factor; higher means a narrower notch.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `data`.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # filtfilt filters all channels in one vectorized call (axis=0 = time);
    # result is identical to the per-channel loop, just faster.
    filtered = signal.filtfilt(b, a, data, axis=0)
    # The original implementation wrote into np.zeros_like(data), which
    # cast back to the input dtype — keep that behaviour.
    return filtered.astype(data.dtype, copy=False)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Zero-phase Butterworth band-pass filter applied to every channel.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
        EMG signal, time along axis 0.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth filter order (effective order doubles via filtfilt).

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `emg`.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    # Vectorized over channels (axis=0 = time) — identical to looping
    # channel by channel, but a single C-level call.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    # Preserve the input dtype, as the original zeros_like buffer did.
    return filtered.astype(emg.dtype, copy=False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# βββββββββββββββ Sliding window ββββββββββββββ
|
| 31 |
+
def sliding_window_segment(emg, label, rerepetition, window_size, stride):
    """
    Segment EMG with a sliding window.

    The frame at the window centre supplies the segment's label and
    repetition index.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
    label : array-like, shape (n_samples,)
        Per-frame gesture labels.
    rerepetition : array-like, shape (n_samples,)
        Per-frame repetition ids.
    window_size : int
        Window length in samples.
    stride : int
        Hop between consecutive window starts.

    Returns
    -------
    (segments, labels, reps)
        segments: (n_windows, window_size, n_channels); labels and reps
        are 1-D with one entry per window.
    """
    segments, labels, reps = [], [], []
    n_samples = len(label)

    for start in range(0, n_samples - window_size + 1, stride):
        end = start + window_size
        centre_idx = (start + end) // 2
        segments.append(emg[start:end])  # (window_size, n_channels)
        labels.append(label[centre_idx])
        reps.append(rerepetition[centre_idx])

    if not segments:
        # Recording shorter than one window: np.array([]) would yield
        # shape-(0,) arrays that break downstream transpose/concatenate,
        # so return correctly-shaped empties instead.
        n_ch = emg.shape[1] if emg.ndim > 1 else 1
        return (
            np.empty((0, window_size, n_ch), dtype=emg.dtype),
            np.empty((0,), dtype=np.asarray(label).dtype),
            np.empty((0,), dtype=np.asarray(rerepetition).dtype),
        )

    return np.array(segments), np.array(labels), np.array(reps)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# βββββββββββββββ Main pipeline βββββββββββββββ
|
| 51 |
+
def main():
    """CLI entry point: preprocess Ninapro DB7 into train/val/test HDF5 files.

    Per subject directory: load each .mat session, band-pass + notch
    filter, z-score per channel, cut into sliding windows, split windows
    by repetition id (1-4 train, 5 val, 6 test), and write one {split}.h5
    containing "data" (N, ch, window_size) float32 and "label" (N,)
    int64 datasets.
    """
    import argparse

    args = argparse.ArgumentParser(description="Process EMG data from DB7.")
    args.add_argument("--download_data", action="store_true")
    args.add_argument("--data_dir", type=str)
    args.add_argument("--save_dir", type=str)
    args.add_argument(
        "--window_size",
        type=int,
        default=256,
        help="Size of the sliding window for segmentation.",
    )
    args.add_argument(
        "--stride",
        type=int,
        default=128,
        help="Stride for the sliding window segmentation.",
    )
    args = args.parse_args()
    data_dir = args.data_dir  # input folder with .mat files
    save_dir = args.save_dir  # output folder for .h5 files
    os.makedirs(save_dir, exist_ok=True)

    # Download mode: fetch and unzip all subjects, then exit.
    # NOTE(review): os.system with unquoted paths breaks on directories
    # containing spaces — confirm data_dir is space-free.
    if args.download_data:
        # https://ninapro.hevs.ch/instructions/DB7.html
        len_data = range(1, 23)  # subjects 1-22
        base_url = "https://ninapro.hevs.ch/files/DB7_Preproc/"
        # download and unzip
        for i in len_data:
            url = f"{base_url}Subject_{i}.zip"
            os.system(f"wget -P {data_dir} {url}")
            os.system(f"unzip -o {data_dir}/Subject_{i}.zip -d {data_dir}/Subject_{i}")
            os.system(f"rm {data_dir}/Subject_{i}.zip")
            print(f"Downloaded and unzipped subject {i}\n{data_dir}/Subject_{i}.zip")
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    fs = 2000.0  # DB7 sampling rate in Hz
    window_size, stride = args.window_size, args.stride

    # Repetition-based split (each gesture repeated 6 times in DB7).
    train_reps = [1, 2, 3, 4]  # repetitions 1-4
    val_reps = [5]  # repetition 5
    test_reps = [6]  # repetition 6

    splits = {
        "train": {"data": [], "label": []},
        "val": {"data": [], "label": []},
        "test": {"data": [], "label": []},
    }

    # iterate subjects (one sub-directory each)
    for subj in sorted(os.listdir(data_dir)):
        subj_path = os.path.join(data_dir, subj)
        if not os.path.isdir(subj_path):
            continue
        print(f"Processing subject {subj} ...")

        subj_seg, subj_lbl, subj_rep = [], [], []

        # iterate .mat session files of this subject
        for mat_file in sorted(os.listdir(subj_path)):
            if not mat_file.endswith(".mat"):
                continue
            mat_path = os.path.join(subj_path, mat_file)
            mat = scipy.io.loadmat(mat_path)

            emg = mat["emg"]  # (N, 16)
            label = mat["restimulus"].ravel()
            rerep = mat["rerepetition"].ravel()

            # band-pass 20-450 Hz then 50 Hz notch (power-line)
            emg = bandpass_filter_emg(emg, 20.0, 450.0, fs=fs)
            emg = notch_filter(emg, 50.0, 30.0, fs=fs)

            # z-score per channel (guard against constant channels)
            mu = emg.mean(axis=0)
            sd = emg.std(axis=0, ddof=1)
            sd[sd == 0] = 1.0
            emg = (emg - mu) / sd

            # sliding-window segmentation; labels taken at window centre
            seg, lbl, rep = sliding_window_segment(
                emg, label, rerep, window_size, stride
            )
            subj_seg.append(seg)
            subj_lbl.append(lbl)
            subj_rep.append(rep)

        if not subj_seg:
            continue

        # NOTE(review): comment below says 14 channels but DB7 emg is
        # loaded as (N, 16) above — confirm actual channel count.
        seg = np.concatenate(subj_seg, axis=0)  # (M, win, 14)
        lbl = np.concatenate(subj_lbl)
        rep = np.concatenate(subj_rep)

        # split by repetition id
        for split_name, mask in (
            ("train", np.isin(rep, train_reps)),
            ("val", np.isin(rep, val_reps)),
            ("test", np.isin(rep, test_reps)),
        ):
            # transpose each window to channels-first for the model
            X = seg[mask].transpose(0, 2, 1)  # (N, 14, 1024)
            y = lbl[mask]
            splits[split_name]["data"].append(X)
            splits[split_name]["label"].append(y)

    # concatenate, save, and report per-split label distribution
    for split in ["train", "val", "test"]:
        X = (
            np.concatenate(splits[split]["data"], axis=0)
            if splits[split]["data"]
            else np.empty((0, 14, window_size))
        )
        y = (
            np.concatenate(splits[split]["label"], axis=0)
            if splits[split]["label"]
            else np.empty((0,), dtype=int)
        )

        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int64))

        uniq, cnt = np.unique(y, return_counts=True)
        print(f"\n{split.upper()} β X={X.shape}, label distribution:")
        for u, c in zip(uniq, cnt):
            print(f"  label {u}: {c} samples")

    print("\nSaved: train.h5, val.h5, test.h5")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if __name__ == "__main__":
|
| 184 |
+
main()
|
scripts/db8.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import scipy.io
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from joblib import Parallel, delayed
|
| 9 |
+
from scipy.signal import iirnotch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
# Linear map from the 18 glove degrees of freedom (DoF) to 5 degrees of
# actuation (DoA), stored transposed as (18, 5) and transposed below.
_MATRIX_DOF2DOA_TRANSPOSED = np.array(
    # https://www.frontiersin.org/articles/10.3389/fnins.2019.00891/full
    # Open supplemental data > Data Sheet 1.PDF >
    # > SUPPLEMENTARY METHODS > Eqn. S2
    # https://www.frontiersin.org/articles/file/downloadfile/461612_supplementary-materials_datasheets_1_pdf/octet-stream/Data%20Sheet%201.PDF/1/461612
    [
        [+0.6390, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.3830, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +1.0000, +0.0000, +0.0000, +0.0000],
        [-0.6390, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.4000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.6000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.4000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.6000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.1667],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.3333],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.1667],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.3333],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
        [-0.1900, +0.0000, +0.0000, +0.0000, +0.0000],
        [+0.0000, +0.0000, +0.0000, +0.0000, +0.0000],
    ],
    dtype=np.float32,
)

# (5, 18) matrix used as MATRIX_DOF2DOA @ glove.T in process_mat_file.
MATRIX_DOF2DOA = _MATRIX_DOF2DOA_TRANSPOSED.T
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# βββββββββββββββ Filtering ββββββββββββββββββ
|
| 44 |
+
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=1111.0):
    """Apply a zero-phase IIR notch filter to every channel of `data`.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Multi-channel signal, time along axis 0.
    notch_freq : float
        Centre frequency to reject in Hz (power-line interference).
    Q : float
        Quality factor; higher means a narrower notch.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `data`.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # filtfilt filters all channels in one vectorized call (axis=0 = time);
    # result is identical to the per-channel loop, just faster.
    filtered = signal.filtfilt(b, a, data, axis=0)
    # The original implementation wrote into np.zeros_like(data), which
    # cast back to the input dtype — keep that behaviour.
    return filtered.astype(data.dtype, copy=False)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Zero-phase Butterworth band-pass filter applied to every channel.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
        EMG signal, time along axis 0.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth filter order (effective order doubles via filtfilt).

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `emg`.
    """
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    # Vectorized over channels (axis=0 = time) — identical to looping
    # channel by channel, but a single C-level call.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    # Preserve the input dtype, as the original zeros_like buffer did.
    return filtered.astype(emg.dtype, copy=False)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# βββββββββββββββ Sliding window ββββββββββββββ
|
| 63 |
+
def sliding_window_segment(emg, label, window_size, stride):
    """
    Segment EMG and its per-frame regression targets with a sliding window.

    Unlike the classification variants, the label here is a time series
    (e.g. DoA angles), so each window keeps the full label slice rather
    than a single centre value.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
    label : np.ndarray, shape (n_samples, n_targets)
        Per-frame regression targets.
    window_size : int
        Window length in samples.
    stride : int
        Hop between consecutive window starts.

    Returns
    -------
    (segments, labels)
        segments: (n_windows, window_size, n_channels);
        labels: (n_windows, window_size, n_targets).
    """
    segments, labels = [], []
    n_samples = len(label)

    for start in range(0, n_samples - window_size + 1, stride):
        end = start + window_size
        segments.append(emg[start:end])  # (window_size, n_channels)
        labels.append(label[start:end])  # (window_size, n_targets)

    if not segments:
        # Recording shorter than one window: np.array([]) would yield
        # shape-(0,) arrays that break downstream transpose/concatenate,
        # so return correctly-shaped empties instead.
        n_ch = emg.shape[1] if emg.ndim > 1 else 1
        lab_arr = np.asarray(label)
        return (
            np.empty((0, window_size, n_ch), dtype=emg.dtype),
            np.empty((0, window_size) + lab_arr.shape[1:], dtype=lab_arr.dtype),
        )

    return np.array(segments), np.array(labels)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# βββββββββββββββ Main pipeline βββββββββββββββ
|
| 82 |
+
def process_mat_file(mat_path, window_size, stride, fs):
    """
    Load one DB8 .mat file, drop NaN glove frames, z-score the EMG,
    map glove DoF to DoA targets, segment into windows, and return
    (split, segments, labels) — or None if the filename matches no split.

    NOTE(review): despite the parameter, `fs` is never used here and no
    band-pass/notch filtering is applied (the step numbering below jumps
    from 1) to 3), suggesting a filtering step was removed) — confirm
    whether the DB8 files are pre-filtered.

    Parameters
    ----------
    mat_path : str
        Path to a S{n}_E1_A{1,2,3}.mat file.
    window_size, stride : int
        Sliding-window parameters forwarded to sliding_window_segment.
    fs : float
        Sampling rate in Hz (currently unused, see note above).
    """
    mat = scipy.io.loadmat(mat_path)
    emg = mat["emg"]  # (T, 16)
    label = mat["glove"]  # (T, DoF)

    # 1) Drop timesteps with any NaNs in glove data
    valid = ~np.isnan(label).any(axis=1)
    emg = emg[valid]
    label = label[valid]

    # 3) Z-score per channel (guard against constant channels)
    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg = (emg - mu) / sd

    # 4) DoF -> DoA via the supplementary-material linear map
    y_doa = (MATRIX_DOF2DOA @ label.T).T

    # 5) Windowing (labels kept as per-window time series)
    segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)

    # 6) Determine split from the acquisition tag in the filename
    fname = os.path.basename(mat_path)
    if "_A1" in fname:
        split = "train"
    elif "_A2" in fname:
        split = "val"
    elif "_A3" in fname:
        split = "test"
    else:
        return None  # skip

    return split, segs, labs
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def main():
    """CLI entry point: preprocess Ninapro DB8 into train/val/test HDF5 files.

    Processes every .mat file in --data_dir in parallel (one worker per
    file), splits by acquisition tag (_A1 train, _A2 val, _A3 test), and
    writes one {split}.h5 with "data" (N, ch, window_size) and "label"
    (N, window_size, n_doa) datasets.
    """
    import argparse

    args = argparse.ArgumentParser(description="Process EMG data from DB8.")
    args.add_argument("--download_data", action="store_true")
    args.add_argument("--data_dir", type=str, required=True)
    args.add_argument("--save_dir", type=str, required=True)
    args.add_argument(
        "--window_size", type=int, help="Size of the sliding window for segmentation."
    )
    args.add_argument(
        "--stride", type=int, help="Stride for the sliding window segmentation."
    )
    args.add_argument(
        "--n_jobs", type=int, default=-1, help="Number of parallel jobs to run."
    )
    args = args.parse_args()
    data_dir = args.data_dir  # input folder with .mat files
    os.makedirs(args.save_dir, exist_ok=True)

    # Download mode: fetch all subject files, then exit.
    # NOTE(review): os.system with unquoted paths breaks on directories
    # containing spaces — confirm data_dir is space-free.
    if args.download_data:
        # https://ninapro.hevs.ch/instructions/DB8.html
        len_data = range(1, 13)  # subjects 1-12
        base_url = "https://ninapro.hevs.ch/files/DB8/"
        # download and unzip
        for i in len_data:
            url_a = f"{base_url}S{i}_E1_A1.mat"
            url_b = f"{base_url}S{i}_E1_A2.mat"
            url_c = f"{base_url}S{i}_E1_A3.mat"
            os.system(f"wget -P {data_dir} {url_a}")
            os.system(f"wget -P {data_dir} {url_b}")
            os.system(f"wget -P {data_dir} {url_c}")
            print(
                f"Downloaded subject {i}\n{data_dir}/S{i}_E1_A1.mat and {data_dir}/S{i}_E1_A2.mat and {data_dir}/S{i}_E1_A3.mat"
            )
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    fs = 2000.0  # Hz (passed through but currently unused by process_mat_file)

    # collect all .mat paths
    mat_paths = [
        os.path.join(args.data_dir, f)
        for f in sorted(os.listdir(args.data_dir))
        if f.endswith(".mat")
    ]

    # One worker per file — the files are independent.
    # NOTE(review): with the default n_jobs=-1, min(os.cpu_count(), -1)
    # is -1, which joblib interprets as "all cores"; a positive n_jobs is
    # capped at cpu_count — confirm this mixed semantics is intended.
    results = Parallel(n_jobs=min(os.cpu_count(), args.n_jobs), verbose=5)(
        delayed(process_mat_file)(mp, args.window_size, args.stride, fs)
        for mp in mat_paths
    )

    # aggregate per-file outputs into their splits (None = unmatched file)
    splits = {k: {"data": [], "label": []} for k in ("train", "val", "test")}
    for out in tqdm(results, desc="Processing files", unit="file"):
        if out is None:
            continue
        split, segs, labs = out
        splits[split]["data"].append(segs)
        splits[split]["label"].append(labs)

    # concatenate + save + stats
    for split, d in tqdm(splits.items(), desc="Saving splits", unit="split"):
        if not d["data"]:
            continue

        X = np.concatenate(d["data"], axis=0)
        y = np.concatenate(d["label"], axis=0)

        # transpose to [N, ch, window_size] (channels-first for the model)
        X = X.transpose(0, 2, 1)

        print(f"Split: {split}, X shape: {X.shape}, y shape: {y.shape}")
        # save
        with h5py.File(os.path.join(args.save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
            hf.create_dataset("label", data=y)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
main()
|
scripts/emg2pose.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import h5py
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import scipy.io
|
| 8 |
+
import scipy.signal as signal
|
| 9 |
+
from joblib import Parallel, delayed
|
| 10 |
+
from scipy.signal import iirnotch
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ==== Filter functions (operate at original fs=2000) ====
|
| 15 |
+
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=2000.0):
    """Apply a zero-phase IIR notch filter to every channel of `data`.

    Parameters
    ----------
    data : np.ndarray, shape (n_samples, n_channels)
        Multi-channel signal, time along axis 0.
    notch_freq : float
        Centre frequency to reject in Hz (power-line interference).
    Q : float
        Quality factor; higher means a narrower notch.
    fs : float
        Sampling rate in Hz.

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `data`.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # filtfilt filters all channels in one vectorized call (axis=0 = time);
    # result is identical to the per-channel loop, just faster.
    filtered = signal.filtfilt(b, a, data, axis=0)
    # The original implementation wrote into np.zeros_like(data), which
    # cast back to the input dtype — keep that behaviour.
    return filtered.astype(data.dtype, copy=False)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Zero-phase Butterworth band-pass filter applied to every channel.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
        EMG signal, time along axis 0.
    lowcut, highcut : float
        Pass-band edges in Hz.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth filter order (effective order doubles via filtfilt).

    Returns
    -------
    np.ndarray
        Filtered signal, same shape and dtype as `emg`.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = signal.butter(order, [low, high], btype="bandpass")
    # Vectorized over channels (axis=0 = time) — identical to looping
    # channel by channel, but a single C-level call.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    # Preserve the input dtype, as the original zeros_like buffer did.
    return filtered.astype(emg.dtype, copy=False)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ==== Window segmentation ====
|
| 35 |
+
def process_emg_features(emg, window_size=1000, stride=500):
    """Cut `emg` into overlapping fixed-length windows.

    The ragged tail (a final window that would run past the end of the
    signal) is dropped, matching the original skip-incomplete behaviour.

    Parameters
    ----------
    emg : np.ndarray, shape (n_samples, n_channels)
    window_size : int
        Window length in samples.
    stride : int
        Hop between consecutive window starts.

    Returns
    -------
    np.ndarray, shape (n_windows, window_size, n_channels)
    """
    n_samples = len(emg)
    # Bound the range up front instead of scanning every start and
    # skipping the incomplete ones; also drops the unused labels list
    # the original carried around.
    segs = [
        emg[start : start + window_size]
        for start in range(0, n_samples - window_size + 1, stride)
    ]
    if not segs:
        # Signal shorter than one window: return a correctly-shaped empty
        # array (np.array([]) would be shape (0,) and break concatenation).
        n_ch = emg.shape[1] if emg.ndim > 1 else 1
        return np.empty((0, window_size, n_ch), dtype=emg.dtype)
    return np.array(segs)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def process_one_recording(file_path, fs=2000.0, window_size=1000, stride=500):
    """
    Load one emg2pose HDF5 recording, filter and normalize its EMG, and
    cut it into sliding windows.

    Designed to be called per file from the main pipeline via joblib —
    each recording is processed independently.

    Parameters
    ----------
    file_path : path-like
        Path to a recording .hdf5 file with layout emg2pose/timeseries/emg.
    fs : float
        Sampling rate in Hz used for the filters.
    window_size, stride : int
        Sliding-window parameters forwarded to process_emg_features.

    Returns
    -------
    np.ndarray, shape (n_windows, window_size, n_channels)
        Filtered, z-scored EMG windows (no labels are extracted here).
    """
    with h5py.File(file_path, "r") as f:
        grp = f["emg2pose"]
        data = grp["timeseries"]
        # [:] materializes the dataset so the file can be closed
        emg = data["emg"][:].astype(np.float32)

    # ==== Preprocessing EMG data ====
    # band-pass 20-450 Hz then 50 Hz notch (power-line)
    emg_filt = bandpass_filter_emg(emg, 20, 450, fs=fs)
    emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)

    # z-score per channel (guard against constant channels)
    mu = emg_filt.mean(axis=0)
    sd = emg_filt.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg_z = (emg_filt - mu) / sd

    # segment into overlapping windows
    segs = process_emg_features(emg_z, window_size, stride)

    return segs
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ==== Main pipeline ====
|
| 74 |
+
def main():
    """CLI entry point: preprocess the emg2pose dataset into per-split HDF5.

    Reads metadata.csv from --data_dir, optionally subsamples each split,
    processes every recording in parallel (filter, z-score, window), and
    writes one {split}.h5 per split containing a "data" dataset of shape
    (N, n_channels, window_size). No labels are written by this script.
    """
    import argparse

    # Fix: the description previously said "DB5" — a copy-paste left over
    # from the Ninapro scripts; this script processes emg2pose recordings.
    args = argparse.ArgumentParser(description="Process EMG data from emg2pose.")
    args.add_argument("--data_dir", type=str)
    args.add_argument("--save_dir", type=str)
    args.add_argument(
        "--window_size", type=int, help="Size of the sliding window for segmentation."
    )
    args.add_argument(
        "--stride", type=int, help="Stride for the sliding window segmentation."
    )
    args.add_argument(
        "--subsample", type=float, default=1.0, help="Whether to subsample the data"
    )
    args.add_argument(
        "--n_jobs",
        type=int,
        default=-1,
        help="Number of parallel jobs to run. -1 means using all available cores.",
    )
    args.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )
    args = args.parse_args()

    data_dir = args.data_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    fs = 2000.0  # original sampling rate
    window_size, stride = args.window_size, args.stride

    # Optionally subsample each split with a fixed seed for reproducibility.
    # NOTE(review): DataFrameGroupBy.apply including the grouping column is
    # deprecated in recent pandas — confirm the pinned pandas version.
    df = pd.read_csv(os.path.join(data_dir, "metadata.csv"))
    df = df.groupby("split").apply(
        lambda x: (
            x.sample(frac=args.subsample, random_state=args.seed)
            if args.subsample < 1.0
            else x
        )
    )
    df.reset_index(drop=True, inplace=True)

    # Map each split name to the list of recording file paths.
    splits = {}
    for split, df_ in df.groupby("split"):
        sessions = list(df_.filename)
        splits[split] = [
            Path(data_dir).expanduser().joinpath(f"{session}.hdf5")
            for session in sessions
        ]

    all_data = {"train": [], "val": [], "test": []}

    for split, files in splits.items():
        # Here we use joblib to parallelize the file processing, each file is
        # processed independently as the task is embarrassingly parallel. We
        # scale the processing across all available CPU cores since the number
        # of files is around 25k (with training being 17k).
        results = Parallel(n_jobs=args.n_jobs)(
            delayed(process_one_recording)(file_path, fs, window_size, stride)
            for file_path in tqdm(files, desc=f"Processing {split} files")
        )
        # Collect results
        for segs in tqdm(results, desc=f"Collecting {split} data"):
            all_data[split].append(segs)

        # stack, transpose, and save
        X = np.concatenate(all_data[split], axis=0)  # [N, window_size, ch]

        # transpose to [N, ch, window_size] (channels-first for the model)
        X = X.transpose(0, 2, 1)

        # save
        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
main()
|
scripts/epn.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
import h5py
|
| 7 |
+
import numpy as np
|
| 8 |
+
import scipy.signal as signal
|
| 9 |
+
from joblib import Parallel, delayed
|
| 10 |
+
from scipy.signal import iirnotch
|
| 11 |
+
from tqdm.auto import tqdm
|
| 12 |
+
|
| 13 |
+
# Sampling frequency (Hz) and number of EMG channels.
# NOTE(review): 200 Hz / 8 channels matches a Myo-armband recording setup —
# confirm against the EMG-EPN612 dataset description.
tfs, n_ch = 200.0, 8

# Gesture label mapping: dataset gestureName strings -> integer class ids.
# "notProvided" (6) marks unlabeled samples; downstream processing skips
# samples whose label is 6.
gesture_map = {
    "noGesture": 0,
    "waveIn": 1,
    "waveOut": 2,
    "pinch": 3,
    "open": 4,
    "fist": 5,
    "notProvided": 6,
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Filtering utilities
|
| 29 |
+
def bandpass_filter_emg(emg, low=20.0, high=90.0, fs=tfs, order=4):
    """Zero-phase Butterworth band-pass over the time axis (axis=1) of emg."""
    nyquist = 0.5 * fs
    normalized_band = [low / nyquist, high / nyquist]
    b, a = signal.butter(order, normalized_band, btype="bandpass")
    return signal.filtfilt(b, a, emg, axis=1)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def notch_filter_emg(emg, notch=50.0, Q=30.0, fs=tfs):
    """Zero-phase IIR notch (e.g. 50 Hz mains) over the time axis of emg."""
    normalized_freq = notch / (0.5 * fs)
    b, a = iirnotch(normalized_freq, Q)
    return signal.filtfilt(b, a, emg, axis=1)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Normalization helpers
|
| 42 |
+
def zscore_per_channel(emg):
    """Standardize each channel (row) of emg to zero mean, unit sample std.

    Channels with zero variance are only mean-centred (divisor forced to 1)
    to avoid division by zero.
    """
    center = emg.mean(axis=1, keepdims=True)
    spread = emg.std(axis=1, ddof=1, keepdims=True)
    spread = np.where(spread == 0, 1.0, spread)
    return (emg - center) / spread
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def adjust_length(x, max_len):
    """Crop or zero-pad x of shape (channels, time) to exactly max_len samples."""
    channels, length = x.shape
    if length < max_len:
        tail = np.zeros((channels, max_len - length), dtype=x.dtype)
        return np.concatenate([x, tail], axis=1)
    return x[:, :max_len]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Single-sample processing
|
| 58 |
+
def extract_emg_signal(sample, seq_len):
    """Preprocess one recorded sample into an (emg, label) pair.

    Pipeline: scale raw values by 1/128 -> band-pass -> notch -> per-channel
    z-score -> crop/pad to seq_len. Unknown gesture names map to label 6.
    """
    channel_arrays = list(sample["emg"].values())
    emg = np.stack(channel_arrays, dtype=np.float32) / 128.0
    emg = bandpass_filter_emg(emg, 20.0, 90.0)
    emg = notch_filter_emg(emg, 50.0, 30.0)
    emg = zscore_per_channel(emg)
    emg = adjust_length(emg, seq_len)
    gesture_name = sample.get("gestureName", "notProvided")
    return emg, gesture_map.get(gesture_name, 6)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# Process one user JSON for train/validation
|
| 69 |
+
def process_user_training(path, seq_len):
    """Load one training-split user's JSON and return (train_X, train_y, val_X, val_y).

    The user's "trainingSamples" feed the train split and their
    "testingSamples" feed the validation split; samples whose label is 6
    (unknown/notProvided) are dropped.
    """
    with open(path, "r", encoding="utf-8") as fh:
        user_data = json.load(fh)

    def collect(section):
        xs, ys = [], []
        for sample in user_data.get(section, {}).values():
            emg, lbl = extract_emg_signal(sample, seq_len)
            if lbl != 6:
                xs.append(emg)
                ys.append(lbl)
        return xs, ys

    train_X, train_y = collect("trainingSamples")
    val_X, val_y = collect("testingSamples")
    return train_X, train_y, val_X, val_y
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Process one user JSON for testing split
|
| 87 |
+
def process_user_testing(path, seq_len):
    """Load one test-split user's JSON and split it into extra-train + test data.

    Within each gesture, the first 10 samples go to the train split and the
    remainder to the test split; samples whose label is 6 are dropped.
    Returns (train_X, train_y, test_X, test_y).
    """
    with open(path, "r", encoding="utf-8") as fh:
        user_data = json.load(fh)

    # Bucket samples by gesture name so the per-gesture 10-sample cut applies.
    per_gesture = {name: [] for name in gesture_map}
    for sample in user_data.get("trainingSamples", {}).values():
        name = sample.get("gestureName", "notProvided")
        per_gesture.setdefault(name, []).append(sample)

    train_X, train_y, test_X, test_y = [], [], [], []
    for samples in per_gesture.values():
        for rank, sample in enumerate(samples):
            emg, lbl = extract_emg_signal(sample, seq_len)
            if lbl == 6:
                continue
            dest_X, dest_y = (train_X, train_y) if rank < 10 else (test_X, test_y)
            dest_X.append(emg)
            dest_y.append(lbl)
    return train_X, train_y, test_X, test_y
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# Save to HDF5
|
| 109 |
+
def save_h5(path, data, labels):
    """Write data (float32) and labels (int64) to an HDF5 file at path."""
    X = np.asarray(data, np.float32)
    y = np.asarray(labels, np.int64)
    with h5py.File(path, "w") as hf:
        hf.create_dataset("data", data=X)
        hf.create_dataset("label", data=y)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# Main parallelized pipeline
|
| 116 |
+
def main():
    """CLI entry point: preprocess the EMG-EPN612 dataset into HDF5 splits.

    Per-user JSON recordings are filtered/normalized (see extract_emg_signal)
    and written as train/val/test HDF5 files under --dest_dir. With
    --download_data the script only fetches and unpacks the archive, then
    exits.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--download_data", action="store_true")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--source_training", required=True)
    parser.add_argument("--source_testing", required=True)
    parser.add_argument("--dest_dir", required=True)
    parser.add_argument("--window_size", type=int, required=True)
    # -1 lets joblib use all available CPU cores.
    parser.add_argument("--n_jobs", type=int, default=-1)
    args = parser.parse_args()
    data_dir = args.data_dir
    os.makedirs(args.dest_dir, exist_ok=True)

    # download data if requested
    if args.download_data:
        # Dataset home: https://zenodo.org/records/4421500
        url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
        os.system(f"wget -O {data_dir}/EMG-EPN612_Dataset.zip {url}")
        os.system(f"unzip -o {data_dir}/EMG-EPN612_Dataset.zip -d {data_dir}")
        # move the contents one level up
        os.system(rf"mv {data_dir}/EMG-EPN612\ Dataset/* {data_dir}/")
        # NOTE(review): the extracted folder is "EMG-EPN612 Dataset" (space),
        # but this rmdir targets the underscore variant — verify it removes
        # the intended, now-empty directory.
        os.system(f"rmdir {data_dir}/EMG-EPN612_Dataset")
        # clean up zip file
        os.system(f"rm {data_dir}/EMG-EPN612_Dataset.zip")
        print(f"Downloaded and unzipped dataset\n{data_dir}/EMG-EPN612_Dataset.zip")
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    seq_len = args.window_size
    train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []

    paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))

    # Parallel process training JSONs (one file per user; embarrassingly parallel)
    results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_training)(p, seq_len)
        for p in tqdm(paths, desc="Training files")
    )
    for tX, ty, vX, vy in results:
        train_X.extend(tX)
        train_y.extend(ty)
        val_X.extend(vX)
        val_y.extend(vy)

    # Parallel process testing JSONs; their first-10-per-gesture samples are
    # folded back into train (see process_user_testing).
    test_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_testing)(p, seq_len)
        for p in tqdm(
            glob.glob(os.path.join(args.source_testing, "user*", "user*.json")),
            desc="Testing files",
        )
    )
    for tX, ty, teX, tey in test_results:
        train_X.extend(tX)
        train_y.extend(ty)
        test_X.extend(teX)
        test_y.extend(tey)

    # Save datasets
    save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
    save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
    save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)

    # Print distributions
    for split, X, y in [
        ("Train", train_X, train_y),
        ("Val", val_X, val_y),
        ("Test", test_X, test_y),
    ]:
        arr = np.array(y)
        uniq, cnt = np.unique(arr, return_counts=True)
        uniq = [i.item() for i in uniq]
        cnt = [i.item() for i in cnt]
        print(f"{split} β total={len(y)}, classes={{}}".format(dict(zip(uniq, cnt))))


if __name__ == "__main__":
    main()
|
scripts/uci.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import h5py
|
| 6 |
+
import numpy as np
|
| 7 |
+
import scipy.signal as signal
|
| 8 |
+
from scipy.signal import iirnotch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 12 |
+
# Filtering utilities
|
| 13 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
+
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
    """Zero-phase Butterworth band-pass along the time axis (axis=0).

    emg is (samples, channels); the pass band is [lowcut, highcut] Hz.
    """
    half_fs = 0.5 * fs
    normalized = [lowcut / half_fs, highcut / half_fs]
    numer, denom = signal.butter(order, normalized, btype="bandpass")
    return signal.filtfilt(numer, denom, emg, axis=0)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
    """Zero-phase IIR notch (default 50 Hz mains) along the time axis (axis=0)."""
    normalized_freq = notch_freq / (0.5 * fs)
    numer, denom = iirnotch(normalized_freq, Q)
    return signal.filtfilt(numer, denom, emg, axis=0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
# Core I/O + preprocessing helpers
|
| 27 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
def read_emg_txt(txt_path):
    """
    Parse a whitespace-delimited txt file with columns: time ch1 … ch8 class.

    The first (header) line is skipped and malformed rows (anything other
    than exactly 10 columns) are ignored. Returns float32 array (N, 10).
    """
    rows = []
    with open(txt_path, "r") as fh:
        next(fh, None)  # skip header line
        for raw in fh:
            fields = raw.split()
            if len(fields) == 10:
                rows.append([float(v) for v in fields])
    return np.asarray(rows, dtype=np.float32)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def preprocess_emg(arr, fs=200.0, remove_class0=True):
    """
    Clean one raw recording of shape (N, 10): time, ch1..ch8, class.

    Steps: optionally drop class-0 (label < 1) rows, band-pass then notch
    filter the 8 EMG columns, and z-score each channel over the recording
    (zero-variance channels are only mean-centred). The cleaned channels are
    written back into arr, which is returned; an empty result is returned
    as-is.
    """
    if remove_class0:
        arr = arr[arr[:, -1] >= 1]
    if arr.size == 0:
        return arr

    channels = arr[:, 1:9]  # the 8 EMG columns
    channels = bandpass_filter_emg(channels, 20, 90, fs)
    channels = notch_filter_emg(channels, 50, 30, fs)

    center = channels.mean(axis=0)
    spread = channels.std(axis=0, ddof=1)
    spread[spread == 0] = 1.0

    arr[:, 1:9] = (channels - center) / spread
    return arr
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def find_label_runs(arr):
    """Split arr into (label, rows) runs of consecutive identical class labels.

    The class label is read from the last column; an empty array yields [].
    """
    if arr.size == 0:
        return []
    labels = arr[:, -1].astype(int)
    runs = []
    run_start = 0
    for idx in range(1, len(arr)):
        if labels[idx] != labels[run_start]:
            runs.append((int(labels[run_start]), arr[run_start:idx]))
            run_start = idx
    runs.append((int(labels[run_start]), arr[run_start:]))
    return runs
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def sliding_window_majority(seg_arr, window_size=1000, stride=500):
    """Cut seg_arr (N, 10) into overlapping windows with majority-vote labels.

    Returns (segments, labels): segments is (num_win, window_size, 8) float32
    holding the 8 EMG columns, labels the most frequent class per window
    (ties resolved to the smallest class id, as np.argmax does).
    """
    windows, votes = [], []
    last_start = len(seg_arr) - window_size
    for offset in range(0, last_start + 1, stride):
        window = seg_arr[offset : offset + window_size]
        class_counts = np.bincount(window[:, -1].astype(int))
        windows.append(window[:, 1:9])
        votes.append(np.argmax(class_counts))
    return np.asarray(windows, dtype=np.float32), np.asarray(votes, dtype=np.int32)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def users_with_gesture(
    data_root, gesture_id, subj_range=range(1, 37), return_counts=False
):
    """Scan per-subject folders and report which subjects performed gesture_id.

    Subject folders are named "01".."36" under data_root. Unreadable txt
    files are skipped (best-effort scan). Returns a sorted list of subject
    ids, or — with return_counts=True — a dict mapping subject id to the
    number of rows labelled with gesture_id.
    """
    target = int(gesture_id)
    hits = {}
    for subj in subj_range:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        n_rows = 0
        for fname in os.listdir(subj_dir):
            if not fname.endswith(".txt"):
                continue
            try:
                arr = read_emg_txt(os.path.join(subj_dir, fname))
            except Exception:
                continue  # unparsable file: skip, keep scanning
            if arr.size == 0:
                continue
            # Labels are stored as floats in the last column; compare as int.
            labels = arr[:, -1].astype(int)
            if np.any(labels == target):
                n_rows += int((labels == target).sum())
        if n_rows > 0:
            hits[subj] = n_rows
    return hits if return_counts else sorted(hits)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
+
# Safe concatenation utilities
|
| 126 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
def concat_data(lst, window_size=1000, n_channels=8):
    """Concatenate a list of (N_i, window_size, n_channels) arrays along axis 0.

    Fix: the empty-list fallback previously hard-coded (0, 1000, 8) while the
    stale comment claimed (N, 256, 8); the shape is now parameterized (with
    the old 1000/8 defaults, so existing callers are unaffected) and always
    returns an array — never None — so downstream transpose/save code works
    even for an empty split.
    """
    if lst:
        return np.concatenate(lst, axis=0)
    return np.empty((0, window_size, n_channels), np.float32)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def concat_label(lst):
    """Concatenate 1-D label arrays; an empty list yields an empty int32 array."""
    if not lst:
        return np.empty((0,), np.int32)
    return np.concatenate(lst, axis=0)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
# Main
|
| 137 |
+
# ─────────────────────────────────────────────
|
| 138 |
+
if __name__ == "__main__":
    import argparse

    arg = argparse.ArgumentParser(description="Convert UCI EMG dataset to h5 format.")
    arg.add_argument("--download_data", action="store_true")
    arg.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Root directory of the UCI EMG dataset",
    )
    arg.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save the output h5 files",
    )
    arg.add_argument("--window_size", type=int, help="Window size for sliding window")
    arg.add_argument("--stride", type=int, help="Stride for sliding window")
    args = arg.parse_args()

    data_root = args.data_dir
    save_root = args.save_dir
    os.makedirs(save_root, exist_ok=True)

    # Download-and-exit mode: fetch the archive, unzip, then ask the user to rerun.
    if args.download_data:
        # Dataset home: https://archive.ics.uci.edu/dataset/481/emg+data+for+gestures
        base_url = (
            "https://archive.ics.uci.edu/static/public/481/emg+data+for+gestures.zip"
        )
        os.system(f"wget -O {data_root}/emg_gestures.zip '{base_url}'")
        # NOTE(review): extraction targets the PARENT of data_root — confirm the
        # zip contains a top-level folder matching data_root's basename.
        os.system(f"unzip -o {data_root}/emg_gestures.zip -d {Path(data_root).parent}")
        os.system(f"rm {data_root}/emg_gestures.zip")
        print("Dataset downloaded and cleaned up.")
        sys.exit("Rerun without --download_data.")

    fs = 200.0  # sampling rate of MYO bracelet
    window_size, stride = args.window_size, args.stride

    # Fixed subject-level split over the dataset's 36 subject folders.
    split_map = {
        "train": list(range(1, 25)),  # subjects 1-24
        "val": list(range(25, 31)),  # subjects 25-30
        "test": list(range(31, 37)),  # subjects 31-36
    }
    # Remove users that performed gesture 7 so the label set is consistent
    # across all splits.
    gesture_id = 7
    gesture7_users = users_with_gesture(data_root, gesture_id)
    print(f"Users that performed gesture {gesture_id}:", gesture7_users)

    keep_subjs = []
    for k in split_map:
        split_map[k] = [u for u in split_map[k] if u not in gesture7_users]
        keep_subjs.extend(split_map[k])
    print("Updated split map after removing gesture-7 users:", keep_subjs)

    datasets = {k: {"data": [], "label": []} for k in split_map}

    for subj in keep_subjs:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        # Each subject belongs to exactly one split by construction.
        split_key = next(k for k, v in split_map.items() if subj in v)

        for fname in sorted(os.listdir(subj_dir)):
            if not fname.endswith(".txt"):
                continue
            arr = read_emg_txt(os.path.join(subj_dir, fname))
            arr = preprocess_emg(arr, fs)

            # Window each constant-label run separately so windows never span
            # a gesture boundary; labels are shifted by -1 to start at 0.
            for lbl, seg_arr in find_label_runs(arr):
                segs, labs = sliding_window_majority(seg_arr, window_size, stride)
                if segs.size:
                    datasets[split_key]["data"].append(segs)
                    datasets[split_key]["label"].append(labs - 1)

    # concatenate, transpose & save
    for split in ["train", "val", "test"]:
        X = concat_data(datasets[split]["data"])  # (N, window_size, 8)
        y = concat_label(datasets[split]["label"])
        X = X.transpose(0, 2, 1)  # (N, 8, window_size): channels-first

        with h5py.File(os.path.join(save_root, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int32))
        uniq, cnt = np.unique(y, return_counts=True)
        print(
            f"{split.upper():5} β X={X.shape}, label dist:",
            dict(zip(uniq.tolist(), cnt.tolist())),
        )

    print("\nAll splits saved to:", save_root)
|