| import glob
|
| import json
|
| import os
|
| import sys
|
|
|
| import h5py
|
| import numpy as np
|
| import scipy.signal as signal
|
| from joblib import Parallel, delayed
|
| from scipy.signal import iirnotch
|
| from tqdm.auto import tqdm
|
|
|
def sequence_to_seconds(seq_len, fs):
    """Convert a sample count to a duration in seconds at sampling rate *fs* (Hz)."""
    # PEP 8 (E731): use a def instead of assigning a lambda to a name, so the
    # function has a proper __name__ and traceback entry.
    return seq_len / fs
|
|
|
|
|
# tfs: sampling frequency of the EMG signals in Hz; n_ch: number of EMG channels.
# NOTE(review): 200 Hz / 8 channels — presumably the acquisition device's specs;
# confirm against the EMG-EPN612 dataset documentation.
tfs, n_ch = 200.0, 8
|
|
|
|
|
# Map gesture names (as they appear in the dataset JSON) to integer class labels.
# Label 6 ("notProvided") marks unlabeled samples; the processing functions below
# filter those out.
gesture_map = {
    "noGesture": 0,
    "waveIn": 1,
    "waveOut": 2,
    "pinch": 3,
    "open": 4,
    "fist": 5,
    "notProvided": 6,
}
|
|
|
|
|
|
|
def bandpass_filter_emg(emg, low=20.0, high=90.0, fs=tfs, order=4):
    """Apply a zero-phase Butterworth band-pass filter along the time axis.

    Parameters
    ----------
    emg : ndarray, shape (channels, samples)
        EMG signal; each channel is filtered independently (axis=1).
    low, high : float
        Pass-band edges in Hz.
    fs : float
        Sampling frequency in Hz (defaults to the module-level ``tfs``).
    order : int
        Butterworth filter order.

    Returns
    -------
    ndarray with the same shape as ``emg``.
    """
    nyquist = fs / 2.0
    edges = [low / nyquist, high / nyquist]
    numer, denom = signal.butter(order, edges, btype="bandpass")
    # filtfilt runs the filter forward then backward -> no phase distortion.
    return signal.filtfilt(numer, denom, emg, axis=1)
|
|
|
|
|
def notch_filter_emg(emg, notch=50.0, Q=30.0, fs=tfs):
    """Suppress a narrow frequency band (e.g. 50 Hz mains hum) with a zero-phase
    IIR notch filter applied along the time axis.

    Parameters
    ----------
    emg : ndarray, shape (channels, samples)
        EMG signal; filtered per channel (axis=1).
    notch : float
        Center frequency to reject, in Hz.
    Q : float
        Quality factor of the notch (higher = narrower).
    fs : float
        Sampling frequency in Hz (defaults to the module-level ``tfs``).
    """
    # iirnotch expects the frequency normalized to the Nyquist rate.
    normalized = notch / (fs / 2.0)
    numer, denom = iirnotch(normalized, Q)
    return signal.filtfilt(numer, denom, emg, axis=1)
|
|
|
|
|
|
|
def zscore_per_channel(emg):
    """Standardize each channel (row) of *emg* to zero mean and unit variance.

    Uses the sample standard deviation (ddof=1). Channels with zero variance
    are only mean-centered: their divisor is replaced by 1.0 to avoid a
    division by zero.
    """
    mu = emg.mean(axis=1, keepdims=True)
    sigma = emg.std(axis=1, ddof=1, keepdims=True)
    safe_sigma = np.where(sigma == 0, 1.0, sigma)
    return (emg - mu) / safe_sigma
|
|
|
|
|
def adjust_length(x, max_len):
    """Force *x* (channels, samples) to exactly *max_len* samples along axis 1.

    Longer signals are truncated; shorter ones are right-padded with zeros of
    the same dtype.
    """
    channels, length = x.shape
    if length < max_len:
        padding = np.zeros((channels, max_len - length), dtype=x.dtype)
        return np.concatenate([x, padding], axis=1)
    return x[:, :max_len]
|
|
|
|
|
|
|
def extract_emg_signal(sample, seq_len):
    """Turn one raw dataset sample into a fixed-length (channels, seq_len) array.

    Pipeline: stack the per-channel sequences, scale by 1/128 (presumably the
    raw values are int8-range — TODO confirm against the dataset spec),
    band-pass 20-90 Hz, 50 Hz notch, per-channel z-score, then pad/truncate
    to ``seq_len`` samples.

    Returns
    -------
    (emg, label) : the processed float32 array and the integer class label
        looked up in ``gesture_map`` (6, "notProvided", for unknown or
        missing gesture names).
    """
    channels = list(sample["emg"].values())
    arr = np.stack(channels, dtype=np.float32) / 128.0
    arr = bandpass_filter_emg(arr, 20.0, 90.0)
    arr = notch_filter_emg(arr, 50.0, 30.0)
    arr = adjust_length(zscore_per_channel(arr), seq_len)
    gesture_name = sample.get("gestureName", "notProvided")
    return arr, gesture_map.get(gesture_name, 6)
