|
|
from torch.utils.data import Dataset |
|
|
from torchvision.transforms import functional as F |
|
|
import torch |
|
|
import numpy as np |
|
|
from scipy import signal |
|
|
from functools import reduce |
|
|
from scipy.signal import butter, lfilter, detrend |
|
|
|
|
|
class Augmentations:
    """Crop a multi-channel waveform around its P position and augment it.

    The pipeline implemented by :meth:`apply` is: choose a crop window
    (:meth:`shuffle`), slice it (:meth:`cut`), optionally band-pass filter and
    taper (:meth:`bandpass_filter`), pad to three channels
    (:meth:`add_z_component`), optionally "rotate" (:meth:`rotate`), remove the
    per-channel mean (:meth:`demean`), scale by the global peak
    (:meth:`normalize`), and optionally blank or permute channels
    (:meth:`channel_dropout`, :meth:`channel_shuffle`).

    Every random step is skipped when ``test`` is true. Random numbers come
    from torch's global generator.

    Args:
        padding: Margin, in samples, kept between the P position and the
            window edges when the window is placed.
        crop_length: Length of the crop window in samples.
        fs: Sampling rate handed to the filter design.
        lowcut: Lower cutoff of the default band-pass filter.
        highcut: Upper cutoff of the default band-pass filter.
        order: Butterworth filter order.
    """

    def __init__(self, padding=120, crop_length=6000, fs=100, lowcut=0.2, highcut=40, order=5):
        self.padding = padding
        self.crop_length = crop_length
        self.fs = fs
        self.lowcut = lowcut
        self.highcut = highcut
        self.order = order

        # Coefficients of the default band-pass, exposed as attributes.
        # butter_bandpass_filter designs its own filter from its arguments.
        b, a = self.butter_bandpass(self.lowcut, self.highcut, self.fs, self.order)
        self.filter_b = b
        self.filter_a = a

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        """Return the ``(b, a)`` coefficients of a Butterworth band-pass filter."""
        return butter(order, [lowcut, highcut], fs=fs, btype='band')

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        """Band-pass filter ``data`` along its last axis.

        The filter is designed from the ``lowcut``/``highcut``/``fs``/``order``
        arguments and applied as cascaded second-order sections. The scipy
        documentation recommends that form over ``(b, a)`` polynomials, which
        lose precision for higher orders and very low cutoffs.

        Args:
            data: Array to filter; filtering runs along the last axis.
            lowcut: Lower cutoff frequency, in the units of ``fs``.
            highcut: Upper cutoff frequency, in the units of ``fs``.
            fs: Sampling rate.
            order: Butterworth order.

        Returns:
            The filtered array, same shape as ``data``.
        """
        sos = butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
        return signal.sosfilt(sos, data)

    def rotate_waveform(self, waveform, angle):
        """Multiply the spectrum of ``waveform`` by ``exp(1j * angle)``.

        The same complex factor is applied to every frequency bin, so up to
        rounding the result equals ``waveform * exp(1j * angle)``. For real
        input its real part is therefore ``cos(angle) * waveform``.

        Returns:
            A complex array with the shape of ``waveform``.
        """
        spectrum = np.fft.fft(waveform)
        phase_factor = np.exp(1j * angle)
        return np.fft.ifft(spectrum * phase_factor)

    def shuffle(self, sample, target_P, target_S, test):
        """Choose the crop window and express the targets relative to it.

        Window placement depends on ``target_P``:

        * ``target_P > crop_length``: the P position is placed at a random
          offset in ``[padding, crop_length - padding)`` from the window start
          (``2 * padding`` in test mode).
        * ``target_P - padding >= 1``: the window starts a random
          ``[0, target_P - padding)`` samples before the P position
          (``padding`` samples before it in test mode).
        * otherwise, including a NaN ``target_P``: the window starts at
          ``padding``.

        A window running past the end of ``sample`` is shifted left, but never
        before index 0. If ``sample`` is shorter than ``crop_length`` the
        returned range overruns the data and slicing yields a shorter crop.

        Args:
            sample: Array whose last axis is time.
            target_P: P position in samples (may be NaN).
            target_S: S position in samples (may be NaN).
            test: If true, use the deterministic placement and draw no
                random numbers.

        Returns:
            ``(start_indx, end_indx, new_target_P, new_target_S)`` where the
            new targets are the originals minus ``start_indx``.
        """
        inner_span = self.crop_length - self.padding

        if target_P - inner_span > self.padding:
            if test:
                # NOTE(review): assumes the P position is the first phase.
                start_indx = int(target_P - 2 * self.padding)
            else:
                offset = torch.randint(low=self.padding, high=inner_span, size=(1,)).item()
                start_indx = int(target_P - offset)
        elif target_P - self.padding >= 1:
            # Same as int(target_P - padding) > 0 for finite values, but a NaN
            # target falls through to the fixed-offset branch instead of raising.
            if test:
                start_indx = int(target_P - self.padding)
            else:
                offset = torch.randint(low=0, high=int(target_P - self.padding), size=(1,)).item()
                start_indx = int(target_P - offset)
        else:
            start_indx = self.padding

        end_indx = start_indx + self.crop_length

        overshoot = end_indx - sample.shape[-1]
        if overshoot > 0:
            # Shift the window back inside the record. Clamp at 0 because a
            # negative start would be read as an index from the end.
            start_indx = max(start_indx - overshoot, 0)
            end_indx = start_indx + self.crop_length

        new_target_P = target_P - start_indx
        new_target_S = target_S - start_indx

        return start_indx, end_indx, new_target_P, new_target_S

    def cut(self, sample, start_indx, end_indx):
        """Return ``sample[:, start_indx:end_indx]``."""
        return sample[:, start_indx:end_indx]

    def bandpass_filter(self, sample_cropped, test):
        """Randomly band-pass filter the crop, then apply a Tukey taper.

        Outside test mode, with probability 1/2 the crop is filtered with a
        lower cutoff drawn from ``[0.001, 1)`` and an upper cutoff drawn from
        ``[10, 49)``. The Tukey window (``alpha=0.1``) is applied in every
        mode.
        """
        if not test:
            apply_filter = torch.randint(0, 2, size=(1,)).item()
            if apply_filter == 1:
                lowcut = torch.empty(1, dtype=torch.float32).uniform_(0.001, 1).item()
                highcut = torch.empty(1, dtype=torch.float32).uniform_(10, 49).item()
                sample_cropped = self.butter_bandpass_filter(
                    sample_cropped, lowcut=lowcut, highcut=highcut, fs=self.fs, order=self.order
                )

        window = signal.windows.tukey(sample_cropped.shape[-1], alpha=0.1)
        return sample_cropped * window

    def add_z_component(self, sample_cropped):
        """Pad the channel axis with zero rows up to three channels.

        Existing channels keep their order in the leading rows. Input that
        already has three or more channels is returned unchanged.
        """
        n_channels = len(sample_cropped)
        if n_channels < 3:
            padded = np.zeros((3, sample_cropped.shape[-1]))
            padded[:n_channels] = sample_cropped
            sample_cropped = padded
        return sample_cropped

    def rotate(self, sample_cropped, test):
        """With probability 1/2 (never in test mode) apply ``rotate_waveform``.

        Only the real part of the result is kept, so for real input the net
        effect is a multiplication by ``cos(angle)``.
        """
        if not test:
            apply_rotation = torch.randint(0, 2, size=(1,)).item()
            if apply_rotation == 1:
                # NOTE(review): the range looks like degrees, but np.exp(1j * angle)
                # interprets the value as radians — confirm the intent.
                angle = torch.empty(1, dtype=torch.float32).uniform_(0.01, 359.9).item()
                sample_cropped = self.rotate_waveform(sample_cropped, angle).real
        return sample_cropped

    def demean(self, sample_cropped):
        """Subtract the mean along the last axis from every channel."""
        return sample_cropped - np.mean(sample_cropped, axis=-1, keepdims=True)

    def normalize(self, sample_cropped):
        """Divide by the largest absolute value over all channels.

        An all-zero input yields NaNs (0/0). They are passed on deliberately:
        ``Waveforms_dataset.__getitem__`` checks for NaNs in the sample.
        """
        max_val = np.max(np.abs(sample_cropped))
        return sample_cropped / max_val

    def channel_dropout(self, sample_cropped_norm, test):
        """With probability 1/2 (never in test mode) blank channel 1 or 2.

        The chosen row is overwritten in place with the constant ``1e-6``.
        Channel 0 is never selected.
        """
        if not test:
            # Both draws always happen, in this order, whether or not a
            # channel ends up being blanked.
            drop = torch.randint(0, 2, size=(1,)).item()
            channel = torch.randint(1, 3, size=(1,)).item()
            if drop == 1:
                sample_cropped_norm[channel, :] = 1e-6
        return sample_cropped_norm

    def channel_shuffle(self, sample_cropped_norm, test):
        """With probability 1/2 (never in test mode) permute the channel order."""
        if not test:
            do_shuffle = torch.randint(0, 2, size=(1,)).item()
            if do_shuffle == 1:
                # Convert to a NumPy index array before indexing NumPy data.
                order = torch.randperm(sample_cropped_norm.shape[0]).numpy()
                sample_cropped_norm = sample_cropped_norm[order, :]
        return sample_cropped_norm

    def apply(self, sample, target_P, target_S, test=False):
        """Run the full crop-and-augment pipeline on one waveform.

        Args:
            sample: 2-D array, channels by time.
            target_P: P position in samples.
            target_S: S position in samples.
            test: If true, skip every random step.

        Returns:
            ``(waveform, target_P, target_S)`` where the waveform is the
            processed crop and the targets are expressed relative to the crop
            start and divided by ``crop_length``.
        """
        start_indx, end_indx, new_target_P, new_target_S = self.shuffle(sample, target_P, target_S, test)

        sample_cropped = self.cut(sample, start_indx, end_indx)
        sample_cropped = self.bandpass_filter(sample_cropped, test)
        sample_cropped = self.add_z_component(sample_cropped)
        sample_cropped = self.rotate(sample_cropped, test)
        sample_cropped = self.demean(sample_cropped)
        sample_cropped_norm = self.normalize(sample_cropped)

        sample_cropped_norm = self.channel_dropout(sample_cropped_norm, test)
        sample_cropped_norm = self.channel_shuffle(sample_cropped_norm, test)

        new_target_P = new_target_P / self.crop_length
        new_target_S = new_target_S / self.crop_length

        return sample_cropped_norm, new_target_P, new_target_S
