# TinyMyo / scripts / db5.py
# Uploaded by MatteoFasulo ("Upload 9 files", commit ca8e271, verified).
import os
import sys
import h5py
import numpy as np
import scipy.io
import scipy.signal as signal
from scipy.signal import iirnotch
# ==== Data augmentation functions ====
def random_amplitude_scale(sig, scale_range=(0.9, 1.1)):
    """Multiply the whole signal by a single random gain.

    One factor is drawn with ``np.random.uniform(*scale_range)`` and applied
    to every sample and channel alike.

    Args:
        sig: Signal array.
        scale_range: ``(low, high)`` bounds for the uniform gain draw.

    Returns:
        A new array equal to ``sig`` times the drawn gain.
    """
    gain = np.random.uniform(*scale_range)  # one scalar for the whole window
    scaled = sig * gain
    return scaled
def random_time_jitter(sig, jitter_ratio=0.01):
    """Add Gaussian noise scaled to each channel's standard deviation.

    Despite the name, no samples are moved in time. Each channel receives
    i.i.d. standard-normal noise multiplied by ``jitter_ratio`` times that
    channel's std (computed over axis 0).

    Args:
        sig: 2-D array of shape ``(T, D)``.
        jitter_ratio: Noise std as a fraction of each channel's std.

    Returns:
        A new array ``sig + noise`` with the same shape as ``sig``.
    """
    n_time, n_chan = sig.shape
    per_channel_scale = jitter_ratio * sig.std(axis=0)
    gaussian = np.random.randn(n_time, n_chan)
    return sig + gaussian * per_channel_scale
def random_channel_dropout(sig, dropout_prob=0.05):
    """Zero out whole channels (columns) at random.

    Each of the ``D`` channels is independently dropped when a uniform draw
    from ``np.random.rand`` falls below ``dropout_prob``. The input is not
    modified; the previous version zeroed columns of ``sig`` in place,
    silently altering the caller's array.

    Args:
        sig: 2-D array of shape ``(T, D)``.
        dropout_prob: Per-channel probability of being zeroed.

    Returns:
        A copy of ``sig`` with the selected channels set to 0.
    """
    out = np.array(sig, copy=True)
    _, n_ch = out.shape  # enforce 2-D input, as before
    drop = np.random.rand(n_ch) < dropout_prob
    out[:, drop] = 0.0
    return out
def augment_one_sample(seg):
    """Return one randomly augmented copy of a single window.

    Applies, in order: a global gain drawn from [0.9, 1.1], additive
    per-channel Gaussian noise at 1% of each channel's std, and per-channel
    dropout with probability 0.05. ``seg`` is copied first, so the caller's
    array is never modified.
    """
    pipeline = (
        (random_amplitude_scale, (0.9, 1.1)),
        (random_time_jitter, 0.01),
        (random_channel_dropout, 0.05),
    )
    result = seg.copy()
    for transform, param in pipeline:
        result = transform(result, param)
    return result
def augment_train_data(data, labels, factor=3):
    """Append ``factor`` augmented copies of every sample to the dataset.

    The result starts with the untouched ``data`` / ``labels``. It is
    followed by ``factor`` augmented versions of sample 0, then of sample 1,
    and so on. Each copy inherits the label of its source sample.

    Args:
        data: Array of shape ``(N, window_size, n_ch)``.
        labels: Length-``N`` sequence of labels.
        factor: Augmented copies per sample; ``<= 0`` disables augmentation.

    Returns:
        ``(new_data, new_labels)`` holding ``N * (1 + factor)`` samples, or
        the inputs unchanged when ``factor <= 0`` or ``data`` is empty.
    """
    n_samples = data.shape[0]
    if factor <= 0 or n_samples == 0:
        return data, labels

    extra_segs = []
    extra_lbls = []
    for idx, seg in enumerate(data):
        lab = labels[idx]
        for _ in range(factor):
            extra_segs.append(augment_one_sample(seg)[np.newaxis])
            extra_lbls.append([lab])

    stacked = np.concatenate([data] + extra_segs, axis=0)
    flat_labels = np.concatenate([labels] + extra_lbls, axis=0).ravel()
    return stacked, flat_labels