|
|
|
|
|
|
|
def process_user_training(path, seq_len):
    """Load one user's JSON file from the training split.

    The user's "trainingSamples" go to the train lists and their
    "testingSamples" to the validation lists; samples labeled 6
    ("notProvided") are dropped.

    Returns
    -------
    (train_X, train_y, val_X, val_y) : plain lists of arrays / int labels.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    def _collect(section):
        # Gather all usable samples from one section of the JSON payload.
        xs, ys = [], []
        for sample in payload.get(section, {}).values():
            emg, label = extract_emg_signal(sample, seq_len)
            if label != 6:
                xs.append(emg)
                ys.append(label)
        return xs, ys

    train_X, train_y = _collect("trainingSamples")
    val_X, val_y = _collect("testingSamples")
    return train_X, train_y, val_X, val_y
|
|
|
|
|
|
|
def process_user_testing(path, seq_len):
    """Load one user's JSON file from the testing split.

    Samples are grouped per gesture name; the first 10 occurrences of each
    gesture (in file order) are assigned to the train lists and the rest to
    the test lists. Samples labeled 6 ("notProvided") are discarded.

    Returns
    -------
    (train_X, train_y, test_X, test_y) : plain lists of arrays / int labels.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    # Pre-seed with every known gesture so bucket iteration order is stable;
    # setdefault still catches any gesture name not present in gesture_map.
    grouped = {name: [] for name in gesture_map}
    for sample in payload.get("trainingSamples", {}).values():
        grouped.setdefault(sample.get("gestureName", "notProvided"), []).append(sample)

    train_X, train_y, test_X, test_y = [], [], [], []
    for gesture_samples in grouped.values():
        for idx, sample in enumerate(gesture_samples):
            emg, label = extract_emg_signal(sample, seq_len)
            if label == 6:
                continue
            # First 10 repetitions per gesture are reserved for training.
            dest_X, dest_y = (train_X, train_y) if idx < 10 else (test_X, test_y)
            dest_X.append(emg)
            dest_y.append(label)
    return train_X, train_y, test_X, test_y
|
|
|
|
|
|
|
def save_h5(path, data, labels):
    """Write *data* and *labels* to an HDF5 file at *path* (overwriting it).

    Creates two datasets: "data" (float32) and "label" (int64).
    """
    data_arr = np.asarray(data, np.float32)
    label_arr = np.asarray(labels, np.int64)
    with h5py.File(path, "w") as out:
        out.create_dataset("data", data=data_arr)
        out.create_dataset("label", data=label_arr)
|
|
|
|
|
|
|
def main():
    """CLI entry point: optionally download, then preprocess EMG-EPN612.

    Walks the per-user JSON files of the training and testing source trees in
    parallel, filters and normalizes each sample, and writes train/val/test
    HDF5 files into --dest_dir. With --download_data, only downloads and
    unpacks the dataset, then exits.
    """
    import argparse
    import shlex

    parser = argparse.ArgumentParser()
    parser.add_argument("--download_data", action="store_true")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--source_training", required=True)
    parser.add_argument("--source_testing", required=True)
    parser.add_argument("--dest_dir", required=True)
    parser.add_argument(
        # BUG FIX: previously neither required nor defaulted, so args.seq_len
        # was None and the pipeline crashed later in sequence_to_seconds /
        # adjust_length. Fail fast at argument parsing instead.
        "--seq_len",
        type=int,
        required=True,
        help="Size of the window in samples for segmentation.",
    )
    parser.add_argument("--n_jobs", type=int, default=-1)
    args = parser.parse_args()
    data_dir = args.data_dir
    os.makedirs(args.dest_dir, exist_ok=True)

    if args.download_data:
        url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
        # Quote every interpolated path so directories containing spaces or
        # shell metacharacters do not break (or alter) the commands.
        q_dir = shlex.quote(data_dir)
        zip_path = f"{q_dir}/EMG-EPN612_Dataset.zip"
        os.system(f"wget -O {zip_path} {shlex.quote(url)}")
        os.system(f"unzip -o {zip_path} -d {q_dir}")
        # The archive extracts to "EMG-EPN612 Dataset" (with a space).
        extracted = shlex.quote(f"{data_dir}/EMG-EPN612 Dataset")
        os.system(f"mv {extracted}/* {q_dir}/")
        # BUG FIX: the old code ran `rmdir .../EMG-EPN612_Dataset` (underscore),
        # which never matched the extracted directory name, leaving the empty
        # "EMG-EPN612 Dataset" directory behind.
        os.system(f"rmdir {extracted}")
        os.system(f"rm {zip_path}")
        # (Message fixed: the zip file has just been deleted, so point at the
        # directory the data now lives in rather than the removed archive.)
        print(f"Downloaded and unzipped dataset into {data_dir}")
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    seq_len = args.seq_len
    window_seconds = sequence_to_seconds(seq_len, tfs)
    print(f"Window size: {seq_len} samples ({window_seconds:.2f} seconds)")

    train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []

    # Training split: each user's held-out "testingSamples" become validation.
    train_paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))
    results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_training)(p, seq_len)
        for p in tqdm(train_paths, desc="Training files")
    )
    for tX, ty, vX, vy in results:
        train_X.extend(tX)
        train_y.extend(ty)
        val_X.extend(vX)
        val_y.extend(vy)

    # Testing split: the first 10 repetitions per gesture are folded into the
    # training set (see process_user_testing), the rest become the test set.
    test_paths = glob.glob(os.path.join(args.source_testing, "user*", "user*.json"))
    test_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_testing)(p, seq_len)
        for p in tqdm(test_paths, desc="Testing files")
    )
    for tX, ty, teX, tey in test_results:
        train_X.extend(tX)
        train_y.extend(ty)
        test_X.extend(teX)
        test_y.extend(tey)

    save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
    save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
    save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)

    # Report the per-class sample counts of each split.
    for split, _X, y in [
        ("Train", train_X, train_y),
        ("Val", val_X, val_y),
        ("Test", test_X, test_y),
    ]:
        uniq, cnt = np.unique(np.array(y), return_counts=True)
        counts = {u.item(): c.item() for u, c in zip(uniq, cnt)}
        # (Cleaned up: the old line mixed an f-string with .format(); this
        # prints the identical text with a single f-string.)
        print(f"{split} → total={len(y)}, classes={counts}")
|
|
|
|
|
# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|