crimeacs committed
Commit d265965 · 1 Parent(s): 9a9b499
phasehunter/.ipynb_checkpoints/dataloader-checkpoint.py ADDED
@@ -0,0 +1,158 @@
+ from torch.utils.data import Dataset
+ from torchvision.transforms import functional as F
+ import torch
+ import numpy as np
+ from scipy import signal
+ from functools import reduce
+ from scipy.signal import butter, lfilter, detrend
+
+ class Augmentations:
+     def __init__(self, padding=120, crop_length=6000, fs=100, lowcut=0.2, highcut=40, order=5):
+         self.padding = padding
+         self.crop_length = crop_length
+         self.fs = fs
+         self.lowcut = lowcut
+         self.highcut = highcut
+         self.order = order
+
+         b, a = self.butter_bandpass(self.lowcut, self.highcut, self.fs, self.order)
+         self.filter_b = b
+         self.filter_a = a
+
+     def butter_bandpass(self, lowcut, highcut, fs, order=5):
+         return butter(order, [lowcut, highcut], fs=fs, btype='band')
+
+     def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
+         b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
+         y = lfilter(b, a, data)  # use the just-designed coefficients; the precomputed self.filter_b/self.filter_a ignored the arguments
+         return y
+
+     def rotate_waveform(self, waveform, angle):
+         fft_waveform = np.fft.fft(waveform)
+         rotate_factor = np.exp(1j * angle)
+         rotated_fft_waveform = fft_waveform * rotate_factor
+         rotated_waveform = np.fft.ifft(rotated_fft_waveform)
+         return rotated_waveform
+
+     def shuffle(self, sample, target_P, target_S, test):
+         if target_P - (self.crop_length-self.padding) > self.padding:
+             start_indx = int(target_P - torch.randint(low=self.padding,
+                                                       high=(self.crop_length-self.padding),
+                                                       size=(1,)))
+             if test == True:
+                 start_indx = int(target_P - 2*self.padding)  # was the undefined name 'first_phase'
+
+         elif int(target_P-self.padding) > 0:
+             start_indx = int(target_P - torch.randint(low=0,
+                                                       high=(int(target_P-self.padding)),
+                                                       size=(1,)))
+             if test == True:
+                 start_indx = int(target_P - self.padding)
+         else:
+             start_indx = self.padding
+
+         end_indx = start_indx + self.crop_length
+
+         if (sample.shape[-1] - end_indx) < 0:
+             start_indx += (sample.shape[-1] - end_indx)
+             end_indx = start_indx + self.crop_length
+
+         new_target_P = target_P - start_indx
+         new_target_S = target_S - start_indx
+
+         return start_indx, end_indx, new_target_P, new_target_S
+
+     def cut(self, sample, start_indx, end_indx):
+         sample_cropped = sample[:, start_indx:end_indx]
+         return sample_cropped
+
+     def preprocess(self, sample_cropped):
+         # sample_cropped = detrend(sample_cropped)
+         sample_cropped = self.butter_bandpass_filter(sample_cropped, lowcut=self.lowcut, highcut=self.highcut, fs=self.fs, order=self.order)
+         window = signal.windows.tukey(sample_cropped[-1].shape[0], alpha=0.1)
+         sample_cropped = sample_cropped*window
+         return sample_cropped
+
+     def add_z_component(self, sample_cropped):
+         if len(sample_cropped) < 3:
+             zeros = np.zeros((3, sample_cropped.shape[-1]))
+             zeros[0] = sample_cropped
+             sample_cropped = zeros
+         return sample_cropped
+
+     def rotate(self, sample_cropped, test):
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             if probability == 1:
+                 angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
+                 sample_cropped = self.rotate_waveform(sample_cropped, angle).real
+         return sample_cropped
+
+     def normalize(self, sample_cropped):
+         max_val = np.max(np.abs(sample_cropped))
+         sample_cropped_norm = sample_cropped/max_val
+         return sample_cropped_norm
+
+     def channel_dropout(self, sample_cropped_norm, test):
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             channel = torch.randint(1, 3, size=(1,)).item()
+             if probability == 1:
+                 sample_cropped_norm[channel, :] = 1e-6
+         return sample_cropped_norm
+
+     def apply(self, sample, target_P, target_S, test=False):
+
+         start_indx, end_indx, new_target_P, new_target_S = self.shuffle(sample, target_P, target_S, test)
+
+         sample_cropped = self.cut(sample, start_indx, end_indx)
+         # sample_cropped = self.preprocess(sample_cropped)
+         sample_cropped = self.add_z_component(sample_cropped)
+         sample_cropped = self.rotate(sample_cropped, test)
+         sample_cropped_norm = self.normalize(sample_cropped)
+         sample_cropped_norm = self.channel_dropout(sample_cropped_norm, test)
+
+         new_target_P = new_target_P/self.crop_length
+         new_target_S = new_target_S/self.crop_length
+
+         return sample_cropped_norm, new_target_P, new_target_S
+
+ class Waveforms_dataset(Dataset):
+     def __init__(self, meta, data, test=False, transform=None, augmentations=None):
+         # self.data_list = glob(data_path)
+         self.meta = meta
+         self.data = data
+         self.test = test
+         self.augmentations = augmentations
+
+     def __len__(self):
+         return len(self.meta)
+
+     def __getitem__(self, idx):
+         meta = self.meta.iloc[idx]
+         sample = self.data[meta.name]
+
+         target_P = float(meta.trace_P_final)
+         target_S = float(meta.trace_S_final)
+
+         if self.augmentations:
+             sample, target_P, target_S = self.augmentations.apply(sample, target_P, target_S, test=self.test)
+
+         # Set labels to zero if they're not in the valid range or are NaNs
+         if (target_P <= 0) or (target_P >= 1) or (np.isnan(target_P)):
+             target_P = 0
+         if (target_S <= 0) or (target_S >= 1) or (np.isnan(target_S)):
+             target_S = 0
+
+         # If something went wrong
+         if np.isnan(sample).any():
+             sample = np.zeros((3, self.augmentations.crop_length))
+             target_P = 0
+             target_S = 0
+
+         # Convert to tensor
+         sample = torch.tensor(sample, dtype=torch.float)
+         target_P = torch.tensor(target_P, dtype=torch.float)
+         target_S = torch.tensor(target_S, dtype=torch.float)
+
+         return sample, target_P, target_S
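A side note on rotate_waveform, not part of the commit: because np.fft.ifft is linear, multiplying every FFT bin by exp(1j*angle) and keeping the real part reduces to scaling a real-valued trace by cos(angle) (with angle interpreted in radians, even though the 0.01-359.9 sampling range in rotate reads like degrees). A minimal sketch that checks this:

import numpy as np

def rotate_waveform(waveform, angle):
    # Same operation as Augmentations.rotate_waveform above
    return np.fft.ifft(np.fft.fft(waveform) * np.exp(1j * angle))

t = np.linspace(0, 1, 1000, endpoint=False)
x = np.sin(2 * np.pi * 5 * t)

for angle in (0.5, 2.0, 200.0):
    rotated = rotate_waveform(x, angle).real
    # For a real input the result is a pure amplitude scaling by cos(angle)
    assert np.allclose(rotated, np.cos(angle) * x)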
phasehunter/.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,606 @@
+ from typing import Optional, Union, Tuple, Any
+ import math
+
+ from lightning import seed_everything
+ import lightning as pl
+
+ from masksembles import common
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchmetrics import MeanAbsoluteError
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+ from scipy.stats import gaussian_kde
+ from scipy.special import comb
+
+ from tqdm.auto import tqdm
+ import pandas as pd
+
+ from obspy import Stream
+
+ seed_everything(42, workers=False)
+ torch.set_float32_matmul_precision('medium')
+
+ class BlurPool1D(nn.Module):
+     """Implements a 1D version of blur pooling.
+
+     Attributes:
+         channels (int): Number of input channels.
+         pad_type (str): Type of padding (reflect, replicate, zero).
+         filt_size (int): Filter size for blur pooling.
+         stride (int): Stride size for downsampling.
+         pad_off (int): Padding offset.
+     """
+     def __init__(self, channels: int, pad_type: str='reflect', filt_size: int=3, stride: int=2, pad_off: int=0):
+         super(BlurPool1D, self).__init__()
+         self.filt_size = filt_size
+         self.pad_off = pad_off
+         # Calculate padding sizes for the beginning and end of the signal
+         self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+         self.stride = stride
+         self.off = int((self.stride - 1) / 2.)
+         self.channels = channels
+
+         # Generate coefficients for the specified filter size using binomial coefficients
+         a = np.array([comb(filt_size-1, i, exact=False) for i in range(filt_size)])
+
+         filt = torch.Tensor(a)
+         filt = filt / torch.sum(filt)  # normalize the filter
+         # Repeat the filter so there is one copy per input channel
+         self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
+
+         # Get the appropriate padding layer
+         self.pad = self.get_pad_layer_1d(pad_type)(self.pad_sizes)
+
+     def forward(self, inp):
+         """Computes forward pass for blur pooling."""
+         if self.filt_size == 1:
+             if self.pad_off == 0:
+                 return inp[:, :, ::self.stride]
+             else:
+                 # Apply padding if pad_off is not zero
+                 return self.pad(inp)[:, :, ::self.stride]
+         else:
+             # Convolve input with the filter and then downsample
+             return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+     def get_pad_layer_1d(self, pad_type: str):
+         """Returns the appropriate padding layer based on the pad_type string.
+
+         Args:
+             pad_type: Type of padding. It can be 'refl', 'reflect', 'repl', 'replicate', or 'zero'.
+
+         Returns:
+             Appropriate padding layer based on pad_type.
+
+         Raises:
+             ValueError: If pad_type is not recognized.
+         """
+         # Define the padding layer depending on the input pad_type
+         if pad_type in ['refl', 'reflect']:
+             pad_layer = nn.ReflectionPad1d
+         elif pad_type in ['repl', 'replicate']:
+             pad_layer = nn.ReplicationPad1d
+         elif pad_type == 'zero':
+             pad_layer = nn.ZeroPad1d
+         else:
+             # Raise an error if pad_type is not recognized
+             raise ValueError(f"Pad type [{pad_type}] not recognized")
+         return pad_layer
+
+
+ class Masksembles1D(nn.Module):
+     """Implements a 1D version of the Masksembles operation.
+
+     The Masksembles operation applies different masks to the input so that the model can estimate uncertainty and confidence at inference time.
+
+     Attributes:
+         channels (int): Number of input channels.
+         n (int): Number of masks to generate.
+         scale (float): Scaling factor for masks.
+     """
+     def __init__(self, channels: int, n: int, scale: float):
+         super().__init__()
+
+         self.channels = channels
+         self.n = n
+         self.scale = scale
+
+         # Generate masks using a provided function
+         masks = common.generation_wrapper(channels, n, scale)
+         masks = torch.from_numpy(masks)
+
+         # Convert masks into a PyTorch Parameter that does not require gradients
+         self.masks = torch.nn.Parameter(masks, requires_grad=False)
+
+     def forward(self, inputs):
+         """Computes forward pass for the Masksembles operation.
+
+         The input is divided into multiple groups, each group is multiplied with a different mask, and then the results
+         are concatenated together.
+
+         Args:
+             inputs (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after applying the Masksembles operation.
+         """
+         # Number of samples in the batch
+         batch = inputs.shape[0]
+
+         # Divide the input into n groups along the batch dimension
+         x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
+
+         # Concatenate the groups along the new dimension and permute the dimensions
+         x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
+
+         # Multiply each group with a different mask
+         x = x * self.masks.unsqueeze(1).unsqueeze(-1)
+
+         # Concatenate the results along the channel dimension
+         x = torch.cat(torch.split(x, 1, dim=0), dim=1)
+
+         # Remove the extra dimension and convert the tensor to the original data type
+         return x.squeeze(0).type(inputs.dtype)
+
+
+ class BasicBlock(nn.Module):
+     """Implements a basic block of convolutions, a fundamental part of PhaseHunter.
+
+     A basic block consists of two convolutional layers, each followed by batch normalization. The output from the second
+     convolutional layer is added to the shortcut connection before applying an optional activation function.
+
+     Attributes:
+         in_planes (int): Number of input channels (also known as input planes).
+         planes (int): Number of output channels (also known as output planes or filters).
+         stride (int, optional): Stride size for convolution. Default is 1.
+         kernel_size (int, optional): Kernel size for convolution. Default is 7.
+         groups (int, optional): Number of groups for convolution. Default is 1.
+         do_activation (bool, optional): Whether to apply an activation function (ReLU) at the end. Introduced for embedding capture. Default is True.
+     """
+     def __init__(self, in_planes: int, planes: int, stride: int = 1, kernel_size: int = 7, groups: int = 1, do_activation: bool = True):
+         super(BasicBlock, self).__init__()
+
+         self.do_activation = do_activation
+
+         # First convolutional layer
+         self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=kernel_size, stride=stride, padding='same', bias=False)
+         self.bn1 = nn.BatchNorm1d(planes)
+
+         # Second convolutional layer
+         self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=1, padding='same', bias=False)
+         self.bn2 = nn.BatchNorm1d(planes)
+
+         # Shortcut connection, used to match the dimensionality between input and output
+         self.shortcut = nn.Sequential(
+             nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, padding='same', bias=False),
+             nn.BatchNorm1d(planes)
+         )
+
+     def forward(self, x):
+         """Computes forward pass for the block.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after passing through the basic block.
+         """
+         # Apply first convolution followed by ReLU activation
+         out = F.relu(self.bn1(self.conv1(x)))
+
+         # Apply second convolution
+         out = self.bn2(self.conv2(out))
+
+         # Add the output of the shortcut connection
+         out += self.shortcut(x)
+
+         # Apply activation (optional, so the pre-activation embedding can be captured)
+         if self.do_activation:
+             out = F.relu(out)
+
+         return out
+
+
+ class PhaseHunter(pl.LightningModule):
+     """Implements the PhaseHunter model for seismic phase picking.
+
+     Attributes:
+         n_masks (int): Number of masks for the Masksembles operation.
+         n_outs (int): Number of output units.
+     """
+     def __init__(self, n_masks=128, n_outs=2):
+         super().__init__()
+
+         self.n_masks = n_masks  # was hard-coded to 128, silently ignoring the constructor argument
+         self.n_outs = n_outs
+
+         # Define sequential layers for blocks 1 to 9.
+         # Each block consists of BasicBlock, GELU activation, BlurPool1D, and GroupNorm layers.
+         # Blocks vary in the number of input and output features.
+
+         self.block1 = nn.Sequential(
+             BasicBlock(3, 8, kernel_size=7, groups=1),
+             nn.GELU(),
+             BlurPool1D(8, filt_size=3, stride=2),
+             nn.GroupNorm(2, 8),
+         )
+
+         self.block2 = nn.Sequential(
+             BasicBlock(8, 16, kernel_size=7, groups=8),
+             nn.GELU(),
+             BlurPool1D(16, filt_size=3, stride=2),
+             nn.GroupNorm(2, 16),
+         )
+
+         self.block3 = nn.Sequential(
+             BasicBlock(16, 32, kernel_size=7, groups=16),
+             nn.GELU(),
+             BlurPool1D(32, filt_size=3, stride=2),
+             nn.GroupNorm(2, 32),
+         )
+
+         self.block4 = nn.Sequential(
+             BasicBlock(32, 64, kernel_size=7, groups=32),
+             nn.GELU(),
+             BlurPool1D(64, filt_size=3, stride=2),
+             nn.GroupNorm(2, 64),
+         )
+
+         self.block5 = nn.Sequential(
+             BasicBlock(64, 128, kernel_size=7, groups=64),
+             nn.GELU(),
+             BlurPool1D(128, filt_size=3, stride=2),
+             nn.GroupNorm(2, 128),
+         )
+
+         self.block6 = nn.Sequential(
+             Masksembles1D(128, self.n_masks, 2.0),
+             BasicBlock(128, 256, kernel_size=7, groups=128),
+             nn.GELU(),
+             BlurPool1D(256, filt_size=3, stride=2),
+             nn.GroupNorm(2, 256),
+         )
+
+         self.block7 = nn.Sequential(
+             Masksembles1D(256, self.n_masks, 2.0),
+             BasicBlock(256, 512, kernel_size=7, groups=256),
+             BlurPool1D(512, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 512),
+         )
+
+         self.block8 = nn.Sequential(
+             Masksembles1D(512, self.n_masks, 2.0),
+             BasicBlock(512, 1024, kernel_size=7, groups=512),
+             BlurPool1D(1024, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 1024),
+         )
+
+         self.block9 = nn.Sequential(
+             Masksembles1D(1024, self.n_masks, 2.0),
+             BasicBlock(1024, 128, kernel_size=7, groups=128, do_activation=False),
+
+             # Works better with these off on the last layer before the regressor
+             # BlurPool1D(512, filt_size=3, stride=2),
+             # nn.GELU(),
+             # nn.GroupNorm(2,512),
+         )
+
+         # Final output layer with Sigmoid activation
+         self.out = nn.Sequential(
+             nn.LazyLinear(n_outs),
+             nn.Sigmoid()
+         )
+
+         # Save hyperparameters and initialize the Mean Absolute Error metric
+         self.save_hyperparameters(ignore=['picker'])
+         self.mae = MeanAbsoluteError()
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Computes forward pass for the model."""
+         # Feature extraction
+         x = self.block1(x)
+         x = self.block2(x)
+
+         x = self.block3(x)
+         x = self.block4(x)
+
+         x = self.block5(x)
+         x = self.block6(x)
+
+         x = self.block7(x)
+         x = self.block8(x)
+
+         x = self.block9(x)
+
+         # Regressor
+         embedding = x.flatten(start_dim=1)
+         x = self.out(F.relu(embedding))
+
+         return x, embedding
+
+     def compute_loss(self, y: torch.Tensor, pick: torch.Tensor, mae_name: Optional[Union[str, bool]] = False) -> torch.Tensor:
+         """Computes loss for the predictions.
+
+         Args:
+             y (torch.Tensor): The ground truth tensor.
+             pick (torch.Tensor): The predicted tensor.
+             mae_name (Union[str, bool], optional): The name for the Mean Absolute Error (MAE) metric.
+                 If provided, it logs the MAE metric with the name 'MAE/{mae_name}_val'. Default is False.
+
+         Returns:
+             torch.Tensor: The computed loss.
+         """
+         # Filter non-zero values
+         y_filt = y[y != 0]
+         pick_filt = pick[y != 0]
+
+         # Compute L1 loss if there are non-zero values
+         if len(y_filt) > 0:
+             loss = F.l1_loss(y_filt, pick_filt.flatten())
+
+             # If mae_name is provided, log the MAE metric
+             if mae_name != False:
+                 mae_phase = self.mae(y_filt, pick_filt.flatten())*30  # window fraction -> seconds for 30 s inputs
+                 self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False)
+         else:
+             loss = 0
+         return loss
+
+     def get_likely_val(self, array: np.ndarray) -> Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]:
+         """Computes the most likely value using Kernel Density Estimation.
+
+         Args:
+             array (np.ndarray): The input array for which to compute the most likely value.
+
+         Returns:
+             Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]: A tuple containing
+                 - the distribution space (dist_space),
+                 - the Kernel Density Estimation (kde),
+                 - the most likely value (val), and
+                 - the uncertainty of the estimation.
+         """
+         # Compute KDE for the input array
+         kde = gaussian_kde(array)
+
+         # Define the distribution space
+         dist_space = np.linspace(min(array)-0.001, max(array)+0.001, 512)
+
+         # Compute the most likely value and the uncertainty
+         val = torch.tensor(dist_space[np.argmax(kde(dist_space))], dtype=torch.float32)
+         uncertainty = dist_space.ptp()/2
+
+         return dist_space, kde, val, uncertainty
+
+     def process_continuous_waveform(self, st: Stream) -> pd.DataFrame:
+         """
+         Processes a continuous seismic waveform and predicts P and S wave arrival times using PhaseHunter.
+
+         Parameters:
+         -----------
+         st : Stream
+             The input seismic data as an ObsPy Stream object with three components.
+
+         Returns:
+         --------
+         pd.DataFrame
+             A DataFrame containing the following columns:
+             - p_time: Predicted P-wave arrival time.
+             - s_time: Predicted S-wave arrival time.
+             - p_uncert: Uncertainty associated with the P-wave prediction.
+             - s_uncert: Uncertainty associated with the S-wave prediction.
+             - embedding: Embedding representation of the chunk.
+             - p_conf: Confidence level of the P-wave prediction.
+             - s_conf: Confidence level of the S-wave prediction.
+             - p_time_rel: Relative P-wave arrival time in seconds from the start of the input stream.
+             - s_time_rel: Relative S-wave arrival time in seconds from the start of the input stream.
+
+         Notes:
+         ------
+         The function assumes that the input Stream object has three components.
+         The neural network inference is performed on 30-second chunks of data.
+         The output DataFrame is the result of aggregating predictions for each chunk and filtering duplicate rows.
+
+         Raises:
+         -------
+         AssertionError
+             If the input Stream object doesn't contain three components.
+
+         Examples:
+         ---------
+         >>> from obspy import read
+         >>> st = read('path_to_your_waveform_data')
+         >>> predictions = model.process_continuous_waveform(st)
+         >>> print(predictions)
+         """
+         assert len(st) == 3, 'For the moment, PhaseHunter works only with 3C input data'
+
+         start_time = st[0].stats.starttime
+         end_time = st[0].stats.endtime
+
+         chunk_size = 30
+
+         chunks = []
+         predictions = pd.DataFrame()
+
+         for chunk_start in tqdm(np.arange(start_time, end_time, chunk_size)):
+             chunk_end = chunk_start + chunk_size
+
+             chunk = st.slice(chunk_start, chunk_end)
+
+             # chunk_orig = np.vstack([x.data for x in chunk], dtype='float')[:,:-1]
+             chunk_orig = np.vstack([x.data for x in chunk])
+             chunk_orig = chunk_orig.astype('float')[:, :-1]
+
+             if chunk_orig.shape[-1] != chunk_size * 100:
+                 continue
+
+             chunk = chunk_orig - chunk_orig.mean(axis=0)
+             max_val = np.max(np.abs(chunk))
+             chunk = chunk/max_val
+
+             chunk = torch.tensor(chunk, dtype=torch.float)
+
+             inference_sample = torch.stack([chunk]*128).to(self.device)
+
+             with torch.no_grad():
+                 preds, embeddings = self(inference_sample)
+
+             p_pred = preds[:, 0].detach().cpu()
+             s_pred = preds[:, 1].detach().cpu()
+             embeddings = torch.mean(embeddings, axis=0).detach().cpu().numpy()
+
+             p_dist, p_kde, p_val, p_uncert = self.get_likely_val(p_pred)
+             s_dist, s_kde, s_val, s_uncert = self.get_likely_val(s_pred)
+
+             p_time = chunk_start + p_val.item()*chunk_size
+             s_time = chunk_start + s_val.item()*chunk_size
+
+             current_predictions = pd.DataFrame({'p_time': p_time, 's_time': s_time,
+                                                 'p_uncert': p_uncert, 's_uncert': s_uncert,
+                                                 'embedding': [embeddings]})
+
+             predictions = pd.concat([predictions, current_predictions], ignore_index=True)
+
+         predictions = predictions.drop_duplicates(subset=['p_uncert', 's_uncert']).reset_index()
+
+         predictions['p_conf'] = 1/predictions['p_uncert']
+         predictions['s_conf'] = 1/predictions['s_uncert']
+
+         predictions['p_conf'] /= predictions['p_conf'].max()
+         predictions['s_conf'] /= predictions['s_conf'].max()
+
+         predictions['p_time_rel'] = (predictions.p_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s')) - pd.Timestamp(predictions.p_time.iloc[0].date)).dt.total_seconds()
+         predictions['s_time_rel'] = (predictions.s_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s')) - pd.Timestamp(predictions.s_time.iloc[0].date)).dt.total_seconds()
+
+         return predictions
+
+     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """
+         Defines a single step in the training loop for PhaseHunter.
+
+         Args:
+             batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
+                 and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
+             batch_idx (int): The index of the current batch.
+
+         Returns:
+             torch.Tensor: The computed loss for this training step.
+         """
+         # Unpack the batch
+         x, y_p, y_s = batch
+
+         # Perform forward pass and get predictions
+         picks, embedding = self(x)
+
+         # Extract P and S phase picks
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         # Compute losses for P and S phase picks
+         p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
+         s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
+
+         # Combine losses
+         loss = (p_loss+s_loss)/self.n_outs
+
+         # Log the loss
+         self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True)
+
+         return loss
+
+     def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """
+         Defines a single step in the validation loop for PhaseHunter.
+
+         Args:
+             batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
+                 and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
+             batch_idx (int): The index of the current batch.
+
+         Returns:
+             torch.Tensor: The computed loss for this validation step.
+         """
+         # Unpack the batch
+         x, y_p, y_s = batch
+
+         # Perform forward pass and get predictions
+         picks, embedding = self(x)
+
+         # Extract P and S phase picks
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         # Compute losses for P and S phase picks
+         p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
+         s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
+
+         # Combine losses
+         loss = (p_loss+s_loss)/self.n_outs
+
+         # Log the loss
+         self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False)
+
+         return loss
+
+     # def configure_optimizers(self) -> dict:
+     #     """
+     #     Defines the optimizer and scheduler for PhaseHunter.
+
+     #     Returns:
+     #         dict: A dictionary containing the optimizer, the learning rate scheduler, and the metric to monitor.
+     #     """
+     #     # Define the optimizer
+     #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+     #     # Define the learning rate scheduler
+     #     # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-6)
+
+     #     # Define the metric to monitor
+     #     # monitor = 'Loss/train'
+
+     #     return {"optimizer": optimizer}#, "lr_scheduler": scheduler, 'monitor': monitor}
+
+     def configure_optimizers(self) -> dict:
+         """
+         Defines the optimizer and scheduler for PhaseHunter.
+
+         Returns:
+             dict: A dictionary containing the optimizer and the learning rate scheduler.
+         """
+         # Define the optimizer
+         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+         # Number of epochs spent decaying the learning rate
+         decay_epochs = 100
+
+         # Total number of epochs including the constant learning rate period
+         total_epochs = 200
+
+         # Final learning rate
+         final_lr = 1e-7
+
+         # Lambda function for the learning rate schedule
+         def lambda_func(epoch):
+             if epoch < decay_epochs:
+                 return 1.0  # constant learning rate
+             else:
+                 epoch_adjusted = epoch - decay_epochs
+                 return 1 - epoch_adjusted/decay_epochs + (final_lr/1e-3)*epoch_adjusted/decay_epochs
+
+         # Define the learning rate scheduler
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)
+
+         # Define the metric to monitor
+         # monitor = 'Loss/train'
+
+         return {"optimizer": optimizer, "lr_scheduler": scheduler}
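A minimal sketch (assuming the masksembles package is installed and the module is importable as phasehunter.model) of why inference above replicates one chunk 128 times: Masksembles1D splits the batch into n equal groups, one per mask, so the batch size must be divisible by n:

import torch
from phasehunter.model import Masksembles1D

layer = Masksembles1D(channels=128, n=4, scale=2.0)
x = torch.randn(8, 128, 100)  # batch of 8 = 2 samples per mask x 4 masks
y = layer(x)
print(y.shape)  # torch.Size([8, 128, 100])
# A batch of e.g. 6 would not split evenly into 4 groups and would fail.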
phasehunter/__init__.py ADDED
File without changes
phasehunter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes).
 
phasehunter/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (6.03 kB).
 
phasehunter/__pycache__/model.cpython-310.pyc ADDED
Binary file (17.3 kB).
 
phasehunter/dataloader.py ADDED
@@ -0,0 +1,179 @@
+ from torch.utils.data import Dataset
+ from torchvision.transforms import functional as F
+ import torch
+ import numpy as np
+ from scipy import signal
+ from functools import reduce
+ from scipy.signal import butter, lfilter, detrend
+
+ class Augmentations:
+     def __init__(self, padding=120, crop_length=6000, fs=100, lowcut=0.2, highcut=40, order=5):
+         self.padding = padding
+         self.crop_length = crop_length
+         self.fs = fs
+         self.lowcut = lowcut
+         self.highcut = highcut
+         self.order = order
+
+         b, a = self.butter_bandpass(self.lowcut, self.highcut, self.fs, self.order)
+         self.filter_b = b
+         self.filter_a = a
+
+     def butter_bandpass(self, lowcut, highcut, fs, order=5):
+         return butter(order, [lowcut, highcut], fs=fs, btype='band')
+
+     def butter_bandpass_filter(self, data, lowcut, highcut, fs, order=5):
+         b, a = self.butter_bandpass(lowcut, highcut, fs, order=order)
+         y = lfilter(b, a, data)  # use the coefficients for the requested band; previously the precomputed self.filter_b/self.filter_a were applied, ignoring lowcut/highcut
+         return y
+
+     def rotate_waveform(self, waveform, angle):
+         fft_waveform = np.fft.fft(waveform)
+         rotate_factor = np.exp(1j * angle)
+         rotated_fft_waveform = fft_waveform * rotate_factor
+         rotated_waveform = np.fft.ifft(rotated_fft_waveform)
+         return rotated_waveform
+
+     def shuffle(self, sample, target_P, target_S, test):
+         if target_P - (self.crop_length-self.padding) > self.padding:
+             start_indx = int(target_P - torch.randint(low=self.padding,
+                                                       high=(self.crop_length-self.padding),
+                                                       size=(1,)))
+             if test == True:
+                 start_indx = int(target_P - 2*self.padding)  # was the undefined name 'first_phase'
+
+         elif int(target_P-self.padding) > 0:
+             start_indx = int(target_P - torch.randint(low=0,
+                                                       high=(int(target_P-self.padding)),
+                                                       size=(1,)))
+             if test == True:
+                 start_indx = int(target_P - self.padding)
+         else:
+             start_indx = self.padding
+
+         end_indx = start_indx + self.crop_length
+
+         if (sample.shape[-1] - end_indx) < 0:
+             start_indx += (sample.shape[-1] - end_indx)
+             end_indx = start_indx + self.crop_length
+
+         new_target_P = target_P - start_indx
+         new_target_S = target_S - start_indx
+
+         return start_indx, end_indx, new_target_P, new_target_S
+
+     def cut(self, sample, start_indx, end_indx):
+         sample_cropped = sample[:, start_indx:end_indx]
+         return sample_cropped
+
+     def bandpass_filter(self, sample_cropped, test):
+         # sample_cropped = detrend(sample_cropped)
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             if probability == 1:
+                 lowcut = torch.FloatTensor(size=(1,)).uniform_(0.001, 1).item()
+                 highcut = torch.FloatTensor(size=(1,)).uniform_(10, 49).item()
+                 sample_cropped = self.butter_bandpass_filter(sample_cropped, lowcut=lowcut, highcut=highcut, fs=self.fs, order=self.order)
+                 window = signal.windows.tukey(sample_cropped[-1].shape[0], alpha=0.1)
+                 sample_cropped = sample_cropped*window
+         return sample_cropped
+
+     def add_z_component(self, sample_cropped):
+         if len(sample_cropped) < 3:
+             zeros = np.zeros((3, sample_cropped.shape[-1]))
+             zeros[0] = sample_cropped
+             sample_cropped = zeros
+         return sample_cropped
+
+     def rotate(self, sample_cropped, test):
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             if probability == 1:
+                 angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
+                 sample_cropped = self.rotate_waveform(sample_cropped, angle).real
+         return sample_cropped
+
+     def demean(self, sample_cropped):
+         # Subtract the per-channel mean from the data
+         sample_cropped = sample_cropped - np.mean(sample_cropped, axis=-1, keepdims=True)
+         return sample_cropped
+
+     def normalize(self, sample_cropped):
+         max_val = np.max(np.abs(sample_cropped))
+         sample_cropped_norm = sample_cropped/max_val
+         return sample_cropped_norm
+
+     def channel_dropout(self, sample_cropped_norm, test):
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             channel = torch.randint(1, 3, size=(1,)).item()
+             if probability == 1:
+                 sample_cropped_norm[channel, :] = 1e-6
+         return sample_cropped_norm
+
+     def channel_shuffle(self, sample_cropped_norm, test):
+         if test == False:
+             probability = torch.randint(0, 2, size=(1,)).item()
+             if probability == 1:
+                 shuffled_indices = torch.randperm(sample_cropped_norm.shape[0])
+                 sample_cropped_norm = sample_cropped_norm[shuffled_indices, :]
+         return sample_cropped_norm
+
+     def apply(self, sample, target_P, target_S, test=False):
+
+         start_indx, end_indx, new_target_P, new_target_S = self.shuffle(sample, target_P, target_S, test)
+
+         sample_cropped = self.cut(sample, start_indx, end_indx)
+         sample_cropped = self.bandpass_filter(sample_cropped, test)
+         sample_cropped = self.add_z_component(sample_cropped)
+         sample_cropped = self.rotate(sample_cropped, test)
+         sample_cropped = self.demean(sample_cropped)
+         sample_cropped_norm = self.normalize(sample_cropped)
+
+         sample_cropped_norm = self.channel_dropout(sample_cropped_norm, test)
+         sample_cropped_norm = self.channel_shuffle(sample_cropped_norm, test)
+
+         new_target_P = new_target_P/self.crop_length
+         new_target_S = new_target_S/self.crop_length
+
+         return sample_cropped_norm, new_target_P, new_target_S
+
+ class Waveforms_dataset(Dataset):
+     def __init__(self, meta, data, test=False, transform=None, augmentations=None):
+         # self.data_list = glob(data_path)
+         self.meta = meta
+         self.data = data
+         self.test = test
+         self.augmentations = augmentations
+
+     def __len__(self):
+         return len(self.meta)
+
+     def __getitem__(self, idx):
+         meta = self.meta.iloc[idx]
+         sample = self.data[meta.name]
+
+         target_P = float(meta.trace_P_final)
+         target_S = float(meta.trace_S_final)
+
+         if self.augmentations:
+             sample, target_P, target_S = self.augmentations.apply(sample, target_P, target_S, test=self.test)
+
+         # Set labels to zero if they're not in the valid range or are NaNs
+         if (target_P <= 0) or (target_P >= 1) or (np.isnan(target_P)):
+             target_P = 0
+         if (target_S <= 0) or (target_S >= 1) or (np.isnan(target_S)):
+             target_S = 0
+
+         # If something went wrong
+         if np.isnan(sample).any():
+             sample = np.zeros((3, self.augmentations.crop_length))
+             target_P = 0
+             target_S = 0
+
+         # Convert to tensor
+         sample = torch.tensor(sample, dtype=torch.float)
+         target_P = torch.tensor(target_P, dtype=torch.float)
+         target_S = torch.tensor(target_S, dtype=torch.float)
+
+         return sample, target_P, target_S
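A minimal usage sketch, not part of the commit; catalog and waveforms are illustrative stand-ins for the metadata table (indexed by trace name, with trace_P_final/trace_S_final pick indices in samples) and the waveform store the dataset expects:

import numpy as np
import pandas as pd
from phasehunter.dataloader import Augmentations, Waveforms_dataset

catalog = pd.DataFrame({'trace_P_final': [3000.0], 'trace_S_final': [4500.0]},
                       index=['trace_0'])
waveforms = {'trace_0': np.random.randn(3, 12000)}  # (channels, samples) at 100 Hz

dataset = Waveforms_dataset(catalog, waveforms, augmentations=Augmentations())
sample, target_P, target_S = dataset[0]
print(sample.shape, float(target_P), float(target_S))  # torch.Size([3, 6000]); picks as fractions of the crop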
phasehunter/main.py ADDED
@@ -0,0 +1,7 @@
+ from fastapi import FastAPI
+
+ app = FastAPI()
+
+ @app.get("/")
+ def read_root():
+     return {"Hello": "World"}
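The API stub above only exposes a root health-check route; assuming uvicorn is installed, it can be served locally with:

uvicorn phasehunter.main:app --reload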
phasehunter/model.py ADDED
@@ -0,0 +1,639 @@
+ from typing import Optional, Union, Tuple, Any
+ import math
+
+ from lightning import seed_everything
+ import lightning as pl
+
+ from masksembles import common
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchmetrics import MeanAbsoluteError
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+ from scipy.stats import gaussian_kde
+ from scipy.special import comb
+
+ from tqdm.auto import tqdm
+ import pandas as pd
+
+ from obspy import Stream
+
+ seed_everything(42, workers=False)
+ torch.set_float32_matmul_precision('medium')
+
+ class BlurPool1D(nn.Module):
+     """Implements a 1D version of blur pooling.
+
+     Attributes:
+         channels (int): Number of input channels.
+         pad_type (str): Type of padding (reflect, replicate, zero).
+         filt_size (int): Filter size for blur pooling.
+         stride (int): Stride size for downsampling.
+         pad_off (int): Padding offset.
+     """
+     def __init__(self, channels: int, pad_type: str='reflect', filt_size: int=3, stride: int=2, pad_off: int=0):
+         super(BlurPool1D, self).__init__()
+         self.filt_size = filt_size
+         self.pad_off = pad_off
+         # Calculate padding sizes for the beginning and end of the signal
+         self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+         self.stride = stride
+         self.off = int((self.stride - 1) / 2.)
+         self.channels = channels
+
+         # Generate coefficients for the specified filter size using binomial coefficients
+         a = np.array([comb(filt_size-1, i, exact=False) for i in range(filt_size)])
+
+         filt = torch.Tensor(a)
+         filt = filt / torch.sum(filt)  # normalize the filter
+         # Repeat the filter so there is one copy per input channel
+         self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
+
+         # Get the appropriate padding layer
+         self.pad = self.get_pad_layer_1d(pad_type)(self.pad_sizes)
+
+     def forward(self, inp):
+         """Computes forward pass for blur pooling."""
+         if self.filt_size == 1:
+             if self.pad_off == 0:
+                 return inp[:, :, ::self.stride]
+             else:
+                 # Apply padding if pad_off is not zero
+                 return self.pad(inp)[:, :, ::self.stride]
+         else:
+             # Convolve input with the filter and then downsample
+             return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+     def get_pad_layer_1d(self, pad_type: str):
+         """Returns the appropriate padding layer based on the pad_type string.
+
+         Args:
+             pad_type: Type of padding. It can be 'refl', 'reflect', 'repl', 'replicate', or 'zero'.
+
+         Returns:
+             Appropriate padding layer based on pad_type.
+
+         Raises:
+             ValueError: If pad_type is not recognized.
+         """
+         # Define the padding layer depending on the input pad_type
+         if pad_type in ['refl', 'reflect']:
+             pad_layer = nn.ReflectionPad1d
+         elif pad_type in ['repl', 'replicate']:
+             pad_layer = nn.ReplicationPad1d
+         elif pad_type == 'zero':
+             pad_layer = nn.ZeroPad1d
+         else:
+             # Raise an error if pad_type is not recognized
+             raise ValueError(f"Pad type [{pad_type}] not recognized")
+         return pad_layer
+
+
+ class Masksembles1D(nn.Module):
+     """Implements a 1D version of the Masksembles operation.
+
+     The Masksembles operation applies different masks to the input so that the model can estimate uncertainty and confidence at inference time.
+
+     Attributes:
+         channels (int): Number of input channels.
+         n (int): Number of masks to generate.
+         scale (float): Scaling factor for masks.
+     """
+     def __init__(self, channels: int, n: int, scale: float):
+         super().__init__()
+
+         self.channels = channels
+         self.n = n
+         self.scale = scale
+
+         # Generate masks using a provided function
+         masks = common.generation_wrapper(channels, n, scale)
+         masks = torch.from_numpy(masks)
+
+         # Convert masks into a PyTorch Parameter that does not require gradients
+         self.masks = torch.nn.Parameter(masks, requires_grad=False)
+
+     def forward(self, inputs):
+         """Computes forward pass for the Masksembles operation.
+
+         The input is divided into multiple groups, each group is multiplied with a different mask, and then the results
+         are concatenated together.
+
+         Args:
+             inputs (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after applying the Masksembles operation.
+         """
+         # Number of samples in the batch
+         batch = inputs.shape[0]
+
+         # Divide the input into n groups along the batch dimension
+         x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
+
+         # Concatenate the groups along the new dimension and permute the dimensions
+         x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
+
+         # Multiply each group with a different mask
+         x = x * self.masks.unsqueeze(1).unsqueeze(-1)
+
+         # Concatenate the results along the channel dimension
+         x = torch.cat(torch.split(x, 1, dim=0), dim=1)
+
+         # Remove the extra dimension and convert the tensor to the original data type
+         return x.squeeze(0).type(inputs.dtype)
+
+
+ class BasicBlock(nn.Module):
+     """Implements a basic block of convolutions, a fundamental part of PhaseHunter.
+
+     A basic block consists of two convolutional layers, each followed by batch normalization. The output from the second
+     convolutional layer is added to the shortcut connection before applying an optional activation function.
+
+     Attributes:
+         in_planes (int): Number of input channels (also known as input planes).
+         planes (int): Number of output channels (also known as output planes or filters).
+         stride (int, optional): Stride size for convolution. Default is 1.
+         kernel_size (int, optional): Kernel size for convolution. Default is 7.
+         groups (int, optional): Number of groups for convolution. Default is 1.
+         do_activation (bool, optional): Whether to apply an activation function (ReLU) at the end. Introduced for embedding capture. Default is True.
+     """
+     def __init__(self, in_planes: int, planes: int, stride: int = 1, kernel_size: int = 7, groups: int = 1, do_activation: bool = True):
+         super(BasicBlock, self).__init__()
+
+         self.do_activation = do_activation
+
+         # First convolutional layer
+         self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=kernel_size, stride=stride, padding='same', bias=False)
+         self.bn1 = nn.BatchNorm1d(planes)
+
+         # Second convolutional layer
+         self.conv2 = nn.Conv1d(planes, planes, kernel_size=kernel_size, stride=1, padding='same', bias=False)
+         self.bn2 = nn.BatchNorm1d(planes)
+
+         # Shortcut connection, used to match the dimensionality between input and output
+         self.shortcut = nn.Sequential(
+             nn.Conv1d(in_planes, planes, kernel_size=1, stride=stride, padding='same', bias=False),
+             nn.BatchNorm1d(planes)
+         )
+
+     def forward(self, x):
+         """Computes forward pass for the block.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             torch.Tensor: Output tensor after passing through the basic block.
+         """
+         # Apply first convolution followed by ReLU activation
+         out = F.relu(self.bn1(self.conv1(x)))
+
+         # Apply second convolution
+         out = self.bn2(self.conv2(out))
+
+         # Add the output of the shortcut connection
+         out += self.shortcut(x)
+
+         # Apply activation (optional, so the pre-activation embedding can be captured)
+         if self.do_activation:
+             out = F.relu(out)
+
+         return out
+
+
+ class PhaseHunter(pl.LightningModule):
+     """Implements the PhaseHunter model for seismic phase picking.
+
+     Attributes:
+         n_masks (int): Number of masks for the Masksembles operation.
+         n_outs (int): Number of output units.
+     """
+     def __init__(self, n_masks=128, n_outs=2):
+         super().__init__()
+
+         self.n_masks = n_masks  # was hard-coded to 128, silently ignoring the constructor argument
+         self.n_outs = n_outs
+
+         # Define sequential layers for blocks 1 to 9.
+         # Each block consists of BasicBlock, GELU activation, BlurPool1D, and GroupNorm layers.
+         # Blocks vary in the number of input and output features.
+
+         self.block1 = nn.Sequential(
+             BasicBlock(3, 8, kernel_size=7, groups=1),
+             nn.GELU(),
+             BlurPool1D(8, filt_size=3, stride=2),
+             nn.GroupNorm(2, 8),
+         )
+
+         self.block2 = nn.Sequential(
+             BasicBlock(8, 16, kernel_size=7, groups=8),
+             nn.GELU(),
+             BlurPool1D(16, filt_size=3, stride=2),
+             nn.GroupNorm(2, 16),
+         )
+
+         self.block3 = nn.Sequential(
+             BasicBlock(16, 32, kernel_size=7, groups=16),
+             nn.GELU(),
+             BlurPool1D(32, filt_size=3, stride=2),
+             nn.GroupNorm(2, 32),
+         )
+
+         self.block4 = nn.Sequential(
+             BasicBlock(32, 64, kernel_size=7, groups=32),
+             nn.GELU(),
+             BlurPool1D(64, filt_size=3, stride=2),
+             nn.GroupNorm(2, 64),
+         )
+
+         self.block5 = nn.Sequential(
+             BasicBlock(64, 128, kernel_size=7, groups=64),
+             nn.GELU(),
+             BlurPool1D(128, filt_size=3, stride=2),
+             nn.GroupNorm(2, 128),
+         )
+
+         self.block6 = nn.Sequential(
+             Masksembles1D(128, self.n_masks, 2.0),
+             BasicBlock(128, 256, kernel_size=7, groups=128),
+             nn.GELU(),
+             BlurPool1D(256, filt_size=3, stride=2),
+             nn.GroupNorm(2, 256),
+         )
+
+         self.block7 = nn.Sequential(
+             Masksembles1D(256, self.n_masks, 2.0),
+             BasicBlock(256, 512, kernel_size=7, groups=256),
+             BlurPool1D(512, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 512),
+         )
+
+         self.block8 = nn.Sequential(
+             Masksembles1D(512, self.n_masks, 2.0),
+             BasicBlock(512, 1024, kernel_size=7, groups=512),
+             BlurPool1D(1024, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 1024),
+         )
+
+         self.block9 = nn.Sequential(
+             Masksembles1D(1024, self.n_masks, 2.0),
+             BasicBlock(1024, 128, kernel_size=7, groups=128, do_activation=False),
+
+             # Works better with these off on the last layer before the regressor
+             # BlurPool1D(512, filt_size=3, stride=2),
+             # nn.GELU(),
+             # nn.GroupNorm(2,512),
+         )
+
+         # Final output layer with Sigmoid activation
+         self.out = nn.Sequential(
+             nn.LazyLinear(n_outs),
+             nn.Sigmoid()
+         )
+
+         # Save hyperparameters and initialize the Mean Absolute Error metric
+         self.save_hyperparameters(ignore=['picker'])
+         self.mae = MeanAbsoluteError()
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Computes forward pass for the model."""
+         # Feature extraction
+         x = self.block1(x)
+         x = self.block2(x)
+
+         x = self.block3(x)
+         x = self.block4(x)
+
+         x = self.block5(x)
+         x = self.block6(x)
+
+         x = self.block7(x)
+         x = self.block8(x)
+
+         x = self.block9(x)
+
+         # Regressor
+         embedding = x.flatten(start_dim=1)
+         x = self.out(F.relu(embedding))
+
+         return x, embedding
+
+     def compute_loss(self, y: torch.Tensor, pick: torch.Tensor, mae_name: Optional[Union[str, bool]] = False) -> torch.Tensor:
+         """Computes loss for the predictions.
+
+         Args:
+             y (torch.Tensor): The ground truth tensor.
+             pick (torch.Tensor): The predicted tensor.
+             mae_name (Union[str, bool], optional): The name for the Mean Absolute Error (MAE) metric.
+                 If provided, it logs the MAE metric with the name 'MAE/{mae_name}_val'. Default is False.
+
+         Returns:
+             torch.Tensor: The computed loss.
+         """
+         # Filter non-zero values
+         y_filt = y[y != 0]
+         pick_filt = pick[y != 0]
+
+         # Compute L1 loss if there are non-zero values
+         if len(y_filt) > 0:
+             loss = F.l1_loss(y_filt, pick_filt.flatten())
+
+             # If mae_name is provided, log the MAE metric
+             if mae_name != False:
+                 mae_phase = self.mae(y_filt, pick_filt.flatten())*30  # window fraction -> seconds for 30 s inputs
+                 self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False)
+         else:
+             loss = 0
+         return loss
+
+     def get_likely_val(self, array: np.ndarray) -> Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]:
+         """Computes the most likely value using Kernel Density Estimation.
+
+         Args:
+             array (np.ndarray): The input array for which to compute the most likely value.
+
+         Returns:
+             Tuple[np.ndarray, gaussian_kde, torch.Tensor, float]: A tuple containing
+                 - the distribution space (dist_space),
+                 - the Kernel Density Estimation (kde),
+                 - the most likely value (val), and
+                 - the uncertainty of the estimation.
+         """
+         # Compute KDE for the input array
+         kde = gaussian_kde(array)
+
+         # Define the distribution space
+         dist_space = np.linspace(min(array)-0.001, max(array)+0.001, 512)
+
+         # Compute the most likely value and the uncertainty
+         val = torch.tensor(dist_space[np.argmax(kde(dist_space))], dtype=torch.float32)
+         uncertainty = dist_space.ptp()/2
+
+         return dist_space, kde, val, uncertainty
+
+     def align_and_pad_chunk(self, chunk, expected_samples):
+         """
+         Align and pad seismic data in a chunk.
+
+         This function ensures that all traces in the chunk have the same start and end times
+         and are of the same length (as specified by expected_samples). If any trace is shorter than
+         expected_samples, it is padded with zeros.
+
+         Parameters:
+         - chunk (Stream): The seismic data chunk to be processed.
+         - expected_samples (int): The expected number of samples for each trace in the chunk.
+
+         Returns:
+         - Stream: The aligned and padded seismic data chunk.
+         """
+
+         # Get the latest start time and earliest end time among the traces
+         latest_start_time = max([trace.stats.starttime for trace in chunk])
+         earliest_end_time = min([trace.stats.endtime for trace in chunk])
+
+         for trace in chunk:
+             # Trim the trace to the new start and end times
+             trace.trim(starttime=latest_start_time, endtime=earliest_end_time, nearest_sample=True, pad=True, fill_value=0.0)
+
+             # Check the length of the trace data and pad with zeros if necessary
+             if len(trace.data) < expected_samples:
+                 padding = expected_samples - len(trace.data)
+                 trace.data = np.pad(trace.data, (0, padding), 'constant')
+
+         return chunk
+
+     def process_continuous_waveform(self, st: Stream) -> pd.DataFrame:
+         """
+         Processes a continuous seismic waveform and predicts P and S wave arrival times using PhaseHunter.
+
+         Parameters:
+         -----------
+         st : Stream
+             The input seismic data as an ObsPy Stream object with three components.
+
+         Returns:
+         --------
+         pd.DataFrame
+             A DataFrame containing the following columns:
+             - p_time: Predicted P-wave arrival time.
+             - s_time: Predicted S-wave arrival time.
+             - p_uncert: Uncertainty associated with the P-wave prediction.
+             - s_uncert: Uncertainty associated with the S-wave prediction.
+             - embedding: Embedding representation of the chunk.
+             - p_conf: Confidence level of the P-wave prediction.
+             - s_conf: Confidence level of the S-wave prediction.
+             - p_time_rel: Relative P-wave arrival time in seconds from the start of the input stream.
+             - s_time_rel: Relative S-wave arrival time in seconds from the start of the input stream.
+
+         Notes:
+         ------
+         The function assumes that the input Stream object has three components.
+         The neural network inference is performed on 30-second chunks of data.
+         The output DataFrame is the result of aggregating predictions for each chunk and filtering duplicate rows.
+
+         Raises:
+         -------
+         AssertionError
+             If the input Stream object doesn't contain three components.
+
+         Examples:
+         ---------
+         >>> from obspy import read
+         >>> st = read('path_to_your_waveform_data')
+         >>> predictions = model.process_continuous_waveform(st)
+         >>> print(predictions)
+         """
+         assert len(st) == 3, 'For the moment, PhaseHunter works only with 3C input data'
+
+         start_time = st[0].stats.starttime
+         end_time = st[0].stats.endtime
+
+         chunk_size = 30
+         chunk_size_samples = int(chunk_size*st[0].stats.sampling_rate) + 1
+
+         chunks = []
+         predictions = pd.DataFrame()
+
+         for chunk_start in tqdm(np.arange(start_time, end_time, chunk_size)):
+             chunk_end = chunk_start + chunk_size
+
+             chunk = st.slice(chunk_start, chunk_end)
+             chunk = self.align_and_pad_chunk(chunk, expected_samples=chunk_size_samples)
+
+             # chunk_orig = np.vstack([x.data for x in chunk], dtype='float')[:,:-1]
+             chunk_orig = np.vstack([x.data for x in chunk])
+             chunk_orig = chunk_orig.astype('float')[:, :-1]
+
+             if chunk_orig.shape[-1] != chunk_size * 100:
+                 continue
+
+             chunk = chunk_orig - chunk_orig.mean(axis=0)
+             max_val = np.max(np.abs(chunk))
+             chunk = chunk/max_val
+
+             chunk = torch.tensor(chunk, dtype=torch.float)
+
+             inference_sample = torch.stack([chunk]*128).to(self.device)
+
+             with torch.no_grad():
+                 preds, embeddings = self(inference_sample)
+
+             p_pred = preds[:, 0].detach().cpu()
+             s_pred = preds[:, 1].detach().cpu()
+             embeddings = torch.mean(embeddings, axis=0).detach().cpu().numpy()
+
+             p_dist, p_kde, p_val, p_uncert = self.get_likely_val(p_pred)
+             s_dist, s_kde, s_val, s_uncert = self.get_likely_val(s_pred)
+
+             p_time = chunk_start + p_val.item()*chunk_size
+             s_time = chunk_start + s_val.item()*chunk_size
+
+             current_predictions = pd.DataFrame({'p_time': p_time, 's_time': s_time,
+                                                 'p_uncert': p_uncert, 's_uncert': s_uncert,
+                                                 'embedding': [embeddings]})
+
+             predictions = pd.concat([predictions, current_predictions], ignore_index=True)
+
+         predictions = predictions.drop_duplicates(subset=['p_uncert', 's_uncert']).reset_index()
+
+         predictions['p_conf'] = 1/predictions['p_uncert']
+         predictions['s_conf'] = 1/predictions['s_uncert']
+
+         predictions['p_conf'] /= predictions['p_conf'].max()
+         predictions['s_conf'] /= predictions['s_conf'].max()
+
+         predictions['p_time_rel'] = predictions.p_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s') - pd.Timestamp(start_time.timestamp, unit='s')).dt.total_seconds()
+         predictions['s_time_rel'] = predictions.s_time.apply(lambda x: pd.Timestamp(x.timestamp, unit='s') - pd.Timestamp(start_time.timestamp, unit='s')).dt.total_seconds()
+
+         return predictions
+
+     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """
+         Defines a single step in the training loop for PhaseHunter.
+
+         Args:
+             batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
+                 and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
+             batch_idx (int): The index of the current batch.
+
+         Returns:
+             torch.Tensor: The computed loss for this training step.
+         """
+         # Unpack the batch
+         x, y_p, y_s = batch
+
+         # Perform forward pass and get predictions
+         picks, embedding = self(x)
+
+         # Extract P and S phase picks
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         # Compute losses for P and S phase picks
+         p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
+         s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
+
+         # Combine losses
+         loss = (p_loss+s_loss)/self.n_outs
+
+         # Log the loss
+         self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True)
+
+         return loss
+
+     def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
+         """
+         Defines a single step in the validation loop for PhaseHunter.
+
+         Args:
+             batch (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing an input batch (x),
+                 and the corresponding P-wave (y_p) and S-wave (y_s) target tensors.
+             batch_idx (int): The index of the current batch.
+
+         Returns:
+             torch.Tensor: The computed loss for this validation step.
+         """
+         # Unpack the batch
+         x, y_p, y_s = batch
+
+         # Perform forward pass and get predictions
+         picks, embedding = self(x)
+
+         # Extract P and S phase picks
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         # Compute losses for P and S phase picks
+         p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
+         s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
+
+         # Combine losses
+         loss = (p_loss+s_loss)/self.n_outs
+
+         # Log the loss
+         self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False)
+
+         return loss
+
+     # def configure_optimizers(self) -> dict:
+     #     """
+     #     Defines the optimizer and scheduler for PhaseHunter.
+
+     #     Returns:
+     #         dict: A dictionary containing the optimizer, the learning rate scheduler, and the metric to monitor.
+     #     """
+     #     # Define the optimizer
+     #     optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+     #     # Define the learning rate scheduler
+     #     # scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-6)
+
+     #     # Define the metric to monitor
+     #     # monitor = 'Loss/train'
+
+     #     return {"optimizer": optimizer}#, "lr_scheduler": scheduler, 'monitor': monitor}
+
+     def configure_optimizers(self) -> dict:
+         """
+         Defines the optimizer and scheduler for PhaseHunter.
+
+         Returns:
+             dict: A dictionary containing the optimizer and the learning rate scheduler.
+         """
+         # Define the optimizer
+         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+
+         # Number of epochs spent decaying the learning rate
+         decay_epochs = 100
+
+         # Total number of epochs including the constant learning rate period
+         total_epochs = 200
+
+         # Final learning rate
+         final_lr = 1e-7
+
+         # Lambda function for the learning rate schedule
+         def lambda_func(epoch):
+             if epoch < decay_epochs:
+                 return 1.0  # constant learning rate
+             else:
+                 epoch_adjusted = epoch - decay_epochs
+                 return 1 - epoch_adjusted/decay_epochs + (final_lr/1e-3)*epoch_adjusted/decay_epochs
+
+         # Define the learning rate scheduler
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_func)
+
+         # Define the metric to monitor
+         # monitor = 'Loss/train'
+
+         return {"optimizer": optimizer, "lr_scheduler": scheduler}
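A small standalone sanity check of the LambdaLR schedule defined in configure_optimizers, not part of the commit: the factor stays at 1.0 for the first 100 epochs, then decays linearly so that the learning rate reaches final_lr at epoch 200 (past total_epochs the factor would go negative, so training is assumed to stop there):

decay_epochs, base_lr, final_lr = 100, 1e-3, 1e-7

def lambda_func(epoch):
    if epoch < decay_epochs:
        return 1.0
    e = epoch - decay_epochs
    return 1 - e/decay_epochs + (final_lr/base_lr)*e/decay_epochs

for epoch in (0, 99, 100, 150, 200):
    print(epoch, base_lr * lambda_func(epoch))
# 0..100 -> 1.0e-3 (constant); 150 -> ~5.0e-4; 200 -> 1.0e-7 (linear decay)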