Barisylmz committed on
Commit
0dfdc08
·
verified ·
1 Parent(s): 6a2e592

Upload 4 files

Browse files
Files changed (4) hide show
  1. TimesNet.py +418 -0
  2. TimesNet_PointCloud.py +213 -0
  3. app.py +236 -0
  4. generate_samples_git.py +815 -0
TimesNet.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.fft
5
+ import numpy as np
6
+ # Basit embedding ve conv blocks - layers klasörü olmadan
7
class DataEmbedding(nn.Module):
    """Simple value + positional embedding (no external `layers` package needed).

    Projects the raw input channels to `d_model` with a linear layer and adds a
    learned positional embedding of length `seq_len`. Inputs longer than
    `seq_len` are handled by extending the learned table with a sinusoidal
    encoding for the overhanging positions.

    Args:
        c_in: number of input channels/features.
        d_model: embedding dimension.
        embed_type, freq: kept for interface compatibility; unused here.
        dropout: dropout probability applied to the final embedding.
        seq_len: length of the learned positional-embedding table.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, seq_len=6000):
        super(DataEmbedding, self).__init__()
        self.c_in = c_in
        self.d_model = d_model
        self.embed_type = embed_type
        self.freq = freq
        self.seq_len = seq_len

        # Simple linear value embedding.
        self.value_embedding = nn.Linear(c_in, d_model)
        # Learned positional embedding sized to seq_len.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed `x` (B, T, c_in) -> (B, T, d_model); `x_mark` is ignored."""
        x = self.value_embedding(x)

        if x.size(1) <= self.position_embedding.size(1):
            # Crop the learned table to the input length.
            x = x + self.position_embedding[:, :x.size(1), :]
        else:
            # Input is longer than the learned table: extend it with a
            # sinusoidal encoding for the extra positions, then add once.
            # (Fixes the original broadcast error `x + position_embedding`
            # with mismatched lengths, and avoids in-place mutation of an
            # autograd tensor.)
            remaining_length = x.size(1) - self.position_embedding.size(1)
            extra = self._get_sinusoidal_encoding(remaining_length, self.d_model)
            extra = extra.unsqueeze(0).to(device=x.device, dtype=x.dtype)
            full_pos = torch.cat([self.position_embedding.to(x.dtype), extra], dim=1)
            x = x + full_pos

        return self.dropout(x)

    def _get_sinusoidal_encoding(self, length, d_model):
        """Build a (length, d_model) sinusoidal position encoding."""
        position = torch.arange(length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))

        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # Slice to d_model // 2 columns so odd d_model does not raise.
        pos_encoding[:, 1::2] = torch.cos(position * div_term)[:, :d_model // 2]

        return pos_encoding
51
+
52
class Inception_Block_V1(nn.Module):
    """Inception-style 2D block: runs parallel Conv2d branches with odd kernel
    sizes 1, 3, 5, ... (padding keeps spatial size) and averages their outputs.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        # Branch i uses kernel size 2*i+1 with padding i => same spatial size.
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
            for i in range(self.num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for every conv, zero bias.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Apply all branches to x (B, C_in, H, W) and average -> (B, C_out, H, W)."""
        branch_outputs = [branch(x) for branch in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
78
+
79
+
80
def FFT_for_Period(x, k=2):
    """Find the k dominant periods of x (B, T, C) via the real FFT.

    Returns:
        periods: numpy array of k period lengths (T // dominant frequency).
        weights: tensor (B, k) of per-sample amplitudes at those frequencies.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    # Mean amplitude over batch and channels, one value per frequency bin.
    amplitude = abs(spectrum).mean(0).mean(-1)
    amplitude[0] = 0  # suppress the DC component
    _, top_freqs = torch.topk(amplitude, k)
    top_freqs = top_freqs.detach().cpu().numpy()
    periods = x.shape[1] // top_freqs
    return periods, abs(spectrum).mean(-1)[:, top_freqs]
90
+
91
+
92
class TimesBlock(nn.Module):
    """Core TimesNet block: folds the 1D sequence into 2D (cycles x period) for
    each of the top-k FFT periods, applies an Inception conv stack, unfolds,
    and aggregates the k results weighted by their FFT amplitudes.
    """

    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design: d_model -> d_ff -> d_model bottleneck
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x):
        # B: batch size, T: length of time series, N: number of features
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # Zero-pad so the total length is an exact multiple of the period.
            if (self.seq_len + self.pred_len) % period != 0:
                length = (
                    ((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x
            # Reshape (B, T, N) -> (B, N, length//period, period): 1D variation
            # becomes a 2D image of cycles x phase.
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1D variation to 2D variation
            out = self.conv(out)
            # Reshape back to (B, T, N) and drop the padded tail.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)
        # Adaptive aggregation: softmax over the k periods' FFT amplitudes.
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(
            1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        # Residual connection.
        res = res + x
        return res
140
+
141
+
142
class Model(nn.Module):
    """TimesNet multi-task model (forecast / imputation / anomaly detection /
    classification), with optional P-S phase-picking heads for transfer learning.

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        # Stack of e_layers TimesBlocks.
        self.model = nn.ModuleList([TimesBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Extends the temporal dimension from seq_len to seq_len + pred_len.
            self.predict_linear = nn.Linear(
                self.seq_len, self.pred_len + self.seq_len)
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)

        # P-S prediction heads for transfer learning (only built on demand).
        if hasattr(configs, 'use_ps_heads') and configs.use_ps_heads:
            # Skip attention for memory efficiency - use only pooling.

            # Multi-scale feature extraction (reduced sizes for memory).
            self.multi_scale_pools = nn.ModuleList([
                nn.AdaptiveAvgPool1d(16),  # local patterns (reduced)
                nn.AdaptiveAvgPool1d(4),   # medium patterns
                nn.AdaptiveAvgPool1d(1),   # global patterns
            ])

            # Feature fusion - pool sizes 16 + 4 + 1 = 21 => d_model * 21 inputs.
            fusion_dim = configs.d_model * (16 + 4 + 1)
            self.feature_fusion = nn.Sequential(
                nn.Linear(fusion_dim, configs.d_model),
                nn.ReLU(),
                nn.Dropout(configs.dropout)
            )

            # Separate P and S regression heads (each outputs one arrival time).
            self.p_regression_head = nn.Sequential(
                nn.Linear(configs.d_model, 128),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 1)  # P time only
            )

            self.s_regression_head = nn.Sequential(
                nn.Linear(configs.d_model, 128),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 1)  # S time only
            )

            # Separate P and S classification heads (phase present / absent).
            self.p_classification_head = nn.Sequential(
                nn.Linear(configs.d_model, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(32, 1),  # P exists/not
                nn.Sigmoid()
            )

            self.s_classification_head = nn.Sequential(
                nn.Linear(configs.d_model, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(32, 1),  # S exists/not
                nn.Sigmoid()
            )
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

    def anomaly_detection(self, x_enc):
        # If the transfer-learning P-S heads exist, use ONLY them: return
        # (ps_times, ps_classification) instead of a reconstruction.
        if hasattr(self, 'p_regression_head'):
            # Normalization from Non-stationary Transformer.
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(
                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc /= stdev

            # Embedding.
            enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
            # TimesNet encoder stack.
            for i in range(self.layer):
                enc_out = self.layer_norm(self.model[i](enc_out))

            # Skip attention for memory - use direct multi-scale pooling on
            # the TimesNet output.
            enc_out_transposed = enc_out.permute(0, 2, 1)  # (B, d_model, T)
            multi_scale_features = []

            # Manual pooling for large sequences to avoid CUDA memory issues.
            pool_sizes = [16, 4, 1]  # target pool sizes
            for i, target_size in enumerate(pool_sizes):
                T = enc_out_transposed.size(2)  # sequence length

                if T >= 8000:  # very large - use manual avg pooling
                    # Manual average pooling:
                    # (B, d_model, T) -> (B, d_model, target_size, window_size)
                    window_size = T // target_size
                    if window_size > 0:
                        trimmed_T = (T // window_size) * window_size
                        trimmed = enc_out_transposed[:, :, :trimmed_T]
                        reshaped = trimmed.view(trimmed.size(0), trimmed.size(1), target_size, window_size)
                        pooled = reshaped.mean(dim=3)  # average over window
                    else:
                        # Fallback: simple truncation.
                        pooled = enc_out_transposed[:, :, :target_size] if T >= target_size else enc_out_transposed
                else:
                    # Normal adaptive pooling for smaller sequences.
                    pool = self.multi_scale_pools[i]
                    pooled = pool(enc_out_transposed)  # (B, d_model, pool_size)

                flattened = pooled.flatten(1)  # (B, d_model * pool_size)
                multi_scale_features.append(flattened)

            # Concatenate multi-scale features: (B, d_model * 21).
            fused_features = torch.cat(multi_scale_features, dim=1)

            # Feature fusion down to (B, d_model).
            final_features = self.feature_fusion(fused_features)

            # Separate P and S arrival-time regressions.
            p_time = self.p_regression_head(final_features)  # (B, 1)
            s_time = self.s_regression_head(final_features)  # (B, 1)
            ps_times = torch.cat([p_time, s_time], dim=1)    # (B, 2)

            # Separate P and S presence classifications.
            p_class = self.p_classification_head(final_features)  # (B, 1)
            s_class = self.s_classification_head(final_features)  # (B, 1)
            ps_classification = torch.cat([p_class, s_class], dim=1)  # (B, 2)

            return ps_times, ps_classification
        else:
            # Original anomaly detection (reconstruction).
            # Normalization from Non-stationary Transformer.
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(
                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc /= stdev

            # Embedding.
            enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
            # TimesNet encoder stack.
            for i in range(self.layer):
                enc_out = self.layer_norm(self.model[i](enc_out))
            # Project back to c_out channels.
            dec_out = self.projection(enc_out)

            # De-normalization from Non-stationary Transformer.
            dec_out = dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
            dec_out = dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
            return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        # Dispatch on the configured task.
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            result = self.anomaly_detection(x_enc)
            # Either a reconstruction [B, L, D], or the P-S head tuple
            # ([B, 2], [B, 2]) when the transfer-learning heads are present.
            return result
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization from Non-stationary Transformer.
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # Embedding, then extend the temporal dimension to seq_len + pred_len.
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
            0, 2, 1)  # align temporal dimension
        # TimesNet encoder stack.
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        # Project back to c_out channels.
        dec_out = self.projection(enc_out)

        # De-normalization from Non-stationary Transformer.
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        # Masked normalization (statistics over observed entries only),
        # from the Non-stationary Transformer.
        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
                           torch.sum(mask == 1, dim=1) + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        # Embedding.
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        # TimesNet encoder stack.
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        # Project back to c_out channels.
        dec_out = self.projection(enc_out)

        # De-normalization from Non-stationary Transformer.
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        # Embedding (time marks unused here).
        enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
        # TimesNet encoder stack.
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))

        # Output head: the encoder embeddings don't include a non-linearity,
        # so apply one before projection.
        output = self.act(enc_out)
        output = self.dropout(output)
        # Zero-out padding embeddings (x_mark_enc acts as a validity mask here).
        output = output * x_mark_enc.unsqueeze(-1)
        # Flatten to (batch_size, seq_length * d_model) and project.
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output
418
+
TimesNet_PointCloud.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ try:
7
+ from .TimesNet import DataEmbedding
8
+ except Exception:
9
+ from TimesNet import DataEmbedding
10
+
11
+
12
+ class _BlockConfig:
13
+ def __init__(self, seq_len: int, pred_len: int, d_model: int, d_ff: int, num_kernels: int, top_k: int = 2, num_stations: int = 0):
14
+ self.seq_len = seq_len
15
+ self.pred_len = pred_len
16
+ self.d_model = d_model
17
+ self.d_ff = d_ff
18
+ self.num_kernels = num_kernels
19
+ self.top_k = top_k
20
+ self.num_stations = num_stations
21
+
22
+
23
class Inception_Block_V1(nn.Module):
    """Parallel multi-kernel Conv2d block (duplicate of the TimesNet version).

    Branch i uses kernel size 2*i+1 with padding i, so every branch preserves
    the spatial size; the forward pass averages all branch outputs.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
            for i in range(self.num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming-normal weights, zero biases, for every conv in the block.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Average the outputs of all conv branches applied to x (B, C, H, W)."""
        branch_outputs = [branch(x) for branch in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
49
+
50
+
51
def FFT_for_Period(x, k=2):
    """Return the k dominant periods of x (B, T, C), found via rFFT amplitudes.

    The DC bin is zeroed before ranking. Periods are integer divisions
    T // frequency_index; the second return value is the per-sample amplitude
    (B, k) at the selected frequencies.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    mean_amplitude = abs(spectrum).mean(0).mean(-1)
    mean_amplitude[0] = 0  # ignore the DC component
    _, dominant = torch.topk(mean_amplitude, k)
    dominant = dominant.detach().cpu().numpy()
    periods = x.shape[1] // dominant
    return periods, abs(spectrum).mean(-1)[:, dominant]
61
+
62
+
63
class TimesBlockStationCond(nn.Module):
    """TimesBlock with station-ID conditioning via a learned embedding that is
    added to the 2D feature map before the Inception convolutions.
    """

    def __init__(self, configs):
        super(TimesBlockStationCond, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        self.num_stations = getattr(configs, 'num_stations', 0)

        # Station ID embedding: maps station ID to d_model dimension.
        # This provides richer conditioning information than a single scalar.
        if self.num_stations > 0:
            self.station_embedding = nn.Embedding(self.num_stations, configs.d_model)
            # Initialize with small random values.
            nn.init.normal_(self.station_embedding.weight, mean=0.0, std=0.02)

        # Inception conv stack: d_model -> d_ff -> d_model.
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x, station_ids: torch.Tensor = None):
        """
        Args:
            x: (B, T, N) input features
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1)
        """
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # Zero-pad so the total length is an exact multiple of the period.
            if (self.seq_len + self.pred_len) % period != 0:
                length = (((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x

            # Reshape to 2D: (B, N, H, W) with H = length // period, W = period.
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()

            # Inject station-ID conditioning via embedding addition.
            if station_ids is not None and self.num_stations > 0:
                # Get station embeddings: (B, d_model).
                station_ids_flat = station_ids.view(B)
                station_emb = self.station_embedding(station_ids_flat)  # (B, d_model)

                # Broadcast the embedding over the spatial dimensions.
                # NOTE(review): view(B, N, 1, 1) only works when N == d_model,
                # i.e. when x is an already-embedded sequence — confirm callers
                # never pass raw-channel input here.
                H = out.size(2)
                W = out.size(3)
                station_emb_spatial = station_emb.view(B, N, 1, 1).expand(-1, -1, H, W)

                # Element-wise addition lets the model learn station-specific
                # feature modifications.
                out = out + station_emb_spatial

            # 2D conv: from 1D variation to 2D variation.
            out = self.conv(out)

            # Reshape back to (B, T, N) and drop the padded tail.
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])

        res = torch.stack(res, dim=-1)
        # Adaptive aggregation over the k periods, weighted by FFT amplitudes.
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)

        # Residual connection.
        res = res + x
        return res
145
+
146
+
147
class TimesNetPointCloud(nn.Module):
    """TimesNet reconstruction model with exposed encode/project methods so the
    latent representation can be manipulated (point-cloud mixing) between the
    encoding and projection stages.
    """

    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.pred_len = getattr(configs, 'pred_len', 0)
        self.top_k = configs.top_k
        self.d_model = configs.d_model
        self.d_ff = configs.d_ff
        self.num_kernels = configs.num_kernels
        self.e_layers = configs.e_layers
        self.dropout = configs.dropout
        self.c_out = configs.c_out

        self.num_stations = getattr(configs, 'num_stations', 0)

        self.enc_embedding = DataEmbedding(configs.enc_in, self.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        # NOTE: blocks are built with pred_len=0 regardless of configs.pred_len;
        # each block gets its own _BlockConfig.
        self.model = nn.ModuleList([
            TimesBlockStationCond(_BlockConfig(self.seq_len, 0, self.d_model, self.d_ff,
                                               self.num_kernels, self.top_k, self.num_stations))
            for _ in range(self.e_layers)
        ])
        self.layer = self.e_layers
        self.layer_norm = nn.LayerNorm(self.d_model)
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def encode_features_for_reconstruction(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """
        Encode input with optional station ID conditioning.

        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1), optional

        Returns:
            (enc_out, means, stdev): encoded features plus the per-sample
            normalization statistics needed to de-normalize later.
        """
        # Instance normalization (non-stationary-transformer style).
        means = x_enc.mean(1, keepdim=True).detach()
        x_norm = x_enc - means
        stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_norm = x_norm / stdev
        enc_out = self.enc_embedding(x_norm, None)
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out, station_ids))
        return enc_out, means, stdev

    def project_features_for_reconstruction(self, enc_out: torch.Tensor, means: torch.Tensor, stdev: torch.Tensor):
        """Project encoded features back to signal space and de-normalize."""
        dec_out = self.projection(enc_out)
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        return dec_out

    def anomaly_detection(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """Full reconstruction pass with optional station ID conditioning."""
        enc_out, means, stdev = self.encode_features_for_reconstruction(x_enc, station_ids)
        return self.project_features_for_reconstruction(enc_out, means, stdev)

    def forward(self, x_enc, station_ids=None, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """
        Forward pass compatible with the anomaly_detection task.

        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs, optional

        The remaining arguments are accepted for interface compatibility and
        ignored.
        """
        return self.anomaly_detection(x_enc, station_ids)
212
+
213
+
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import matplotlib
4
+ matplotlib.use('Agg') # Use non-interactive backend
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import io
8
+ import os
9
+ import torch
10
+
11
+ # Import model and generation functions
12
+ try:
13
+ from TimesNet_PointCloud import TimesNetPointCloud
14
+ from generate_samples_git import SimpleArgs, generate_samples_from_latent_bank
15
+ except ImportError:
16
+ # Fallback for local imports
17
+ import sys
18
+ sys.path.insert(0, '.')
19
+ from TimesNet_PointCloud import TimesNetPointCloud
20
+ try:
21
+ from generate_samples_git import SimpleArgs, generate_samples_from_latent_bank
22
+ except:
23
+ # If generate_samples_git doesn't exist, use generate_samples
24
+ from generate_samples import FineTuneArgs as SimpleArgs, generate_samples_from_latent_bank
25
+
26
def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model (matching generate_samples_git.py).

    Args:
        checkpoint_path: path to a torch checkpoint (raw state_dict or a dict
            with a 'model_state_dict' key).
        args: config object providing seq_len, d_model, d_ff, num_kernels,
            top_k, e_layers, d_layers, dropout, use_gpu (and optionally
            latent_dim / num_stations).

    Returns:
        The model in eval mode (moved to CUDA when args.use_gpu is set).
    """
    # Build the model config from args; fixed values mirror the training setup.
    class ModelConfig:
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0
            self.enc_in = 3   # 3-component seismic input
            self.c_out = 3
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = getattr(args, 'latent_dim', 256)
            # num_stations is needed for station conditioning.
            # Try to get from args, or fall back to the training default.
            self.num_stations = getattr(args, 'num_stations', 705)

    config = ModelConfig(args)
    model = TimesNetPointCloud(config)

    # Load checkpoint.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True if the torch version supports it).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        # Try to get num_stations from checkpoint.
        # NOTE(review): this runs after the model is already constructed, so it
        # updates only the config object, not the embedding sizes — confirm
        # intended.
        if 'num_stations' in checkpoint:
            config.num_stations = checkpoint['num_stations']
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
68
+
69
# Configuration - can be set via environment variables for HF Space.
PHASE1_MODEL_CHECKPOINT_PATH = os.getenv('PHASE1_MODEL_CHECKPOINT_PATH', './checkpoints/timesnet_pointcloud_phase1_final.pth')
LATENT_BANK_PATH = os.getenv('LATENT_BANK_PATH', './latent_bank_station_cond.npz')
ENCODER_STD_PATH = os.getenv('ENCODER_STD_PATH', './pcgen_stats/encoder_feature_std.npy')

# Test stations (5 unseen stations) offered in the UI dropdown.
TEST_STATIONS = ['0205', '1716', '2020', '3130', '4628']

# Module-level singletons populated once at startup by initialize_model().
model = None
args = None
encoder_std = None
81
+
82
def initialize_model():
    """Load Phase 1 model and encoder_std once at startup.

    Populates the module-level globals `model`, `args`, `encoder_std`.
    Returns True on success, False when a required file is missing.
    Idempotent: returns immediately if the model is already loaded.
    """
    global model, args, encoder_std

    if model is not None:
        return True  # already initialized

    print("[INFO] Initializing TimesNet-Gen Phase 1 model...")

    # Create args (matching generate_samples_git.py).
    args = SimpleArgs()

    # Load Phase 1 model — required.
    if not os.path.exists(PHASE1_MODEL_CHECKPOINT_PATH):
        print(f"[ERROR] Phase 1 model checkpoint not found: {PHASE1_MODEL_CHECKPOINT_PATH}")
        print("[ERROR] Please set PHASE1_MODEL_CHECKPOINT_PATH environment variable")
        return False

    model = load_model(PHASE1_MODEL_CHECKPOINT_PATH, args)

    # Latent bank — required for generation.
    if not os.path.exists(LATENT_BANK_PATH):
        print(f"[ERROR] Latent bank not found: {LATENT_BANK_PATH}")
        print("[ERROR] Please set LATENT_BANK_PATH environment variable or create latent bank first")
        return False

    print(f"[INFO] Using latent bank: {LATENT_BANK_PATH}")

    # encoder_std — optional (only used for fine-tuning, not for generation).
    if os.path.exists(ENCODER_STD_PATH):
        encoder_std = np.load(ENCODER_STD_PATH)
        print(f"[INFO] Loaded encoder_std from {ENCODER_STD_PATH} (shape: {encoder_std.shape})")
        print(f"[INFO] encoder_std loaded (used only for fine-tuning, NOT for generation)")
    else:
        print(f"[INFO] No encoder_std found (not needed for generation, only for fine-tuning)")
        encoder_std = None

    print("[INFO] ✓ Phase 1 model initialized successfully!")
    return True
121
+
122
# Initialize on import so the Gradio app is ready at first request; failures
# are logged and the app falls back to dummy generation.
try:
    initialize_model()
except Exception as e:
    print(f"[ERROR] Failed to initialize model: {e}")
    import traceback
    traceback.print_exc()
    print("[WARN] App will run in dummy mode")
130
+
131
def generate_seismic_data(station_id_str, num_samples):
    """Generate seismic signals using the Phase 1 model and pre-computed latent bank.

    Args:
        station_id_str: station ID string selected in the UI.
        num_samples: number of waveforms to generate.

    Returns:
        RGB numpy array of the rendered matplotlib figure (for gr.Image).
    """
    global model, args, encoder_std

    # Check if model is loaded; lazily retry initialization.
    if model is None:
        print("[ERROR] Model not initialized! Attempting to initialize...")
        if not initialize_model():
            # Fallback to dummy generation.
            # NOTE(review): this assignment is overwritten by the try block
            # below (which will fail with model=None and set
            # generated_signals=None), so the dummy data is never plotted —
            # confirm whether an early return was intended here.
            print("[WARN] Using dummy generation as fallback")
            generated_signals = np.random.randn(num_samples, 3, 6000) * 0.1 + np.sin(np.linspace(0, 100, 6000))
        else:
            # Retry generation with the now-loaded model.
            return generate_seismic_data(station_id_str, num_samples)

    # Generate samples from the pre-computed latent bank
    # (matching generate_samples_git.py).
    try:
        print(f"[INFO] Generating {num_samples} samples for station {station_id_str} from latent bank...")

        # Note: encoder_std is passed but NOT used during generation in
        # generate_samples_git.py (noise is NOT added during generation).
        generated_signals, _ = generate_samples_from_latent_bank(
            model, LATENT_BANK_PATH, station_id_str, num_samples, args, encoder_std
        )

        if generated_signals is None:
            print(f"[ERROR] Failed to generate samples for station {station_id_str}")
            generated_signals = None
        else:
            print(f"[INFO] ✓ Generated {len(generated_signals)} samples successfully")
    except Exception as e:
        print(f"[ERROR] Generation failed: {e}")
        import traceback
        traceback.print_exc()
        generated_signals = None

    # Handle the failure case with an error-message plot.
    if generated_signals is None:
        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        ax.text(0.5, 0.5, f'Error: Could not generate samples\nfor station {station_id_str}',
                ha='center', va='center', fontsize=14, transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout()
    else:
        # generated_signals shape: (num_samples, 3, 6000).
        num_plots = min(len(generated_signals), 5)  # show up to 5 samples

        if num_plots == 1:
            # Single sample: one row per channel.
            fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
            channel_names = ['E-W', 'N-S', 'U-D']
            for ch, ax in enumerate(axes):
                ax.plot(generated_signals[0, ch, :], linewidth=0.8)
                ax.set_ylabel(channel_names[ch], fontweight='bold')
                ax.grid(True, alpha=0.3)
            axes[-1].set_xlabel('Time Steps', fontweight='bold')
            fig.suptitle(f'Generated Sample for Station {station_id_str}', fontsize=12, fontweight='bold')
            plt.tight_layout()
        else:
            # Multiple samples: grid of samples (rows) x channels (columns).
            fig, axes = plt.subplots(num_plots, 3, figsize=(12, 2*num_plots), sharex=True)
            channel_names = ['E-W', 'N-S', 'U-D']

            for i in range(num_plots):
                for ch in range(3):
                    ax = axes[i, ch] if num_plots > 1 else axes[ch]
                    ax.plot(generated_signals[i, ch, :], linewidth=0.8)
                    if i == 0:
                        ax.set_title(channel_names[ch], fontweight='bold')
                    if i == num_plots - 1:
                        ax.set_xlabel('Time Steps', fontweight='bold')
                    ax.set_ylabel('Amplitude', fontsize=9)
                    ax.grid(True, alpha=0.3)

            fig.suptitle(f'Generated Samples for Station {station_id_str}', fontsize=12, fontweight='bold')
            plt.tight_layout()

    # Render the figure to a PNG buffer, then to a numpy RGB array.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)

    img = Image.open(buf)
    img_array = np.array(img)
    buf.close()

    return img_array
222
+
223
# Gradio Interface: station dropdown + sample-count slider in,
# rendered waveform figure out.
demo = gr.Interface(
    fn=generate_seismic_data,
    inputs=[
        gr.Dropdown(choices=TEST_STATIONS, label="Station ID", value=TEST_STATIONS[0], info="Select one of the 5 test stations"),
        gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Number of Samples to Generate")
    ],
    outputs=gr.Image(label="Generated Seismic Signals", type="numpy"),
    title="TimesNet-Gen: Site-Specific Strong Motion Generation",
    description="Generate synthetic seismic signals using Phase 1 model and pre-computed latent bank. Select a station ID and number of samples to generate. (Matching GitHub generate_samples_git.py workflow)"
)

if __name__ == "__main__":
    demo.launch()
generate_samples_git.py ADDED
@@ -0,0 +1,815 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simplified inference script for TimesNet-Gen.
4
+ Only loads data for the 5 fine-tuned stations.
5
+
6
+ Usage:
7
+ python generate_samples.py --num_samples 50
8
+ """
9
+ import os
10
+ import argparse
11
+ import torch
12
+ import numpy as np
13
+ from datetime import datetime
14
+ import matplotlib.pyplot as plt
15
+ import glob
16
+ import scipy.io as sio
17
+
18
+
19
class SimpleArgs:
    """Configuration for generation.

    Plain attribute bag mirroring the training-time argument object; values
    must match the checkpoint's architecture or state_dict loading will fail.
    """
    def __init__(self):
        # Model architecture
        self.seq_len = 6000       # samples per waveform (per channel)
        self.d_model = 128        # encoder feature width
        self.d_ff = 256           # feed-forward hidden width
        self.e_layers = 2         # encoder depth
        self.d_layers = 2         # decoder depth
        self.num_kernels = 6      # kernels per TimesNet block — TODO confirm against model
        self.top_k = 2            # dominant periods used by TimesNet
        self.dropout = 0.1
        self.latent_dim = 256

        # System
        self.use_gpu = torch.cuda.is_available()
        self.seed = 0             # seeds both torch and numpy in main()

        # Point-cloud generation
        self.pcgen_k = 5              # latents mixed per generated sample (bootstrap size)
        self.pcgen_jitter_std = 0.0   # jitter disabled by default
40
+
41
+
42
+ def _iter_np_arrays(obj):
43
+ """Recursively iterate through numpy arrays in nested structures."""
44
+ if isinstance(obj, np.ndarray):
45
+ if obj.dtype == object:
46
+ for item in obj.flat:
47
+ yield from _iter_np_arrays(item)
48
+ else:
49
+ yield obj
50
+ elif isinstance(obj, dict):
51
+ for v in obj.values():
52
+ yield from _iter_np_arrays(v)
53
+ elif isinstance(obj, np.void):
54
+ if obj.dtype.names:
55
+ for name in obj.dtype.names:
56
+ yield from _iter_np_arrays(obj[name])
57
+
58
+
59
+ def _find_3ch_from_arrays(arrays):
60
+ """Find 3-channel array from list of arrays."""
61
+ # Prefer arrays that are 2D with a 3-channel dimension
62
+ for arr in arrays:
63
+ if isinstance(arr, np.ndarray) and arr.ndim == 2 and (arr.shape[0] == 3 or arr.shape[1] == 3):
64
+ return arr
65
+ # Otherwise, try to find three 1D arrays of same length
66
+ one_d = [a for a in arrays if isinstance(a, np.ndarray) and a.ndim == 1]
67
+ for i in range(len(one_d)):
68
+ for j in range(i + 1, len(one_d)):
69
+ for k in range(j + 1, len(one_d)):
70
+ if one_d[i].shape[0] == one_d[j].shape[0] == one_d[k].shape[0]:
71
+ return np.stack([one_d[i], one_d[j], one_d[k]], axis=0)
72
+ return None
73
+
74
+
75
def load_mat_file(filepath, seq_len=6000, debug=False):
    """Load and preprocess a .mat file (using data_loader_gen.py logic).

    Strategy: first try the known ``EQ.anEQ.Accel`` struct layout; on any
    failure fall back to scanning every numpy array in the file for a
    3-channel candidate. The result is resampled to *seq_len* samples per
    channel.

    Args:
        filepath: Path to the .mat file.
        seq_len: Target number of samples per channel after resampling.
        debug: When True, print verbose diagnostic messages.

    Returns:
        torch.FloatTensor of shape (3, seq_len), or None on any failure
        (this function never raises).
    """
    try:
        if debug:
            print(f"\n[DEBUG] Loading: {os.path.basename(filepath)}")

        # Load with squeeze_me and struct_as_record like data_loader_gen.py
        mat = sio.loadmat(filepath, squeeze_me=True, struct_as_record=False)

        if debug:
            print(f"[DEBUG] Keys in mat file: {[k for k in mat.keys() if not k.startswith('__')]}")

        # Check if 'EQ' is a struct with nested 'anEQ' structure (like in data_loader_gen.py)
        if 'EQ' in mat:
            try:
                eq_obj = mat['EQ']

                if debug:
                    print(f"[DEBUG] EQ type: {type(eq_obj)}")
                    print(f"[DEBUG] EQ shape: {eq_obj.shape if hasattr(eq_obj, 'shape') else 'N/A'}")

                # Since struct_as_record=False, EQ is a mat_struct object
                # Access with attributes, not subscripts
                if hasattr(eq_obj, 'anEQ'):
                    dataset = eq_obj.anEQ
                    if debug:
                        print(f"[DEBUG] Found anEQ, type: {type(dataset)}")

                    if hasattr(dataset, 'Accel'):
                        accel = dataset.Accel

                        if debug:
                            print(f"[DEBUG] Found Accel: type={type(accel)}, shape={accel.shape if hasattr(accel, 'shape') else 'N/A'}")

                        if isinstance(accel, np.ndarray):
                            # Transpose to (3, N) if needed
                            if accel.ndim == 2:
                                if accel.shape[1] == 3:
                                    accel = accel.T

                                if accel.shape[0] == 3:
                                    data = accel
                                    if debug:
                                        print(f"[DEBUG] ✅ Successfully extracted 3-channel data! Shape: {data.shape}")

                                    # Resample if needed
                                    if data.shape[1] != seq_len:
                                        from scipy import signal as sp_signal
                                        data_resampled = np.zeros((3, seq_len), dtype=np.float32)
                                        for i in range(3):
                                            data_resampled[i] = sp_signal.resample(data[i], seq_len)
                                        data = data_resampled
                                        if debug:
                                            print(f"[DEBUG] Resampled to {seq_len} samples")

                                    return torch.FloatTensor(data)
                                else:
                                    if debug:
                                        print(f"[DEBUG] Unexpected Accel shape[0]: {accel.shape[0]} (expected 3)")
                            else:
                                if debug:
                                    print(f"[DEBUG] Accel is not 2D: ndim={accel.ndim}")
                    else:
                        if debug:
                            print(f"[DEBUG] anEQ has no 'Accel' attribute")
                            if hasattr(dataset, '__dict__'):
                                print(f"[DEBUG] anEQ attributes: {list(vars(dataset).keys())}")
                else:
                    if debug:
                        print(f"[DEBUG] EQ has no 'anEQ' attribute")
                        if hasattr(eq_obj, '__dict__'):
                            print(f"[DEBUG] EQ attributes: {list(vars(eq_obj).keys())}")

            except Exception as e:
                # Fall through to the generic array scan below.
                if debug:
                    import traceback
                    print(f"[DEBUG] Could not parse EQ structure: {e}")
                    print(f"[DEBUG] Traceback: {traceback.format_exc()}")

        # Generic fallback: collect every numpy array nested anywhere in the file.
        arrays = list(_iter_np_arrays(mat))

        if debug:
            print(f"[DEBUG] Found {len(arrays)} arrays")
            for i, arr in enumerate(arrays[:5]):  # Show first 5
                if isinstance(arr, np.ndarray):
                    print(f"[DEBUG] Array {i}: shape={arr.shape}, dtype={arr.dtype}")

        # Common direct keys first
        for key in ['signal', 'data', 'sig', 'x', 'X', 'signal3c', 'acc', 'NS', 'EW', 'UD']:
            if key in mat and isinstance(mat[key], np.ndarray):
                arrays.insert(0, mat[key])
                if debug:
                    print(f"[DEBUG] Found key '{key}': shape={mat[key].shape}")

        # Find 3-channel array
        data = _find_3ch_from_arrays(arrays)

        if data is None:
            if debug:
                print(f"[DEBUG] Could not find 3-channel array!")
            return None

        if debug:
            print(f"[DEBUG] Found 3-channel data: shape={data.shape}")

        # Ensure shape is (3, N)
        if data.shape[0] != 3 and data.shape[1] == 3:
            data = data.T
            if debug:
                print(f"[DEBUG] Transposed to: shape={data.shape}")

        if data.shape[0] != 3:
            if debug:
                print(f"[DEBUG] Wrong number of channels: {data.shape[0]}")
            return None

        # Resample to seq_len
        if data.shape[1] != seq_len:
            from scipy import signal as sp_signal
            data_resampled = np.zeros((3, seq_len), dtype=np.float32)
            for i in range(3):
                data_resampled[i] = sp_signal.resample(data[i], seq_len)
            data = data_resampled
            if debug:
                print(f"[DEBUG] Resampled to: shape={data.shape}")

        if debug:
            print(f"[DEBUG] ✅ Successfully loaded!")

        return torch.FloatTensor(data)

    except Exception as e:
        # Best-effort loader: any failure is reported as "file unusable".
        if debug:
            print(f"[DEBUG] ❌ Exception: {e}")
        return None
210
+
211
+
212
def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model.

    Builds a model config from *args*, instantiates the network, restores
    weights from *checkpoint_path* (supports both a raw state dict and a
    checkpoint dict with a 'model_state_dict' key), switches to eval mode,
    and moves the model to GPU when args.use_gpu is set.

    Args:
        checkpoint_path: Path to a .pth checkpoint file.
        args: SimpleArgs instance describing the architecture.

    Returns:
        The loaded TimesNetPointCloud model, ready for inference.
    """
    from models.TimesNet_PointCloud import TimesNetPointCloud

    # Create model config
    class ModelConfig:
        # Thin namespace translating SimpleArgs into the attribute set
        # expected by the TimesNetPointCloud constructor.
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0   # reconstruction-only setup: no forecast horizon
            self.enc_in = 3     # three seismic components (E-W, N-S, U-D)
            self.c_out = 3
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = args.latent_dim

    config = ModelConfig(args)
    model = TimesNetPointCloud(config)

    # Load checkpoint
    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
250
+
251
+
252
def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args, encoder_std=None):
    """
    Generate samples directly from pre-computed latent bank.
    NO REAL DATA NEEDED!

    Each sample is produced by bootstrap aggregation: k latent vectors of the
    station are drawn with replacement, averaged (together with their
    normalization means/stdevs), and decoded back to a waveform.

    Args:
        model: TimesNet model
        latent_bank_path: Path to latent_bank_phase1.npz
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate
        args: Model arguments
        encoder_std: Encoder std vector for noise injection (currently unused
            here — noise is applied only during fine-tuning, see below)

    Returns:
        generated_signals: (num_samples, 3, seq_len) array
        real_names_used: List of lists indicating which latent vectors were used
        (or (None, None) when the bank is missing/unreadable or lacks the station)
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")

    try:
        latent_data = np.load(latent_bank_path)
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None

    # Load latent vectors for this station
    latents_key = f'latents_{station_id}'
    means_key = f'means_{station_id}'
    stdev_key = f'stdev_{station_id}'

    if latents_key not in latent_data:
        print(f"[ERROR] Station {station_id} not found in latent bank!")
        print(f"Available stations: {[k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]}")
        return None, None

    latents = latent_data[latents_key]   # (N_samples, seq_len, d_model)
    means = latent_data[means_key]       # (N_samples, seq_len, d_model)
    stdevs = latent_data[stdev_key]      # (N_samples, seq_len, d_model)

    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")

    generated_signals = []
    real_names_used = []

    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: randomly select k latent vectors with replacement
            k = min(args.pcgen_k, len(latents))
            selected_indices = np.random.choice(len(latents), size=k, replace=True)

            # Mix latent features (average)
            selected_latents = latents[selected_indices]   # (k, seq_len, d_model)
            selected_means = means[selected_indices]       # (k, seq_len, d_model)
            selected_stdevs = stdevs[selected_indices]     # (k, seq_len, d_model)

            mixed_features = np.mean(selected_latents, axis=0)   # (seq_len, d_model)
            mixed_means = np.mean(selected_means, axis=0)        # (seq_len, d_model)
            mixed_stdevs = np.mean(selected_stdevs, axis=0)      # (seq_len, d_model)

            # NOTE: Do NOT add noise during generation (matching untitled1_gen.py)
            # untitled1_gen.py only uses noise during TRAINING (Phase 1), not during generation
            # if encoder_std is not None:
            #     noise = np.random.randn(*mixed_features.shape) * encoder_std
            #     mixed_features = mixed_features + noise

            # Convert to torch tensors
            mixed_features_torch = torch.from_numpy(mixed_features).float().unsqueeze(0)  # (1, seq_len, d_model)
            means_b = torch.from_numpy(mixed_means).float().unsqueeze(0)                  # (1, seq_len, d_model)
            stdev_b = torch.from_numpy(mixed_stdevs).float().unsqueeze(0)                 # (1, seq_len, d_model)

            if args.use_gpu:
                mixed_features_torch = mixed_features_torch.cuda()
                means_b = means_b.cuda()
                stdev_b = stdev_b.cuda()

            # Decode
            xg = model.project_features_for_reconstruction(mixed_features_torch, means_b, stdev_b)

            # Store - transpose to (3, 6000)
            generated_np = xg.squeeze(0).cpu().numpy().T  # (6000, 3) → (3, 6000)
            generated_signals.append(generated_np)

            # Track which latent indices were used
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])

            if (i + 1) % 10 == 0:
                print(f" Generated {i + 1}/{num_samples} samples...")

    return np.array(generated_signals), real_names_used
343
+
344
+
345
+ def _preprocess_component_boore(data: np.ndarray, fs: float, corner_freq: float, filter_order: int = 2) -> np.ndarray:
346
+ """Boore (2005) style preprocessing: detrend (linear), zero-padding, high-pass Butterworth (zero-phase)."""
347
+ from scipy.signal import butter, filtfilt
348
+ x = np.asarray(data, dtype=np.float64)
349
+ n = x.shape[0]
350
+ # Linear detrend
351
+ t = np.arange(n, dtype=np.float64)
352
+ t_mean = t.mean()
353
+ x_mean = x.mean()
354
+ denom = np.sum((t - t_mean) ** 2)
355
+ slope = 0.0 if denom == 0 else float(np.sum((t - t_mean) * (x - x_mean)) / denom)
356
+ intercept = float(x_mean - slope * t_mean)
357
+ x_detr = x - (slope * t + intercept)
358
+ # Zero-padding
359
+ Tzpad = (1.5 * filter_order) / max(corner_freq, 1e-6)
360
+ pad_samples = int(round(Tzpad * fs))
361
+ x_pad = np.concatenate([np.zeros(pad_samples, dtype=np.float64), x_detr, np.zeros(pad_samples, dtype=np.float64)])
362
+ # High-pass filter (zero-phase)
363
+ normalized = corner_freq / (fs / 2.0)
364
+ normalized = min(max(normalized, 1e-6), 0.999999)
365
+ b, a = butter(filter_order, normalized, btype='high')
366
+ y = filtfilt(b, a, x_pad)
367
+ return y
368
+
369
+
370
+ def _konno_ohmachi_smoothing(spectrum: np.ndarray, freq: np.ndarray, b: float = 40.0) -> np.ndarray:
371
+ """Konno-Ohmachi smoothing as in MATLAB reference (O(n^2))."""
372
+ f = np.asarray(freq, dtype=np.float64).reshape(-1)
373
+ s = np.asarray(spectrum, dtype=np.float64).reshape(-1)
374
+ f = np.where(f == 0.0, 1e-12, f)
375
+ n = f.shape[0]
376
+ out = np.zeros_like(s)
377
+ for i in range(n):
378
+ w = np.exp(-b * (np.log(f / f[i])) ** 2)
379
+ w[~np.isfinite(w)] = 0.0
380
+ denom = np.sum(w)
381
+ out[i] = 0.0 if denom == 0 else float(np.sum(w * s) / denom)
382
+ return out
383
+
384
+
385
def _compute_hvsr_simple(signal: np.ndarray, fs: float = 100.0):
    """Compute HVSR curve using MATLAB-style pipeline (Boore HP filter + FAS + Konno-Ohmachi).

    Args:
        signal: (T, 3) array with columns [E-W, N-S, U-D].
        fs: Sampling rate in Hz.

    Returns:
        (freq, hvsr) arrays restricted to the 1-20 Hz band, or (None, None)
        when the input is malformed, too short, or any numerical step fails.
    """
    try:
        # Reject anything that is not a clean (T, 3) array.
        if signal.ndim != 2 or signal.shape[1] != 3:
            return None, None
        if np.any(np.isnan(signal)) or np.any(np.isinf(signal)):
            return None, None

        # Preprocess components (Boore 2005): detrend + zero-padding + high-pass (0.05 Hz)
        ew = _preprocess_component_boore(signal[:, 0], fs, 0.05, 2)
        ns = _preprocess_component_boore(signal[:, 1], fs, 0.05, 2)
        ud = _preprocess_component_boore(signal[:, 2], fs, 0.05, 2)

        # Truncate to a common length; require enough samples for a spectrum.
        n = int(min(len(ew), len(ns), len(ud)))
        if n < 16:
            return None, None
        ew = ew[:n]; ns = ns[:n]; ud = ud[:n]

        # FFT amplitudes and linear frequency grid
        half = n // 2
        if half <= 1:
            return None, None
        freq = (np.arange(0, half, dtype=np.float64)) * (fs / n)
        amp_ew = np.abs(np.fft.fft(ew))[:half]
        amp_ns = np.abs(np.fft.fft(ns))[:half]
        amp_ud = np.abs(np.fft.fft(ud))[:half]

        # Horizontal combination via geometric mean, then Konno-Ohmachi smoothing
        combined_h = np.sqrt(np.maximum(amp_ew, 0.0) * np.maximum(amp_ns, 0.0))
        sm_h = _konno_ohmachi_smoothing(combined_h, freq, 40.0)
        sm_v = _konno_ohmachi_smoothing(amp_ud, freq, 40.0)

        # Avoid division by zero in the spectral ratio.
        sm_v_safe = np.where(sm_v <= 0.0, 1e-12, sm_v)
        sm_hvsr = sm_h / sm_v_safe

        # Limit to 1-20 Hz band
        mask = (freq >= 1.0) & (freq <= 20.0)
        if not np.any(mask):
            return None, None
        return freq[mask], sm_hvsr[mask]
    except Exception:
        # Best-effort: any numerical failure is reported as "no curve".
        return None, None
427
+
428
+
429
def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Save generated samples to NPZ file with HVSR and f0 data.

    Args:
        generated_signals: (N, 3, T) array of generated waveforms.
        real_names: Per-sample lists identifying the mixed latent vectors.
        station_id: Station identifier used in the output filename.
        output_dir: Directory for the NPZ (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Compute HVSR and f0 for all generated signals
    f0_list = []
    hvsr_curves = []
    fs = 100.0  # assumed sampling rate in Hz — TODO confirm against source data

    print(f"[INFO] Computing HVSR and f0 for {len(generated_signals)} generated samples...")
    for idx, sig in enumerate(generated_signals):
        # sig is (3, T), need to transpose to (T, 3)
        sig_t = sig.T  # (T, 3)
        freq, hvsr = _compute_hvsr_simple(sig_t, fs)
        if freq is not None and hvsr is not None:
            hvsr_curves.append((freq, hvsr))
            # f0 = frequency at max HVSR
            max_idx = np.argmax(hvsr)
            f0 = float(freq[max_idx])
            f0_list.append(f0)

    # Build median HVSR curve on a fixed frequency grid (1-20 Hz, 400 points for consistency)
    hvsr_freq = None
    hvsr_median = None
    if hvsr_curves:
        # Use a fixed frequency grid for consistency with other plots
        hvsr_freq = np.linspace(1.0, 20.0, 400)
        # Interpolate all curves to common grid
        hvsr_matrix = []
        for freq, hvsr in hvsr_curves:
            hvsr_interp = np.interp(hvsr_freq, freq, hvsr, left=hvsr[0], right=hvsr[-1])
            hvsr_matrix.append(hvsr_interp)
        hvsr_median = np.median(np.vstack(hvsr_matrix), axis=0)

    # Build f0 histogram (PDF), normalized to sum to 1 when non-empty
    f0_bins = np.linspace(1.0, 20.0, 21)
    f0_array = np.array(f0_list)
    f0_hist, _ = np.histogram(f0_array, bins=f0_bins)
    f0_pdf = f0_hist.astype(float)
    f0_sum = f0_pdf.sum()
    if f0_sum > 0:
        f0_pdf = f0_pdf / f0_sum

    # Save timeseries NPZ with HVSR data
    output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
    np.savez_compressed(
        output_path,
        generated_signals=generated_signals,
        signals_generated=generated_signals,  # Alias for compatibility
        real_names=real_names,
        station_id=station_id,
        station=station_id,  # Alias for compatibility
        f0_timesnet=f0_array,
        f0_bins=f0_bins,
        pdf_timesnet=f0_pdf,
        hvsr_freq_timesnet=hvsr_freq if hvsr_freq is not None else np.array([]),
        hvsr_median_timesnet=hvsr_median if hvsr_median is not None else np.array([]),
    )
    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    if len(f0_list) > 0:
        print(f"[INFO] - f0 samples: {len(f0_list)}, median f0: {np.median(f0_array):.2f} Hz")
    else:
        print(f"[INFO] - No valid HVSR computed")
492
+
493
+
494
def fine_tune_model(model, all_station_files, args, encoder_std, epochs=10, lr=1e-4):
    """
    Fine-tune the model on 5 stations with noise injection.
    Matches Phase 1 training in untitled1_gen.py exactly.

    Args:
        model: Pre-trained TimesNetPointCloud model (updated in place).
        all_station_files: Dict mapping station_id -> list of .mat file paths.
        args: SimpleArgs configuration (seq_len, use_gpu, ...).
        encoder_std: Per-feature std vector for latent noise injection, or None.
        epochs: Number of fine-tuning epochs.
        lr: AdamW learning rate.

    Returns:
        The fine-tuned model (the same object passed in). Also writes a
        checkpoint to ./checkpoints/timesnet_pointcloud_phase1_finetuned.pth.
    """
    print("\n" + "="*80)
    print("Phase 1: Fine-Tuning with Noise Injection")
    print("="*80)

    # Prepare data loader: load every station file into memory up front.
    all_data = []
    for station_id, files in all_station_files.items():
        for fpath in files:
            data = load_mat_file(fpath, args.seq_len, debug=False)
            if data is not None:
                all_data.append(data)

    if len(all_data) == 0:
        print("[WARN] No data loaded for fine-tuning!")
        return model

    print(f"[INFO] Loaded {len(all_data)} samples for fine-tuning")

    # Create optimizer (matching untitled1_gen.py Phase 1)
    batch_size = 32
    weight_decay = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # AMP scaler (matching untitled1_gen.py)
    # NOTE(review): torch.cuda.amp.* is deprecated in recent PyTorch in favor
    # of torch.amp — confirm the targeted torch version.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.use_gpu))

    # Gradient clipping (matching untitled1_gen.py)
    grad_clip = 1.0

    train_losses_p1 = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_rec = 0.0
        num_batches = 0

        # Shuffle data
        np.random.shuffle(all_data)

        for i in range(0, len(all_data), batch_size):
            batch = all_data[i:i+batch_size]
            if len(batch) == 0:
                continue

            # Stack batch
            x_list = []
            for sig in batch:
                # sig is (3, 6000), transpose to (6000, 3)
                x_list.append(sig.transpose(0, 1))

            x = torch.stack(x_list, dim=0)  # (batch, 6000, 3)
            if args.use_gpu:
                x = x.cuda()

            # Zero gradients (matching untitled1_gen.py)
            optimizer.zero_grad(set_to_none=True)

            # Forward with AMP and noise injection (matching untitled1_gen.py Phase 1)
            with torch.cuda.amp.autocast(enabled=(args.use_gpu)):
                enc_out, means_b, stdev_b = model.encode_features_for_reconstruction(x)

                # Add noise if encoder_std available (matching untitled1_gen.py line 945-948)
                if encoder_std is not None:
                    std_vec = torch.from_numpy(encoder_std).to(enc_out.device).float()
                    noise = torch.randn_like(enc_out) * std_vec.view(1, 1, -1) * 1.0  # noise_std_scale=1.0
                    enc_out = enc_out + noise

                # Decode
                x_hat = model.project_features_for_reconstruction(enc_out, means_b, stdev_b)

                # Reconstruction loss (MSE, matching untitled1_gen.py)
                loss_rec = torch.nn.functional.mse_loss(x_hat, x)
                loss = loss_rec

            # Backward with gradient scaling (matching untitled1_gen.py)
            scaler.scale(loss).backward()

            # Gradient clipping (matching untitled1_gen.py)
            # NOTE(review): this clips *scaled* gradients; the documented AMP
            # pattern calls scaler.unscale_(optimizer) first — confirm this
            # mirrors the reference trainer intentionally.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)

            # Optimizer step with scaler (matching untitled1_gen.py)
            scaler.step(optimizer)
            scaler.update()

            total_loss += float(loss.detach().cpu())
            total_rec += float(loss_rec.detach().cpu())
            num_batches += 1

        # Scheduler step (matching untitled1_gen.py): once per epoch
        scheduler.step()

        avg_loss = total_loss / max(1, num_batches)
        avg_rec = total_rec / max(1, num_batches)
        train_losses_p1.append(avg_loss)
        print(f"[P1] epoch {epoch+1}/{epochs} loss={avg_loss:.4f} (rec={avg_rec:.4f})")

    print("[INFO] Phase 1 fine-tuning complete!")

    # Save fine-tuned model (matching untitled1_gen.py Phase 1 checkpoint)
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    fine_tuned_path = os.path.join(checkpoint_dir, 'timesnet_pointcloud_phase1_finetuned.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses_phase1': train_losses_p1,
        'phase': 'phase1'
    }, fine_tuned_path)
    print(f"[INFO] ✓ Fine-tuned model saved to: {fine_tuned_path}")

    return model
613
+
614
+
615
def plot_sample_preview(generated_signals, station_id, output_dir, num_preview=2):
    """Write per-sample 3-channel preview PNGs for the first few generated signals."""
    os.makedirs(output_dir, exist_ok=True)

    n_plots = min(num_preview, len(generated_signals))
    labels = ['E-W', 'N-S', 'U-D']

    for sample_idx in range(n_plots):
        trace = generated_signals[sample_idx]
        fig, axes = plt.subplots(3, 1, figsize=(12, 8))

        # One stacked axis per component; trace rows are the channels.
        for axis, label, channel in zip(axes, labels, trace):
            axis.plot(channel, linewidth=0.8)
            axis.set_ylabel(f'{label}\nAmplitude', fontsize=10, fontweight='bold')
            axis.grid(True, alpha=0.3)

        axes[-1].set_xlabel('Time Steps', fontsize=10, fontweight='bold')
        fig.suptitle(f'Generated Sample - Station {station_id}', fontsize=12, fontweight='bold')
        plt.tight_layout()

        plt.savefig(os.path.join(output_dir, f'station_{station_id}_preview_{sample_idx}.png'),
                    dpi=150, bbox_inches='tight')
        plt.close()

    print(f"[INFO] Saved {n_plots} preview plots to {output_dir}")
638
+
639
+
640
def main():
    """CLI entry point: load model + latent bank, optionally fine-tune,
    generate per-station samples, save NPZs/previews, then run HVSR plots."""
    parser = argparse.ArgumentParser(description='Generate seismic samples (simplified version)')
    parser.add_argument('--checkpoint', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\timesnet_pointcloud_phase1_final.pth',
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--latent_bank', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\latent_bank_phase1.npz',
                        help='Path to latent bank NPZ file')
    parser.add_argument('--num_samples', type=int, default=50,
                        help='Number of samples to generate per station')
    parser.add_argument('--output_dir', type=str, default='./generated_samples',
                        help='Output directory')
    parser.add_argument('--num_preview', type=int, default=2,
                        help='Number of preview plots per station')
    parser.add_argument('--stations', type=str, nargs='+', default=['0205', '1716', '2020', '3130', '4628'],
                        help='Target station IDs')
    parser.add_argument('--data_root', type=str, default=r"D:\Baris\5stats/",
                        help='Root path to seismic data (only needed if --fine_tune is used)')
    parser.add_argument('--fine_tune', action='store_true',
                        help='Fine-tune the model before generation (use with Phase 0 checkpoint)')
    parser.add_argument('--fine_tune_epochs', type=int, default=10,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--fine_tune_lr', type=float, default=1e-4,
                        help='Learning rate for fine-tuning')

    args_cli = parser.parse_args()

    # Check checkpoint
    if not os.path.exists(args_cli.checkpoint):
        print(f"\n{'='*80}")
        print(f"❌ ERROR: Checkpoint not found!")
        print(f"{'='*80}")
        print(f"\nLooking for: {args_cli.checkpoint}")
        return

    # Create configuration
    args = SimpleArgs()

    print("="*80)
    print("TimesNet-Gen Sample Generation (Simplified)")
    print("="*80)
    print(f"Checkpoint: {args_cli.checkpoint}")
    print(f"Target stations: {args_cli.stations}")
    print(f"Samples per station: {args_cli.num_samples}")
    print(f"Output directory: {args_cli.output_dir}")
    print("="*80)

    # Set random seed (reproducible bootstrap sampling and training)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model
    model = load_model(args_cli.checkpoint, args)

    # Try to load encoder_std from Phase 0 (only needed if fine-tuning)
    encoder_std_path = './pcgen_stats/encoder_feature_std.npy'
    encoder_std = None
    if os.path.exists(encoder_std_path):
        encoder_std = np.load(encoder_std_path)
        print(f"[INFO] Loaded encoder_std from {encoder_std_path} (shape: {encoder_std.shape})")
        print(f"[INFO] encoder_std loaded (used only for fine-tuning, NOT for generation)")
    else:
        print(f"[INFO] No encoder_std found (not needed for generation, only for fine-tuning)")

    # Check if latent bank exists
    if not os.path.exists(args_cli.latent_bank):
        print(f"\n❌ ERROR: Latent bank not found!")
        print(f"Looking for: {args_cli.latent_bank}")
        print(f"\nPlease run untitled1_gen.py first to generate the latent bank.")
        return

    print(f"[INFO] Using latent bank: {args_cli.latent_bank}")

    # Fine-tune if requested (requires real data)
    if args_cli.fine_tune:
        print("\n[INFO] Fine-tuning enabled! Loading real data...")

        all_station_files = {}
        for station_id in args_cli.stations:
            # Find all .mat files for this station
            pattern = os.path.join(args_cli.data_root, f"*{station_id}*.mat")
            station_files = glob.glob(pattern)

            if len(station_files) == 0:
                print(f"[WARN] No files found for station {station_id}")
            else:
                print(f"[INFO] Found {len(station_files)} files for station {station_id}")
                all_station_files[station_id] = station_files

        if len(all_station_files) == 0:
            print(f"\n❌ ERROR: No data files found in {args_cli.data_root}")
            return

        model = fine_tune_model(model, all_station_files, args, encoder_std,
                                epochs=args_cli.fine_tune_epochs,
                                lr=args_cli.fine_tune_lr)

    # Create output directories
    npz_output_dir = os.path.join(args_cli.output_dir, 'generated_timeseries_npz')
    plot_output_dir = os.path.join(args_cli.output_dir, 'preview_plots')

    # Generate samples for each station (from latent bank)
    print("\n[INFO] Generating samples from latent bank...")
    for station_id in args_cli.stations:
        print(f"\n{'='*60}")
        print(f"Processing Station: {station_id}")
        print(f"{'='*60}")

        generated_signals, real_names = generate_samples_from_latent_bank(
            model, args_cli.latent_bank, station_id, args_cli.num_samples, args, encoder_std
        )

        if generated_signals is not None:
            # Save to NPZ
            save_generated_samples(generated_signals, real_names, station_id, npz_output_dir)

            # Create preview plots
            plot_sample_preview(generated_signals, station_id, plot_output_dir, args_cli.num_preview)

    print("\n" + "="*80)
    print("Generation Complete!")
    print("="*80)
    print(f"Generated samples saved to: {npz_output_dir}")
    print(f"Preview plots saved to: {plot_output_dir}")

    # Debug: Show how many samples were generated per station
    print("\n[DEBUG] Generated samples per station:")
    for station_id in args_cli.stations:
        npz_path = os.path.join(npz_output_dir, f'station_{station_id}_generated_timeseries.npz')
        if os.path.exists(npz_path):
            try:
                data = np.load(npz_path, allow_pickle=True)
                if 'signals_generated' in data:
                    n_samples = data['signals_generated'].shape[0]
                    print(f" Station {station_id}: {n_samples} samples")
            except Exception as e:
                print(f" Station {station_id}: Error loading NPZ - {e}")
    print("="*80)

    # Create HVSR comparison plots (import plot_combined_hvsr_all_sources and call main)
    print("\n[INFO] Creating HVSR comparison plots (matrices, HVSR curves, f0 distributions)...")
    print("[INFO] Only plotting TimesNet-Gen vs Real (no Recon/VAE)")
    try:
        import sys
        # Import the plotting module
        import plot_combined_hvsr_all_sources as hvsr_plotter

        # Override sys.argv to pass arguments to the plotter
        # Only provide gen_dir and gen_ts_dir, explicitly disable others with empty strings
        original_argv = sys.argv
        sys.argv = [
            'plot_combined_hvsr_all_sources.py',
            '--gen_dir', npz_output_dir,       # Use our generated NPZs as gen_dir (they now have HVSR/f0 data)
            '--gen_ts_dir', npz_output_dir,    # Also use for timeseries plots
            '--out', os.path.join(args_cli.output_dir, 'hvsr_analysis'),
            '--recon_dir', '',                 # Explicitly empty to disable auto-default
            '--vae_dir', '',                   # Explicitly empty to disable auto-default
            '--vae_gen_dir', '',               # Explicitly empty to disable auto-default
        ]

        # Call the main plotting function
        hvsr_plotter.main()

        # Restore original argv
        # NOTE(review): the restore is skipped when hvsr_plotter.main() raises;
        # a try/finally around the override would be safer.
        sys.argv = original_argv

        print(f"[INFO] ✅ HVSR analysis complete! Plots saved to: {os.path.join(args_cli.output_dir, 'hvsr_analysis')}")
    except Exception as e:
        import traceback
        print(f"[WARN] Could not create HVSR plots: {e}")
        traceback.print_exc()
811
+
812
+
813
if __name__ == "__main__":
    # Script entry point: parse CLI args and run the generation pipeline.
    main()
815
+