|
|
import os |
|
|
import sys |
|
|
|
|
|
import h5py |
|
|
import numpy as np |
|
|
import scipy.io |
|
|
import scipy.signal as signal |
|
|
from joblib import Parallel, delayed |
|
|
from scipy.signal import iirnotch |
|
|
from tqdm import tqdm |
|
|
|
|
|
# Linear DoF -> DoA mapping (presumably "degrees of freedom" to "degrees of
# actuation" -- confirm against the dataset documentation).
#
# The table is written one row per input channel: 18 rows x 5 columns.
# process_mat_file() multiplies its transpose with the "glove" labels, so the
# labels are expected to carry 18 columns and the mapped output has 5.
_MATRIX_DOF2DOA_TRANSPOSED = np.asarray(
    (
        (0.639, 0.0, 0.0, 0.0, 0.0),  # row 0
        (0.383, 0.0, 0.0, 0.0, 0.0),  # row 1
        (0.0, 1.0, 0.0, 0.0, 0.0),  # row 2
        (-0.639, 0.0, 0.0, 0.0, 0.0),  # row 3
        (0.0, 0.0, 0.4, 0.0, 0.0),  # row 4
        (0.0, 0.0, 0.6, 0.0, 0.0),  # row 5
        (0.0, 0.0, 0.0, 0.4, 0.0),  # row 6
        (0.0, 0.0, 0.0, 0.6, 0.0),  # row 7
        (0.0, 0.0, 0.0, 0.0, 0.0),  # row 8 (unused input)
        (0.0, 0.0, 0.0, 0.0, 0.1667),  # row 9
        (0.0, 0.0, 0.0, 0.0, 0.3333),  # row 10
        (0.0, 0.0, 0.0, 0.0, 0.0),  # row 11 (unused input)
        (0.0, 0.0, 0.0, 0.0, 0.1667),  # row 12
        (0.0, 0.0, 0.0, 0.0, 0.3333),  # row 13
        (0.0, 0.0, 0.0, 0.0, 0.0),  # row 14 (unused input)
        (0.0, 0.0, 0.0, 0.0, 0.0),  # row 15 (unused input)
        (-0.19, 0.0, 0.0, 0.0, 0.0),  # row 16
        (0.0, 0.0, 0.0, 0.0, 0.0),  # row 17 (unused input)
    ),
    dtype=np.float32,
)

# Shape (5, 18): the orientation used as `MATRIX_DOF2DOA @ label.T`.
MATRIX_DOF2DOA = np.transpose(_MATRIX_DOF2DOA_TRANSPOSED)
|
|
|
|
|
|
|
|
|
|
|
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=1111.0):
    """Apply a zero-phase IIR notch filter along the sample axis.

    Every channel is filtered independently with the same
    ``iirnotch(notch_freq, Q, fs)`` coefficients via ``filtfilt``.

    Args:
        data: Array whose first axis is time, either ``(n_samples,)`` or
            ``(n_samples, n_channels)``.
        notch_freq: Frequency to remove, in the same unit as ``fs``.
        Q: Quality factor of the notch.
        fs: Sampling frequency.
            NOTE(review): the default (1111.0) differs from the 2000.0 used
            in ``main``; callers should pass ``fs`` explicitly.

    Returns:
        Filtered array with the same shape as ``data``. Floating/complex
        inputs keep their dtype; any other dtype (e.g. integer ADC counts)
        is returned as float64 so the filtered values are not truncated.

    Raises:
        ValueError: Propagated from ``filtfilt`` when the signal is shorter
            than its default padding length.
    """
    data = np.asarray(data)
    b, a = iirnotch(notch_freq, Q, fs)
    # Writing the result into zeros_like(data) would silently cast the output
    # back to an integer dtype; only preserve dtypes that can hold fractions.
    if np.issubdtype(data.dtype, np.inexact):
        out_dtype = data.dtype
    else:
        out_dtype = np.float64
    # axis=0 filters all channels in one call and also accepts 1-D input.
    filtered = signal.filtfilt(b, a, data, axis=0)
    return filtered.astype(out_dtype, copy=False)
