import os
import tarfile
import urllib.request

import h5py
import numpy as np
from scipy import signal


def fix_length(emg_signal, target_length=120_000):
    """
    Ensures a 2D EMG signal (time x channels) is exactly target_length samples long.
    Pads with zeros or trims only along the time axis.
    """
    current_length = emg_signal.shape[0]

    if current_length > target_length:
        # Trim along time dimension
        return emg_signal[:target_length, :]
    elif current_length < target_length:
        # Pad with zeros at the end along the time axis
        pad_amount = target_length - current_length
        padding = ((0, pad_amount), (0, 0))  # pad only time axis
        return np.pad(emg_signal, padding, mode='constant')
    else:
        # Already correct length
        return emg_signal
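

# Illustrative sketch (hypothetical helper, not used by the pipeline): why fix_length
# passes ((0, pad_amount), (0, 0)) to np.pad. np.pad takes one (before, after) pair
# per axis, so only the end of the time axis (axis 0) receives zeros and the channel
# axis is untouched. A single pair such as (0, pad_amount) would be broadcast to every
# axis and would also append zero-valued "channels".
import numpy as np


def pad_time_axis_demo(emg_signal, target_length):
    # emg_signal: (time, channels) with time <= target_length; np.pad rejects negative
    # pad widths, which is why fix_length trims longer inputs instead of padding them
    pad_amount = target_length - emg_signal.shape[0]
    # Pad the end of axis 0 only; mode='constant' fills with zeros
    return np.pad(emg_signal, ((0, pad_amount), (0, 0)), mode='constant')


# Usage: pad_time_axis_demo(np.ones((3, 2)), 5) has shape (5, 2); rows 3 and 4 are zeros.
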
        
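'''
Example of how to preprocess an EMG recording from emg2pose
(a single .hdf5 file from the mini dataset, recorded at 2000Hz)
Length: trimmed or zero-padded to 120,000 samples (60 s at 2000Hz)
High Pass: 20Hz
Low Pass: 400Hz (399.5Hz applied, to stay below the Nyquist frequency)
Resample at: 1000Hz
'''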
def preprocessing_emg():

    # Paths for the downloaded archive and the extracted dataset
    sample_dir = './example_files/emg_sample'
    dataset_dir = os.path.join(sample_dir, 'emg2pose_dataset_mini')
    # urlretrieve does not create missing folders, so create the target folder first
    os.makedirs(sample_dir, exist_ok=True)

    # Download file
    print("Downloading emg2pose (mini)...")
    url = "https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_dataset_mini.tar"
    tar_path = os.path.join(sample_dir, "emg2pose_dataset_mini.tar")
    urllib.request.urlretrieve(url, tar_path)
    # Extract tar file
    with tarfile.open(tar_path, "r") as tar:
        tar.extractall(path=sample_dir)

    print("Download and extraction complete.")

    # Band-pass cutoffs and output sampling rate (Hz)
    target_fs = 1000
    highpass = 20
    lowpass = 400

    # Sampling rate of the emg2pose EMG recordings (2000 Hz, from the dataset documentation)
    fs = 2000

    # Hardcode the 16 channel names
    ch_names = np.array(['c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c13', 'c14', 'c15', 'c16'])

    # List the extracted hdf5 recordings (sorted, so the file picked below is deterministic)
    f_names = sorted(d for d in os.listdir(dataset_dir) if d.endswith('.hdf5'))
    if not f_names:
        raise FileNotFoundError(f"No .hdf5 files found in {dataset_dir}")

    # Only the first recording is processed (see the break at the end of the loop)
    for f_i in f_names:

        with h5py.File(os.path.join(dataset_dir, f_i), 'r') as f_emg:
            dset = f_emg['emg2pose']['timeseries']
            # Load everything into memory at once (structured NumPy array)
            data = dset[:]

        # Fix the length to 120,000 samples, then transpose (time, channels) -> (channels, time)
        x = fix_length(data['emg']).T

        # Zero-phase Butterworth band-pass; butter() needs the upper cutoff below fs / 2
        lowpass_applied = min(lowpass, fs / 2) - 0.5
        b, a = signal.butter(N=3, Wn=[highpass, lowpass_applied], btype='bandpass', fs=fs)
        x = signal.filtfilt(b, a, x, axis=-1)

        # Resampling (FFT-based) to target_fs
        if target_fs != fs:
            x = signal.resample(x, num=int(x.shape[-1] / fs * target_fs), axis=-1)

        # Convert to float16 only after filtering, then add a leading trial axis
        x = x.astype('float16')
        x = x.reshape(1, x.shape[0], x.shape[1])  # (1, channels, time)
        ch_names = np.array([c.lower().encode() for c in ch_names])
        break

    return x, ch_names
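

# Illustrative sketch (hypothetical helper, not called by preprocessing_emg):
# tar.extractall() trusts the member paths stored in the archive. This variant refuses
# any member whose resolved path would land outside dest_dir (e.g. "../evil.txt")
# before extracting anything. It is a minimal check only: it does not inspect symlink
# or hard-link targets. On Python 3.12+ tarfile also offers built-in extraction
# filters (filter='data') as an alternative.
import os
import tarfile


def safe_extract(tar_path, dest_dir):
    dest_dir = os.path.realpath(dest_dir)
    with tarfile.open(tar_path, "r") as tar:
        # Validate every member first so nothing is written if one path is unsafe
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(dest_dir, member.name))
            if os.path.commonpath([dest_dir, target]) != dest_dir:
                raise ValueError(f"Blocked path outside {dest_dir}: {member.name}")
        tar.extractall(path=dest_dir)


# Usage: safe_extract(tar_path, sample_dir) could stand in for the extraction step above.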

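
# Illustrative sketch (hypothetical helper, not called by preprocessing_emg): a
# synthetic check of the filter-then-resample steps with the same settings as the
# pipeline (3rd-order Butterworth band-pass, 20 to 399.5 Hz, applied with filtfilt at
# fs = 2000 Hz, then FFT resampling to 1000 Hz). The test signal is an assumption:
# 2 s of unit-amplitude sines at 5, 100 and 900 Hz, each completing a whole number of
# cycles so that its amplitude can be read from one exact FFT bin. Note that
# signal.resample treats the signal as periodic; signal.resample_poly is a common
# polyphase alternative.
import numpy as np
from scipy import signal


def bandpass_resample_check(fs=2000, target_fs=1000, duration_s=2.0):
    t = np.arange(int(fs * duration_s)) / fs
    freqs = (5, 100, 900)
    x = sum(np.sin(2 * np.pi * f * t) for f in freqs)

    # Same design as preprocessing_emg: cap the upper cutoff below Nyquist
    b, a = signal.butter(N=3, Wn=[20, min(400, fs / 2) - 0.5], btype='bandpass', fs=fs)
    y = signal.filtfilt(b, a, x)

    # Single-sided amplitude of each tone; the bin spacing is 1 / duration_s Hz
    spectrum = np.abs(np.fft.rfft(y)) * 2 / len(y)
    amplitudes = {f: spectrum[int(round(f * duration_s))] for f in freqs}

    # FFT-based downsampling to target_fs, as in the pipeline
    y_resampled = signal.resample(y, num=int(len(y) / fs * target_fs))
    return amplitudes, y_resampled


# Usage: amplitudes[100] should stay close to 1, while the 5 Hz and 900 Hz tones
# (outside the pass band) should be strongly attenuated.
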
'''
Function to create patches for NeuroRVQ
'''
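def create_patches(emg_signal, maximum_patches, patch_size, channels_use):
    """
    Trims an EMG array of shape (trials, channels, time) to n_time * patch_size samples
    and keeps only the channels listed in channels_use (channel indices), where
    n_time = maximum_patches // len(channels_use) is the number of patches per channel,
    so that len(channels_use) * n_time <= maximum_patches.

    Returns the array of shape (trials, len(channels_use), time) and n_time. The time
    axis is not split into patches here, and signals shorter than n_time * patch_size
    samples are not padded.
    """
    n, c, t = emg_signal.shape  # Batch / trials, channels, time
    # Patches per channel so that the total number of patches fits maximum_patches
    n_time = maximum_patches // len(channels_use)
    # Keep only the first n_time * patch_size samples
    emg_signal = emg_signal[:, :, :n_time * patch_size]
    # Select the requested channels
    emg_signal_patches = emg_signal[:, channels_use, :]
    return emg_signal_patches, n_time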
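

# Illustrative sketch (hypothetical helper, not part of the NeuroRVQ code): splits an
# already trimmed (trials, channels, time) array into non-overlapping patches of
# patch_size samples. It assumes patches are contiguous windows along time and that
# time is an exact multiple of patch_size, which holds for the output of create_patches
# when the input had at least n_time * patch_size samples.
import numpy as np


def split_into_patches(emg_signal, patch_size):
    n, c, t = emg_signal.shape
    if t % patch_size != 0:
        raise ValueError("time axis must be a multiple of patch_size")
    # Row-major reshape: patch j of each channel holds samples j*patch_size .. (j+1)*patch_size - 1
    return emg_signal.reshape(n, c, t // patch_size, patch_size)


# Usage with hypothetical values: maximum_patches=256, 16 channels and patch_size=200
# give n_time = 256 // 16 = 16 patches per channel, i.e. 3,200 samples (3.2 s at
# 1000 Hz), and split_into_patches then returns shape (trials, 16, 16, 200).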