# ==== Filter functions (operate at original fs=200) ====
def notch_filter(data, notch_freq=50.0, Q=30.0, fs=200.0):
    """Suppress a narrow band around ``notch_freq`` with a zero-phase IIR notch.

    Args:
        data: Signal with time along axis 0, e.g. ``(T, C)`` or ``(T,)``.
        notch_freq: Frequency to reject, in Hz.
        Q: Quality factor passed to ``iirnotch`` (higher = narrower notch).
        fs: Sampling rate in Hz.

    Returns:
        Filtered array with the same shape as ``data``. Floating or complex
        inputs keep their dtype. Integer inputs come back as float64 instead
        of being truncated to integers.
    """
    b, a = iirnotch(notch_freq, Q, fs)
    # Filter every channel at once along time. The old per-column loop wrote
    # into np.zeros_like(data), which truncated results for integer input.
    filtered = signal.filtfilt(b, a, data, axis=0)
    in_dtype = np.asarray(data).dtype
    out_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else np.float64
    return filtered.astype(out_dtype, copy=False)
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
    """Zero-phase Butterworth band-pass filter applied along time (axis 0).

    Args:
        emg: Signal with time along axis 0, e.g. ``(T, C)`` or ``(T,)``.
        lowcut: Lower cutoff in Hz.
        highcut: Upper cutoff in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth design order passed to ``signal.butter``.

    Returns:
        Filtered array with the same shape as ``emg``. Floating or complex
        inputs keep their dtype. Integer inputs come back as float64 instead
        of being truncated to integers.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = signal.butter(order, [low, high], btype="bandpass")
    # Filter all channels in one call. The old per-column loop wrote into
    # np.zeros_like(emg), which truncated results for integer input.
    filtered = signal.filtfilt(b, a, emg, axis=0)
    in_dtype = np.asarray(emg).dtype
    out_dtype = in_dtype if np.issubdtype(in_dtype, np.inexact) else np.float64
    return filtered.astype(out_dtype, copy=False)
# ==== Window segmentation ====
def process_emg_features(emg, label, rerep, window_size=1024, stride=512):
    """Cut a recording into fixed-length sliding windows.

    A window starts at every multiple of ``stride`` below ``len(label)``.
    Windows that run past ``len(label)`` are right-padded with zeros, so the
    last few windows may be mostly padding. Each window is tagged with the
    label and repetition value at its *first* sample.

    Args:
        emg: Array with time along axis 0, e.g. ``(N, C)``.
        label: Length-``N`` per-sample labels.
        rerep: Length-``N`` per-sample repetition indices.
        window_size: Samples per window (must be positive).
        stride: Step between window starts (must be positive).

    Returns:
        ``(segs, lbls, reps)`` with shapes ``(W, window_size, C)``, ``(W,)``
        and ``(W,)``. ``segs`` keeps ``emg``'s dtype, because padding now uses
        the same dtype. Previously padding was always float64, which promoted
        the output only when the last window happened to be padded. For an
        empty recording, ``W == 0`` and ``segs`` still has shape
        ``(0, window_size, C)``.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got {window_size} and {stride}"
        )
    emg = np.asarray(emg)
    label = np.asarray(label)
    rerep = np.asarray(rerep)
    n = len(label)
    trailing = emg.shape[1:]  # non-time dimensions, e.g. (C,)
    starts = np.arange(0, n, stride)
    if starts.size == 0:
        return (
            np.zeros((0, window_size) + trailing, dtype=emg.dtype),
            label[:0],
            rerep[:0],
        )

    segs = []
    for start in starts:
        end = start + window_size
        if end > n:
            # Zero-pad the tail up to window_size, matching emg's dtype.
            pad = np.zeros((end - n,) + trailing, dtype=emg.dtype)
            win = np.concatenate([emg[start:n], pad], axis=0)
        else:
            win = emg[start:end]
        segs.append(win)
    return np.stack(segs, axis=0), label[starts], rerep[starts]
# ==== Main pipeline ====
def _parse_args():
    """Build and parse the command-line interface for this script."""
    import argparse

    parser = argparse.ArgumentParser(description="Process EMG data from DB5.")
    parser.add_argument("--download_data", action="store_true")
    # Both directories were previously optional, and omitting them crashed
    # later with an opaque TypeError (e.g. os.makedirs(None)).
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    # Defaults mirror process_emg_features(window_size=1024, stride=512).
    parser.add_argument(
        "--window_size",
        type=int,
        default=1024,
        help="Size of the sliding window for segmentation.",
    )
    parser.add_argument(
        "--stride",
        type=int,
        default=512,
        help="Stride for the sliding window segmentation.",
    )
    return parser.parse_args()


def _download_db5(data_dir):
    """Download and unpack subject archives s1.zip ... s10.zip into ``data_dir``.

    Relies on the external ``wget`` and ``unzip`` tools. Arguments are passed
    as lists (no shell), so paths with spaces or shell metacharacters are
    safe. A failing tool raises ``subprocess.CalledProcessError`` instead of
    being ignored.
    """
    import subprocess

    # https://ninapro.hevs.ch/instructions/DB5.html
    base_url = "https://ninapro.hevs.ch/files/DB5_Preproc/"
    os.makedirs(data_dir, exist_ok=True)
    for i in range(1, 11):  # subjects 1-10
        zip_path = os.path.join(data_dir, f"s{i}.zip")
        # -O overwrites a stale archive; -P would save it as s{i}.zip.1.
        subprocess.run(["wget", "-O", zip_path, f"{base_url}s{i}.zip"], check=True)
        subprocess.run(["unzip", "-o", zip_path, "-d", data_dir], check=True)
        os.remove(zip_path)
        print(f"Downloaded and unzipped subject {i}\n{zip_path}")


