|
|
import glob
|
|
|
import json
|
|
|
import os
|
|
|
import sys
|
|
|
|
|
|
import h5py
|
|
|
import numpy as np
|
|
|
import scipy.signal as signal
|
|
|
from joblib import Parallel, delayed
|
|
|
from scipy.signal import iirnotch
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
# Target sampling rate (Hz) and number of EMG channels on the armband.
tfs, n_ch = 200.0, 8

# Gesture-name → integer class label, in fixed label order.
# "notProvided" (label 6) marks samples without a usable gesture name
# and is filtered out downstream.
_GESTURE_NAMES = (
    "noGesture",
    "waveIn",
    "waveOut",
    "pinch",
    "open",
    "fist",
    "notProvided",
)
gesture_map = {name: label for label, name in enumerate(_GESTURE_NAMES)}
|
|
|
|
|
|
|
|
|
|
|
|
def bandpass_filter_emg(emg, low=20.0, high=90.0, fs=tfs, order=4):
    """Zero-phase Butterworth band-pass along the time axis.

    Args:
        emg: array of shape (channels, time).
        low, high: passband edges in Hz.
        fs: sampling rate in Hz (defaults to the module-level ``tfs``).
        order: Butterworth filter order.

    Returns:
        Filtered array, same shape as ``emg``.
    """
    nyquist = fs / 2.0
    band = [low / nyquist, high / nyquist]
    b, a = signal.butter(order, band, btype="bandpass")
    # filtfilt runs the filter forward and backward, so no phase shift.
    return signal.filtfilt(b, a, emg, axis=1)
|
|
|
|
|
|
|
|
|
def notch_filter_emg(emg, notch=50.0, Q=30.0, fs=tfs):
    """Zero-phase IIR notch filter along the time axis (default 50 Hz mains).

    Args:
        emg: array of shape (channels, time).
        notch: center frequency to reject, in Hz.
        Q: quality factor (higher = narrower notch).
        fs: sampling rate in Hz (defaults to the module-level ``tfs``).

    Returns:
        Filtered array, same shape as ``emg``.
    """
    normalized = notch / (fs / 2.0)  # iirnotch expects a Nyquist-normalized w0
    b, a = iirnotch(normalized, Q)
    return signal.filtfilt(b, a, emg, axis=1)
|
|
|
|
|
|
|
|
|
|
|
|
def zscore_per_channel(emg):
    """Standardize each channel (row) to zero mean and unit variance.

    Uses the sample standard deviation (ddof=1). Channels with zero
    variance are divided by 1.0 instead, yielding an all-zero row rather
    than NaNs.

    Args:
        emg: array of shape (channels, time).

    Returns:
        Standardized array, same shape as ``emg``.
    """
    mu = emg.mean(axis=1, keepdims=True)
    sigma = emg.std(axis=1, ddof=1, keepdims=True)
    sigma = np.where(sigma == 0, 1.0, sigma)
    return (emg - mu) / sigma
|
|
|
|
|
|
|
|
|
def adjust_length(x, max_len):
    """Crop or right-pad ``x`` along time so it has exactly ``max_len`` samples.

    Args:
        x: array of shape (channels, time).
        max_len: desired number of time samples.

    Returns:
        Array of shape (channels, max_len); zero-padded on the right when
        the input is shorter, truncated when longer. Dtype is preserved.
    """
    _, seq_len = x.shape
    if seq_len < max_len:
        # Pad only the time axis; mode defaults to constant zeros.
        return np.pad(x, ((0, 0), (0, max_len - seq_len)))
    return x[:, :max_len]
|
|
|
|
|
|
|
|
|
|
|
|
def extract_emg_signal(sample, seq_len):
    """Convert one raw JSON sample into a model-ready array and label.

    Pipeline: stack the per-channel EMG dict into (channels, time),
    scale, band-pass (20–90 Hz), notch (50 Hz), z-score per channel,
    then crop/pad to ``seq_len``.

    Args:
        sample: dict with an "emg" mapping of channel-name → sample list,
            and optionally "gestureName".
        seq_len: target length of the time axis.

    Returns:
        (emg, label) where emg is float32 of shape (channels, seq_len) and
        label is the integer from ``gesture_map`` (6 for unknown/missing).
    """
    channels = list(sample["emg"].values())
    # 128.0 presumably maps the raw int8-range readings into ~[-1, 1] — confirm.
    emg = np.stack(channels, dtype=np.float32) / 128.0
    emg = bandpass_filter_emg(emg, 20.0, 90.0)
    emg = notch_filter_emg(emg, 50.0, 30.0)
    emg = zscore_per_channel(emg)
    emg = adjust_length(emg, seq_len)
    name = sample.get("gestureName", "notProvided")
    return emg, gesture_map.get(name, 6)
|
|
|
|
|
|
|
|
|
|
|
|
def process_user_training(path, seq_len):
    """Load one training-split user JSON and extract its samples.

    The user's "trainingSamples" become the train portion and the user's
    "testingSamples" become the validation portion. Samples labelled 6
    ("notProvided") are dropped.

    Args:
        path: path to a user*.json file.
        seq_len: target time-axis length for each sample.

    Returns:
        (train_X, train_y, val_X, val_y) lists of arrays/labels.
    """
    with open(path, "r", encoding="utf-8") as fh:
        content = json.load(fh)

    def _collect(section):
        # Extract every usable sample from one section of the JSON.
        xs, ys = [], []
        for sample in content.get(section, {}).values():
            emg, label = extract_emg_signal(sample, seq_len)
            if label != 6:
                xs.append(emg)
                ys.append(label)
        return xs, ys

    train_X, train_y = _collect("trainingSamples")
    val_X, val_y = _collect("testingSamples")
    return train_X, train_y, val_X, val_y
