# TinyMyo / scripts/uci.py
# Author: MatteoFasulo — commit ca8e271 ("Upload 9 files")
import os
import sys
from pathlib import Path
import h5py
import numpy as np
import scipy.signal as signal
from scipy.signal import iirnotch
# ─────────────────────────────────────────────
# Filtering utilities
# ─────────────────────────────────────────────
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
    """Apply a zero-phase Butterworth band-pass filter along axis 0.

    Args:
        emg: samples along axis 0 (e.g. an (N, channels) array).
        lowcut: lower cut-off frequency in Hz.
        highcut: upper cut-off frequency in Hz.
        fs: sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The forward-backward (``filtfilt``) filtered signal, same shape as ``emg``.
    """
    nyquist = fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    numer, denom = signal.butter(order, band, btype="bandpass")
    return signal.filtfilt(numer, denom, emg, axis=0)
def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
    """Suppress a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Args:
        emg: samples along axis 0.
        notch_freq: centre frequency to remove, in Hz.
        Q: quality factor of the notch (higher = narrower).
        fs: sampling rate in Hz.

    Returns:
        The forward-backward (``filtfilt``) filtered signal, same shape as ``emg``.
    """
    # iirnotch without fs expects a frequency normalised to Nyquist (fs / 2).
    w0 = notch_freq / (fs / 2.0)
    num, den = iirnotch(w0, Q)
    filtered = signal.filtfilt(num, den, emg, axis=0)
    return filtered
# ─────────────────────────────────────────────
# Core I/O + preprocessing helpers
# ─────────────────────────────────────────────
def read_emg_txt(txt_path):
    """
    Read a whitespace-separated txt file with columns: time ch1 … ch8 class.

    The first line is treated as a header and skipped. Any later line that
    does not split into exactly 10 fields is ignored.

    Args:
        txt_path: path of the text file to read.

    Returns:
        float32 array of shape (N, 10). When no valid rows are found the
        result is an empty (0, 10) array, so callers can always index
        ``arr[:, col]``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a 10-field row contains a non-numeric token.
    """
    rows = []
    with open(txt_path, "r") as f:
        next(f, None)  # skip header
        for line in f:  # stream lines instead of materialising readlines()
            cols = line.split()
            if len(cols) == 10:
                rows.append([float(c) for c in cols])
    # reshape keeps the documented 2-D layout even when `rows` is empty
    # (np.asarray([]) alone would be 1-D with shape (0,)).
    return np.asarray(rows, dtype=np.float32).reshape(-1, 10)
def preprocess_emg(arr, fs=200.0, remove_class0=True):
    """
    Filter and standardise the 8 EMG channels of one recording.

    Steps:
      1) optional removal of rows whose class label (last column) is < 1
      2) band-pass 20–90 Hz → 50 Hz notch (Q=30) → Z-score, on columns 1..8

    Args:
        arr: 2-D array laid out as (time, ch1..ch8, class), e.g. the output
            of ``read_emg_txt``.
        fs: sampling rate in Hz, forwarded to both filters.
        remove_class0: drop rows with label < 1 before filtering.

    Returns:
        A new array with the same columns, where columns 1..8 hold the
        filtered, per-channel Z-scored signal. The input array is never
        modified. Empty input, or input that becomes empty after class-0
        removal, is returned without filtering.
    """
    # Guard before any 2-D indexing: an empty 1-D array has no `[:, -1]`.
    if arr.size == 0:
        return arr
    if remove_class0:
        arr = arr[arr[:, -1] >= 1]  # boolean mask already yields a copy
        if arr.size == 0:
            return arr
    else:
        # Copy so the in-place column assignment below never writes
        # filtered values into the caller's array.
        arr = arr.copy()
    emg = arr[:, 1:9]  # (N, 8)
    emg = bandpass_filter_emg(emg, 20, 90, fs)
    emg = notch_filter_emg(emg, 50, 30, fs)
    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0  # flat channel: avoid division by zero
    emg = (emg - mu) / sd
    arr[:, 1:9] = emg
    return arr
