| import datetime |
| import itertools |
| import os |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, DistributedSampler |
| import torch.nn.functional as F |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| import random |
| import numpy as np |
| from typing import Tuple, List, Dict, Any, Union, Optional |
| from dataclasses import dataclass |
|
|
| from .dataset import ChatTSTimeRCDPretrainDataset |
| from .ts_encoder_bi_bias import TimeSeriesEncoder |
| from .time_rcd_config import TimeRCDConfig, default_config |
|
|
import warnings

# Suppress ALL warnings module-wide.
# NOTE(review): blanket suppression can hide real deprecation/numeric issues —
# consider narrowing to specific categories.
warnings.filterwarnings("ignore")
|
|
@dataclass
class PretrainBatch:
    """Batch structure for pretraining tasks."""

    # Padded input series; presumably (batch, seq_len, num_features) — TODO confirm against collate_fn output
    time_series: torch.Tensor
    # Per-timestep anomaly labels; collate_fn pads with -1 for positions past a sample's length
    labels: torch.Tensor
    # Copy of time_series with masked patches replaced by small random noise (see create_random_mask)
    masked_time_series: torch.Tensor
    # Boolean/0-1 mask marking which timesteps were masked for reconstruction
    mask_indices: torch.Tensor
|
|
|
|
class TimeSeriesPretrainModel(nn.Module):
    """Model for time series pretraining with masked reconstruction and anomaly detection.

    A shared ``TimeSeriesEncoder`` backbone feeds two task heads:
      * ``reconstruction_head`` — maps each embedding back to one scalar value
        for masked-value reconstruction (MSE on masked positions only).
      * ``anomaly_head`` — 2-way (normal/anomaly) classifier per timestep.
    """

    def __init__(self, config: "TimeRCDConfig"):
        """Build the encoder and both task heads from *config*.

        Args:
            config: Project configuration; ``config.ts_config`` holds the
                encoder hyper-parameters and ``config.dropout`` the head dropout.
        """
        super().__init__()
        self.config = config

        ts_config = config.ts_config
        # Shared backbone producing per-position embeddings of size d_proj.
        self.ts_encoder = TimeSeriesEncoder(
            d_model=ts_config.d_model,
            d_proj=ts_config.d_proj,
            patch_size=ts_config.patch_size,
            num_layers=ts_config.num_layers,
            num_heads=ts_config.num_heads,
            d_ff_dropout=ts_config.d_ff_dropout,
            use_rope=ts_config.use_rope,
            num_features=ts_config.num_features,
            activation=ts_config.activation,
        )

        # MLP: d_proj -> 4*d_proj -> 4*d_proj -> 1 reconstructed value.
        self.reconstruction_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, 1),
        )

        # Binary classifier: d_proj -> d_proj//2 -> 2 logits (normal/anomaly).
        self.anomaly_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj // 2, 2),
        )

    def forward(self, time_series: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Encode *time_series* and return the per-position local embeddings."""
        local_embeddings = self.ts_encoder(time_series, mask)
        return local_embeddings

    def masked_reconstruction_loss(self,
                                   local_embeddings: torch.Tensor,
                                   original_time_series: torch.Tensor,
                                   mask: torch.Tensor
                                   ) -> torch.Tensor:
        """MSE between reconstructed and original values on masked positions only.

        Args:
            local_embeddings: Encoder output; last dim is d_proj — assumed to
                flatten to (batch, seq_len, num_features) after the 1-dim head.
            original_time_series: Ground truth, (batch, seq_len, num_features).
            mask: (batch, seq_len); nonzero marks masked timesteps.

        Returns:
            Scalar MSE loss; 0.0 when nothing is masked (previously NaN).
        """
        batch_size, seq_len, num_features = original_time_series.shape

        mask = mask.bool()

        reconstructed = self.reconstruction_head(local_embeddings)
        reconstructed = reconstructed.view(batch_size, seq_len, num_features)

        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, num_features)
        # Guard: F.mse_loss over zero elements returns NaN, which would poison
        # the backward pass; return a well-defined 0 loss instead.
        if not mask_expanded.any():
            return torch.tensor(0.0, device=local_embeddings.device)
        return F.mse_loss(
            reconstructed[mask_expanded],
            original_time_series[mask_expanded],
        )

    def anomaly_detection_loss(self,
                               local_embeddings: torch.Tensor,
                               labels: torch.Tensor) -> torch.Tensor:
        """Cross-entropy anomaly loss per timestep, skipping padded positions.

        Args:
            local_embeddings: Encoder output; logits are averaged over dim -2
                (the feature axis) after the head.
            labels: (batch, seq_len) with -1 marking padding (see collate_fn).

        Returns:
            Scalar CE loss; 0.0 when every position is padding.
        """
        logits = self.anomaly_head(local_embeddings)
        logits = torch.mean(logits, dim=-2)

        logits = logits.view(-1, 2)
        labels = labels.view(-1)
        # BUG FIX: identify padding BEFORE thresholding. The old code ran
        # (labels > 0.5).long() first, mapping the -1 pad value to class 0,
        # so (labels != -1) was always all-True and padded timesteps were
        # silently included in the loss.
        valid_mask = labels != -1
        labels = (labels > 0.5).long()

        if valid_mask.any():
            anomaly_loss = F.cross_entropy(
                logits[valid_mask],
                labels[valid_mask],
            )
        else:
            anomaly_loss = torch.tensor(0.0, device=logits.device)

        return anomaly_loss
|
|
|
|
def create_random_mask(time_series: torch.Tensor,
                       attention_mask: torch.Tensor,
                       mask_ratio: float = 0.15,
                       patch_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create random patch masks over the valid (non-padded) part of each series.

    Args:
        time_series: (batch, seq_len, num_features) padded input.
        attention_mask: (batch, seq_len) boolean; True marks real data.
        mask_ratio: Fraction of valid patches to mask per sample.
        patch_size: Patch length in timesteps. Defaults to the project-wide
            ``default_config.ts_config.patch_size`` for backward compatibility
            (previously this global was hard-coded inside the function).

    Returns:
        Tuple of:
            masked_time_series: Copy of input with masked positions replaced
                by small Gaussian noise (std 0.1).
            mask: (batch, seq_len) float 0/1 mask restricted to valid positions.
    """
    batch_size, seq_len, num_features = time_series.shape
    if patch_size is None:
        patch_size = default_config.ts_config.patch_size

    # Allocate on the input's device so GPU inputs don't trip device mismatches.
    mask = torch.zeros(batch_size, seq_len, device=time_series.device)

    for i in range(batch_size):
        valid_length = int(attention_mask[i].sum().item())

        # ceil(valid_length / patch_size) patches cover the valid region.
        num_valid_patches = (valid_length - 1) // patch_size + 1
        num_masked = int(num_valid_patches * mask_ratio)

        if num_masked > 0:
            # Pick patches uniformly at random without replacement.
            masked_patches = torch.randperm(num_valid_patches)[:num_masked]
            for j in masked_patches:
                start_idx = j * patch_size
                # Last patch may be short; clamp to the valid length.
                end_idx = min((j + 1) * patch_size, valid_length)
                mask[i, start_idx:end_idx] = 1

    # Replace masked positions with low-amplitude Gaussian noise.
    masked_time_series = time_series.clone()
    mask_indices = mask.bool() & attention_mask
    mask_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, num_features)
    masked_time_series[mask_expanded] = torch.randn_like(masked_time_series[mask_expanded]) * 0.1

    # Zero out any mask bits that fell on padding.
    mask = mask * attention_mask.float()

    return masked_time_series, mask
