Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from scipy import signal | |
| from scipy.signal import butter, lfilter, detrend | |
| # Make bandpass filter | |
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter in transfer-function form.

    Args:
        lowcut: Lower band edge, in the same frequency units as ``fs``.
        highcut: Upper band edge, in the same frequency units as ``fs``.
        fs: Sampling frequency.
        order: Order handed to ``scipy.signal.butter``.

    Returns:
        Tuple ``(b, a)`` of numerator and denominator coefficients.
    """
    nyquist = 0.5 * fs
    # butter() takes the band edges as fractions of the Nyquist frequency.
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band_edges, btype="band")
    return numerator, denominator
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a Butterworth filter built by ``butter_bandpass``.

    The coefficients are applied with a single ``scipy.signal.lfilter`` pass
    using its default axis.

    Args:
        data: Array-like signal to filter.
        lowcut: Lower band edge, in the same frequency units as ``fs``.
        highcut: Upper band edge, in the same frequency units as ``fs``.
        fs: Sampling frequency.
        order: Filter order forwarded to ``butter_bandpass``.

    Returns:
        The filtered signal as returned by ``lfilter``.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(*coefficients, data)
def rotate_waveform(waveform, angle):
    """Multiply the spectrum of ``waveform`` by the phase factor ``exp(1j * angle)``.

    The FFT and inverse FFT are taken along NumPy's default (last) axis.
    ``angle`` is used directly as the argument of the complex exponential,
    i.e. it is interpreted in radians.

    Args:
        waveform: Array-like signal.
        angle: Phase angle applied uniformly to every frequency bin.

    Returns:
        Complex array with the same shape as ``waveform``; callers that need
        a real signal must take ``.real`` themselves.
    """
    spectrum = np.fft.fft(waveform)
    # One scalar factor for all bins: a frequency-independent phase shift.
    phase_factor = np.exp(1j * angle)
    return np.fft.ifft(spectrum * phase_factor)
def augment(sample):
    """Crop, preprocess and (for training samples) augment one waveform sample.

    The sample is cut to a 6000-sample window placed around a randomly chosen
    phase pick, detrended, band-pass filtered (0.2-40 with ``fs=100``),
    tapered with a Tukey window, detrended again and scaled by its global
    absolute maximum. Inputs with fewer than three channels are copied into
    the leading rows of a zero-filled three-channel array. Samples whose
    ``split`` is ``"train"`` may additionally be passed through
    ``rotate_waveform`` and have channel 1 or 2 overwritten with ``1e-6``.

    Args:
        sample: Mapping with a 2-D ``(channels, samples)`` array under
            ``"waveform.npy"`` and a metadata mapping under ``"meta.json"``
            that provides ``"split"``, ``"trace_p_arrival_sample"`` and
            ``"trace_s_arrival_sample"`` (either pick may be ``None``).

    Returns:
        Tuple ``(waveform, target_P, target_S)``. ``waveform`` is a
        ``(3, 6000)`` array; each target is the pick position divided by the
        crop length, or ``0`` when the pick is missing, lies outside the
        open interval ``(0, 1)``, or the waveform is unusable (NaN or all
        zeros).
    """
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120

    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]
    # Every split other than "train" gets a fixed crop offset and no
    # augmentation.
    test = meta["split"] != "train"

    target_sample_P = meta["trace_p_arrival_sample"]
    target_sample_S = meta["trace_s_arrival_sample"]
    if target_sample_P is None:
        target_sample_P = 0
    if target_sample_S is None:
        target_sample_S = 0

    # Randomly select a phase to anchor the crop. A sample without any pick
    # has nothing to choose from; np.random.randint(0, 0) would raise, so
    # fall back to 0, which selects the default crop start below.
    current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
    if current_phases:
        phase_selector = np.random.randint(0, len(current_phases))
        first_phase = current_phases[phase_selector]
    else:
        first_phase = 0

    # Choose the crop start so that the selected phase falls inside the crop.
    if first_phase - (crop_length - padding) > padding:
        offset = torch.randint(
            low=padding, high=(crop_length - padding), size=(1,)
        ).item()
        start_indx = int(first_phase - offset)
        if test:
            start_indx = int(first_phase - 2 * padding)
    elif int(first_phase - padding) > 0:
        offset = torch.randint(
            low=0, high=int(first_phase - padding), size=(1,)
        ).item()
        start_indx = int(first_phase - offset)
        if test:
            start_indx = int(first_phase - padding)
    else:
        start_indx = padding

    # Keep the crop inside the trace: shift it left when it would run past
    # the end, but never before sample 0 (a negative start would be read as
    # an index counted from the end of the trace).
    n_samples = waveform.shape[-1]
    end_indx = start_indx + crop_length
    if end_indx > n_samples:
        start_indx = max(n_samples - crop_length, 0)
        end_indx = start_indx + crop_length

    # Update target
    new_target_P = target_sample_P - start_indx
    new_target_S = target_sample_S - start_indx

    # Cut; traces shorter than the crop are zero-padded at the end so the
    # output always has crop_length samples.
    waveform_cropped = waveform[:, start_indx:end_indx]
    missing = crop_length - waveform_cropped.shape[-1]
    if missing > 0:
        waveform_cropped = np.pad(
            waveform_cropped,
            ((0, 0), (0, missing)),
            mode="constant",
            constant_values=0,
        )

    # Preprocess
    waveform_cropped = detrend(waveform_cropped)
    waveform_cropped = butter_bandpass_filter(
        waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform_cropped[-1].shape[0], alpha=0.1)
    waveform_cropped = waveform_cropped * window
    waveform_cropped = detrend(waveform_cropped)

    if np.isnan(waveform_cropped).any():
        waveform_cropped = np.zeros(shape=waveform_cropped.shape)
        new_target_P = 0
        new_target_S = 0

    # Normalize data. An all-zero waveform carries no picks and must not be
    # divided by its zero maximum, which would fill the output with NaN.
    max_val = np.max(np.abs(waveform_cropped))
    if max_val > 0:
        waveform_cropped_norm = waveform_cropped / max_val
    else:
        waveform_cropped_norm = waveform_cropped
        new_target_P = 0
        new_target_S = 0

    # Fewer than three channels: copy what exists into the leading rows of a
    # zero-filled three-channel array.
    n_channels = len(waveform_cropped_norm)
    if n_channels < 3:
        zeros = np.zeros((3, waveform_cropped_norm.shape[-1]))
        zeros[:n_channels] = waveform_cropped_norm
        waveform_cropped_norm = zeros

    if not test:
        ##### Rotate waveform #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # The angle is drawn in degrees; rotate_waveform feeds it to
        # np.exp(1j * angle), which expects radians.
        angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
        if probability == 1:
            waveform_cropped_norm = rotate_waveform(
                waveform_cropped_norm, np.deg2rad(angle)
            ).real

        #### Channel DropOUT #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # Only channel 1 or 2 is ever overwritten; channel 0 is always kept.
        channel = torch.randint(1, 3, size=(1,)).item()
        if probability == 1:
            waveform_cropped_norm[channel, :] = 1e-6

    # Normalize target
    new_target_P = new_target_P / crop_length
    new_target_S = new_target_S / crop_length

    # Picks on or outside the crop boundaries are reported as "no pick".
    if (new_target_P <= 0) or (new_target_P >= 1) or (np.isnan(new_target_P)):
        new_target_P = 0
    if (new_target_S <= 0) or (new_target_S >= 1) or (np.isnan(new_target_S)):
        new_target_S = 0

    return waveform_cropped_norm, new_target_P, new_target_S
def collation_fn(sample):
    """Collate ``(waveform, target_P, target_S)`` items into float tensors.

    Args:
        sample: Sequence of items whose first three entries are the
            waveform, the P target and the S target.

    Returns:
        Tuple ``(waveforms, targets_P, targets_S)`` of ``torch.float``
        tensors, each stacked along a new leading batch axis.
    """
    # Stack every field with NumPy first, then convert each one to a tensor.
    stacked_fields = [
        np.stack([item[field] for item in sample]) for field in range(3)
    ]
    return tuple(
        torch.tensor(stacked, dtype=torch.float) for stacked in stacked_fields
    )
def my_split_by_node(urls):
    """Return the subset of ``urls`` assigned to the current distributed rank.

    The URLs are dealt out round-robin: rank ``r`` of ``w`` processes keeps
    items ``r, r + w, r + 2w, ...``. The rank and world size are read from
    ``torch.distributed``.

    Args:
        urls: Iterable of shard URLs.

    Returns:
        List with this rank's share of ``urls``.
    """
    dist = torch.distributed
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    return list(urls)[rank::world_size]
def prepare_waveform(waveform):
    """Preprocess one raw waveform into a fixed-size float batch.

    The waveform is zero-padded or truncated to 6000 samples, detrended,
    band-pass filtered (0.2-40 with ``fs=100``), tapered with a Tukey
    window, detrended again and scaled by its global absolute maximum.
    Inputs with fewer than three channels are copied into the leading rows
    of a zero-filled three-channel array. The result is replicated 128 times
    along a new leading axis.

    Args:
        waveform: 2-D array ``(channels, samples)`` with at most 3 channels.

    Returns:
        ``torch.float`` tensor of shape ``(128, 3, 6000)`` holding 128
        identical copies of the preprocessed waveform.

    Raises:
        AssertionError: If the input has more than 3 channels, or the
            preprocessed waveform contains NaN or is entirely zero.
    """
    # SET PARAMETERS:
    crop_length = 6000

    assert waveform.shape[0] <= 3, "Waveform has more than 3 channels"

    # Bring the trace to exactly crop_length samples.
    n_samples = waveform.shape[-1]
    if n_samples < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - n_samples)),
            mode="constant",
            constant_values=0,
        )
    elif n_samples > crop_length:
        waveform = waveform[:, :crop_length]

    # Preprocess
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform[-1].shape[0], alpha=0.1)
    waveform = waveform * window
    waveform = detrend(waveform)

    assert not np.isnan(waveform).any(), "Nan in waveform"
    # Reject an all-zero trace outright: a signal can sum to exactly zero
    # without being empty, and only an empty one would make the
    # normalisation below divide by zero.
    assert np.any(waveform), "Sum of waveform sample is zero"

    # Normalize data
    max_val = np.max(np.abs(waveform))
    waveform = waveform / max_val

    # Fewer than three channels: copy what exists into the leading rows of a
    # zero-filled three-channel array.
    n_channels = len(waveform)
    if n_channels < 3:
        zeros = np.zeros((3, waveform.shape[-1]))
        zeros[:n_channels] = waveform
        waveform = zeros

    # Convert once and replicate on the tensor side; building a tensor from
    # a Python list of 128 ndarrays is far slower for the same result.
    single = torch.tensor(waveform, dtype=torch.float)
    return single.unsqueeze(0).repeat(128, 1, 1)