def find_label_runs(arr):
    """Split ``arr`` into maximal runs of consecutive rows sharing a class label.

    The label is the last column, truncated with ``int``.

    Returns:
        List of ``(label, rows)`` tuples in original order. ``rows`` is a
        slice (view) of ``arr``. An empty list is returned for empty input.
    """
    if arr.size == 0:
        return []
    labels = [int(value) for value in arr[:, -1]]
    # Indices where the label differs from the previous row start a new run.
    cuts = [idx for idx in range(1, len(labels)) if labels[idx] != labels[idx - 1]]
    bounds = [0, *cuts, len(labels)]
    return [(labels[lo], arr[lo:hi]) for lo, hi in zip(bounds[:-1], bounds[1:])]
def sliding_window_majority(seg_arr, window_size=1000, stride=500):
    """
    Cut ``seg_arr`` into fixed-length windows and label each by majority vote.

    Args:
        seg_arr: 2-D array laid out as (time, ch1..ch8, class).
        window_size: samples per window (> 0).
        stride: step between consecutive window starts (> 0).

    Returns:
        ``(segs, labs)``. ``segs`` is float32 of shape (W, window_size, 8)
        holding columns 1..8 of each window. ``labs`` is int32 of shape (W,)
        holding the most frequent integer class in each window. When the
        input is shorter than one window, W == 0 and ``segs`` still has
        shape (0, window_size, 8).

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got {window_size}, {stride}"
        )
    starts = range(0, len(seg_arr) - window_size + 1, stride)
    if len(starts) == 0:
        # Keep the 3-D layout so empty results stack/transposes consistently.
        return (
            np.empty((0, window_size, 8), dtype=np.float32),
            np.empty((0,), dtype=np.int32),
        )
    segs, labs = [], []
    for start in starts:
        win = seg_arr[start : start + window_size]
        # bincount requires non-negative integer labels.
        maj = np.argmax(np.bincount(win[:, -1].astype(int)))
        segs.append(win[:, 1:9])  # keep 8-channel EMG
        labs.append(maj)
    return np.asarray(segs, dtype=np.float32), np.asarray(labs, dtype=np.int32)
def users_with_gesture(
    data_root, gesture_id, subj_range=range(1, 37), return_counts=False
):
    """
    Find subjects whose recordings contain a given class label.

    Subject ``s`` is looked up in the directory ``data_root/<s:02d>``. Every
    ``*.txt`` file in it is parsed with ``read_emg_txt``.

    Args:
        data_root: dataset root containing zero-padded subject directories.
        gesture_id: class label to look for (compared as ``int``).
        subj_range: subject ids to scan. Missing directories are skipped.
        return_counts: if True, return ``{subject: row_count}``; otherwise
            return the sorted list of matching subject ids.

    Returns:
        Dict of subject -> number of rows with that label, or a sorted list
        of subject ids. Only subjects with at least one matching row appear.
    """
    target = int(gesture_id)
    found = {}
    for subj in subj_range:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        count = 0
        for fname in os.listdir(subj_dir):
            if not fname.endswith(".txt"):
                continue
            txt_path = os.path.join(subj_dir, fname)
            try:
                arr = read_emg_txt(txt_path)
            except (OSError, ValueError) as err:
                # Best-effort scan: report unreadable/unparsable files, then skip them.
                print(f"warning: skipping {txt_path}: {err}", file=sys.stderr)
                continue
            if arr.size == 0:
                continue
            # last column is class label (as float). Compare as int.
            count += int(np.count_nonzero(arr[:, -1].astype(int) == target))
        if count > 0:
            found[subj] = count
    return found if return_counts else sorted(found)
# ─────────────────────────────────────────────
# Safe concatenation utilities
# ─────────────────────────────────────────────
def concat_data(lst, window_size=1000, n_channels=8):
    """
    Stack a list of window batches along axis 0.

    Args:
        lst: list of arrays, each shaped (N_i, window, channels).
        window_size: window length of the empty fallback array.
        n_channels: channel count of the empty fallback array.
            The defaults reproduce the previous fixed (0, 1000, 8) fallback.

    Returns:
        The concatenation of ``lst`` along axis 0. If ``lst`` is empty,
        returns an empty float32 array of shape (0, window_size, n_channels).
    """
    if not lst:
        return np.empty((0, window_size, n_channels), np.float32)
    return np.concatenate(lst, axis=0)
def concat_label(lst):
    """Join a list of 1-D label arrays end to end.

    Returns an empty int32 array of shape (0,) when ``lst`` is empty.
    """
    if not lst:
        return np.empty((0,), np.int32)
    return np.concatenate(lst, axis=0)
# ─────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────
if __name__ == "__main__":
    import argparse
    import shutil
    import urllib.request
    import zipfile

    arg = argparse.ArgumentParser(description="Convert UCI EMG dataset to h5 format.")
    arg.add_argument("--download_data", action="store_true")
    arg.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Root directory of the UCI EMG dataset",
    )
    arg.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save the output h5 files",
    )
    # Explicit defaults (equal to sliding_window_majority's own defaults):
    # without them argparse passes None and the windowing range() raises TypeError.
    arg.add_argument(
        "--window_size", type=int, default=1000, help="Window size for sliding window"
    )
    arg.add_argument("--stride", type=int, default=500, help="Stride for sliding window")
    args = arg.parse_args()
    if args.window_size <= 0 or args.stride <= 0:
        arg.error("--window_size and --stride must be positive integers")

    data_root = args.data_dir
    save_root = args.save_dir
    os.makedirs(save_root, exist_ok=True)

    # download data if requested
    if args.download_data:
        # https://archive.ics.uci.edu/dataset/481/emg+data+for+gestures
        base_url = (
            "https://archive.ics.uci.edu/static/public/481/emg+data+for+gestures.zip"
        )
        # Pure-Python download/extract. No shell string is built from user
        # paths, and failures raise instead of being ignored the way
        # os.system exit codes were.
        os.makedirs(data_root, exist_ok=True)
        zip_path = os.path.join(data_root, "emg_gestures.zip")
        with urllib.request.urlopen(base_url) as resp, open(zip_path, "wb") as out:
            shutil.copyfileobj(resp, out)
        # Extract into data_root's parent, like the previous `unzip -o -d <parent>`.
        with zipfile.ZipFile(zip_path) as zf:
            zf.extractall(Path(data_root).parent)
        os.remove(zip_path)
        print("Dataset downloaded and cleaned up.")
        sys.exit("Rerun without --download_data.")

    fs = 200.0  # sampling rate of MYO bracelet
    window_size, stride = args.window_size, args.stride

    split_map = {
        "train": list(range(1, 25)),  # 1–24
        "val": list(range(25, 31)),  # 25–30
        "test": list(range(31, 37)),  # 31–36
    }

    # remove users that performed gesture 7
    gesture_id = 7
    gesture7_users = users_with_gesture(data_root, gesture_id)
    print(f"Users that performed gesture {gesture_id}:", gesture7_users)
    excluded = set(gesture7_users)

    keep_subjs = []
    subj_to_split = {}  # subject id -> split name, for O(1) lookup below
    for k in split_map:
        split_map[k] = [u for u in split_map[k] if u not in excluded]
        keep_subjs.extend(split_map[k])
        subj_to_split.update(dict.fromkeys(split_map[k], k))
    print("Updated split map after removing gesture-7 users:", keep_subjs)

    datasets = {k: {"data": [], "label": []} for k in split_map}
    for subj in keep_subjs:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        split_key = subj_to_split[subj]
        for fname in sorted(os.listdir(subj_dir)):
            if not fname.endswith(".txt"):
                continue
            arr = read_emg_txt(os.path.join(subj_dir, fname))
            arr = preprocess_emg(arr, fs)
            for _, seg_arr in find_label_runs(arr):
                segs, labs = sliding_window_majority(seg_arr, window_size, stride)
                if segs.size:
                    datasets[split_key]["data"].append(segs)
                    # preprocess_emg dropped labels < 1, so shift to 0-based classes
                    datasets[split_key]["label"].append(labs - 1)

    # concatenate, transpose & save
    for split in ["train", "val", "test"]:
        X = concat_data(datasets[split]["data"])  # (N, window_size, 8)
        if X.shape[0] == 0:
            # concat_data([]) is not told the window length; use the one in effect here.
            X = np.empty((0, window_size, 8), np.float32)
        y = concat_label(datasets[split]["label"])
        X = X.transpose(0, 2, 1)  # (N, 8, window_size)
        with h5py.File(os.path.join(save_root, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int32))
        uniq, cnt = np.unique(y, return_counts=True)
        print(
            f"{split.upper():5} → X={X.shape}, label dist:",
            dict(zip(uniq.tolist(), cnt.tolist())),
        )
    print("\nAll splits saved to:", save_root)