|
|
|
|
def collate_fn(batch):
    """Collate pretraining samples: batch-normalize, pad, and random-mask.

    Each sample is a (time_series, normal_time_series, labels, attribute)
    tuple. Series are standardized with per-feature statistics pooled over
    the whole batch, padded to a common length, and a random patch mask is
    applied for the reconstruction objective.
    """
    series, normal_series, label_seqs, attributes = zip(*batch)

    # Promote univariate 1-D series to (seq_len, 1) so downstream code can
    # always assume a trailing feature dimension.
    if series[0].ndim == 1:
        series = [s.unsqueeze(-1) for s in series]
        normal_series = [s.unsqueeze(-1) for s in normal_series]
    else:
        series = list(series)
        normal_series = list(normal_series)

    # Standardize both series lists with stats pooled over the entire batch;
    # the epsilon keeps near-constant features from exploding.
    pooled = torch.cat(series, dim=0)
    mean = pooled.mean(dim=0, keepdim=True)
    std = pooled.std(dim=0, keepdim=True) + 1e-4
    series = [(s - mean) / std for s in series]
    normal_series = [(s - mean) / std for s in normal_series]

    pad = torch.nn.utils.rnn.pad_sequence
    padded_time_series = pad(series, batch_first=True, padding_value=0.0)
    padded_normal_time_series = pad(normal_series, batch_first=True, padding_value=0.0)
    # Labels padded with -1, which the loss treats as "ignore".
    padded_labels = pad(list(label_seqs), batch_first=True, padding_value=-1)

    # Boolean mask: True where a position holds real (non-padded) data.
    lengths = torch.tensor([s.size(0) for s in series])
    max_seq_len = padded_time_series.size(1)
    attention_mask = torch.arange(max_seq_len).unsqueeze(0) < lengths.unsqueeze(1)

    # Random patch masking for the masked-reconstruction objective.
    masked_time_series, mask = create_random_mask(padded_time_series, attention_mask)

    return {
        'time_series': padded_time_series,
        'normal_time_series': padded_normal_time_series,
        'masked_time_series': masked_time_series,
        'mask': mask,
        'labels': padded_labels,
        'attention_mask': attention_mask,
        'attribute': attributes,
    }
|
|
| def test_collate_fn(batch): |
| """Collate function for pretraining dataset.""" |
| |
| time_series_list, mask_list = zip(*batch) |
|
|
| |
| |
| batched_time_series = torch.stack(time_series_list, dim=0) |
| print(f"batched_time_series shape: {batched_time_series.shape}") |
| |
| batched_mask = torch.stack(mask_list, dim=0) |
| print(f"batched_mask shape: {batched_mask.shape}") |
|
|
| return { |
| 'time_series': batched_time_series, |
| 'attention_mask': batched_mask, |
| } |