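"""Convert the UCI "EMG Data for Gestures" recordings (MYO bracelet, 8 channels,
200 Hz) into train/val/test HDF5 files.

Pipeline: read each subject's txt recordings, drop class-0 rows, band-pass
(20-90 Hz) + 50 Hz notch + per-channel z-score, split into runs of constant
label, cut overlapping windows with majority-vote labels, exclude subjects that
performed gesture 7, and write per-split "data" (N, 8, window_size) and "label"
(0-based) datasets to <save_dir>/{train,val,test}.h5.

Example (the script filename is a placeholder):
    python convert_emg_to_h5.py --data_dir /path/to/uci_emg --save_dir ./h5 \
        --window_size 1000 --stride 500
"""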
import os
import sys
from pathlib import Path

import h5py
import numpy as np
import scipy.signal as signal
from scipy.signal import iirnotch


# ─────────────────────────────────────────────
# Filtering utilities
# ─────────────────────────────────────────────
def bandpass_filter_emg(emg, lowcut=20.0, highcut=90.0, fs=200.0, order=4):
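    """Zero-phase Butterworth band-pass (default 20-90 Hz at fs=200 Hz),
    applied along the time axis of an (N, channels) EMG array."""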
    nyq = 0.5 * fs
    b, a = signal.butter(order, [lowcut / nyq, highcut / nyq], btype="bandpass")
    return signal.filtfilt(b, a, emg, axis=0)


def notch_filter_emg(emg, notch_freq=50.0, Q=30.0, fs=200.0):
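    """Zero-phase IIR notch (default 50 Hz, Q=30) to suppress mains interference;
    the frequency is normalised to the Nyquist rate as expected by iirnotch."""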
    b, a = iirnotch(notch_freq / (0.5 * fs), Q)
    return signal.filtfilt(b, a, emg, axis=0)


# ─────────────────────────────────────────────
# Core I/O + preprocessing helpers
# ─────────────────────────────────────────────
def read_emg_txt(txt_path):
    """

    Read a txt file with columns: time ch1 … ch8 class.

    Return float32 array of shape (N, 10).

    """
    data = []
    with open(txt_path, "r") as f:
        for line in f.readlines()[1:]:  # skip header
            cols = line.strip().split()
            if len(cols) == 10:
                data.append(list(map(float, cols)))
    return np.asarray(data, dtype=np.float32)


def preprocess_emg(arr, fs=200.0, remove_class0=True):
    """

    1) optional removal of class-0 rows

    2) band-pass β†’ notch β†’ Z-score  (on 8 channels)

    """
    if remove_class0:
        arr = arr[arr[:, -1] >= 1]
    if arr.size == 0:
        return arr

    emg = arr[:, 1:9]  # (N, 8)
    emg = bandpass_filter_emg(emg, 20, 90, fs)
    emg = notch_filter_emg(emg, 50, 30, fs)

    mu = emg.mean(axis=0)
    sd = emg.std(axis=0, ddof=1)
    sd[sd == 0] = 1.0
    emg = (emg - mu) / sd

    arr[:, 1:9] = emg
    return arr


def find_label_runs(arr):
    """Group consecutive rows with identical class labels."""
    runs = []
    if arr.size == 0:
        return runs
    curr_lbl = int(arr[0, -1])
    start = 0
    for i in range(1, len(arr)):
        lbl = int(arr[i, -1])
        if lbl != curr_lbl:
            runs.append((curr_lbl, arr[start:i]))
            curr_lbl, start = lbl, i
    runs.append((curr_lbl, arr[start:]))
    return runs


def sliding_window_majority(seg_arr, window_size=1000, stride=500):
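    """Slice a constant-label run into overlapping windows.

    Each window keeps the 8 EMG channels (shape (window_size, 8)); its label is
    the majority class within the window. Runs shorter than window_size yield
    no windows.
    """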
    segs, labs = [], []
    for start in range(0, len(seg_arr) - window_size + 1, stride):
        win = seg_arr[start : start + window_size]
        maj = np.argmax(np.bincount(win[:, -1].astype(int)))
        segs.append(win[:, 1:9])  # keep 8-channel EMG
        labs.append(maj)
    return np.asarray(segs, dtype=np.float32), np.asarray(labs, dtype=np.int32)


def users_with_gesture(
    data_root, gesture_id, subj_range=range(1, 37), return_counts=False
):
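    """Scan every subject directory for recordings containing `gesture_id`.

    Returns a sorted list of subject ids, or (with return_counts=True) a dict
    mapping subject id -> number of rows labelled with that gesture.
    """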
    found = {}
    for subj in subj_range:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        count = 0
        for fname in os.listdir(subj_dir):
            if not fname.endswith(".txt"):
                continue
            txt_path = os.path.join(subj_dir, fname)
            try:
                arr = read_emg_txt(txt_path)
            except Exception:
                # skip files we can't parse
                continue
            if arr.size == 0:
                continue
            # last column is class label (as float). Compare as int.
            if np.any(arr[:, -1].astype(int) == int(gesture_id)):
                # count occurrences (rows) of that gesture in this file
                count += int((arr[:, -1].astype(int) == int(gesture_id)).sum())
        if count > 0:
            found[subj] = count

    if return_counts:
        return found  # dict subj -> count
    else:
        return sorted(found.keys())


# ─────────────────────────────────────────────
# Safe concatenation utilities
# ─────────────────────────────────────────────
def concat_data(lst):  # lst of (N, window_size, 8) window arrays
    # NOTE: the empty fallback hard-codes the default window_size of 1000
    return np.concatenate(lst, axis=0) if lst else np.empty((0, 1000, 8), np.float32)


def concat_label(lst):
    return np.concatenate(lst, axis=0) if lst else np.empty((0,), np.int32)


# ─────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────
if __name__ == "__main__":
    import argparse

    arg = argparse.ArgumentParser(description="Convert UCI EMG dataset to h5 format.")
    arg.add_argument("--download_data", action="store_true")
    arg.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Root directory of the UCI EMG dataset",
    )
    arg.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Directory to save the output h5 files",
    )
    arg.add_argument("--window_size", type=int, help="Window size for sliding window")
    arg.add_argument("--stride", type=int, help="Stride for sliding window")
    args = arg.parse_args()

    data_root = args.data_dir
    save_root = args.save_dir
    os.makedirs(save_root, exist_ok=True)

    # download data if requested
    if args.download_data:
        # https://archive.ics.uci.edu/dataset/481/emg+data+for+gestures
        base_url = (
            "https://archive.ics.uci.edu/static/public/481/emg+data+for+gestures.zip"
        )
        os.system(f"wget -O {data_root}/emg_gestures.zip '{base_url}'")
        os.system(f"unzip -o {data_root}/emg_gestures.zip -d {Path(data_root).parent}")
        os.system(f"rm {data_root}/emg_gestures.zip")
        print("Dataset downloaded and cleaned up.")
        sys.exit("Rerun without --download_data.")

    fs = 200.0  # sampling rate of MYO bracelet
    window_size, stride = args.window_size, args.stride
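    # With the default window_size=1000 and stride=500 at fs=200 Hz, each
    # window covers 5 s of signal with 50 % overlap.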

    split_map = {
        "train": list(range(1, 25)),  # 1–24
        "val": list(range(25, 31)),  # 25–30
        "test": list(range(31, 37)),  # 31–36
    }
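    # Subject-level split: no subject appears in more than one split.
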
    # remove users that performed gesture 7
    gesture_id = 7
    gesture7_users = users_with_gesture(data_root, gesture_id)
    print(f"Users that performed gesture {gesture_id}:", gesture7_users)

    keep_subjs = []
    for k in split_map:
        split_map[k] = [u for u in split_map[k] if u not in gesture7_users]
        keep_subjs.extend(split_map[k])
    print("Updated split map after removing gesture-7 users:", keep_subjs)

    datasets = {k: {"data": [], "label": []} for k in split_map}

    for subj in keep_subjs:
        subj_dir = os.path.join(data_root, f"{subj:02d}")
        if not os.path.isdir(subj_dir):
            continue
        split_key = next(k for k, v in split_map.items() if subj in v)

        for fname in sorted(os.listdir(subj_dir)):
            if not fname.endswith(".txt"):
                continue
            arr = read_emg_txt(os.path.join(subj_dir, fname))
            arr = preprocess_emg(arr, fs)

            for lbl, seg_arr in find_label_runs(arr):
                segs, labs = sliding_window_majority(seg_arr, window_size, stride)
                if segs.size:
                    datasets[split_key]["data"].append(segs)
                    datasets[split_key]["label"].append(labs - 1)

    # concatenate, transpose & save
    for split in ["train", "val", "test"]:
        X = concat_data(datasets[split]["data"])  # (N, window_size, 8)
        y = concat_label(datasets[split]["label"])
        X = X.transpose(0, 2, 1)  # (N, 8, window_size)

        with h5py.File(os.path.join(save_root, f"{split}.h5"), "w") as f:
            f.create_dataset("data", data=X.astype(np.float32))
            f.create_dataset("label", data=y.astype(np.int32))
        uniq, cnt = np.unique(y, return_counts=True)
        print(
            f"{split.upper():5} β†’ X={X.shape}, label dist:",
            dict(zip(uniq.tolist(), cnt.tolist())),
        )

    print("\nAll splits saved to:", save_root)