Init
- phasehunter/.ipynb_checkpoints/dataloader-checkpoint.py +158 -0
- phasehunter/.ipynb_checkpoints/model-checkpoint.py +606 -0
- phasehunter/__init__.py +0 -0
- phasehunter/__pycache__/__init__.cpython-310.pyc +0 -0
- phasehunter/__pycache__/dataloader.cpython-310.pyc +0 -0
- phasehunter/__pycache__/model.cpython-310.pyc +0 -0
- phasehunter/dataloader.py +179 -0
- phasehunter/main.py +7 -0
- phasehunter/model.py +639 -0
phasehunter/.ipynb_checkpoints/dataloader-checkpoint.py
ADDED
@@ -0,0 +1,158 @@

from torch.utils.data import Dataset
from torchvision.transforms import functional as F
import torch
import numpy as np
from scipy import signal
from functools import reduce
from scipy.signal import butter, lfilter, detrend

class Augmentations:
    def __init__(self, padding=120, crop_length=6000, fs=100, lowcut=0.2, highcut=40, order=5):
        self.padding = padding
        self.crop_length = crop_length
        self.fs = fs
        self.lowcut = lowcut
        self.highcut = highcut
        self.order = order

        b, a = self.butter_bandpass(self.lowcut, self.highcut, self.fs, self.order)
        self.filter_b = b
        self.filter_a = a

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        return butter(order, [lowcut, highcut], fs=fs, btype='band')

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        y = lfilter(self.filter_b, self.filter_a, data)
        return y

    def rotate_waveform(self, waveform, angle):
        fft_waveform = np.fft.fft(waveform)
        rotate_factor = np.exp(1j * angle)
        rotated_fft_waveform = fft_waveform * rotate_factor
        rotated_waveform = np.fft.ifft(rotated_fft_waveform)
        return rotated_waveform

    def shuffle(self, sample, target_P, target_S, test):
        if target_P - (self.crop_length-self.padding) > self.padding:
            start_indx = int(target_P - torch.randint(low=self.padding,
                                                      high=(self.crop_length-self.padding),
                                                      size=(1,)))
            if test == True:
                start_indx = int(first_phase - 2*self.padding)

        elif int(target_P-self.padding) > 0:
            start_indx = int(target_P - torch.randint(low=0,
                                                      high=(int(target_P-self.padding)),
                                                      size=(1,)))
            if test == True:
                start_indx = int(target_P - self.padding)
        else:
            start_indx = self.padding

        end_indx = start_indx + self.crop_length

        if (sample.shape[-1] - end_indx) < 0:
            start_indx += (sample.shape[-1] - end_indx)
            end_indx = start_indx + self.crop_length

        new_target_P = target_P - start_indx
        new_target_S = target_S - start_indx

        return start_indx, end_indx, new_target_P, new_target_S

    def cut(self, sample, start_indx, end_indx):
        sample_cropped = sample[:,start_indx:end_indx]
        return sample_cropped

    def preprocess(self, sample_cropped):
        # sample_cropped = detrend(sample_cropped)
        sample_cropped = self.butter_bandpass_filter(sample_cropped, lowcut=self.lowcut, highcut=self.highcut, fs=self.fs, order=self.order)
        window = signal.windows.tukey(sample_cropped[-1].shape[0], alpha=0.1)
        sample_cropped = sample_cropped*window
        return sample_cropped

    def add_z_component(self, sample_cropped):
        if len(sample_cropped) < 3:
            zeros = np.zeros((3, sample_cropped.shape[-1]))
            zeros[0] = sample_cropped
            sample_cropped = zeros
        return sample_cropped

    def rotate(self, sample_cropped, test):
        if test == False:
            probability = torch.randint(0,2, size=(1,)).item()
            if probability==1:
                angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
                sample_cropped = self.rotate_waveform(sample_cropped, angle).real
        return sample_cropped

    def normalize(self, sample_cropped):
        max_val = np.max(np.abs(sample_cropped))
        sample_cropped_norm = sample_cropped/max_val
        return sample_cropped_norm

    def channel_dropout(self, sample_cropped_norm, test):
        if test == False:
            probability = torch.randint(0,2, size=(1,)).item()
            channel = torch.randint(1,3, size=(1,)).item()
            if probability==1:
                sample_cropped_norm[channel,:] = 1e-6
        return sample_cropped_norm

    def apply(self, sample, target_P, target_S, test=False):

        start_indx, end_indx, new_target_P, new_target_S = self.shuffle(sample, target_P, target_S, test)

        sample_cropped = self.cut(sample, start_indx, end_indx)
        # sample_cropped = self.preprocess(sample_cropped)
        sample_cropped = self.add_z_component(sample_cropped)
        sample_cropped = self.rotate(sample_cropped, test)
        sample_cropped_norm = self.normalize(sample_cropped)
        sample_cropped_norm = self.channel_dropout(sample_cropped_norm, test)

        new_target_P = new_target_P/self.crop_length
        new_target_S = new_target_S/self.crop_length

        return sample_cropped_norm, new_target_P, new_target_S

class Waveforms_dataset(Dataset):
    def __init__(self, meta, data, test=False, transform=None, augmentations=None):
        # self.data_list = glob(data_path)
        self.meta = meta
        self.data = data
        self.test = test
        self.augmentations = augmentations

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        meta = self.meta.iloc[idx]
        sample = self.data[meta.name]

        target_P = float(meta.trace_P_final)
        target_S = float(meta.trace_S_final)

        if self.augmentations:
            sample, target_P, target_S = self.augmentations.apply(sample, target_P, target_S, test=self.test)

        # Setting labels to zero if they're not in the valid range or are NaNs
        if (target_P <= 0) or (target_P >= 1) or (np.isnan(target_P)):
            target_P = 0
        if (target_S <= 0) or (target_S >= 1) or (np.isnan(target_S)):
            target_S = 0

        # If something went wrong
        if np.isnan(sample).any():
            sample = np.zeros((3, self.augmentations.crop_length))
            target_P = 0
            target_S = 0

        # Convert to tensor
        sample = torch.tensor(sample, dtype=torch.float)
        target_P = torch.tensor(target_P, dtype=torch.float)
        target_S = torch.tensor(target_S, dtype=torch.float)

        return sample, target_P, target_S
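A minimal usage sketch of the Augmentations pipeline defined above, assuming a synthetic 3-component record sampled at 100 Hz with made-up pick indices (the import path is the package module, which defines the same class):

# Illustrative only: random data, arbitrary pick positions.
import numpy as np
from phasehunter.dataloader import Augmentations

rng = np.random.default_rng(0)
waveform = rng.standard_normal((3, 12000))   # 3 components, 120 s at 100 Hz
target_P, target_S = 3000.0, 4500.0          # picks as sample indices into the trace

aug = Augmentations(padding=120, crop_length=6000, fs=100)
cropped, p_norm, s_norm = aug.apply(waveform, target_P, target_S, test=False)

print(cropped.shape)    # (3, 6000): a window cut around the P pick
print(p_norm, s_norm)   # picks rescaled to [0, 1] fractions of the crop length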
phasehunter/.ipynb_checkpoints/model-checkpoint.py
ADDED
@@ -0,0 +1,606 @@

from typing import Optional, Union, Tuple, Any
import math

from lightning import seed_everything
import lightning as pl

from masksembles import common
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import MeanAbsoluteError
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scipy.stats import gaussian_kde
from scipy.special import comb

from tqdm.auto import tqdm
import pandas as pd

from obspy import Stream

seed_everything(42, workers=False)
torch.set_float32_matmul_precision('medium')

class BlurPool1D(nn.Module):
    """Implements 1D version of blur pooling.

    Attributes:
        channels (int): Number of input channels.
        pad_type (str): Type of padding (reflect, replicate, zero).
        filt_size (int): Filter size for blur pooling.
        stride (int): Stride size for downsampling.
        pad_off (int): Padding offset.
    """
    def __init__(self, channels: int, pad_type: str='reflect', filt_size: int=3, stride: int=2, pad_off: int=0):
        super(BlurPool1D, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        # Calculate padding sizes for the beginning and end of signal
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Generate coefficients for the specified filter size using binomial coefficients
        a = np.array([comb(filt_size-1, i, exact=False) for i in range(filt_size)])

        filt = torch.Tensor(a)
        filt = filt / torch.sum(filt) # normalize the filter
        # Make the filter to have same size with number of channels
        self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))

        # Get the appropriate padding layer
        self.pad = self.get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        """Computes forward pass for blur pooling."""
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, ::self.stride]
            else:
                # Apply padding if pad_off is not zero
                return self.pad(inp)[:, :, ::self.stride]
        else:
            # Convolve input with filter and then apply downsampling
            return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

    def get_pad_layer_1d(self, pad_type: str):
        """Returns appropriate padding layer based on the pad_type string.

        Args:
            pad_type: Type of padding. It can be 'refl', 'reflect', 'repl', 'replicate', or 'zero'.

        Returns:
            Appropriate padding layer based on pad_type.

        Raises:
            ValueError: If pad_type is not recognized.
        """
        # Define the padding layer depending on the input pad_type
        if pad_type in ['refl', 'reflect']:
            pad_layer = nn.ReflectionPad1d
        elif pad_type in ['repl', 'replicate']:
            pad_layer = nn.ReplicationPad1d
        elif pad_type == 'zero':
            pad_layer = nn.ZeroPad1d
        else:
            # Raise an error if pad_type is not recognized
            raise ValueError(f"Pad type [{pad_type}] not recognized")
        return pad_layer


class Masksembles1D(nn.Module):
    """Implements 1D version of Masksembles operation.

    Masksembles operation applies different masks to the input in a way that allows the model to estimate uncertainty and confidence at inference time.

    Attributes:
        channels (int): Number of input channels.
        n (int): Number of masks to generate.
        scale (float): Scaling factor for masks.
    """
    def __init__(self, channels: int, n: int, scale: float):
        super().__init__()

        self.channels = channels
        self.n = n
        self.scale = scale

        # Generate masks using a provided function
        masks = common.generation_wrapper(channels, n, scale)
        masks = torch.from_numpy(masks)

        # Convert masks into PyTorch Parameter and set it to not require gradient
        self.masks = torch.nn.Parameter(masks, requires_grad=False)

    def forward(self, inputs):
        """Computes forward pass for Masksembles operation.

        The input is divided into multiple groups, each group is multiplied with a different mask, and then the results
        are concatenated together.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying Masksembles operation.
        """
        # Number of samples in the batch
        batch = inputs.shape[0]

        # Divide the input into n groups along the batch dimension
        x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)

        # Concatenate the groups along the new dimension and permute the dimensions
        x = torch.cat(x, dim=1).permute([1, 0, 2, 3])

        # Multiply each group with a different mask
        x = x * self.masks.unsqueeze(1).unsqueeze(-1)

        # Concatenate the results along the channel dimension
        x = torch.cat(torch.split(x, 1, dim=0), dim=1)

        # Remove the extra dimension and convert the tensor to the original data type
        return x.squeeze(0).type(inputs.dtype)


class BasicBlock(nn.Module):
    """Implements a basic block of convolutions, a fundamental part of PhaseHunter.

    A basic block consists of two convolutional layers, each followed by batch normalization. The output from the second
    convolutional layer is added to the shortcut connection before applying an optional activation function.

    Attributes:
        in_planes (int): Number of input channels (also known as input planes).
        planes (int): Number of output channels (also known as output planes or filters).
        stride (int, optional): Stride size for convolution. Default is 1.
        kernel_size (int, optional): Kernel size for convolution. Default is 7.
        groups (int, optional): Number of groups for convolution. Default is 1.
        do_activation (bool, optional): Whether to apply an activation function (ReLU) at the end. Introduced for embedding capture. Default is True.
    """
    def __init__(self, in_planes: int, planes: int, stride: int = 1, kernel_size: int = 7, groups: int = 1, do_activation: bool = True):
        super(BasicBlock, self).__init__()

        self.do_activation = do_activation

        # First convolutional layer
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=kernel_size, stride=stride, padding='same', bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        # Second convolutional layer
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=1, padding='same', bias=False)
        self.bn2 = nn.BatchNorm1d(planes)

        # Shortcut connection, used to match the dimensionality between input and output
        self.shortcut = nn.Sequential(
            nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, padding='same', bias=False),
            nn.BatchNorm1d(planes)
        )

    def forward(self, x):
        """Computes forward pass for the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the basic block.
        """
        # Apply first convolution followed by ReLU activation
        out = F.relu(self.bn1(self.conv1(x)))

        # Apply second convolution
        out = self.bn2(self.conv2(out))

        # Add the output of the shortcut connection
        out += self.shortcut(x)

        # Apply activation (it's here for the embedding)
        if self.do_activation:
            out = F.relu(out)

        return out


class PhaseHunter(pl.LightningModule):
    """Implements PhaseHunter model for seismic phase picking.

    Attributes:
        n_masks (int): Number of masks for Masksembles operation.
        n_outs (int): Number of output units.
    """
    def __init__(self, n_masks=128, n_outs=2):
        super().__init__()

        self.n_masks = 128
        self.n_outs = n_outs

        # Define sequential layers for block 1 to 9
        # Each block consist of BasicBlock, GELU activation, BlurPool1D, and GroupNorm layers
        # Blocks vary in the number of in and out features

        self.block1 = nn.Sequential(
            BasicBlock(3,8, kernel_size=7, groups=1),
            nn.GELU(),
            BlurPool1D(8, filt_size=3, stride=2),
            nn.GroupNorm(2,8),
        )

        self.block2 = nn.Sequential(
            BasicBlock(8, 16, kernel_size=7, groups=8),
            nn.GELU(),
            BlurPool1D(16, filt_size=3, stride=2),
            nn.GroupNorm(2,16),
        )

        self.block3 = nn.Sequential(
            BasicBlock(16,32, kernel_size=7, groups=16),
            nn.GELU(),
            BlurPool1D(32, filt_size=3, stride=2),
            nn.GroupNorm(2,32),
        )

        self.block4 = nn.Sequential(
            BasicBlock(32,64, kernel_size=7, groups=32),
            nn.GELU(),
            BlurPool1D(64, filt_size=3, stride=2),
            nn.GroupNorm(2,64),
        )

        self.block5 = nn.Sequential(
            BasicBlock(64,128, kernel_size=7, groups=64),
            nn.GELU(),
            BlurPool1D(128, filt_size=3, stride=2),
            nn.GroupNorm(2,128),
        )

        self.block6 = nn.Sequential(
            Masksembles1D(128, self.n_masks, 2.0),
            BasicBlock(128,256, kernel_size=7, groups=128),
            nn.GELU(),
            BlurPool1D(256, filt_size=3, stride=2),
            nn.GroupNorm(2,256),
        )

        self.block7 = nn.Sequential(
            Masksembles1D(256, self.n_masks, 2.0),
            BasicBlock(256,512, kernel_size=7, groups=256),
            BlurPool1D(512, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2,512),
        )

        self.block8 = nn.Sequential(
            Masksembles1D(512, self.n_masks, 2.0),
            BasicBlock(512,1024, kernel_size=7, groups=512),
            BlurPool1D(1024, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2,1024),
        )

        self.block9 = nn.Sequential(
            Masksembles1D(1024, self.n_masks, 2.0),
            BasicBlock(1024,128, kernel_size=7, groups=128, do_activation=False),

            # Works better with those off on the last layer before regressor
            # BlurPool1D(512, filt_size=3, stride=2),
            # nn.GELU(),
            # nn.GroupNorm(2,512),
        )

        # Final output layer with Sigmoid activation
        self.out = nn.Sequential(
            nn.LazyLinear(n_outs),
            nn.Sigmoid()
        )

        # Save hyperparameters and initialize Mean Absolute Error loss
        self.save_hyperparameters(ignore=['picker'])
        self.mae = MeanAbsoluteError()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes forward pass for the model."""
        # Feature extraction
        x = self.block1(x)
        x = self.block2(x)

        x = self.block3(x)
        x = self.block4(x)

        x = self.block5(x)
        x = self.block6(x)

        x = self.block7(x)
        x = self.block8(x)

        x = self.block9(x)

        # Regressor
        embedding = x.flatten(start_dim=1)
        x = self.out(F.relu(embedding))

        return x, embedding

    def compute_loss(self, y: torch.Tensor, pick: torch.Tensor, mae_name: Optional[Union[str, bool]] = False) -> torch.Tensor:
        """Computes loss for the predictions.

        Args:
            y (torch.Tensor): The ground truth tensor.
            pick (torch.Tensor): The predicted tensor.
            mae_name (Union[str, bool], optional): The name for the Mean Absolute Error (MAE) metric.
                If provided, it logs the MAE metric with the name 'MAE/{mae_name}_val'. Default is False.

        Returns:
            torch.Tensor: The computed loss.
        """
        # Filter non-zero values
        y_filt = y[y != 0]
        pick_filt = pick[y != 0]

        # Compute L1 loss if there are non-zero values
        if len(y_filt) > 0:
            loss = F.l1_loss(y_filt, pick_filt.flatten())

            # If mae_name is provided, log the MAE metric
            if mae_name != False:
                mae_phase = self.mae(y_filt, pick_filt.flatten())*30
                self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False)
        else:
            loss = 0
        return loss

    def get_likely_val(self, array: np.ndarray) -> Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]:
        """Computes most likely value using Kernel Density Estimation.

        Args:
            array (np.ndarray): The input array for which to compute the most likely value.

        Returns:
            Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]: A tuple containing
                - the distribution space (dist_space),
                - the Kernel Density Estimation (kde),
                - the most likely value (val), and
                - the uncertainty of the estimation.
        """
        # Compute KDE for the input array
        kde = gaussian_kde(array)

        # Define the distribution space
        dist_space = np.linspace(min(array)-0.001, max(array)+0.001, 512)

        # Compute the most likely value and the uncertainty
        val = torch.tensor(dist_space[np.argmax(kde(dist_space))], dtype=torch.float32)
        uncertainty = dist_space.ptp()/2

        return dist_space, kde, val, uncertainty

    def process_continuous_waveform(self, st: Stream) -> pd.DataFrame:
        """
        Processes a continuous seismic waveform and predicts P and S wave arrival times using PhaseHunter.

        Parameters:
        -----------
        st : Stream
            The input seismic data as an ObsPy Stream object with three components.

        Returns:
        --------
        pd.DataFrame
            A DataFrame containing the following columns:
            - p_time: Predicted P-wave arrival time.
            - s_time: Predicted S-wave arrival time.
            - p_uncert: Uncertainty associated with the P-wave prediction.
            - s_uncert: Uncertainty associated with the S-wave prediction.
            - embedding: Embedding representation of the chunk.
            - p_conf: Confidence level of the P-wave prediction.
            - s_conf: Confidence level of the S-wave prediction.
            - p_time_rel: Relative P-wave arrival time in seconds from the start of the input stream.
            - s_time_rel: Relative S-wave arrival time in seconds from the start of the input stream.

        Notes:
        ------
        The function assumes that the input Stream object has three components.
        The neural network inference is performed on chunks of data of 30 seconds.
        The output DataFrame is a result of aggregating predictions for each chunk and filtering duplicate rows.

        Raises:
        -------
        AssertionError
            If the input Stream object doesn't contain three components.

        Examples:
        ---------
        >>> from obspy import read
        >>> st = read('path_to_your_waveform_data')
        >>> predictions = process_continuous_waveform(st)
        >>> print(predictions)
        """
        assert len(st) == 3, 'For the moment, PhaseHunter works only with 3C input data'

        start_time = st[0].stats.starttime
        end_time = st[0].stats.endtime

        chunk_size = 30

        chunks = []
        predictions = pd.DataFrame()

        for chunk_start in tqdm(np.arange(start_time, end_time, chunk_size)):
            chunk_end = chunk_start + chunk_size

            chunk = st.slice(chunk_start, chunk_end)

            # chunk_orig = np.vstack([x.data for x in chunk], dtype='float')[:,:-1]
            chunk_orig = np.vstack([x.data for x in chunk])
            chunk_orig = chunk_orig.astype('float')[:,:-1]

            if chunk_orig.shape[-1] != chunk_size * 100:
                continue

            chunk = chunk_orig - chunk_orig.mean(axis=0)
            max_val = np.max(np.abs(chunk))
            chunk = chunk/max_val

            chunk = torch.tensor(chunk, dtype=torch.float)

            inference_sample = torch.stack([chunk]*128).to(self.device)

            with torch.no_grad():
                preds, embeddings = self(inference_sample)

            p_pred = preds[:,0].detach().cpu()
            s_pred = preds[:,1].detach().cpu()
            embeddings = torch.mean(embeddings, axis=0).detach().cpu().numpy()

            p_dist, p_kde, p_val, p_uncert = self.get_likely_val(p_pred)
            s_dist, s_kde, s_val, s_uncert = self.get_likely_val(s_pred)

            p_time = chunk_start+p_val.item()*chunk_size
            s_time = chunk_start+s_val.item()*chunk_size

            current_predictions = pd.DataFrame({'p_time': p_time, 's_time':s_time,
                                                'p_uncert' : p_uncert, 's_uncert' : s_uncert,
                                                'embedding' : [embeddings]})

            predictions = pd.concat([predictions, current_predictions], ignore_index=True)

        predictions = predictions.drop_duplicates(subset=['p_uncert', 's_uncert']).reset_index()

        predictions['p_conf'] = 1/predictions['p_uncert']
        predictions['s_conf'] = 1/predictions['s_uncert']

        predictions['p_conf'] /= predictions['p_conf'].max()
        predictions['s_conf'] /= predictions['s_conf'].max()

        predictions['p_time_rel'] = (predictions.p_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s')) - pd.Timestamp(predictions.p_time.iloc[0].date)).dt.total_seconds()
        predictions['s_time_rel'] = (predictions.s_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s')) - pd.Timestamp(predictions.s_time.iloc[0].date)).dt.total_seconds()

        return predictions

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Defines a single step in the training loop for PhaseHunter.

        Args:
            batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
                and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
            batch_idx (int): The index of the current batch.

        Returns:
            torch.Tensor: The computed loss for this training step.
        """
        # Unpack the batch
        x, y_p, y_s = batch

        # Perform forward pass and get predictions
        picks, embedding = self(x)

        # Extract P and S phase picks
        p_pick = picks[:,0]
        s_pick = picks[:,1]

        # Compute losses for P and S phase picks
        p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
        s_loss = self.compute_loss(y_s, s_pick, mae_name='S')

        # Combine losses
        loss = (p_loss+s_loss)/self.n_outs

        # Log the loss
        self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Defines a single step in the validation loop for PhaseHunter.

        Args:
            batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
                and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
            batch_idx (int): The index of the current batch.

        Returns:
            torch.Tensor: The computed loss for this validation step.
        """
        # Unpack the batch
        x, y_p, y_s = batch

        # Perform forward pass and get predictions
        picks, embedding = self(x)

        # Extract P and S phase picks
        p_pick = picks[:,0]
        s_pick = picks[:,1]

        # Compute losses for P and S phase picks
        p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
        s_loss = self.compute_loss(y_s, s_pick, mae_name='S')

        # Combine losses
        loss = (p_loss+s_loss)/self.n_outs

        # Log the loss
        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False)

        return loss

    # def configure_optimizers(self) -> dict:
    #     """
    #     Defines the optimizer and scheduler for PhaseHunter.

    #     Returns:
    #         dict: A dictionary containing the optimizer, the learning rate scheduler, and the metric to monitor.
    #     """
    #     # Define the optimizer
    #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    #     # Define the learning rate scheduler
    #     # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-6)

    #     # Define the metric to monitor
    #     # monitor = 'Loss/train'

    #     return {"optimizer": optimizer}#, "lr_scheduler": scheduler, 'monitor': monitor}

    def configure_optimizers(self) -> dict:
        """
        Defines the optimizer and scheduler for PhaseHunter.

        Returns:
            dict: A dictionary containing the optimizer, the learning rate scheduler, and the metric to monitor.
        """
        # Define the optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # Total number of epochs for decay
        decay_epochs = 100

        # Total number of epochs including constant learning rate period
        total_epochs = 200

        # Final learning rate
        final_lr = 1e-7

        # Lambda function for learning rate schedule
        def lambda_func(epoch):
            if epoch < decay_epochs:
                return 1.0 # constant learning rate
            else:
                epoch_adjusted = epoch - decay_epochs
                return 1 - epoch_adjusted/decay_epochs + (final_lr/1e-3)*epoch_adjusted/decay_epochs

        # Define the learning rate scheduler
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)

        # Define the metric to monitor
        # monitor = 'Loss/train'

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
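A minimal sketch of the Masksembles-based uncertainty estimate used in process_continuous_waveform above, in isolation. The waveform here is random and the model untrained, so the picks themselves are meaningless; the point is the shape of the pipeline: the same 30 s window is repeated once per mask, and the spread of the 128 predictions is turned into a most likely pick and an uncertainty via get_likely_val:

# Illustrative only: placeholder data, untrained weights.
import torch
from phasehunter.model import PhaseHunter

model = PhaseHunter(n_masks=128, n_outs=2).eval()

window = torch.randn(3, 3000)          # placeholder 30 s, 3-component window at 100 Hz
batch = torch.stack([window] * 128)    # one copy of the trace per Masksembles mask

with torch.no_grad():
    preds, _ = model(batch)            # preds: (128, 2) sigmoid-scaled P and S picks

p_dist, p_kde, p_val, p_uncert = model.get_likely_val(preds[:, 0].numpy())
print(float(p_val) * 30, p_uncert * 30)   # pick and spread, in seconds within the window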
phasehunter/__init__.py
ADDED
File without changes

phasehunter/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes).

phasehunter/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (6.03 kB).

phasehunter/__pycache__/model.cpython-310.pyc
ADDED
Binary file (17.3 kB).
phasehunter/dataloader.py
ADDED
@@ -0,0 +1,179 @@

from torch.utils.data import Dataset
from torchvision.transforms import functional as F
import torch
import numpy as np
from scipy import signal
from functools import reduce
from scipy.signal import butter, lfilter, detrend

class Augmentations:
    def __init__(self, padding=120, crop_length=6000, fs=100, lowcut=0.2, highcut=40, order=5):
        self.padding = padding
        self.crop_length = crop_length
        self.fs = fs
        self.lowcut = lowcut
        self.highcut = highcut
        self.order = order

        b, a = self.butter_bandpass(self.lowcut, self.highcut, self.fs, self.order)
        self.filter_b = b
        self.filter_a = a

    def butter_bandpass(self, lowcut, highcut, fs, order=5):
        return butter(order, [lowcut, highcut], fs=fs, btype='band')

    def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
        y = lfilter(self.filter_b, self.filter_a, data)
        return y

    def rotate_waveform(self, waveform, angle):
        fft_waveform = np.fft.fft(waveform)
        rotate_factor = np.exp(1j * angle)
        rotated_fft_waveform = fft_waveform * rotate_factor
        rotated_waveform = np.fft.ifft(rotated_fft_waveform)
        return rotated_waveform

    def shuffle(self, sample, target_P, target_S, test):
        if target_P - (self.crop_length-self.padding) > self.padding:
            start_indx = int(target_P - torch.randint(low=self.padding,
                                                      high=(self.crop_length-self.padding),
                                                      size=(1,)))
            if test == True:
                start_indx = int(first_phase - 2*self.padding)

        elif int(target_P-self.padding) > 0:
            start_indx = int(target_P - torch.randint(low=0,
                                                      high=(int(target_P-self.padding)),
                                                      size=(1,)))
            if test == True:
                start_indx = int(target_P - self.padding)
        else:
            start_indx = self.padding

        end_indx = start_indx + self.crop_length

        if (sample.shape[-1] - end_indx) < 0:
            start_indx += (sample.shape[-1] - end_indx)
            end_indx = start_indx + self.crop_length

        new_target_P = target_P - start_indx
        new_target_S = target_S - start_indx

        return start_indx, end_indx, new_target_P, new_target_S

    def cut(self, sample, start_indx, end_indx):
        sample_cropped = sample[:,start_indx:end_indx]
        return sample_cropped

    def bandpass_filter(self, sample_cropped, test):
        # sample_cropped = detrend(sample_cropped)
        if test == False:
            probability = torch.randint(0,2, size=(1,)).item()
            if probability==1:
                lowcut = torch.FloatTensor(size=(1,)).uniform_(0.001, 1).item()
                highcut = torch.FloatTensor(size=(1,)).uniform_(10, 49).item()
                sample_cropped = self.butter_bandpass_filter(sample_cropped, lowcut=lowcut, highcut=highcut, fs=self.fs, order=self.order)
                window = signal.windows.tukey(sample_cropped[-1].shape[0], alpha=0.1)
                sample_cropped = sample_cropped*window
        return sample_cropped

    def add_z_component(self, sample_cropped):
        if len(sample_cropped) < 3:
            zeros = np.zeros((3, sample_cropped.shape[-1]))
            zeros[0] = sample_cropped
            sample_cropped = zeros
        return sample_cropped

    def rotate(self, sample_cropped, test):
        if test == False:
            probability = torch.randint(0,2, size=(1,)).item()
            if probability==1:
                angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
                sample_cropped = self.rotate_waveform(sample_cropped, angle).real
        return sample_cropped

    def demean(self, sample_cropped):
        # Subtracting mean from the data
        sample_cropped = sample_cropped - np.mean(sample_cropped, axis=-1, keepdims=True)
        return sample_cropped

    def normalize(self, sample_cropped):
        max_val = np.max(np.abs(sample_cropped))
        sample_cropped_norm = sample_cropped/max_val
        return sample_cropped_norm

    def channel_dropout(self, sample_cropped_norm, test):
        if test == False:
            probability = torch.randint(0,2, size=(1,)).item()
            channel = torch.randint(1,3, size=(1,)).item()
            if probability == 1:
                sample_cropped_norm[channel,:] = 1e-6
        return sample_cropped_norm

    def channel_shuffle(self, sample_cropped_norm, test):
        if test == False:
            probability = torch.randint(0, 2, size=(1,)).item()
            if probability == 1:
                shuffled_indices = torch.randperm(sample_cropped_norm.shape[0])
                sample_cropped_norm = sample_cropped_norm[shuffled_indices, :]
        return sample_cropped_norm

    def apply(self, sample, target_P, target_S, test=False):

        start_indx, end_indx, new_target_P, new_target_S = self.shuffle(sample, target_P, target_S, test)

        sample_cropped = self.cut(sample, start_indx, end_indx)
        sample_cropped = self.bandpass_filter(sample_cropped, test)
        sample_cropped = self.add_z_component(sample_cropped)
        sample_cropped = self.rotate(sample_cropped, test)
        sample_cropped = self.demean(sample_cropped)
        sample_cropped_norm = self.normalize(sample_cropped)

        sample_cropped_norm = self.channel_dropout(sample_cropped_norm, test)
        sample_cropped_norm = self.channel_shuffle(sample_cropped_norm, test)

        new_target_P = new_target_P/self.crop_length
        new_target_S = new_target_S/self.crop_length

        return sample_cropped_norm, new_target_P, new_target_S

class Waveforms_dataset(Dataset):
    def __init__(self, meta, data, test=False, transform=None, augmentations=None):
        # self.data_list = glob(data_path)
        self.meta = meta
        self.data = data
        self.test = test
        self.augmentations = augmentations

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        meta = self.meta.iloc[idx]
        sample = self.data[meta.name]

        target_P = float(meta.trace_P_final)
        target_S = float(meta.trace_S_final)

        if self.augmentations:
            sample, target_P, target_S = self.augmentations.apply(sample, target_P, target_S, test=self.test)

        # Setting labels to zero if they're not in the valid range or are NaNs
        if (target_P <= 0) or (target_P >= 1) or (np.isnan(target_P)):
            target_P = 0
        if (target_S <= 0) or (target_S >= 1) or (np.isnan(target_S)):
            target_S = 0

        # If something went wrong
        if np.isnan(sample).any():
            sample = np.zeros((3, self.augmentations.crop_length))
            target_P = 0
            target_S = 0

        # Convert to tensor
        sample = torch.tensor(sample, dtype=torch.float)
        target_P = torch.tensor(target_P, dtype=torch.float)
        target_S = torch.tensor(target_S, dtype=torch.float)

        return sample, target_P, target_S
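A minimal sketch of wiring Waveforms_dataset into a DataLoader. The metadata columns (trace_P_final, trace_S_final) follow the fields read in __getitem__; the in-memory data dict keyed by trace name, the random waveforms, and the pick values are assumptions made for illustration:

# Illustrative only: synthetic metadata and waveforms.
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from phasehunter.dataloader import Augmentations, Waveforms_dataset

rng = np.random.default_rng(0)
names = [f'trace_{i}' for i in range(8)]
meta = pd.DataFrame({'trace_P_final': rng.uniform(2000, 4000, 8),
                     'trace_S_final': rng.uniform(4500, 7000, 8)}, index=names)
data = {name: rng.standard_normal((3, 12000)) for name in names}

dataset = Waveforms_dataset(meta, data, test=False, augmentations=Augmentations())
# When feeding PhaseHunter, the batch size should stay a multiple of the model's
# n_masks, because Masksembles1D splits each batch into n equal groups.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

waveforms, y_p, y_s = next(iter(loader))
print(waveforms.shape, y_p.shape)   # torch.Size([4, 3, 6000]) torch.Size([4])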
phasehunter/main.py
ADDED
@@ -0,0 +1,7 @@

from fastapi import FastAPI

app = FastAPI()

@app.get("/")
def read_root():
    return {"Hello": "World"}
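The stub API above can be served with uvicorn in the usual FastAPI way; the host and port below are placeholders, and uvicorn is assumed to be installed alongside fastapi:

# Run the stub app locally and check it with: curl http://127.0.0.1:8000/
import uvicorn
from phasehunter.main import app

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)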
phasehunter/model.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Union, Tuple, Any
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
from lightning import seed_everything
|
| 5 |
+
import lightning as pl
|
| 6 |
+
|
| 7 |
+
from masksembles import common
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torchmetrics import MeanAbsoluteError
|
| 13 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
| 14 |
+
|
| 15 |
+
from scipy.stats import gaussian_kde
|
| 16 |
+
from scipy.special import comb
|
| 17 |
+
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
import pandas as pd
|
| 20 |
+
|
| 21 |
+
from obspy import Stream
|
| 22 |
+
|
| 23 |
+
seed_everything(42, workers=False)
|
| 24 |
+
torch.set_float32_matmul_precision('medium')
|
| 25 |
+
|
| 26 |
+
class BlurPool1D(nn.Module):
|
| 27 |
+
"""Implements 1D version of blur pooling.
|
| 28 |
+
|
| 29 |
+
Attributes:
|
| 30 |
+
channels (int): Number of input channels.
|
| 31 |
+
pad_type (str): Type of padding (reflect, replicate, zero).
|
| 32 |
+
filt_size (int): Filter size for blur pooling.
|
| 33 |
+
stride (int): Stride size for downsampling.
|
| 34 |
+
pad_off (int): Padding offset.
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, channels: int, pad_type: str='reflect', filt_size: int=3, stride: int=2, pad_off: int=0):
|
| 37 |
+
super(BlurPool1D, self).__init__()
|
| 38 |
+
self.filt_size = filt_size
|
| 39 |
+
self.pad_off = pad_off
|
| 40 |
+
# Calculate padding sizes for the beginning and end of signal
|
| 41 |
+
self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
|
| 42 |
+
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
| 43 |
+
self.stride = stride
|
| 44 |
+
self.off = int((self.stride - 1) / 2.)
|
| 45 |
+
self.channels = channels
|
| 46 |
+
|
| 47 |
+
# Generate coefficients for the specified filter size using binomial coefficients
|
| 48 |
+
a = np.array([comb(filt_size-1, i, exact=False) for i in range(filt_size)])
|
| 49 |
+
|
| 50 |
+
filt = torch.Tensor(a)
|
| 51 |
+
filt = filt / torch.sum(filt) # normalize the filter
|
| 52 |
+
# Make the filter to have same size with number of channels
|
| 53 |
+
self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
|
| 54 |
+
|
| 55 |
+
# Get the appropriate padding layer
|
| 56 |
+
self.pad = self.get_pad_layer_1d(pad_type)(self.pad_sizes)
|
| 57 |
+
|
| 58 |
+
def forward(self, inp):
|
| 59 |
+
"""Computes forward pass for blur pooling."""
|
| 60 |
+
if self.filt_size == 1:
|
| 61 |
+
if self.pad_off == 0:
|
| 62 |
+
return inp[:, :, ::self.stride]
|
| 63 |
+
else:
|
| 64 |
+
# Apply padding if pad_off is not zero
|
| 65 |
+
return self.pad(inp)[:, :, ::self.stride]
|
| 66 |
+
else:
|
| 67 |
+
# Convolve input with filter and then apply downsampling
|
| 68 |
+
return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
| 69 |
+
|
| 70 |
+
def get_pad_layer_1d(self, pad_type: str):
|
| 71 |
+
"""Returns appropriate padding layer based on the pad_type string.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
pad_type: Type of padding. It can be 'refl', 'reflect', 'repl', 'replicate', or 'zero'.
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Appropriate padding layer based on pad_type.
|
| 78 |
+
|
| 79 |
+
Raises:
|
| 80 |
+
ValueError: If pad_type is not recognized.
|
| 81 |
+
"""
|
| 82 |
+
# Define the padding layer depending on the input pad_type
|
| 83 |
+
if pad_type in ['refl', 'reflect']:
|
| 84 |
+
pad_layer = nn.ReflectionPad1d
|
| 85 |
+
elif pad_type in ['repl', 'replicate']:
|
| 86 |
+
pad_layer = nn.ReplicationPad1d
|
| 87 |
+
elif pad_type == 'zero':
|
| 88 |
+
pad_layer = nn.ZeroPad1d
|
| 89 |
+
else:
|
| 90 |
+
# Raise an error if pad_type is not recognized
|
| 91 |
+
raise ValueError(f"Pad type [{pad_type}] not recognized")
|
| 92 |
+
return pad_layer
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Masksembles1D(nn.Module):
    """Implements 1D version of Masksembles operation.

    Masksembles applies a different binary mask to each group of the batch, which lets the model
    estimate uncertainty and confidence at inference time.

    Attributes:
        channels (int): Number of input channels.
        n (int): Number of masks to generate.
        scale (float): Scaling factor for masks.
    """
    def __init__(self, channels: int, n: int, scale: float):
        super().__init__()

        self.channels = channels
        self.n = n
        self.scale = scale

        # Generate masks using a provided function
        masks = common.generation_wrapper(channels, n, scale)
        masks = torch.from_numpy(masks)

        # Convert masks into a PyTorch Parameter that does not require gradients
        self.masks = torch.nn.Parameter(masks, requires_grad=False)

    def forward(self, inputs):
        """Computes forward pass for Masksembles operation.

        The input is divided into multiple groups, each group is multiplied with a different mask,
        and then the results are concatenated together.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying Masksembles operation.
        """
        # Number of samples in the batch
        batch = inputs.shape[0]

        # Divide the input into n groups along the batch dimension
        x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)

        # Concatenate the groups along the new dimension and permute the dimensions
        x = torch.cat(x, dim=1).permute([1, 0, 2, 3])

        # Multiply each group with a different mask
        x = x * self.masks.unsqueeze(1).unsqueeze(-1)

        # Concatenate the results along the channel dimension
        x = torch.cat(torch.split(x, 1, dim=0), dim=1)

        # Remove the extra dimension and convert the tensor to the original data type
        return x.squeeze(0).type(inputs.dtype)

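
# Illustrative sketch (not part of the original module): Masksembles1D splits the batch into
# n groups and multiplies each group by a different mask, so the batch size must be divisible
# by n. It assumes common.generation_wrapper (imported at the top of this file) accepts these
# channel/scale values.
def _demo_masksembles():
    layer = Masksembles1D(channels=128, n=4, scale=2.0)
    x = torch.randn(8, 128, 100)      # batch of 8 is divisible by n=4
    y = layer(x)
    print(y.shape)                    # shape is preserved: torch.Size([8, 128, 100])
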
class BasicBlock(nn.Module):
    """Implements a basic block of convolutions, a fundamental part of PhaseHunter.

    A basic block consists of two convolutional layers, each followed by batch normalization. The output from the
    second convolutional layer is added to the shortcut connection before applying an optional activation function.

    Attributes:
        in_planes (int): Number of input channels (also known as input planes).
        planes (int): Number of output channels (also known as output planes or filters).
        stride (int, optional): Stride size for convolution. Default is 1.
        kernel_size (int, optional): Kernel size for convolution. Default is 7.
        groups (int, optional): Number of groups for convolution. Accepted but not currently passed to the convolutions. Default is 1.
        do_activation (bool, optional): Whether to apply an activation function (ReLU) at the end. Introduced for embedding capture. Default is True.
    """
    def __init__(self, in_planes: int, planes: int, stride: int = 1, kernel_size: int = 7, groups: int = 1, do_activation: bool = True):
        super(BasicBlock, self).__init__()

        self.do_activation = do_activation

        # First convolutional layer
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=kernel_size, stride=stride, padding='same', bias=False)
        self.bn1 = nn.BatchNorm1d(planes)

        # Second convolutional layer
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=1, padding='same', bias=False)
        self.bn2 = nn.BatchNorm1d(planes)

        # Shortcut connection, used to match the dimensionality between input and output
        self.shortcut = nn.Sequential(
            nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, padding='same', bias=False),
            nn.BatchNorm1d(planes)
        )

    def forward(self, x):
        """Computes forward pass for the block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the basic block.
        """
        # Apply first convolution followed by ReLU activation
        out = F.relu(self.bn1(self.conv1(x)))

        # Apply second convolution
        out = self.bn2(self.conv2(out))

        # Add the output of the shortcut connection
        out += self.shortcut(x)

        # Apply activation unless disabled (it is turned off on the last block so the raw embedding can be captured)
        if self.do_activation:
            out = F.relu(out)

        return out

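
# Illustrative sketch (not part of the original module): with padding='same' and the default
# stride of 1, BasicBlock changes the channel count but preserves the temporal length, which is
# what lets the residual addition with the 1x1 shortcut branch line up.
def _demo_basicblock():
    block = BasicBlock(in_planes=3, planes=8, kernel_size=7)
    x = torch.randn(4, 3, 3000)
    print(block(x).shape)             # torch.Size([4, 8, 3000])
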
class PhaseHunter(pl.LightningModule):
    """Implements PhaseHunter model for seismic phase picking.

    Attributes:
        n_masks (int): Number of masks for Masksembles operation.
        n_outs (int): Number of output units.
    """
    def __init__(self, n_masks=128, n_outs=2):
        super().__init__()

        self.n_masks = n_masks
        self.n_outs = n_outs

        # Define sequential layers for blocks 1 to 9
        # Each block consists of BasicBlock, GELU activation, BlurPool1D, and GroupNorm layers
        # Blocks vary in the number of in and out features

        self.block1 = nn.Sequential(
            BasicBlock(3, 8, kernel_size=7, groups=1),
            nn.GELU(),
            BlurPool1D(8, filt_size=3, stride=2),
            nn.GroupNorm(2, 8),
        )

        self.block2 = nn.Sequential(
            BasicBlock(8, 16, kernel_size=7, groups=8),
            nn.GELU(),
            BlurPool1D(16, filt_size=3, stride=2),
            nn.GroupNorm(2, 16),
        )

        self.block3 = nn.Sequential(
            BasicBlock(16, 32, kernel_size=7, groups=16),
            nn.GELU(),
            BlurPool1D(32, filt_size=3, stride=2),
            nn.GroupNorm(2, 32),
        )

        self.block4 = nn.Sequential(
            BasicBlock(32, 64, kernel_size=7, groups=32),
            nn.GELU(),
            BlurPool1D(64, filt_size=3, stride=2),
            nn.GroupNorm(2, 64),
        )

        self.block5 = nn.Sequential(
            BasicBlock(64, 128, kernel_size=7, groups=64),
            nn.GELU(),
            BlurPool1D(128, filt_size=3, stride=2),
            nn.GroupNorm(2, 128),
        )

        self.block6 = nn.Sequential(
            Masksembles1D(128, self.n_masks, 2.0),
            BasicBlock(128, 256, kernel_size=7, groups=128),
            nn.GELU(),
            BlurPool1D(256, filt_size=3, stride=2),
            nn.GroupNorm(2, 256),
        )

        self.block7 = nn.Sequential(
            Masksembles1D(256, self.n_masks, 2.0),
            BasicBlock(256, 512, kernel_size=7, groups=256),
            BlurPool1D(512, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 512),
        )

        self.block8 = nn.Sequential(
            Masksembles1D(512, self.n_masks, 2.0),
            BasicBlock(512, 1024, kernel_size=7, groups=512),
            BlurPool1D(1024, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 1024),
        )

        self.block9 = nn.Sequential(
            Masksembles1D(1024, self.n_masks, 2.0),
            BasicBlock(1024, 128, kernel_size=7, groups=128, do_activation=False),

            # Works better with these off on the last layer before the regressor
            # BlurPool1D(512, filt_size=3, stride=2),
            # nn.GELU(),
            # nn.GroupNorm(2,512),
        )

        # Final output layer with Sigmoid activation
        self.out = nn.Sequential(
            nn.LazyLinear(n_outs),
            nn.Sigmoid()
        )

        # Save hyperparameters and initialize Mean Absolute Error loss
        self.save_hyperparameters(ignore=['picker'])
        self.mae = MeanAbsoluteError()

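    # Shape sketch (illustrative, approximate): for a 30 s, 100 Hz, 3-component input
    # (shape [batch, 3, 3000]), blocks 1-8 each halve the time axis via BlurPool1D (stride 2),
    # leaving about 3000 / 2**8 ~ 12 time steps with 128 channels after block9, so the flattened
    # embedding fed to the LazyLinear regressor has on the order of 1.5k features.
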
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes forward pass for the model."""
        # Feature extraction
        x = self.block1(x)
        x = self.block2(x)

        x = self.block3(x)
        x = self.block4(x)

        x = self.block5(x)
        x = self.block6(x)

        x = self.block7(x)
        x = self.block8(x)

        x = self.block9(x)

        # Regressor
        embedding = x.flatten(start_dim=1)
        x = self.out(F.relu(embedding))

        return x, embedding

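    # Note on the two outputs: the Sigmoid head squashes each pick into (0, 1), which downstream
    # code (process_continuous_waveform) interprets as a fraction of the 30 s input window, while
    # `embedding` is the flattened block9 feature map taken before the final ReLU and output head.
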
    def compute_loss(self, y: torch.Tensor, pick: torch.Tensor, mae_name: Optional[Union[str, bool]] = False) -> torch.Tensor:
        """Computes loss for the predictions.

        Args:
            y (torch.Tensor): The ground truth tensor.
            pick (torch.Tensor): The predicted tensor.
            mae_name (Union[str, bool], optional): The name for the Mean Absolute Error (MAE) metric.
                If provided, the MAE metric is logged as 'MAE/{mae_name}_val'. Default is False.

        Returns:
            torch.Tensor: The computed loss.
        """
        # Keep only entries with a non-zero target (zero is used for traces without a labelled pick)
        y_filt = y[y != 0]
        pick_filt = pick[y != 0]

        # Compute L1 loss if there are non-zero values
        if len(y_filt) > 0:
            loss = F.l1_loss(y_filt, pick_filt.flatten())

            # If mae_name is provided, log the MAE metric
            if mae_name:
                # Multiplying by 30 converts the normalized pick (fraction of the 30 s window) into seconds
                mae_phase = self.mae(y_filt, pick_filt.flatten()) * 30
                self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False)
        else:
            loss = 0
        return loss

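    # Worked example (illustrative): with y = [0.42, 0.0, 0.77] and picks = [0.40, 0.35, 0.80],
    # only the first and last entries enter the loss, giving an L1 of (0.02 + 0.03) / 2 = 0.025
    # and a logged MAE of 0.025 * 30 = 0.75 s.
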
    def get_likely_val(self, array: np.ndarray) -> Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]:
        """Computes most likely value using Kernel Density Estimation.

        Args:
            array (np.ndarray): The input array for which to compute the most likely value.

        Returns:
            Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]: A tuple containing
                - the distribution space (dist_space),
                - the Kernel Density Estimation (kde),
                - the most likely value (val), and
                - the uncertainty of the estimation.
        """
        # Compute KDE for the input array
        kde = gaussian_kde(array)

        # Define the distribution space
        dist_space = np.linspace(min(array) - 0.001, max(array) + 0.001, 512)

        # Compute the most likely value and the uncertainty
        val = torch.tensor(dist_space[np.argmax(kde(dist_space))], dtype=torch.float32)
        uncertainty = dist_space.ptp() / 2

        return dist_space, kde, val, uncertainty

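    # Worked example (illustrative): if the 128 masked forward passes predict values clustered
    # around 0.31, gaussian_kde peaks near 0.31, so val ~ 0.31 (a fraction of the 30 s window,
    # i.e. about 9.3 s into the chunk), while `uncertainty` is half the span of the evaluated
    # grid, which widens as the masked predictions disagree more.
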
    def align_and_pad_chunk(self, chunk, expected_samples):
        """
        Align and pad seismic data in a chunk.

        This function ensures that all traces in the chunk have the same start and end times
        and are of the same length (as specified by expected_samples). If any trace is shorter than
        expected_samples, it is padded with zeros.

        Parameters:
        - chunk (Stream): The seismic data chunk to be processed.
        - expected_samples (int): The expected number of samples for each trace in the chunk.

        Returns:
        - Stream: The aligned and padded seismic data chunk.
        """

        # Get the latest start time and earliest end time among the traces
        latest_start_time = max([trace.stats.starttime for trace in chunk])
        earliest_end_time = min([trace.stats.endtime for trace in chunk])

        for trace in chunk:
            # Trim the trace to the new start and end times
            trace.trim(starttime=latest_start_time, endtime=earliest_end_time, nearest_sample=True, pad=True, fill_value=0.0)

            # Check the length of the trace data and pad with zeros if necessary
            if len(trace.data) < expected_samples:
                padding = expected_samples - len(trace.data)
                trace.data = np.pad(trace.data, (0, padding), 'constant')

        return chunk

    def process_continuous_waveform(self, st: Stream) -> pd.DataFrame:
        """
        Processes a continuous seismic waveform and predicts P and S wave arrival times using PhaseHunter.

        Parameters:
        -----------
        st : Stream
            The input seismic data as an ObsPy Stream object with three components.

        Returns:
        --------
        pd.DataFrame
            A DataFrame containing the following columns:
            - p_time: Predicted P-wave arrival time.
            - s_time: Predicted S-wave arrival time.
            - p_uncert: Uncertainty associated with the P-wave prediction.
            - s_uncert: Uncertainty associated with the S-wave prediction.
            - embedding: Embedding representation of the chunk.
            - p_conf: Confidence level of the P-wave prediction.
            - s_conf: Confidence level of the S-wave prediction.
            - p_time_rel: Relative P-wave arrival time in seconds from the start of the input stream.
            - s_time_rel: Relative S-wave arrival time in seconds from the start of the input stream.

        Notes:
        ------
        The function assumes that the input Stream object has three components.
        The neural network inference is performed on chunks of data of 30 seconds.
        The output DataFrame is a result of aggregating predictions for each chunk and filtering duplicate rows.

        Raises:
        -------
        AssertionError
            If the input Stream object doesn't contain three components.

        Examples:
        ---------
        >>> from obspy import read
        >>> st = read('path_to_your_waveform_data')
        >>> predictions = model.process_continuous_waveform(st)
        >>> print(predictions)
        """
        assert len(st) == 3, 'For the moment, PhaseHunter works only with 3C input data'

        start_time = st[0].stats.starttime
        end_time = st[0].stats.endtime

        chunk_size = 30
        chunk_size_samples = int(chunk_size * st[0].stats.sampling_rate) + 1

        predictions = pd.DataFrame()

        for chunk_start in tqdm(np.arange(start_time, end_time, chunk_size)):
            chunk_end = chunk_start + chunk_size

            chunk = st.slice(chunk_start, chunk_end)
            chunk = self.align_and_pad_chunk(chunk, expected_samples=chunk_size_samples)

            # chunk_orig = np.vstack([x.data for x in chunk], dtype='float')[:,:-1]
            chunk_orig = np.vstack([x.data for x in chunk])
            chunk_orig = chunk_orig.astype('float')[:, :-1]

            # Skip incomplete chunks (this check assumes 100 Hz data)
            if chunk_orig.shape[-1] != chunk_size * 100:
                continue

            # Demean and normalize the chunk by its absolute maximum
            chunk = chunk_orig - chunk_orig.mean(axis=0)
            max_val = np.max(np.abs(chunk))
            chunk = chunk / max_val

            chunk = torch.tensor(chunk, dtype=torch.float)

            # Replicate the chunk once per mask so every mask sees the same data
            inference_sample = torch.stack([chunk] * self.n_masks).to(self.device)

            with torch.no_grad():
                preds, embeddings = self(inference_sample)

            p_pred = preds[:, 0].detach().cpu()
            s_pred = preds[:, 1].detach().cpu()
            embeddings = torch.mean(embeddings, axis=0).detach().cpu().numpy()

            p_dist, p_kde, p_val, p_uncert = self.get_likely_val(p_pred)
            s_dist, s_kde, s_val, s_uncert = self.get_likely_val(s_pred)

            # Convert the normalized picks (fractions of the 30 s window) into absolute times
            p_time = chunk_start + p_val.item() * chunk_size
            s_time = chunk_start + s_val.item() * chunk_size

            current_predictions = pd.DataFrame({'p_time': p_time, 's_time': s_time,
                                                'p_uncert': p_uncert, 's_uncert': s_uncert,
                                                'embedding': [embeddings]})

            predictions = pd.concat([predictions, current_predictions], ignore_index=True)

        predictions = predictions.drop_duplicates(subset=['p_uncert', 's_uncert']).reset_index()

        # Confidence is the normalized inverse of the estimated uncertainty
        predictions['p_conf'] = 1 / predictions['p_uncert']
        predictions['s_conf'] = 1 / predictions['s_uncert']

        predictions['p_conf'] /= predictions['p_conf'].max()
        predictions['s_conf'] /= predictions['s_conf'].max()

        predictions['p_time_rel'] = predictions.p_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s') - pd.Timestamp(start_time.timestamp, unit='s')).dt.total_seconds()
        predictions['s_time_rel'] = predictions.s_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s') - pd.Timestamp(start_time.timestamp, unit='s')).dt.total_seconds()

        return predictions

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Defines a single step in the training loop for PhaseHunter.

        Args:
            batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
                and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
            batch_idx (int): The index of the current batch.

        Returns:
            torch.Tensor: The computed loss for this training step.
        """
        # Unpack the batch
        x, y_p, y_s = batch

        # Perform forward pass and get predictions
        picks, embedding = self(x)

        # Extract P and S phase picks
        p_pick = picks[:, 0]
        s_pick = picks[:, 1]

        # Compute losses for P and S phase picks
        p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
        s_loss = self.compute_loss(y_s, s_pick, mae_name='S')

        # Combine losses
        loss = (p_loss + s_loss) / self.n_outs

        # Log the loss
        self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True)

        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """
        Defines a single step in the validation loop for PhaseHunter.

        Args:
            batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
                and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
            batch_idx (int): The index of the current batch.

        Returns:
            torch.Tensor: The computed loss for this validation step.
        """
        # Unpack the batch
        x, y_p, y_s = batch

        # Perform forward pass and get predictions
        picks, embedding = self(x)

        # Extract P and S phase picks
        p_pick = picks[:, 0]
        s_pick = picks[:, 1]

        # Compute losses for P and S phase picks
        p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
        s_loss = self.compute_loss(y_s, s_pick, mae_name='S')

        # Combine losses
        loss = (p_loss + s_loss) / self.n_outs

        # Log the loss
        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False)

        return loss

    # def configure_optimizers(self) -> dict:
    #     """
    #     Defines the optimizer and scheduler for PhaseHunter.
    #
    #     Returns:
    #         dict: A dictionary containing the optimizer, the learning rate scheduler, and the metric to monitor.
    #     """
    #     # Define the optimizer
    #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    #
    #     # Define the learning rate scheduler
    #     # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-6)
    #
    #     # Define the metric to monitor
    #     # monitor = 'Loss/train'
    #
    #     return {"optimizer": optimizer}  # , "lr_scheduler": scheduler, 'monitor': monitor}

    def configure_optimizers(self) -> dict:
        """
        Defines the optimizer and scheduler for PhaseHunter.

        Returns:
            dict: A dictionary containing the optimizer and the learning rate scheduler.
        """
        # Define the optimizer
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # Number of epochs with a constant learning rate, and the length of the decay that follows
        decay_epochs = 100

        # Total number of epochs including the constant learning rate period
        total_epochs = 200

        # Final learning rate
        final_lr = 1e-7

        # Lambda function for the learning rate schedule: constant for the first decay_epochs epochs,
        # then a linear decay towards final_lr until total_epochs
        def lambda_func(epoch):
            if epoch < decay_epochs:
                return 1.0  # constant learning rate
            else:
                epoch_adjusted = epoch - decay_epochs
                return 1 - epoch_adjusted / decay_epochs + (final_lr / 1e-3) * epoch_adjusted / decay_epochs

        # Define the learning rate scheduler
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)

        # Define the metric to monitor
        # monitor = 'Loss/train'

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
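
    # Worked example of the schedule above (illustrative): with the base lr of 1e-3, the multiplier
    # stays at 1.0 for the first 100 epochs, then decays linearly, e.g. lambda_func(150) ~ 0.50005
    # (lr ~ 5e-4) and lambda_func(200) = 1e-4 (lr = 1e-7, the target final_lr).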
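

# Minimal usage sketch (illustrative, not part of the original module); the checkpoint and
# waveform paths below are hypothetical placeholders.
if __name__ == '__main__':
    from obspy import read

    model = PhaseHunter.load_from_checkpoint('path/to/phasehunter.ckpt')  # hypothetical path
    model.eval()

    stream = read('path/to/three_component_waveform.mseed')               # hypothetical path
    picks = model.process_continuous_waveform(stream)
    print(picks[['p_time', 's_time', 'p_conf', 's_conf']].head())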