def _load_recording(mat_path, fs):
    """Load one .mat file; return (filtered z-scored emg, labels, repetitions).

    Non-zero labels are offset by 12 for files whose name contains "E2" and
    by 29 for "E3"; label 0 stays 0. EMG is band-passed (20-90 Hz), then
    notch-filtered (50 Hz) at ``fs``, then z-scored per channel over the
    whole file.
    """
    dd = scipy.io.loadmat(mat_path)
    emg = dd["emg"]  # [N,16]
    label = dd["restimulus"].ravel().astype(int)
    rerep = dd["rerepetition"].ravel().astype(int)

    # label shift by exercise
    name = os.path.basename(mat_path)
    if "E2" in name:
        label = np.where(label != 0, label + 12, 0)
    elif "E3" in name:
        label = np.where(label != 0, label + 29, 0)

    # filtering at the original sampling rate
    emg_filt = bandpass_filter_emg(emg, 20, 90, fs=fs)
    emg_filt = notch_filter(emg_filt, 50, 30, fs=fs)

    # NOTE(review): mean/std use the whole file, including the repetitions
    # later routed to val/test, so split statistics are not independent.
    mu = emg_filt.mean(axis=0)
    sd = emg_filt.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0  # avoid division by zero on flat channels
    return (emg_filt - mu) / sd, label, rerep


def main():
    """Run the DB5 pipeline: optional download, then filter, window, split, save.

    Every ``<subject>/*.mat`` file under ``--data_dir`` is filtered and
    windowed. Windows are routed to train/val/test by the repetition index at
    the window start. The train split gets three augmented copies per window
    (``factor=3``). Each split is written to ``<save_dir>/<split>.h5`` with
    datasets ``data`` of shape ``[N, ch, window_size]`` and ``label`` of
    shape ``[N]``.
    """
    args = _parse_args()
    data_dir = args.data_dir
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    if args.download_data:
        _download_db5(data_dir)
        sys.exit("Data downloaded and unzipped. Rerun without --download_data.")

    fs = 200.0  # original sampling rate
    window_size, stride = args.window_size, args.stride
    # Repetition index -> split. Windows whose repetition is not listed
    # (e.g. 0) are discarded.
    rep_to_split = {1: "train", 3: "train", 4: "train", 6: "train", 2: "val", 5: "test"}
    splits = ("train", "val", "test")
    all_data = {split: [] for split in splits}
    all_lbls = {split: [] for split in splits}

    for subj in sorted(os.listdir(data_dir)):
        subj_path = os.path.join(data_dir, subj)
        if not os.path.isdir(subj_path):
            continue
        print(f"Processing subject {subj}...")
        for mat in sorted(os.listdir(subj_path)):
            if not mat.endswith(".mat"):
                continue
            emg_z, label, rerep = _load_recording(os.path.join(subj_path, mat), fs)
            segs, lbls, reps = process_emg_features(
                emg_z, label, rerep, window_size, stride
            )
            for seg, lab, rp in zip(segs, lbls, reps):
                split = rep_to_split.get(int(rp))
                if split is not None:
                    all_data[split].append(seg)
                    all_lbls[split].append(lab)

    # np.stack on an empty list fails with an unhelpful message; fail clearly
    # before writing any file.
    empty = [split for split in splits if not all_data[split]]
    if empty:
        sys.exit(f"No windows found for split(s) {empty} under {data_dir}; nothing saved.")

    stats = {}
    for split in splits:
        X = np.stack(all_data[split], axis=0)  # [N, window_size, ch]
        y = np.array(all_lbls[split], dtype=int)
        if split == "train":
            X, y = augment_train_data(X, y, factor=3)
        X = X.transpose(0, 2, 1)  # -> [N, ch, window_size]
        with h5py.File(os.path.join(save_dir, f"{split}.h5"), "w") as hf:
            hf.create_dataset("data", data=X)
            hf.create_dataset("label", data=y)
        uniq, cnt = np.unique(y, return_counts=True)
        stats[split] = (X.shape, dict(zip(uniq.tolist(), cnt.tolist())))

    for split, (shape, dist) in stats.items():
        print(f"\n{split} → X={shape}, label distribution:")
        for lab, count in dist.items():
            print(f" label {lab}: {count} samples")
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()