|
|
import os
|
|
|
import sys
|
|
|
from pathlib import Path
|
|
|
|
|
|
import h5py
|
|
|
import numpy as np
|
|
|
import scipy.signal as signal
|
|
|
from scipy.signal import iirnotch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
    """Zero-phase Butterworth band-pass filter applied along axis 0.

    Args:
        emg: (N, C) sample array, channels along axis 1.
        lowcut, highcut: pass-band edges in Hz.
        fs: sampling rate in Hz.
        order: filter order handed to ``signal.butter``.
    """
    nyquist = fs / 2.0
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    b, a = signal.butter(order, normalized_band, btype="bandpass")
    # filtfilt runs forward + backward, so the output has no phase lag
    return signal.filtfilt(b, a, emg, axis=0)
|
|
|
|
|
|
|
|
|
def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
    """Zero-phase IIR notch filter (default: 50 Hz mains hum) along axis 0.

    Args:
        emg: (N, C) sample array, channels along axis 1.
        notch_freq: center frequency to reject, in Hz.
        Q: quality factor controlling notch width.
        fs: sampling rate in Hz.
    """
    # iirnotch takes the frequency normalized to the Nyquist rate
    w0 = notch_freq / (fs / 2.0)
    b, a = iirnotch(w0, Q)
    return signal.filtfilt(b, a, emg, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_emg_txt(txt_path):
    """
    Read a txt file with columns: time ch1 ... ch8 class.

    The first line is assumed to be a header and is skipped; rows that do
    not have exactly 10 whitespace-separated columns are silently ignored.

    Returns:
        float32 array of shape (N, 10); shape (0,) when no valid rows exist.
    """
    data = []
    with open(txt_path, "r") as f:
        # Stream line-by-line instead of f.readlines()[1:], which would
        # materialize the whole file in memory just to drop the header.
        next(f, None)
        for line in f:
            cols = line.split()  # split() with no args also strips whitespace
            if len(cols) == 10:
                data.append([float(c) for c in cols])
    return np.asarray(data, dtype=np.float32)
|
|
|
|
|
|
|
|
|
def preprocess_emg(arr, fs=200.0, remove_class0=True):
    """
    Preprocess a raw (N, 10) recording laid out as [time, ch1..ch8, class].

    1) optional removal of class-0 (rest) rows
    2) band-pass -> notch -> per-channel Z-score on the 8 EMG columns

    Returns a new array; the caller's input array is never modified in
    place (previously, with ``remove_class0=False``, the in-place channel
    write-back silently mutated the argument).
    """
    if remove_class0:
        # Boolean masking copies, so `arr` is already detached from the caller.
        arr = arr[arr[:, -1] >= 1]
    else:
        # Copy explicitly so the write-back below cannot leak to the caller.
        arr = arr.copy()
    if arr.size == 0:
        # Guard unconditionally: filtfilt raises on empty/too-short input.
        return arr

    emg = arr[:, 1:9]
    emg = bandpass_filter_emg(emg, 20, 90, fs)
    emg = notch_filter_emg(emg, 50, 30, fs)

    # Z-score each channel; constant channels get sd=1 to avoid divide-by-zero.
    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg = (emg - mu) / sd

    arr[:, 1:9] = emg
    return arr
|
|
|
|
|
|
|
|
|
def find_label_runs(arr):
    """Group consecutive rows sharing the same class label.

    Args:
        arr: (N, M) array whose last column holds integer-valued labels.

    Returns:
        List of ``(label, rows)`` tuples in order of appearance, where
        ``rows`` is the contiguous slice of ``arr`` for that run.
    """
    if arr.size == 0:
        return []
    labels = arr[:, -1].astype(int)
    # Run boundaries: position 0, every index where the label changes, and N.
    cuts = [0]
    cuts += [i for i in range(1, len(labels)) if labels[i] != labels[i - 1]]
    cuts.append(len(arr))
    return [
        (int(labels[lo]), arr[lo:hi])
        for lo, hi in zip(cuts[:-1], cuts[1:])
    ]
|
|
|
|
|
|
|
|
|
def sliding_window_majority(seg_arr, window_size=1000, stride=500):
    """Cut an (N, 10) segment into overlapping windows of the 8 EMG channels.

    Each window's label is the majority vote over its per-sample class column
    (ties resolved toward the smaller label by ``argmax``).

    Returns:
        Tuple ``(windows, labels)``: float32 array of shape
        (num_windows, window_size, 8) and int32 label vector.
    """
    windows, votes = [], []
    last_start = len(seg_arr) - window_size
    for offset in range(0, last_start + 1, stride):
        chunk = seg_arr[offset : offset + window_size]
        counts = np.bincount(chunk[:, -1].astype(int))
        votes.append(np.argmax(counts))
        windows.append(chunk[:, 1:9])
    return np.asarray(windows, dtype=np.float32), np.asarray(votes, dtype=np.int32)
|
|
|
|
|
|
|
|
|
def users_with_gesture(
    data_root, gesture_id, subj_range=range(1, 37), return_counts=False
):
    """Scan subject folders and report who performed a given gesture.

    Args:
        data_root: dataset root with one zero-padded folder per subject.
        gesture_id: class id looked up in the last column of each txt file.
        subj_range: subject ids to scan (folders named "01", "02", ...).
        return_counts: if True, return ``{subject: sample_count}``; otherwise
            a sorted list of subject ids with at least one matching sample.
    """
    gesture_id = int(gesture_id)  # convert once, not per file
    found = {}
    for subj in subj_range:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        count = 0
        for fname in os.listdir(subj_dir):
            if not fname.endswith(".txt"):
                continue
            try:
                arr = read_emg_txt(os.path.join(subj_dir, fname))
            except Exception:
                # Best-effort scan: skip unreadable or malformed files.
                continue
            if arr.size == 0:
                continue
            # Single pass over the label column; a zero sum is a no-op, so
            # the former redundant np.any() pre-check is unnecessary.
            count += int((arr[:, -1].astype(int) == gesture_id).sum())
        if count > 0:
            found[subj] = count

    return found if return_counts else sorted(found.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def concat_data(lst):
    """Stack window arrays along axis 0; empty input -> (0, 1000, 8) float32."""
    if not lst:
        return np.empty((0, 1000, 8), np.float32)
    return np.concatenate(lst, axis=0)
|
|
|
|
|
|
|
|
|
def concat_label(lst):
    """Join label vectors end-to-end; empty input -> (0,) int32."""
    if not lst:
        return np.empty((0,), np.int32)
    return np.concatenate(lst, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    arg = argparse.ArgumentParser(description="Convert UCI EMG dataset to h5 format.")
    arg.add_argument("--download_data", action="store_true")
    arg.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Root directory of the UCI EMG dataset",
    )
    arg.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save the output h5 files",
    )
    # Defaults added: these flags previously had none, so omitting them
    # passed None into sliding_window_majority and crashed on
    # `len(seg_arr) - None`. Defaults mirror the function's own defaults.
    arg.add_argument(
        "--window_size", type=int, default=1000, help="Window size for sliding window"
    )
    arg.add_argument(
        "--stride", type=int, default=500, help="Stride for sliding window"
    )
    args = arg.parse_args()

    data_root = args.data_dir
    save_root = args.save_dir
    os.makedirs(save_root, exist_ok=True)

    if args.download_data:
        # NOTE(review): os.system interpolates CLI-supplied paths into a shell
        # string; acceptable for a trusted local script, but
        # subprocess.run([...], shell=False) would be the safer form.
        base_url = (
            "https://archive.ics.uci.edu/static/public/481/emg+data+for+gestures.zip"
        )
        os.system(f"wget -O {data_root}/emg_gestures.zip '{base_url}'")
        os.system(f"unzip -o {data_root}/emg_gestures.zip -d {Path(data_root).parent}")
        os.system(f"rm {data_root}/emg_gestures.zip")
        print("Dataset downloaded and cleaned up.")
        sys.exit("Rerun without --download_data.")

    fs = 200.0  # dataset sampling rate in Hz
    window_size, stride = args.window_size, args.stride

    # Subject-level split: no subject appears in more than one split.
    split_map = {
        "train": list(range(1, 25)),
        "val": list(range(25, 31)),
        "test": list(range(31, 37)),
    }

    # Drop every user who performed gesture 7 so all splits share one label set.
    gesture_id = 7
    gesture7_users = users_with_gesture(data_root, gesture_id)
    print(f"Users that performed gesture {gesture_id}:", gesture7_users)

    keep_subjs = []
    for k in split_map:
        split_map[k] = [u for u in split_map[k] if u not in gesture7_users]
        keep_subjs.extend(split_map[k])
    print("Updated split map after removing gesture-7 users:", keep_subjs)

    datasets = {k: {"data": [], "label": []} for k in split_map}

    for subj in keep_subjs:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        split_key = next(k for k, v in split_map.items() if subj in v)

        for fname in sorted(os.listdir(subj_dir)):
            if not fname.endswith(".txt"):
                continue
            arr = read_emg_txt(os.path.join(subj_dir, fname))
            arr = preprocess_emg(arr, fs)

            # Window each constant-label run; labels are shifted to 0-based.
            for lbl, seg_arr in find_label_runs(arr):
                segs, labs = sliding_window_majority(seg_arr, window_size, stride)
                if segs.size:
                    datasets[split_key]["data"].append(segs)
                    datasets[split_key]["label"].append(labs - 1)

    for split in ["train", "val", "test"]:
        X = concat_data(datasets[split]["data"])
        y = concat_label(datasets[split]["label"])
        X = X.transpose(0, 2, 1)  # (N, T, C) -> (N, C, T) for conv models

        with h5py.File(os.path.join(save_root, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int32))
        uniq, cnt = np.unique(y, return_counts=True)
        # "→" repaired here: the original format string carried a mojibake "β".
        print(
            f"{split.upper():5} → X={X.shape}, label dist:",
            dict(zip(uniq.tolist(), cnt.tolist())),
        )

    print("\nAll splits saved to:", save_root)
|
|
|
|