|
|
|
|
|
|
|
|
|
|
|
|
def process_user_testing(path, seq_len):
    """Load one testing-split user JSON and split its samples per gesture.

    Samples are grouped by gesture name; within each group the first 10
    go to the train portion and the remainder to the test portion.
    Samples labelled 6 ("notProvided") are dropped.

    Args:
        path: path to a user*.json file.
        seq_len: target time-axis length for each sample.

    Returns:
        (train_X, train_y, test_X, test_y) lists of arrays/labels.
    """
    with open(path, "r", encoding="utf-8") as fh:
        content = json.load(fh)

    # Pre-seed with all known gestures; setdefault also tolerates
    # gesture names outside gesture_map.
    grouped = {name: [] for name in gesture_map}
    for sample in content.get("trainingSamples", {}).values():
        name = sample.get("gestureName", "notProvided")
        grouped.setdefault(name, []).append(sample)

    train_X, train_y, test_X, test_y = [], [], [], []
    for samples in grouped.values():
        for idx, sample in enumerate(samples):
            emg, label = extract_emg_signal(sample, seq_len)
            if label == 6:
                continue
            if idx < 10:
                train_X.append(emg)
                train_y.append(label)
            else:
                test_X.append(emg)
                test_y.append(label)
    return train_X, train_y, test_X, test_y
|
|
|
|
|
|
|
|
|
|
|
|
def save_h5(path, data, labels):
    """Write samples and labels to an HDF5 file at ``path``.

    Creates two datasets: "data" (float32) and "label" (int64). Any
    existing file at ``path`` is overwritten.
    """
    x = np.asarray(data, dtype=np.float32)
    y = np.asarray(labels, dtype=np.int64)
    with h5py.File(path, "w") as out:
        out.create_dataset("data", data=x)
        out.create_dataset("label", data=y)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: optionally fetch the EMG-EPN612 dataset, then
    preprocess every user's JSON recordings into train/val/test HDF5 files.

    Required flags: --data_dir, --source_training, --source_testing,
    --dest_dir, --window_size. Optional: --download_data, --n_jobs.
    """
    import argparse
    import shlex

    parser = argparse.ArgumentParser()
    parser.add_argument("--download_data", action="store_true")
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--source_training", required=True)
    parser.add_argument("--source_testing", required=True)
    parser.add_argument("--dest_dir", required=True)
    parser.add_argument("--window_size", type=int, required=True)
    parser.add_argument("--n_jobs", type=int, default=-1)
    args = parser.parse_args()

    data_dir = args.data_dir
    os.makedirs(args.dest_dir, exist_ok=True)

    if args.download_data:
        url = "https://zenodo.org/records/4421500/files/EMG-EPN612%20Dataset.zip?download=1"
        zip_path = f"{data_dir}/EMG-EPN612_Dataset.zip"
        # Quote every interpolated path: os.system goes through the shell,
        # and data_dir may contain spaces or other metacharacters.
        q_zip = shlex.quote(zip_path)
        q_dir = shlex.quote(data_dir)
        os.system(f"wget -O {q_zip} {shlex.quote(url)}")
        os.system(f"unzip -o {q_zip} -d {q_dir}")
        # The archive extracts into "EMG-EPN612 Dataset" (space, not
        # underscore); flatten its contents into data_dir.
        extracted = shlex.quote(f"{data_dir}/EMG-EPN612 Dataset")
        os.system(f"mv {extracted}/* {q_dir}/")
        # Bugfix: the old code ran `rmdir .../EMG-EPN612_Dataset`
        # (underscore), a directory that never existed, so the emptied
        # "EMG-EPN612 Dataset" directory was left behind.
        os.system(f"rmdir {extracted}")
        os.system(f"rm {q_zip}")
        # Bugfix: the old message pointed at the zip file that was just
        # deleted; report the extraction directory instead.
        print(f"Downloaded and unzipped dataset into {data_dir}")
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    seq_len = args.window_size
    train_X, train_y, val_X, val_y, test_X, test_y = [], [], [], [], [], []

    # Training split: each user's "trainingSamples" feed the train set,
    # their "testingSamples" feed the validation set.
    paths = glob.glob(os.path.join(args.source_training, "user*", "user*.json"))
    results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_training)(p, seq_len)
        for p in tqdm(paths, desc="Training files")
    )
    for tX, ty, vX, vy in results:
        train_X.extend(tX)
        train_y.extend(ty)
        val_X.extend(vX)
        val_y.extend(vy)

    # Testing split: per gesture, the first 10 samples join the train set
    # and the remainder form the test set (see process_user_testing).
    test_results = Parallel(n_jobs=args.n_jobs)(
        delayed(process_user_testing)(p, seq_len)
        for p in tqdm(
            glob.glob(os.path.join(args.source_testing, "user*", "user*.json")),
            desc="Testing files",
        )
    )
    for tX, ty, teX, tey in test_results:
        train_X.extend(tX)
        train_y.extend(ty)
        test_X.extend(teX)
        test_y.extend(tey)

    save_h5(os.path.join(args.dest_dir, "train.h5"), train_X, train_y)
    save_h5(os.path.join(args.dest_dir, "val.h5"), val_X, val_y)
    save_h5(os.path.join(args.dest_dir, "test.h5"), test_X, test_y)

    # Report per-split size and class balance.
    for split, X, y in [
        ("Train", train_X, train_y),
        ("Val", val_X, val_y),
        ("Test", test_X, test_y),
    ]:
        uniq, cnt = np.unique(np.array(y), return_counts=True)
        counts = {u.item(): c.item() for u, c in zip(uniq, cnt)}
        print(f"{split} → total={len(y)}, classes={counts}")


if __name__ == "__main__":
    main()
|
|
|
|