Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files- TimesNet.py +418 -0
- TimesNet_PointCloud.py +213 -0
- app.py +236 -0
- generate_samples_git.py +815 -0
TimesNet.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.fft
|
| 5 |
+
import numpy as np
|
| 6 |
+
# Basit embedding ve conv blocks - layers klasörü olmadan
|
| 7 |
+
class DataEmbedding(nn.Module):
    """Simple value + learned-position embedding (standalone, no `layers` package).

    Projects ``c_in`` input channels to ``d_model`` with a linear layer and adds
    a learned position embedding of length ``seq_len``.  Inputs longer than
    ``seq_len`` get the learned embedding for the first ``seq_len`` steps and a
    sinusoidal extension for the remainder.

    Args:
        c_in: number of input channels.
        d_model: embedding dimension.
        embed_type, freq: kept for interface compatibility; unused here.
        dropout: dropout probability applied to the embedded output.
        seq_len: length of the learned position embedding table.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1, seq_len=6000):
        super(DataEmbedding, self).__init__()
        self.c_in = c_in
        self.d_model = d_model
        self.embed_type = embed_type  # unused; kept so callers' kwargs still work
        self.freq = freq              # unused; kept so callers' kwargs still work
        self.seq_len = seq_len

        # Simple linear value embedding.
        self.value_embedding = nn.Linear(c_in, d_model)
        # Learned position embedding sized to the configured sequence length.
        self.position_embedding = nn.Parameter(torch.randn(1, seq_len, d_model))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` of shape (B, T, c_in) to (B, T, d_model).

        ``x_mark`` is accepted for interface compatibility and ignored.
        """
        x = self.value_embedding(x)

        T = x.size(1)
        P = self.position_embedding.size(1)
        if T <= P:
            # Crop the learned table to the input length.
            pos = self.position_embedding[:, :T, :]
        else:
            # Input longer than the learned table: extend with a sinusoidal
            # encoding for the overhang.  (The original code added the full
            # learned table to a longer tensor, which cannot broadcast, and
            # then mutated a slice in place; building the full position tensor
            # first avoids both problems.)
            extra = self._get_sinusoidal_encoding(T - P, self.d_model)
            extra = extra.unsqueeze(0).to(device=x.device, dtype=x.dtype)
            pos = torch.cat([self.position_embedding, extra], dim=1)
        x = x + pos

        return self.dropout(x)

    def _get_sinusoidal_encoding(self, length, d_model):
        """Return a (length, d_model) sinusoidal position encoding."""
        position = torch.arange(length).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))

        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cos slot than sin slot, so trim
        # the product to d_model // 2 columns (the original assignment raised
        # a size-mismatch error for odd d_model).
        pos_encoding[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]

        return pos_encoding