|
|
|
|
|
|
|
|
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=2000.0, order=4):
    """Apply a zero-phase Butterworth band-pass filter along the sample axis.

    Args:
        emg: Array whose first axis is time, either ``(n_samples,)`` or
            ``(n_samples, n_channels)``.
        lowcut: Lower pass-band edge, in the same unit as ``fs``.
        highcut: Upper pass-band edge; must be below ``fs / 2``.
        fs: Sampling frequency.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        Filtered array with the same shape as ``emg``. Floating/complex
        inputs keep their dtype; any other dtype (e.g. integer ADC counts)
        is returned as float64 so the filtered values are not truncated.

    Raises:
        ValueError: Propagated from ``butter`` for invalid band edges, or
            from ``filtfilt`` when the signal is shorter than its default
            padding length.
    """
    emg = np.asarray(emg)
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    # Writing the result into zeros_like(emg) would silently cast the output
    # back to an integer dtype; only preserve dtypes that can hold fractions.
    if np.issubdtype(emg.dtype, np.inexact):
        out_dtype = emg.dtype
    else:
        out_dtype = np.float64
    # axis=0 filters all channels in one call and also accepts 1-D input.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    return filtered.astype(out_dtype, copy=False)
|
|
|
|
|
|
|
|
|
|
|
def sliding_window_segment(emg, label, window_size, stride):
    """Cut time-aligned EMG and label arrays into overlapping windows.

    Windows start at ``0, stride, 2 * stride, ...`` and only full windows are
    kept. Each EMG window is paired with the *whole* label window covering
    the same samples (not a single centre frame).

    Args:
        emg: Array with time on the first axis, e.g. ``(n_samples, n_channels)``.
        label: Array with time on the first axis, e.g. ``(n_samples, n_outputs)``.
            Its length determines how many windows are produced.
        window_size: Number of samples per window; must be positive.
        stride: Step between consecutive window starts; must be positive.

    Returns:
        ``(segments, labels)`` with shapes
        ``(n_windows, window_size, *emg.shape[1:])`` and
        ``(n_windows, window_size, *label.shape[1:])``. When the input is
        shorter than one window, ``n_windows`` is 0 but the trailing
        dimensions are kept so the result can still be concatenated with
        the output of other calls.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive, or if
            ``emg`` has fewer samples than ``label``.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got {window_size} and {stride}"
        )

    emg = np.asarray(emg)
    label = np.asarray(label)
    n_samples = len(label)
    if len(emg) < n_samples:
        raise ValueError(
            f"emg has {len(emg)} samples but label has {n_samples}; "
            "the arrays must be time-aligned"
        )

    starts = range(0, n_samples - window_size + 1, stride)
    if not starts:
        # np.array([]) would collapse to shape (0,), which cannot be
        # concatenated with real (n, window_size, channels) results.
        empty_emg = np.empty((0, window_size) + emg.shape[1:], dtype=emg.dtype)
        empty_label = np.empty((0, window_size) + label.shape[1:], dtype=label.dtype)
        return empty_emg, empty_label

    segments = np.stack([emg[s : s + window_size] for s in starts])
    labels = np.stack([label[s : s + window_size] for s in starts])
    return segments, labels
|
|
|
|
|
|
|
|
|
|
|
def process_mat_file(mat_path, window_size, stride, fs):
    """Turn one recording (.mat file) into windowed EMG / label arrays.

    Steps, in order:
      1. Pick the dataset split from the file name (``_A1`` -> train,
         ``_A2`` -> val, ``_A3`` -> test); other files are ignored.
      2. Load the ``"emg"`` and ``"glove"`` variables.
      3. Drop every sample whose glove row contains a NaN.
      4. Z-score each EMG channel over the whole recording.
      5. Map the glove channels to the DoA space with ``MATRIX_DOF2DOA``
         (DoF -> DoA); the glove data must therefore have 18 columns.
      6. Segment both arrays with ``sliding_window_segment``.

    No frequency filtering is applied here.

    Args:
        mat_path: Path of the .mat file.
        window_size: Window length in samples.
        stride: Step between window starts in samples.
        fs: Sampling frequency. Currently unused; kept so the call
            signature stays stable.

    Returns:
        ``None`` when the file name carries none of the split tags,
        otherwise ``(split, segs, labs)`` where ``split`` is ``"train"``,
        ``"val"`` or ``"test"`` and ``segs`` / ``labs`` are the windowed EMG
        and mapped labels.
    """
    # Decide the split first so that unrelated .mat files are skipped before
    # any loading or processing is attempted.
    fname = os.path.basename(mat_path)
    if "_A1" in fname:
        split = "train"
    elif "_A2" in fname:
        split = "val"
    elif "_A3" in fname:
        split = "test"
    else:
        return None

    mat = scipy.io.loadmat(mat_path)
    emg = mat["emg"]
    label = mat["glove"]

    # Keep only samples whose label row is fully defined.
    valid = ~np.isnan(label).any(axis=1)
    emg = emg[valid]
    label = label[valid]

    # Per-channel z-score; constant channels get sd = 1 to avoid dividing by 0.
    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg = (emg - mu) / sd

    # (5, 18) @ (18, n_samples) -> (5, n_samples), transposed back to
    # (n_samples, 5).
    y_doa = (MATRIX_DOF2DOA @ label.T).T

    segs, labs = sliding_window_segment(emg, y_doa, window_size, stride)
    return split, segs, labs
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: download DB8 or build the HDF5 splits.

    With ``--download_data`` the script fetches ``S{1..12}_E1_A{1,2,3}.mat``
    into ``--data_dir`` using ``wget`` and exits (status 0 on success, an
    error message and non-zero status if any download failed).

    Otherwise every ``*.mat`` file in ``--data_dir`` is processed in parallel
    with ``process_mat_file``; the resulting windows are grouped by split and
    written to ``<save_dir>/{train,val,test}.h5`` with datasets ``"data"``
    (windows transposed to put the channel axis before the time axis) and
    ``"label"``.
    """
    import argparse
    import subprocess

    parser = argparse.ArgumentParser(description="Process EMG data from DB8.")
    parser.add_argument("--download_data", action="store_true")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parser.add_argument(
        "--window_size", type=int, help="Size of the sliding window for segmentation."
    )
    parser.add_argument(
        "--stride", type=int, help="Stride for the sliding window segmentation."
    )
    parser.add_argument(
        "--n_jobs", type=int, default=-1, help="Number of parallel jobs to run."
    )
    args = parser.parse_args()
    data_dir = args.data_dir
    os.makedirs(args.save_dir, exist_ok=True)

    if args.download_data:
        base_url = "https://ninapro.hevs.ch/files/DB8/"
        failed = []
        for i in range(1, 13):
            subject_ok = True
            for acquisition in ("A1", "A2", "A3"):
                url = f"{base_url}S{i}_E1_{acquisition}.mat"
                # Argument list instead of a shell string: data_dir may
                # contain spaces or shell metacharacters.
                try:
                    completed = subprocess.run(
                        ["wget", "-P", data_dir, url], check=False
                    )
                except FileNotFoundError:
                    sys.exit("wget was not found on PATH; cannot download the data.")
                if completed.returncode != 0:
                    subject_ok = False
                    failed.append(url)
            if subject_ok:
                print(
                    f"Downloaded subject {i}\n{data_dir}/S{i}_E1_A1.mat and {data_dir}/S{i}_E1_A2.mat and {data_dir}/S{i}_E1_A3.mat"
                )
        if failed:
            sys.exit("Failed to download:\n" + "\n".join(failed))
        # sys.exit(<str>) would report failure (status 1) even on success.
        print("Data downloaded. Rerun without --download_data.")
        sys.exit(0)

    # These options are optional only for the download mode above.
    if args.window_size is None or args.stride is None:
        parser.error("--window_size and --stride are required for processing.")
    if args.window_size <= 0 or args.stride <= 0:
        parser.error("--window_size and --stride must be positive integers.")
    if args.n_jobs == 0:
        parser.error("--n_jobs must be non-zero (use -1 for all cores).")

    fs = 2000.0

    mat_paths = [
        os.path.join(data_dir, f)
        for f in sorted(os.listdir(data_dir))
        if f.endswith(".mat")
    ]
    if not mat_paths:
        sys.exit(f"No .mat files found in {data_dir}")

    # Cap positive job counts at the core count; negative values keep their
    # joblib meaning. os.cpu_count() may return None.
    n_jobs = min(os.cpu_count() or 1, args.n_jobs)
    results = Parallel(n_jobs=n_jobs, verbose=5)(
        delayed(process_mat_file)(mp, args.window_size, args.stride, fs)
        for mp in mat_paths
    )

    splits = {k: {"data": [], "label": []} for k in ("train", "val", "test")}
    for out in tqdm(results, desc="Processing files", unit="file"):
        if out is None:
            continue
        split, segs, labs = out
        if len(segs) == 0:
            # Recording shorter than one window: nothing to contribute, and
            # an empty array of the wrong rank would break the concatenation.
            continue
        splits[split]["data"].append(segs)
        splits[split]["label"].append(labs)

    for split, d in tqdm(splits.items(), desc="Saving splits", unit="split"):
        if not d["data"]:
            continue

        X = np.concatenate(d["data"], axis=0)
        y = np.concatenate(d["label"], axis=0)

        # Swap the last two axes of the stacked windows.
        X = X.transpose(0, 2, 1)

        print(f"Split: {split}, X shape: {X.shape}, y shape: {y.shape}")

        with h5py.File(os.path.join(args.save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
            hf.create_dataset("label", data=y)


if __name__ == "__main__":
    main()
|
|
|