|
|
|
|
|
class Waveforms_dataset(Dataset):
    """Dataset yielding ``(waveform, target_P, target_S)`` float tensors.

    Args:
        meta: Table indexed positionally with ``.iloc``. Each row provides
            ``trace_P_final`` and ``trace_S_final``, and the row's ``name``
            (its index label) is the key into ``data``.
        data: Mapping from row label to waveform array.
        test: Forwarded to ``augmentations.apply`` as ``test=``.
        transform: Stored on the instance but not applied by ``__getitem__``.
        augmentations: Optional object exposing
            ``apply(sample, target_P, target_S, test=...)`` and a
            ``crop_length`` attribute.
    """

    def __init__(self, meta, data, test=False, transform=None, augmentations=None):
        self.meta = meta
        self.data = data
        self.test = test
        self.transform = transform
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of rows in ``meta``."""
        return len(self.meta)

    def __getitem__(self, idx):
        """Load, optionally augment, and sanitise the item at position ``idx``.

        Targets outside the open interval ``(0, 1)``, including NaN, are
        replaced by 0. If the waveform contains any NaN it is replaced by
        zeros and both targets are set to 0.

        Returns:
            ``(sample, target_P, target_S)`` as ``torch.float`` tensors.
        """
        row = self.meta.iloc[idx]
        sample = self.data[row.name]

        target_P = float(row.trace_P_final)
        target_S = float(row.trace_S_final)

        if self.augmentations:
            sample, target_P, target_S = self.augmentations.apply(
                sample, target_P, target_S, test=self.test
            )

        # `not 0 < t < 1` is also true for NaN, because comparisons with NaN
        # are false.
        if not 0 < target_P < 1:
            target_P = 0
        if not 0 < target_S < 1:
            target_S = 0

        if np.isnan(sample).any():
            if self.augmentations:
                blank_shape = (3, self.augmentations.crop_length)
            else:
                # With no augmentations there is no crop length to use, so
                # keep the shape of the loaded sample.
                blank_shape = np.shape(sample)
            sample = np.zeros(blank_shape)
            target_P = 0
            target_S = 0

        sample = torch.tensor(sample, dtype=torch.float)
        target_P = torch.tensor(target_P, dtype=torch.float)
        target_S = torch.tensor(target_S, dtype=torch.float)

        return sample, target_P, target_S
|
|
|