|
| 51 |
+
|
| 52 |
+
class Inception_Block_V1(nn.Module):
    """Parallel odd-sized Conv2d branches whose outputs are averaged.

    Branch i uses kernel size ``2*i + 1`` with padding ``i``, so every branch
    preserves the spatial dimensions of the input.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        self.kernels = nn.ModuleList(
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * i + 1, padding=i)
            for i in range(num_kernels)
        )
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-initialize every Conv2d branch; zero the biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run every branch on ``x`` and return their element-wise mean."""
        branch_outputs = [kernel(x) for kernel in self.kernels]
        return torch.stack(branch_outputs, dim=-1).mean(-1)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def FFT_for_Period(x, k=2):
    """Pick the k dominant periods of ``x`` (shape [B, T, C]) via the real FFT.

    Returns:
        periods: int numpy array of length k (``T // frequency_index``).
        weights: tensor of shape [B, k] — mean spectral amplitude per sample
            at each selected frequency.
    """
    spectrum = torch.fft.rfft(x, dim=1)
    # Mean amplitude per frequency, averaged over batch and channels.
    amplitude = abs(spectrum).mean(0).mean(-1)
    amplitude[0] = 0  # ignore the DC component
    _, top_indices = torch.topk(amplitude, k)
    top_indices = top_indices.detach().cpu().numpy()
    periods = x.shape[1] // top_indices
    return periods, abs(spectrum).mean(-1)[:, top_indices]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class TimesBlock(nn.Module):
    """Core TimesNet block: fold the 1-D series into 2-D (cycles x period),
    apply inception convolutions, unfold, and aggregate the top-k period views
    weighted by their spectral amplitudes, plus a residual connection."""

    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k  # number of dominant periods to consider
        # parameter-efficient design: d_model -> d_ff -> d_model inception stack
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x):
        B, T, N = x.size()  # B: batch size, T: length of time series, N: number of features
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # Pad so the total length is a multiple of the period.
            # NOTE(review): an all-zero input can make FFT_for_Period return a
            # zero period, which would break this modulo — assumed not to
            # occur in practice; TODO confirm upstream guarantees.
            if (self.seq_len + self.pred_len) % period != 0:
                length = (
                    ((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x
            # reshape: (B, length, N) -> (B, N, length // period, period)
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back to (B, length, N) and trim any padding
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)  # (B, T', N, k)
        # adaptive aggregation: softmax over the k periods' amplitudes
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(
            1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        # residual connection
        res = res + x
        return res
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class Model(nn.Module):
    """
    TimesNet backbone with multiple task heads (forecasting, imputation,
    anomaly detection, classification), optionally extended with separate
    P-/S-phase regression + classification heads for transfer learning.

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        # Stack of TimesBlocks shared by every task head.
        self.model = nn.ModuleList([TimesBlock(configs)
                                    for _ in range(configs.e_layers)])
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            # Linear extension of the time axis from seq_len to seq_len + pred_len.
            self.predict_linear = nn.Linear(
                self.seq_len, self.pred_len + self.seq_len)
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)

        # Transfer-learning P-S prediction heads (only built when requested).
        if hasattr(configs, 'use_ps_heads') and configs.use_ps_heads:
            # Skip attention for memory efficiency - use only pooling.

            # Multi-scale feature extraction (reduced sizes for memory).
            self.multi_scale_pools = nn.ModuleList([
                nn.AdaptiveAvgPool1d(16),  # local patterns (reduced)
                nn.AdaptiveAvgPool1d(4),   # medium patterns
                nn.AdaptiveAvgPool1d(1),   # global patterns
            ])

            # Feature fusion — pool sizes 16 + 4 + 1 = 21, so the
            # concatenated feature dimension is d_model * 21.
            fusion_dim = configs.d_model * (16 + 4 + 1)  # exact calculation
            self.feature_fusion = nn.Sequential(
                nn.Linear(fusion_dim, configs.d_model),
                nn.ReLU(),
                nn.Dropout(configs.dropout)
            )

            # Separate P and S regression heads (predicted arrival times).
            self.p_regression_head = nn.Sequential(
                nn.Linear(configs.d_model, 128),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 1)  # P time only
            )

            self.s_regression_head = nn.Sequential(
                nn.Linear(configs.d_model, 128),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 1)  # S time only
            )

            # Separate P and S classification heads (phase present / absent).
            self.p_classification_head = nn.Sequential(
                nn.Linear(configs.d_model, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(32, 1),  # P exists/not
                nn.Sigmoid()
            )

            self.s_classification_head = nn.Sequential(
                nn.Linear(configs.d_model, 64),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Dropout(configs.dropout),
                nn.Linear(32, 1),  # S exists/not
                nn.Sigmoid()
            )
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

    def anomaly_detection(self, x_enc):
        """Anomaly-detection entry point.

        When the transfer-learning P-S heads exist, this returns
        ``(ps_times, ps_classification)`` of shapes (B, 2) each; otherwise it
        returns the de-normalized reconstruction of shape [B, T, C].
        """
        # If the transfer-learning P-S heads exist — use ONLY them.
        if hasattr(self, 'p_regression_head'):
            # Normalization from Non-stationary Transformer
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(
                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc /= stdev

            # embedding
            enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
            # TimesNet backbone
            for i in range(self.layer):
                enc_out = self.layer_norm(self.model[i](enc_out))

            # Skip attention for memory - use direct multi-scale pooling
            # on the TimesNet output.
            enc_out_transposed = enc_out.permute(0, 2, 1)  # (B, d_model, T)
            multi_scale_features = []

            # Manual pooling for large sequences to avoid CUDA memory issues.
            pool_sizes = [16, 4, 1]  # target pool sizes
            for i, target_size in enumerate(pool_sizes):
                T = enc_out_transposed.size(2)  # sequence length

                if T >= 8000:  # very large - use manual average pooling
                    window_size = T // target_size
                    if window_size > 0:
                        # (B, d_model, T) -> (B, d_model, target_size, window_size), then average
                        trimmed_T = (T // window_size) * window_size
                        trimmed = enc_out_transposed[:, :, :trimmed_T]
                        reshaped = trimmed.view(trimmed.size(0), trimmed.size(1), target_size, window_size)
                        pooled = reshaped.mean(dim=3)  # average over each window
                    else:
                        # Fallback: simple truncation.
                        # NOTE(review): unreachable for T >= 8000 with these
                        # target sizes, and it would yield fewer than
                        # target_size columns if taken — confirm intent.
                        pooled = enc_out_transposed[:, :, :target_size] if T >= target_size else enc_out_transposed
                else:
                    # Use normal adaptive pooling for smaller sequences.
                    pool = self.multi_scale_pools[i]
                    pooled = pool(enc_out_transposed)  # (B, d_model, pool_size)

                flattened = pooled.flatten(1)  # (B, d_model * pool_size)
                multi_scale_features.append(flattened)

            # Concatenate multi-scale features: (B, d_model * 21)
            fused_features = torch.cat(multi_scale_features, dim=1)

            # Feature fusion
            final_features = self.feature_fusion(fused_features)  # (B, d_model)

            # Separate P and S regressions
            p_time = self.p_regression_head(final_features)  # (B, 1)
            s_time = self.s_regression_head(final_features)  # (B, 1)
            ps_times = torch.cat([p_time, s_time], dim=1)  # (B, 2)

            # Separate P and S classifications
            p_class = self.p_classification_head(final_features)  # (B, 1)
            s_class = self.s_classification_head(final_features)  # (B, 1)
            ps_classification = torch.cat([p_class, s_class], dim=1)  # (B, 2)

            return ps_times, ps_classification
        else:
            # Original anomaly detection (reconstruction).
            # Normalization from Non-stationary Transformer
            means = x_enc.mean(1, keepdim=True).detach()
            x_enc = x_enc - means
            stdev = torch.sqrt(
                torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
            x_enc /= stdev

            # embedding
            enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
            # TimesNet backbone
            for i in range(self.layer):
                enc_out = self.layer_norm(self.model[i](enc_out))
            # project back to signal space
            dec_out = self.projection(enc_out)

            # De-Normalization from Non-stationary Transformer.
            # NOTE(review): the repeat length is pred_len + seq_len — this
            # assumes pred_len is 0 for the anomaly task; confirm in configs.
            dec_out = dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
            dec_out = dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
            return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the head matching configs.task_name."""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(
                x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            result = self.anomaly_detection(x_enc)
            # Either a reconstruction [B, L, D], or (ps_times, ps_classification)
            # when the transfer-learning P-S heads are present.
            return result
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast: embed, extend to seq_len + pred_len, encode, project, de-normalize."""
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(
            torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
            0, 2, 1)  # align temporal dimension: T -> seq_len + pred_len
        # TimesNet backbone
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        # project back to signal space
        dec_out = self.projection(enc_out)

        # De-Normalization from Non-stationary Transformer
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        return dec_out

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Imputation: normalize over observed entries only, encode, project back."""
        # Normalization from Non-stationary Transformer, using only the
        # observed (mask == 1) positions for the statistics.
        means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
                           torch.sum(mask == 1, dim=1) + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        # embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
        # TimesNet backbone
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))
        # project back to signal space
        dec_out = self.projection(enc_out)

        # De-Normalization from Non-stationary Transformer.
        # NOTE(review): assumes pred_len is 0 for imputation so the repeat
        # length equals the input length — confirm in configs.
        dec_out = dec_out * \
            (stdev[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + \
            (means[:, 0, :].unsqueeze(1).repeat(
                1, self.pred_len + self.seq_len, 1))
        return dec_out

    def classification(self, x_enc, x_mark_enc):
        """Classification: encode, zero out padding, flatten, project to class logits.

        ``x_mark_enc`` here is a (B, T) padding mask, not time features.
        """
        # embedding
        enc_out = self.enc_embedding(x_enc, None)  # [B,T,C]
        # TimesNet backbone
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out))

        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output
|
| 418 |
+
|
TimesNet_PointCloud.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from .TimesNet import DataEmbedding
|
| 8 |
+
except Exception:
|
| 9 |
+
from TimesNet import DataEmbedding
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class _BlockConfig:
|
| 13 |
+
def __init__(self, seq_len: int, pred_len: int, d_model: int, d_ff: int, num_kernels: int, top_k: int = 2, num_stations: int = 0):
|
| 14 |
+
self.seq_len = seq_len
|
| 15 |
+
self.pred_len = pred_len
|
| 16 |
+
self.d_model = d_model
|
| 17 |
+
self.d_ff = d_ff
|
| 18 |
+
self.num_kernels = num_kernels
|
| 19 |
+
self.top_k = top_k
|
| 20 |
+
self.num_stations = num_stations
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Inception_Block_V1(nn.Module):
    """Average of parallel 2-D convolutions with kernel sizes 1, 3, 5, ...

    Padding matches each kernel so all branches keep the spatial shape.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super(Inception_Block_V1, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        branch_list = []
        for idx in range(self.num_kernels):
            # kernel 2*idx+1 with padding idx preserves H and W
            branch_list.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=2 * idx + 1, padding=idx))
        self.kernels = nn.ModuleList(branch_list)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming init for conv weights, zeros for biases."""
        for mod in self.modules():
            if not isinstance(mod, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(mod.weight, mode='fan_out', nonlinearity='relu')
            if mod.bias is not None:
                nn.init.constant_(mod.bias, 0)

    def forward(self, x):
        """Return the mean of all branch outputs."""
        stacked = torch.stack([conv(x) for conv in self.kernels], dim=-1)
        return stacked.mean(-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def FFT_for_Period(x, k=2):
    """Return the k dominant periods of ``x`` ([B, T, C]) and their amplitudes.

    The spectrum is averaged over channels and batch, the DC term is dropped,
    and the k largest frequency bins are converted to integer periods.
    """
    spec = torch.fft.rfft(x, dim=1)
    amps = abs(spec).mean(-1)      # (B, F): per-sample amplitude spectrum
    mean_amps = amps.mean(0)       # (F,): averaged over the batch
    mean_amps[0] = 0               # drop the DC component
    top = torch.topk(mean_amps, k).indices.detach().cpu().numpy()
    return x.shape[1] // top, amps[:, top]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TimesBlockStationCond(nn.Module):
    """TimesBlock with station-ID conditioning via a learned embedding added
    to the folded 2-D feature map before the inception convolutions."""

    def __init__(self, configs):
        super(TimesBlockStationCond, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k  # number of dominant periods
        self.num_stations = getattr(configs, 'num_stations', 0)

        # Station ID embedding: maps a station ID to a d_model vector,
        # giving richer conditioning than a single scalar channel.
        if self.num_stations > 0:
            self.station_embedding = nn.Embedding(self.num_stations, configs.d_model)
            # Small random init so conditioning starts near-neutral.
            nn.init.normal_(self.station_embedding.weight, mean=0.0, std=0.02)

        # Inception stack: d_model -> d_ff -> d_model.
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x, station_ids: torch.Tensor = None):
        """
        Args:
            x: (B, T, N) input features (N must equal d_model — x is the
               embedded feature tensor, see the view call below)
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1)
        """
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = period_list[i]
            # Pad so the total length is a multiple of the period.
            if (self.seq_len + self.pred_len) % period != 0:
                length = (((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x

            # reshape to 2D: (B, N, cycles, period)
            out = out.reshape(B, length // period, period, N).permute(0, 3, 1, 2).contiguous()

            # Inject station-ID conditioning by adding the station embedding
            # across all spatial positions of the folded map.
            if station_ids is not None and self.num_stations > 0:
                # Station embeddings: (B, d_model)
                station_ids_flat = station_ids.view(B)
                station_emb = self.station_embedding(station_ids_flat)  # (B, d_model)

                # out shape: (B, d_model, H, W); broadcast the embedding to it.
                H = out.size(2)
                W = out.size(3)
                # NOTE(review): this view uses N, which must equal d_model for
                # the embedding (sized d_model) to fit — true when x is the
                # encoder-embedded tensor; confirm no other caller exists.
                station_emb_spatial = station_emb.view(B, N, 1, 1).expand(-1, -1, H, W)

                # Element-wise addition lets the model learn
                # station-specific feature shifts.
                out = out + station_emb_spatial

            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)

            # reshape back to (B, length, N) and trim the padding
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :(self.seq_len + self.pred_len), :])

        res = torch.stack(res, dim=-1)  # (B, T', N, k)
        # adaptive aggregation over the k periods, weighted by amplitude
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)

        # residual connection
        res = res + x
        return res
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TimesNetPointCloud(nn.Module):
    """TimesNet reconstruction with exposed encode/project methods for point-cloud mixing."""

    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.pred_len = getattr(configs, 'pred_len', 0)
        self.top_k = configs.top_k
        self.d_model = configs.d_model
        self.d_ff = configs.d_ff
        self.num_kernels = configs.num_kernels
        self.e_layers = configs.e_layers
        self.dropout = configs.dropout
        self.c_out = configs.c_out

        # Number of distinct stations for ID conditioning (0 disables it).
        self.num_stations = getattr(configs, 'num_stations', 0)

        self.enc_embedding = DataEmbedding(configs.enc_in, self.d_model, configs.embed, configs.freq,
                                           configs.dropout, configs.seq_len)
        # NOTE: the blocks are built with pred_len=0 regardless of
        # configs.pred_len, so each block operates on seq_len only.
        self.model = nn.ModuleList([
            TimesBlockStationCond(_BlockConfig(self.seq_len, 0, self.d_model, self.d_ff,
                                               self.num_kernels, self.top_k, self.num_stations))
            for _ in range(self.e_layers)
        ])
        self.layer = self.e_layers
        self.layer_norm = nn.LayerNorm(self.d_model)
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def encode_features_for_reconstruction(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """
        Encode input with optional station ID conditioning.

        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs (0 to num_stations-1), optional

        Returns:
            (enc_out, means, stdev): encoded features plus the per-series
            normalization statistics needed by the projection step.
        """
        # Instance normalization (Non-stationary Transformer style).
        means = x_enc.mean(1, keepdim=True).detach()
        x_norm = x_enc - means
        stdev = torch.sqrt(torch.var(x_norm, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_norm = x_norm / stdev
        enc_out = self.enc_embedding(x_norm, None)
        for i in range(self.layer):
            enc_out = self.layer_norm(self.model[i](enc_out, station_ids))
        return enc_out, means, stdev

    def project_features_for_reconstruction(self, enc_out: torch.Tensor, means: torch.Tensor, stdev: torch.Tensor):
        """Project encoded features back to signal space and de-normalize.

        NOTE(review): the repeat length is pred_len + seq_len while the blocks
        were built with pred_len=0 — this assumes configs.pred_len is 0 here;
        confirm against callers.
        """
        dec_out = self.projection(enc_out)
        dec_out = dec_out * (stdev[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        dec_out = dec_out + (means[:, 0, :].unsqueeze(1).repeat(1, self.pred_len + self.seq_len, 1))
        return dec_out

    def anomaly_detection(self, x_enc: torch.Tensor, station_ids: torch.Tensor = None):
        """Full reconstruction pass with optional station ID conditioning."""
        enc_out, means, stdev = self.encode_features_for_reconstruction(x_enc, station_ids)
        return self.project_features_for_reconstruction(enc_out, means, stdev)

    def forward(self, x_enc, station_ids=None, x_mark_enc=None, x_dec=None, x_mark_dec=None, mask=None):
        """
        Forward pass compatible with the anomaly_detection task.

        Args:
            x_enc: (B, T, C) input signal
            station_ids: (B,) LongTensor of station IDs, optional

        The remaining keyword arguments are accepted for interface
        compatibility and ignored.
        """
        return self.anomaly_detection(x_enc, station_ids)
|
| 212 |
+
|
| 213 |
+
|
app.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib
|
| 4 |
+
matplotlib.use('Agg') # Use non-interactive backend
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
# Import model and generation functions
try:
    from TimesNet_PointCloud import TimesNetPointCloud
    from generate_samples_git import SimpleArgs, generate_samples_from_latent_bank
except ImportError:
    # Fallback for local imports
    import sys
    sys.path.insert(0, '.')
    from TimesNet_PointCloud import TimesNetPointCloud
    try:
        from generate_samples_git import SimpleArgs, generate_samples_from_latent_bank
    except ImportError:
        # If generate_samples_git doesn't exist, use generate_samples.
        # BUG FIX: the inner handler was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit and masked unrelated errors inside
        # generate_samples_git; only a missing module should trigger the fallback.
        from generate_samples import FineTuneArgs as SimpleArgs, generate_samples_from_latent_bank
|
| 26 |
+
def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model (matching generate_samples_git.py).

    Args:
        checkpoint_path: Path to a .pth checkpoint — either a raw state_dict or a
            dict with 'model_state_dict' and optionally 'num_stations'.
        args: Configuration object carrying the model hyperparameters.

    Returns:
        The model in eval mode (moved to GPU when args.use_gpu is set).
    """
    # Create model config
    class ModelConfig:
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0
            self.enc_in = 3
            self.c_out = 3
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = getattr(args, 'latent_dim', 256)
            # num_stations is needed for station conditioning.
            # Default used only when neither args nor the checkpoint provide it.
            self.num_stations = getattr(args, 'num_stations', 705)

    config = ModelConfig(args)

    # Load the checkpoint FIRST so 'num_stations' can influence construction.
    # BUG FIX: previously the model was built before reading 'num_stations'
    # from the checkpoint, so the stored value could never affect the
    # architecture and mismatched station-embedding sizes failed to load.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        if 'num_stations' in checkpoint:
            config.num_stations = checkpoint['num_stations']
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model = TimesNetPointCloud(config)
    model.load_state_dict(state_dict)

    model.eval()
    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
|
| 69 |
+
# Configuration - can be set via environment variables for HF Space
PHASE1_MODEL_CHECKPOINT_PATH = os.getenv('PHASE1_MODEL_CHECKPOINT_PATH', './checkpoints/timesnet_pointcloud_phase1_final.pth')
# Pre-computed latent bank (.npz) holding per-station encoder outputs.
LATENT_BANK_PATH = os.getenv('LATENT_BANK_PATH', './latent_bank_station_cond.npz')
# Encoder feature std (.npy); used only for fine-tuning, not for generation.
ENCODER_STD_PATH = os.getenv('ENCODER_STD_PATH', './pcgen_stats/encoder_feature_std.npy')

# Test stations (5 unseen stations)
TEST_STATIONS = ['0205', '1716', '2020', '3130', '4628']

# Initialize model and args (loaded once at startup)
# Module-level globals populated by initialize_model().
model = None
args = None
encoder_std = None
| 82 |
+
def initialize_model():
    """Load Phase 1 model and encoder_std once at startup.

    Populates the module-level globals `model`, `args` and `encoder_std`.
    Idempotent: returns immediately when the model is already loaded.

    Returns:
        True when the model checkpoint and latent bank are available,
        False otherwise (the app then falls back to dummy mode).
    """
    global model, args, encoder_std

    if model is not None:
        return True  # Already initialized

    print("[INFO] Initializing TimesNet-Gen Phase 1 model...")

    # Create args (matching generate_samples_git.py)
    args = SimpleArgs()

    # Load Phase 1 model
    if not os.path.exists(PHASE1_MODEL_CHECKPOINT_PATH):
        print(f"[ERROR] Phase 1 model checkpoint not found: {PHASE1_MODEL_CHECKPOINT_PATH}")
        print("[ERROR] Please set PHASE1_MODEL_CHECKPOINT_PATH environment variable")
        return False

    model = load_model(PHASE1_MODEL_CHECKPOINT_PATH, args)

    # Check if latent bank exists
    if not os.path.exists(LATENT_BANK_PATH):
        print(f"[ERROR] Latent bank not found: {LATENT_BANK_PATH}")
        print("[ERROR] Please set LATENT_BANK_PATH environment variable or create latent bank first")
        return False

    print(f"[INFO] Using latent bank: {LATENT_BANK_PATH}")

    # Load encoder_std (optional, only for fine-tuning, not for generation)
    if os.path.exists(ENCODER_STD_PATH):
        encoder_std = np.load(ENCODER_STD_PATH)
        print(f"[INFO] Loaded encoder_std from {ENCODER_STD_PATH} (shape: {encoder_std.shape})")
        print(f"[INFO] encoder_std loaded (used only for fine-tuning, NOT for generation)")
    else:
        print(f"[INFO] No encoder_std found (not needed for generation, only for fine-tuning)")
        encoder_std = None

    print("[INFO] ✓ Phase 1 model initialized successfully!")
    return True
|
| 122 |
+
# Initialize on import
# NOTE: the broad Exception catch is deliberate at this import-time boundary —
# the Gradio app should still start (in dummy mode) when model assets are missing.
try:
    initialize_model()
except Exception as e:
    print(f"[ERROR] Failed to initialize model: {e}")
    import traceback
    traceback.print_exc()
    print("[WARN] App will run in dummy mode")
|
| 131 |
+
def generate_seismic_data(station_id_str, num_samples):
    """Generate seismic signals using Phase 1 model and pre-computed latent bank.

    Args:
        station_id_str: Station ID string (one of TEST_STATIONS).
        num_samples: Number of waveforms to generate.

    Returns:
        RGB numpy array of the rendered matplotlib figure
        (for gr.Image with type="numpy").
    """
    global model, args, encoder_std

    generated_signals = None

    # Check if model is loaded
    if model is None:
        print("[ERROR] Model not initialized! Attempting to initialize...")
        if initialize_model():
            # Retry generation with real model
            return generate_seismic_data(station_id_str, num_samples)
        # Fallback to dummy generation.
        # BUG FIX: previously the dummy result was unconditionally overwritten
        # by the latent-bank generation below, which then failed (model is
        # None) and produced an error plot instead of the dummy waveforms.
        print("[WARN] Using dummy generation as fallback")
        generated_signals = np.random.randn(num_samples, 3, 6000) * 0.1 + np.sin(np.linspace(0, 100, 6000))
    else:
        # Generate samples from pre-computed latent bank (matching generate_samples_git.py)
        try:
            print(f"[INFO] Generating {num_samples} samples for station {station_id_str} from latent bank...")

            # Note: encoder_std is passed but NOT used during generation in generate_samples_git.py
            # (noise is NOT added during generation)
            generated_signals, _ = generate_samples_from_latent_bank(
                model, LATENT_BANK_PATH, station_id_str, num_samples, args, encoder_std
            )

            if generated_signals is None:
                print(f"[ERROR] Failed to generate samples for station {station_id_str}")
            else:
                print(f"[INFO] ✓ Generated {len(generated_signals)} samples successfully")
        except Exception as e:
            print(f"[ERROR] Generation failed: {e}")
            import traceback
            traceback.print_exc()
            generated_signals = None

    # Handle case where generation failed
    if generated_signals is None:
        # Create error message plot
        fig, ax = plt.subplots(1, 1, figsize=(8, 4))
        ax.text(0.5, 0.5, f'Error: Could not generate samples\nfor station {station_id_str}',
                ha='center', va='center', fontsize=14, transform=ax.transAxes)
        ax.set_xticks([])
        ax.set_yticks([])
        plt.tight_layout()
    else:
        # generated_signals shape: (num_samples, 3, 6000)
        num_plots = min(len(generated_signals), 5)  # Show up to 5 samples
        channel_names = ['E-W', 'N-S', 'U-D']

        if num_plots == 1:
            # Single sample: one row per channel
            fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)
            for ch, ax in enumerate(axes):
                ax.plot(generated_signals[0, ch, :], linewidth=0.8)
                ax.set_ylabel(channel_names[ch], fontweight='bold')
                ax.grid(True, alpha=0.3)
            axes[-1].set_xlabel('Time Steps', fontweight='bold')
            fig.suptitle(f'Generated Sample for Station {station_id_str}', fontsize=12, fontweight='bold')
            plt.tight_layout()
        else:
            # Multiple samples: grid of (num_plots x 3 channels); axes is 2-D
            # here since num_plots >= 2 in this branch.
            fig, axes = plt.subplots(num_plots, 3, figsize=(12, 2*num_plots), sharex=True)
            for i in range(num_plots):
                for ch in range(3):
                    ax = axes[i, ch]
                    ax.plot(generated_signals[i, ch, :], linewidth=0.8)
                    if i == 0:
                        ax.set_title(channel_names[ch], fontweight='bold')
                    if i == num_plots - 1:
                        ax.set_xlabel('Time Steps', fontweight='bold')
                    ax.set_ylabel('Amplitude', fontsize=9)
                    ax.grid(True, alpha=0.3)

            fig.suptitle(f'Generated Samples for Station {station_id_str}', fontsize=12, fontweight='bold')
            plt.tight_layout()

    # Save to BytesIO buffer and convert to PIL Image, then numpy array
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    plt.close(fig)
    buf.seek(0)

    # Convert to PIL Image then numpy array (RGB format)
    img = Image.open(buf)
    img_array = np.array(img)
    buf.close()

    return img_array
|
| 223 |
+
# Gradio Interface
# Single-function UI: pick a test station and a sample count, get back a
# rendered waveform figure (numpy RGB array from generate_seismic_data).
demo = gr.Interface(
    fn=generate_seismic_data,
    inputs=[
        gr.Dropdown(choices=TEST_STATIONS, label="Station ID", value=TEST_STATIONS[0], info="Select one of the 5 test stations"),
        gr.Slider(minimum=1, maximum=50, value=3, step=1, label="Number of Samples to Generate")
    ],
    outputs=gr.Image(label="Generated Seismic Signals", type="numpy"),
    title="TimesNet-Gen: Site-Specific Strong Motion Generation",
    description="Generate synthetic seismic signals using Phase 1 model and pre-computed latent bank. Select a station ID and number of samples to generate. (Matching GitHub generate_samples_git.py workflow)"
)

if __name__ == "__main__":
    demo.launch()
generate_samples_git.py
ADDED
|
@@ -0,0 +1,815 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simplified inference script for TimesNet-Gen.
|
| 4 |
+
Only loads data for the 5 fine-tuned stations.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
python generate_samples.py --num_samples 50
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import glob
|
| 16 |
+
import scipy.io as sio
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SimpleArgs:
    """Configuration for generation.

    NOTE(review): the architecture hyperparameters below presumably must match
    the checkpoint being loaded — changing them without retraining will likely
    break state_dict loading; confirm against the training script.
    """
    def __init__(self):
        # Model architecture
        self.seq_len = 6000       # number of time samples per signal
        self.d_model = 128        # model embedding dimension
        self.d_ff = 256           # feed-forward hidden dimension
        self.e_layers = 2         # encoder layers
        self.d_layers = 2         # decoder layers
        self.num_kernels = 6      # convolution kernels per block
        self.top_k = 2            # top-k periods used by TimesNet
        self.dropout = 0.1
        self.latent_dim = 256

        # System
        self.use_gpu = torch.cuda.is_available()
        self.seed = 0

        # Point-cloud generation
        self.pcgen_k = 5              # latent vectors mixed per generated sample
        self.pcgen_jitter_std = 0.0   # jitter noise std (0.0 disables jitter)
| 41 |
+
|
| 42 |
+
def _iter_np_arrays(obj):
|
| 43 |
+
"""Recursively iterate through numpy arrays in nested structures."""
|
| 44 |
+
if isinstance(obj, np.ndarray):
|
| 45 |
+
if obj.dtype == object:
|
| 46 |
+
for item in obj.flat:
|
| 47 |
+
yield from _iter_np_arrays(item)
|
| 48 |
+
else:
|
| 49 |
+
yield obj
|
| 50 |
+
elif isinstance(obj, dict):
|
| 51 |
+
for v in obj.values():
|
| 52 |
+
yield from _iter_np_arrays(v)
|
| 53 |
+
elif isinstance(obj, np.void):
|
| 54 |
+
if obj.dtype.names:
|
| 55 |
+
for name in obj.dtype.names:
|
| 56 |
+
yield from _iter_np_arrays(obj[name])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _find_3ch_from_arrays(arrays):
|
| 60 |
+
"""Find 3-channel array from list of arrays."""
|
| 61 |
+
# Prefer arrays that are 2D with a 3-channel dimension
|
| 62 |
+
for arr in arrays:
|
| 63 |
+
if isinstance(arr, np.ndarray) and arr.ndim == 2 and (arr.shape[0] == 3 or arr.shape[1] == 3):
|
| 64 |
+
return arr
|
| 65 |
+
# Otherwise, try to find three 1D arrays of same length
|
| 66 |
+
one_d = [a for a in arrays if isinstance(a, np.ndarray) and a.ndim == 1]
|
| 67 |
+
for i in range(len(one_d)):
|
| 68 |
+
for j in range(i + 1, len(one_d)):
|
| 69 |
+
for k in range(j + 1, len(one_d)):
|
| 70 |
+
if one_d[i].shape[0] == one_d[j].shape[0] == one_d[k].shape[0]:
|
| 71 |
+
return np.stack([one_d[i], one_d[j], one_d[k]], axis=0)
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def load_mat_file(filepath, seq_len=6000, debug=False):
    """Load and preprocess a .mat file (using data_loader_gen.py logic).

    Tries the known EQ.anEQ.Accel structure first; if that fails, falls back
    to a generic scan over every array in the file for a 3-channel candidate.

    Args:
        filepath: Path to the .mat file.
        seq_len: Target number of time samples; data is resampled to this length.
        debug: When True, print verbose parsing diagnostics.

    Returns:
        torch.FloatTensor of shape (3, seq_len), or None on any failure.
    """
    try:
        if debug:
            print(f"\n[DEBUG] Loading: {os.path.basename(filepath)}")

        # Load with squeeze_me and struct_as_record like data_loader_gen.py
        mat = sio.loadmat(filepath, squeeze_me=True, struct_as_record=False)

        if debug:
            print(f"[DEBUG] Keys in mat file: {[k for k in mat.keys() if not k.startswith('__')]}")

        # Check if 'EQ' is a struct with nested 'anEQ' structure (like in data_loader_gen.py)
        if 'EQ' in mat:
            try:
                eq_obj = mat['EQ']

                if debug:
                    print(f"[DEBUG] EQ type: {type(eq_obj)}")
                    print(f"[DEBUG] EQ shape: {eq_obj.shape if hasattr(eq_obj, 'shape') else 'N/A'}")

                # Since struct_as_record=False, EQ is a mat_struct object
                # Access with attributes, not subscripts
                if hasattr(eq_obj, 'anEQ'):
                    dataset = eq_obj.anEQ
                    if debug:
                        print(f"[DEBUG] Found anEQ, type: {type(dataset)}")

                    if hasattr(dataset, 'Accel'):
                        accel = dataset.Accel

                        if debug:
                            print(f"[DEBUG] Found Accel: type={type(accel)}, shape={accel.shape if hasattr(accel, 'shape') else 'N/A'}")

                        if isinstance(accel, np.ndarray):
                            # Transpose to (3, N) if needed
                            if accel.ndim == 2:
                                if accel.shape[1] == 3:
                                    accel = accel.T

                                if accel.shape[0] == 3:
                                    data = accel
                                    if debug:
                                        print(f"[DEBUG] ✅ Successfully extracted 3-channel data! Shape: {data.shape}")

                                    # Resample if needed
                                    if data.shape[1] != seq_len:
                                        from scipy import signal as sp_signal
                                        data_resampled = np.zeros((3, seq_len), dtype=np.float32)
                                        for i in range(3):
                                            data_resampled[i] = sp_signal.resample(data[i], seq_len)
                                        data = data_resampled
                                        if debug:
                                            print(f"[DEBUG] Resampled to {seq_len} samples")

                                    return torch.FloatTensor(data)
                                else:
                                    if debug:
                                        print(f"[DEBUG] Unexpected Accel shape[0]: {accel.shape[0]} (expected 3)")
                            else:
                                if debug:
                                    print(f"[DEBUG] Accel is not 2D: ndim={accel.ndim}")
                    else:
                        if debug:
                            print(f"[DEBUG] anEQ has no 'Accel' attribute")
                            if hasattr(dataset, '__dict__'):
                                print(f"[DEBUG] anEQ attributes: {list(vars(dataset).keys())}")
                else:
                    if debug:
                        print(f"[DEBUG] EQ has no 'anEQ' attribute")
                        if hasattr(eq_obj, '__dict__'):
                            print(f"[DEBUG] EQ attributes: {list(vars(eq_obj).keys())}")

            # Any parse failure of the EQ structure falls through to the
            # generic array scan below rather than aborting the load.
            except Exception as e:
                if debug:
                    import traceback
                    print(f"[DEBUG] Could not parse EQ structure: {e}")
                    print(f"[DEBUG] Traceback: {traceback.format_exc()}")

        # Generic fallback: collect every numpy array in the file.
        arrays = list(_iter_np_arrays(mat))

        if debug:
            print(f"[DEBUG] Found {len(arrays)} arrays")
            for i, arr in enumerate(arrays[:5]):  # Show first 5
                if isinstance(arr, np.ndarray):
                    print(f"[DEBUG] Array {i}: shape={arr.shape}, dtype={arr.dtype}")

        # Common direct keys first
        # (inserted at the front so they take precedence in the search)
        for key in ['signal', 'data', 'sig', 'x', 'X', 'signal3c', 'acc', 'NS', 'EW', 'UD']:
            if key in mat and isinstance(mat[key], np.ndarray):
                arrays.insert(0, mat[key])
                if debug:
                    print(f"[DEBUG] Found key '{key}': shape={mat[key].shape}")

        # Find 3-channel array
        data = _find_3ch_from_arrays(arrays)

        if data is None:
            if debug:
                print(f"[DEBUG] Could not find 3-channel array!")
            return None

        if debug:
            print(f"[DEBUG] Found 3-channel data: shape={data.shape}")

        # Ensure shape is (3, N)
        if data.shape[0] != 3 and data.shape[1] == 3:
            data = data.T
            if debug:
                print(f"[DEBUG] Transposed to: shape={data.shape}")

        if data.shape[0] != 3:
            if debug:
                print(f"[DEBUG] Wrong number of channels: {data.shape[0]}")
            return None

        # Resample to seq_len
        if data.shape[1] != seq_len:
            from scipy import signal as sp_signal
            data_resampled = np.zeros((3, seq_len), dtype=np.float32)
            for i in range(3):
                data_resampled[i] = sp_signal.resample(data[i], seq_len)
            data = data_resampled
            if debug:
                print(f"[DEBUG] Resampled to: shape={data.shape}")

        if debug:
            print(f"[DEBUG] ✅ Successfully loaded!")

        return torch.FloatTensor(data)

    except Exception as e:
        if debug:
            print(f"[DEBUG] ❌ Exception: {e}")
        return None
|
| 211 |
+
|
| 212 |
+
def load_model(checkpoint_path, args):
    """Load pre-trained TimesNet-PointCloud model.

    Args:
        checkpoint_path: Path to a .pth checkpoint — either a raw state_dict
            or a dict with a 'model_state_dict' entry.
        args: Configuration object carrying the model hyperparameters.

    Returns:
        The model in eval mode (moved to GPU when args.use_gpu is set).
    """
    from models.TimesNet_PointCloud import TimesNetPointCloud

    # Create model config
    # Thin adapter mapping the flat args object onto the config attribute
    # names the model constructor expects.
    class ModelConfig:
        def __init__(self, args):
            self.seq_len = args.seq_len
            self.pred_len = 0
            self.enc_in = 3
            self.c_out = 3
            self.d_model = args.d_model
            self.d_ff = args.d_ff
            self.num_kernels = args.num_kernels
            self.top_k = args.top_k
            self.e_layers = args.e_layers
            self.d_layers = args.d_layers
            self.dropout = args.dropout
            self.embed = 'timeF'
            self.freq = 'h'
            self.latent_dim = args.latent_dim

    config = ModelConfig(args)
    model = TimesNetPointCloud(config)

    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)

    model.eval()
    if args.use_gpu:
        model = model.cuda()

    print(f"[INFO] Model loaded successfully from {checkpoint_path}")
    return model
|
| 251 |
+
|
| 252 |
+
def generate_samples_from_latent_bank(model, latent_bank_path, station_id, num_samples, args, encoder_std=None):
    """
    Generate samples directly from pre-computed latent bank.
    NO REAL DATA NEEDED!

    Each sample is produced by bootstrap aggregation: k latent vectors are
    drawn with replacement, averaged, and decoded through the model's
    projection head.

    Args:
        model: TimesNet model
        latent_bank_path: Path to latent_bank_phase1.npz
        station_id: Station ID (e.g., '0205')
        num_samples: Number of samples to generate
        args: Model arguments (pcgen_k, use_gpu)
        encoder_std: Encoder std vector for noise injection
            (accepted for API compatibility but intentionally unused here —
            noise is only applied during training, not generation)

    Returns:
        generated_signals: (num_samples, 3, seq_len) array, or None on failure
        real_names_used: List of lists indicating which latent vectors were used
    """
    print(f"[INFO] Loading latent bank from {latent_bank_path}...")

    try:
        latent_data = np.load(latent_bank_path)
    except Exception as e:
        print(f"[ERROR] Could not load latent bank: {e}")
        return None, None

    # Load latent vectors for this station
    latents_key = f'latents_{station_id}'
    means_key = f'means_{station_id}'
    stdev_key = f'stdev_{station_id}'

    if latents_key not in latent_data:
        print(f"[ERROR] Station {station_id} not found in latent bank!")
        print(f"Available stations: {[k.replace('latents_', '') for k in latent_data.keys() if k.startswith('latents_')]}")
        return None, None

    latents = latent_data[latents_key]  # (N_samples, seq_len, d_model)
    means = latent_data[means_key]  # (N_samples, seq_len, d_model)
    stdevs = latent_data[stdev_key]  # (N_samples, seq_len, d_model)

    print(f"[INFO] Loaded {len(latents)} latent vectors for station {station_id}")
    print(f"[INFO] Generating {num_samples} samples via bootstrap aggregation...")

    generated_signals = []
    real_names_used = []

    model.eval()
    with torch.no_grad():
        for i in range(num_samples):
            # Bootstrap: randomly select k latent vectors with replacement
            k = min(args.pcgen_k, len(latents))
            selected_indices = np.random.choice(len(latents), size=k, replace=True)

            # Mix latent features (average)
            selected_latents = latents[selected_indices]  # (k, seq_len, d_model)
            selected_means = means[selected_indices]  # (k, seq_len, d_model)
            selected_stdevs = stdevs[selected_indices]  # (k, seq_len, d_model)

            mixed_features = np.mean(selected_latents, axis=0)  # (seq_len, d_model)
            mixed_means = np.mean(selected_means, axis=0)  # (seq_len, d_model)
            mixed_stdevs = np.mean(selected_stdevs, axis=0)  # (seq_len, d_model)

            # NOTE: Do NOT add noise during generation (matching untitled1_gen.py)
            # untitled1_gen.py only uses noise during TRAINING (Phase 1), not during generation
            # if encoder_std is not None:
            #     noise = np.random.randn(*mixed_features.shape) * encoder_std
            #     mixed_features = mixed_features + noise

            # Convert to torch tensors
            mixed_features_torch = torch.from_numpy(mixed_features).float().unsqueeze(0)  # (1, seq_len, d_model)
            means_b = torch.from_numpy(mixed_means).float().unsqueeze(0)  # (1, seq_len, d_model)
            stdev_b = torch.from_numpy(mixed_stdevs).float().unsqueeze(0)  # (1, seq_len, d_model)

            if args.use_gpu:
                mixed_features_torch = mixed_features_torch.cuda()
                means_b = means_b.cuda()
                stdev_b = stdev_b.cuda()

            # Decode
            xg = model.project_features_for_reconstruction(mixed_features_torch, means_b, stdev_b)

            # Store - transpose to (3, 6000)
            generated_np = xg.squeeze(0).cpu().numpy().T  # (6000, 3) → (3, 6000)
            generated_signals.append(generated_np)

            # Track which latent indices were used
            real_names_used.append([f"latent_{idx}" for idx in selected_indices])

            if (i + 1) % 10 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples...")

    return np.array(generated_signals), real_names_used
|
| 344 |
+
|
| 345 |
+
def _preprocess_component_boore(data: np.ndarray, fs: float, corner_freq: float, filter_order: int = 2) -> np.ndarray:
|
| 346 |
+
"""Boore (2005) style preprocessing: detrend (linear), zero-padding, high-pass Butterworth (zero-phase)."""
|
| 347 |
+
from scipy.signal import butter, filtfilt
|
| 348 |
+
x = np.asarray(data, dtype=np.float64)
|
| 349 |
+
n = x.shape[0]
|
| 350 |
+
# Linear detrend
|
| 351 |
+
t = np.arange(n, dtype=np.float64)
|
| 352 |
+
t_mean = t.mean()
|
| 353 |
+
x_mean = x.mean()
|
| 354 |
+
denom = np.sum((t - t_mean) ** 2)
|
| 355 |
+
slope = 0.0 if denom == 0 else float(np.sum((t - t_mean) * (x - x_mean)) / denom)
|
| 356 |
+
intercept = float(x_mean - slope * t_mean)
|
| 357 |
+
x_detr = x - (slope * t + intercept)
|
| 358 |
+
# Zero-padding
|
| 359 |
+
Tzpad = (1.5 * filter_order) / max(corner_freq, 1e-6)
|
| 360 |
+
pad_samples = int(round(Tzpad * fs))
|
| 361 |
+
x_pad = np.concatenate([np.zeros(pad_samples, dtype=np.float64), x_detr, np.zeros(pad_samples, dtype=np.float64)])
|
| 362 |
+
# High-pass filter (zero-phase)
|
| 363 |
+
normalized = corner_freq / (fs / 2.0)
|
| 364 |
+
normalized = min(max(normalized, 1e-6), 0.999999)
|
| 365 |
+
b, a = butter(filter_order, normalized, btype='high')
|
| 366 |
+
y = filtfilt(b, a, x_pad)
|
| 367 |
+
return y
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def _konno_ohmachi_smoothing(spectrum: np.ndarray, freq: np.ndarray, b: float = 40.0) -> np.ndarray:
|
| 371 |
+
"""Konno-Ohmachi smoothing as in MATLAB reference (O(n^2))."""
|
| 372 |
+
f = np.asarray(freq, dtype=np.float64).reshape(-1)
|
| 373 |
+
s = np.asarray(spectrum, dtype=np.float64).reshape(-1)
|
| 374 |
+
f = np.where(f == 0.0, 1e-12, f)
|
| 375 |
+
n = f.shape[0]
|
| 376 |
+
out = np.zeros_like(s)
|
| 377 |
+
for i in range(n):
|
| 378 |
+
w = np.exp(-b * (np.log(f / f[i])) ** 2)
|
| 379 |
+
w[~np.isfinite(w)] = 0.0
|
| 380 |
+
denom = np.sum(w)
|
| 381 |
+
out[i] = 0.0 if denom == 0 else float(np.sum(w * s) / denom)
|
| 382 |
+
return out
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def _compute_hvsr_simple(signal: np.ndarray, fs: float = 100.0):
    """Compute HVSR curve using MATLAB-style pipeline (Boore HP filter + FAS + Konno-Ohmachi)."""
    try:
        # Expect a (T, 3) array with columns E-W, N-S, U-D and finite samples.
        if signal.ndim != 2 or signal.shape[1] != 3:
            return None, None
        if not np.all(np.isfinite(signal)):
            return None, None

        # Boore (2005) preprocessing per component:
        # linear detrend + zero-padding + 0.05 Hz high-pass, order 2.
        components = [
            _preprocess_component_boore(signal[:, c], fs, 0.05, 2) for c in range(3)
        ]

        n = int(min(len(c) for c in components))
        if n < 16:
            return None, None
        ew, ns, ud = (c[:n] for c in components)

        # One-sided FFT amplitude spectra on a linear frequency grid.
        half = n // 2
        if half <= 1:
            return None, None
        freq = np.arange(half, dtype=np.float64) * (fs / n)
        amp_ew = np.abs(np.fft.fft(ew))[:half]
        amp_ns = np.abs(np.fft.fft(ns))[:half]
        amp_ud = np.abs(np.fft.fft(ud))[:half]

        # Geometric-mean horizontal spectrum, then Konno-Ohmachi smoothing (b=40).
        horizontal = np.sqrt(np.maximum(amp_ew, 0.0) * np.maximum(amp_ns, 0.0))
        smooth_h = _konno_ohmachi_smoothing(horizontal, freq, 40.0)
        smooth_v = _konno_ohmachi_smoothing(amp_ud, freq, 40.0)

        # Guard the division against a non-positive vertical spectrum.
        ratio = smooth_h / np.where(smooth_v <= 0.0, 1e-12, smooth_v)

        # Restrict the curve to the 1-20 Hz band of interest.
        band = (freq >= 1.0) & (freq <= 20.0)
        if not np.any(band):
            return None, None
        return freq[band], ratio[band]
    except Exception:
        # Best-effort helper: any failure means "no curve", never an exception.
        return None, None
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def save_generated_samples(generated_signals, real_names, station_id, output_dir):
    """Save generated samples to NPZ file with HVSR and f0 data.

    Parameters:
        generated_signals: array-like of shape (num_samples, 3, T) — generated
            3-component waveforms, channel-major per sample.
        real_names: per-sample provenance labels stored alongside the signals.
        station_id: station identifier used in the NPZ filename and contents.
        output_dir: destination directory (created if missing).

    The NPZ stores the waveforms (under two key aliases for compatibility),
    the per-sample f0 values, an f0 histogram normalized to a PDF, and the
    median HVSR curve interpolated onto a fixed 1-20 Hz grid.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Compute HVSR and f0 for all generated signals
    f0_list = []
    hvsr_curves = []
    # Sampling frequency assumed to be 100 Hz — TODO confirm against generator config
    fs = 100.0

    print(f"[INFO] Computing HVSR and f0 for {len(generated_signals)} generated samples...")
    for idx, sig in enumerate(generated_signals):
        # sig is (3, T), need to transpose to (T, 3)
        sig_t = sig.T  # (T, 3)
        freq, hvsr = _compute_hvsr_simple(sig_t, fs)
        # Samples with no valid HVSR (helper returns None, None) are skipped entirely.
        if freq is not None and hvsr is not None:
            hvsr_curves.append((freq, hvsr))
            # f0 = frequency at max HVSR
            max_idx = np.argmax(hvsr)
            f0 = float(freq[max_idx])
            f0_list.append(f0)

    # Build median HVSR curve on a fixed frequency grid (1-20 Hz, 400 points for consistency)
    hvsr_freq = None
    hvsr_median = None
    if hvsr_curves:
        # Use a fixed frequency grid for consistency with other plots
        hvsr_freq = np.linspace(1.0, 20.0, 400)
        # Interpolate all curves to common grid
        hvsr_matrix = []
        for freq, hvsr in hvsr_curves:
            # Edge values are extended flat outside each curve's native range.
            hvsr_interp = np.interp(hvsr_freq, freq, hvsr, left=hvsr[0], right=hvsr[-1])
            hvsr_matrix.append(hvsr_interp)
        hvsr_median = np.median(np.vstack(hvsr_matrix), axis=0)

    # Build f0 histogram (PDF): 20 bins spanning 1-20 Hz
    f0_bins = np.linspace(1.0, 20.0, 21)
    f0_array = np.array(f0_list)
    f0_hist, _ = np.histogram(f0_array, bins=f0_bins)
    f0_pdf = f0_hist.astype(float)
    f0_sum = f0_pdf.sum()
    # Normalize only when at least one f0 fell inside the bins (avoids 0/0).
    if f0_sum > 0:
        f0_pdf = f0_pdf / f0_sum

    # Save timeseries NPZ with HVSR data
    output_path = os.path.join(output_dir, f'station_{station_id}_generated_timeseries.npz')
    np.savez_compressed(
        output_path,
        generated_signals=generated_signals,
        signals_generated=generated_signals,  # Alias for compatibility
        real_names=real_names,
        station_id=station_id,
        station=station_id,  # Alias for compatibility
        f0_timesnet=f0_array,
        f0_bins=f0_bins,
        pdf_timesnet=f0_pdf,
        # Empty arrays signal "no valid HVSR" to downstream consumers.
        hvsr_freq_timesnet=hvsr_freq if hvsr_freq is not None else np.array([]),
        hvsr_median_timesnet=hvsr_median if hvsr_median is not None else np.array([]),
    )
    print(f"[INFO] Saved {len(generated_signals)} generated samples to {output_path}")
    if len(f0_list) > 0:
        print(f"[INFO] - f0 samples: {len(f0_list)}, median f0: {np.median(f0_array):.2f} Hz")
    else:
        print(f"[INFO] - No valid HVSR computed")
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def fine_tune_model(model, all_station_files, args, encoder_std, epochs=10, lr=1e-4):
    """
    Fine-tune the model on 5 stations with noise injection.
    Matches Phase 1 training in untitled1_gen.py exactly.

    Parameters:
        model: module exposing encode_features_for_reconstruction /
            project_features_for_reconstruction (TimesNet point-cloud model).
        all_station_files: dict mapping station id -> list of .mat file paths.
        args: configuration object providing at least seq_len and use_gpu.
        encoder_std: per-feature std (np.ndarray) used to scale injected
            latent noise, or None to skip noise injection.
        epochs: number of fine-tuning epochs.
        lr: AdamW learning rate.

    Returns:
        The fine-tuned model (also checkpointed to ./checkpoints).
    """
    print("\n" + "="*80)
    print("Phase 1: Fine-Tuning with Noise Injection")
    print("="*80)

    # Load every available sample across all stations into memory.
    all_data = []
    for station_id, files in all_station_files.items():
        for fpath in files:
            data = load_mat_file(fpath, args.seq_len, debug=False)
            if data is not None:
                all_data.append(data)

    if len(all_data) == 0:
        print("[WARN] No data loaded for fine-tuning!")
        return model

    print(f"[INFO] Loaded {len(all_data)} samples for fine-tuning")

    # Optimizer / schedule (matching untitled1_gen.py Phase 1)
    batch_size = 32
    weight_decay = 1e-4
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # AMP scaler (matching untitled1_gen.py)
    scaler = torch.cuda.amp.GradScaler(enabled=(args.use_gpu))

    # Gradient clipping threshold (matching untitled1_gen.py)
    grad_clip = 1.0

    train_losses_p1 = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_rec = 0.0
        num_batches = 0

        # Fresh sample order every epoch
        np.random.shuffle(all_data)

        for i in range(0, len(all_data), batch_size):
            batch = all_data[i:i+batch_size]
            if len(batch) == 0:
                continue

            # Stack batch: each sig is (3, 6000) -> (6000, 3)
            x_list = []
            for sig in batch:
                x_list.append(sig.transpose(0, 1))
            x = torch.stack(x_list, dim=0)  # (batch, 6000, 3)
            if args.use_gpu:
                x = x.cuda()

            optimizer.zero_grad(set_to_none=True)

            # Forward with AMP and noise injection (matching untitled1_gen.py Phase 1)
            with torch.cuda.amp.autocast(enabled=(args.use_gpu)):
                enc_out, means_b, stdev_b = model.encode_features_for_reconstruction(x)

                # Inject Gaussian noise scaled by per-feature encoder std
                # (noise_std_scale=1.0, matching untitled1_gen.py line 945-948)
                if encoder_std is not None:
                    std_vec = torch.from_numpy(encoder_std).to(enc_out.device).float()
                    noise = torch.randn_like(enc_out) * std_vec.view(1, 1, -1) * 1.0
                    enc_out = enc_out + noise

                # Decode
                x_hat = model.project_features_for_reconstruction(enc_out, means_b, stdev_b)

                # Reconstruction loss (MSE, matching untitled1_gen.py)
                loss_rec = torch.nn.functional.mse_loss(x_hat, x)
                loss = loss_rec

            # Backward with gradient scaling
            scaler.scale(loss).backward()

            # BUGFIX: gradients must be unscaled before clipping; otherwise
            # clip_grad_norm_ operates on loss-scaled gradients and the
            # max_norm threshold is effectively meaningless (see PyTorch AMP
            # "working with scaled gradients" documentation).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)

            # Optimizer step via scaler (skips the step on inf/NaN gradients)
            scaler.step(optimizer)
            scaler.update()

            total_loss += float(loss.detach().cpu())
            total_rec += float(loss_rec.detach().cpu())
            num_batches += 1

        # One LR-scheduler step per epoch (matching untitled1_gen.py)
        scheduler.step()

        avg_loss = total_loss / max(1, num_batches)
        avg_rec = total_rec / max(1, num_batches)
        train_losses_p1.append(avg_loss)
        print(f"[P1] epoch {epoch+1}/{epochs} loss={avg_loss:.4f} (rec={avg_rec:.4f})")

    print("[INFO] Phase 1 fine-tuning complete!")

    # Save fine-tuned model (matching untitled1_gen.py Phase 1 checkpoint)
    checkpoint_dir = './checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    fine_tuned_path = os.path.join(checkpoint_dir, 'timesnet_pointcloud_phase1_finetuned.pth')
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses_phase1': train_losses_p1,
        'phase': 'phase1'
    }, fine_tuned_path)
    print(f"[INFO] ✓ Fine-tuned model saved to: {fine_tuned_path}")

    return model
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def plot_sample_preview(generated_signals, station_id, output_dir, num_preview=2):
    """Create preview plots."""
    os.makedirs(output_dir, exist_ok=True)

    n_plots = min(num_preview, len(generated_signals))
    component_labels = ['E-W', 'N-S', 'U-D']

    for sample_idx in range(n_plots):
        sample = generated_signals[sample_idx]
        fig, axes = plt.subplots(3, 1, figsize=(12, 8))

        # One subplot per component, top to bottom.
        for channel, (ax, label) in enumerate(zip(axes, component_labels)):
            ax.plot(sample[channel], linewidth=0.8)
            ax.set_ylabel(f'{label}\nAmplitude', fontsize=10, fontweight='bold')
            ax.grid(True, alpha=0.3)

        axes[-1].set_xlabel('Time Steps', fontsize=10, fontweight='bold')
        fig.suptitle(f'Generated Sample - Station {station_id}', fontsize=12, fontweight='bold')
        plt.tight_layout()

        out_file = os.path.join(output_dir, f'station_{station_id}_preview_{sample_idx}.png')
        plt.savefig(out_file, dpi=150, bbox_inches='tight')
        plt.close()

    print(f"[INFO] Saved {n_plots} preview plots to {output_dir}")
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def main():
    """CLI entry point: load the pre-trained model, optionally fine-tune it,
    generate samples per station from the latent bank, save NPZs, render
    preview plots, and finally run the HVSR comparison plotter.

    Side effects: reads checkpoint/latent-bank/encoder-std files, writes
    NPZ files and PNG plots under --output_dir, and prints progress.
    """
    parser = argparse.ArgumentParser(description='Generate seismic samples (simplified version)')
    parser.add_argument('--checkpoint', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\timesnet_pointcloud_phase1_final.pth',
                        help='Path to pre-trained model checkpoint')
    parser.add_argument('--latent_bank', type=str,
                        default=r'D:\Baris\codes\Time-Series-Library-main\checkpoints\latent_bank_phase1.npz',
                        help='Path to latent bank NPZ file')
    parser.add_argument('--num_samples', type=int, default=50,
                        help='Number of samples to generate per station')
    parser.add_argument('--output_dir', type=str, default='./generated_samples',
                        help='Output directory')
    parser.add_argument('--num_preview', type=int, default=2,
                        help='Number of preview plots per station')
    parser.add_argument('--stations', type=str, nargs='+', default=['0205', '1716', '2020', '3130', '4628'],
                        help='Target station IDs')
    parser.add_argument('--data_root', type=str, default=r"D:\Baris\5stats/",
                        help='Root path to seismic data (only needed if --fine_tune is used)')
    parser.add_argument('--fine_tune', action='store_true',
                        help='Fine-tune the model before generation (use with Phase 0 checkpoint)')
    parser.add_argument('--fine_tune_epochs', type=int, default=10,
                        help='Number of fine-tuning epochs')
    parser.add_argument('--fine_tune_lr', type=float, default=1e-4,
                        help='Learning rate for fine-tuning')

    args_cli = parser.parse_args()

    # Bail out early if the checkpoint is missing — nothing else can run.
    if not os.path.exists(args_cli.checkpoint):
        print(f"\n{'='*80}")
        print(f"❌ ERROR: Checkpoint not found!")
        print(f"{'='*80}")
        print(f"\nLooking for: {args_cli.checkpoint}")
        return

    # Model/run configuration (separate from the CLI arguments)
    args = SimpleArgs()

    print("="*80)
    print("TimesNet-Gen Sample Generation (Simplified)")
    print("="*80)
    print(f"Checkpoint: {args_cli.checkpoint}")
    print(f"Target stations: {args_cli.stations}")
    print(f"Samples per station: {args_cli.num_samples}")
    print(f"Output directory: {args_cli.output_dir}")
    print("="*80)

    # Seed both torch and numpy for reproducible generation
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Load model
    model = load_model(args_cli.checkpoint, args)

    # Try to load encoder_std from Phase 0 (only needed if fine-tuning)
    encoder_std_path = './pcgen_stats/encoder_feature_std.npy'
    encoder_std = None
    if os.path.exists(encoder_std_path):
        encoder_std = np.load(encoder_std_path)
        print(f"[INFO] Loaded encoder_std from {encoder_std_path} (shape: {encoder_std.shape})")
        print(f"[INFO] encoder_std loaded (used only for fine-tuning, NOT for generation)")
    else:
        print(f"[INFO] No encoder_std found (not needed for generation, only for fine-tuning)")

    # The latent bank is mandatory — generation mixes latents stored there.
    if not os.path.exists(args_cli.latent_bank):
        print(f"\n❌ ERROR: Latent bank not found!")
        print(f"Looking for: {args_cli.latent_bank}")
        print(f"\nPlease run untitled1_gen.py first to generate the latent bank.")
        return

    print(f"[INFO] Using latent bank: {args_cli.latent_bank}")

    # Fine-tune if requested (requires real data)
    if args_cli.fine_tune:
        print("\n[INFO] Fine-tuning enabled! Loading real data...")

        all_station_files = {}
        for station_id in args_cli.stations:
            # Find all .mat files for this station
            pattern = os.path.join(args_cli.data_root, f"*{station_id}*.mat")
            station_files = glob.glob(pattern)

            if len(station_files) == 0:
                print(f"[WARN] No files found for station {station_id}")
            else:
                print(f"[INFO] Found {len(station_files)} files for station {station_id}")
                all_station_files[station_id] = station_files

        if len(all_station_files) == 0:
            print(f"\n❌ ERROR: No data files found in {args_cli.data_root}")
            return

        model = fine_tune_model(model, all_station_files, args, encoder_std,
                                epochs=args_cli.fine_tune_epochs,
                                lr=args_cli.fine_tune_lr)

    # Create output directories
    npz_output_dir = os.path.join(args_cli.output_dir, 'generated_timeseries_npz')
    plot_output_dir = os.path.join(args_cli.output_dir, 'preview_plots')

    # Generate samples for each station (from latent bank)
    print("\n[INFO] Generating samples from latent bank...")
    for station_id in args_cli.stations:
        print(f"\n{'='*60}")
        print(f"Processing Station: {station_id}")
        print(f"{'='*60}")

        generated_signals, real_names = generate_samples_from_latent_bank(
            model, args_cli.latent_bank, station_id, args_cli.num_samples, args, encoder_std
        )

        if generated_signals is not None:
            # Save to NPZ
            save_generated_samples(generated_signals, real_names, station_id, npz_output_dir)

            # Create preview plots
            plot_sample_preview(generated_signals, station_id, plot_output_dir, args_cli.num_preview)

    print("\n" + "="*80)
    print("Generation Complete!")
    print("="*80)
    print(f"Generated samples saved to: {npz_output_dir}")
    print(f"Preview plots saved to: {plot_output_dir}")

    # Debug: Show how many samples were generated per station
    print("\n[DEBUG] Generated samples per station:")
    for station_id in args_cli.stations:
        npz_path = os.path.join(npz_output_dir, f'station_{station_id}_generated_timeseries.npz')
        if os.path.exists(npz_path):
            try:
                data = np.load(npz_path, allow_pickle=True)
                if 'signals_generated' in data:
                    n_samples = data['signals_generated'].shape[0]
                    print(f"  Station {station_id}: {n_samples} samples")
            except Exception as e:
                print(f"  Station {station_id}: Error loading NPZ - {e}")
    print("="*80)

    # Create HVSR comparison plots (import plot_combined_hvsr_all_sources and call main)
    print("\n[INFO] Creating HVSR comparison plots (matrices, HVSR curves, f0 distributions)...")
    print("[INFO] Only plotting TimesNet-Gen vs Real (no Recon/VAE)")
    try:
        import sys
        # Import the plotting module
        import plot_combined_hvsr_all_sources as hvsr_plotter

        # Override sys.argv to pass arguments to the plotter.
        # Only provide gen_dir and gen_ts_dir, explicitly disable others with empty strings.
        original_argv = sys.argv
        sys.argv = [
            'plot_combined_hvsr_all_sources.py',
            '--gen_dir', npz_output_dir,  # Use our generated NPZs as gen_dir (they now have HVSR/f0 data)
            '--gen_ts_dir', npz_output_dir,  # Also use for timeseries plots
            '--out', os.path.join(args_cli.output_dir, 'hvsr_analysis'),
            '--recon_dir', '',  # Explicitly empty to disable auto-default
            '--vae_dir', '',  # Explicitly empty to disable auto-default
            '--vae_gen_dir', '',  # Explicitly empty to disable auto-default
        ]
        try:
            # Call the main plotting function
            hvsr_plotter.main()
        finally:
            # BUGFIX: restore sys.argv even when the plotter raises, so the
            # override cannot leak past this function.
            sys.argv = original_argv

        print(f"[INFO] ✅ HVSR analysis complete! Plots saved to: {os.path.join(args_cli.output_dir, 'hvsr_analysis')}")
    except Exception as e:
        import traceback
        print(f"[WARN] Could not create HVSR plots: {e}")
        traceback.print_exc()
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
if __name__ == "__main__":
    # Script entry point: run the full generate-and-plot pipeline.
    main()
|
| 815 |
+
|