import datetime
import itertools
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, DistributedSampler
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import random
import numpy as np
from typing import Tuple, List, Dict, Any, Union, Optional
from dataclasses import dataclass

from .dataset import ChatTSTimeRCDPretrainDataset
from .ts_encoder_bi_bias import TimeSeriesEncoder
from .time_rcd_config import TimeRCDConfig, default_config

import warnings

warnings.filterwarnings("ignore")


@dataclass
class PretrainBatch:
    """Batch structure for pretraining tasks."""
    time_series: torch.Tensor
    labels: torch.Tensor
    masked_time_series: torch.Tensor
    mask_indices: torch.Tensor


class TimeSeriesPretrainModel(nn.Module):
    """Model for time series pretraining with masked reconstruction and anomaly detection."""

    def __init__(self, config: TimeRCDConfig):
        super().__init__()
        self.config = config

        # Extract TimeSeriesEncoder parameters from config
        ts_config = config.ts_config
        self.ts_encoder = TimeSeriesEncoder(
            d_model=ts_config.d_model,
            d_proj=ts_config.d_proj,
            patch_size=ts_config.patch_size,
            num_layers=ts_config.num_layers,
            num_heads=ts_config.num_heads,
            d_ff_dropout=ts_config.d_ff_dropout,
            use_rope=ts_config.use_rope,
            num_features=ts_config.num_features,
            activation=ts_config.activation
        )

        # Masked reconstruction head
        self.reconstruction_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, 1)  # -> (B, seq_len, num_features, 1)
        )

        # Anomaly detection head
        self.anomaly_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj // 2, 2)  # -> (B, seq_len, num_features, 2) for binary classification
        )

    def forward(self, time_series: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass through the encoder."""
        local_embeddings = self.ts_encoder(time_series, mask)
        return local_embeddings

    def masked_reconstruction_loss(
            self,
            local_embeddings: torch.Tensor,      # (B, seq_len, num_features, d_proj)
            original_time_series: torch.Tensor,  # (B, seq_len, num_features)
            mask: torch.Tensor                   # (B, seq_len)
    ) -> torch.Tensor:
        """Compute MSE reconstruction loss on masked timesteps only."""
        batch_size, seq_len, num_features = original_time_series.shape
        mask = mask.bool()

        reconstructed = self.reconstruction_head(local_embeddings)  # (B, seq_len, num_features, 1)
        reconstructed = reconstructed.view(batch_size, seq_len, num_features)

        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, num_features)  # (B, seq_len, num_features)
        reconstruction_loss = F.mse_loss(
            reconstructed[mask_expanded],
            original_time_series[mask_expanded]
        )
        return reconstruction_loss

    def anomaly_detection_loss(
            self,
            local_embeddings: torch.Tensor,  # (B, seq_len, num_features, d_proj)
            labels: torch.Tensor             # (B, seq_len)
    ) -> torch.Tensor:
        """Compute anomaly detection loss for each timestep."""
        # Project local embeddings to anomaly logits
        logits = self.anomaly_head(local_embeddings)  # (B, seq_len, num_features, 2)
        logits = torch.mean(logits, dim=-2)  # average over num_features -> (B, seq_len, 2)

        # Reshape for loss computation
        logits = logits.view(-1, 2)  # (B*seq_len, 2)
        labels = labels.view(-1)     # (B*seq_len,)

        # Identify padded timesteps (padding_value=-1) *before* binarizing;
        # thresholding first would silently map the -1 padding to class 0
        valid_mask = (labels != -1)
        labels = (labels > 0.5).long()

        # Compute loss only on valid timesteps
        if valid_mask.sum() > 0:
            anomaly_loss = F.cross_entropy(logits[valid_mask], labels[valid_mask])
        else:
            anomaly_loss = torch.tensor(0.0, device=logits.device)

        return anomaly_loss
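# Minimal smoke-test sketch for TimeSeriesPretrainModel. Illustrative only:
# `_demo_pretrain_model` is a hypothetical helper, not part of the training
# pipeline, and it assumes the encoder returns embeddings shaped
# (B, seq_len, num_features, d_proj) as documented above.
def _demo_pretrain_model(config: TimeRCDConfig = default_config) -> None:
    model = TimeSeriesPretrainModel(config)
    B, seq_len = 2, 64
    num_features = config.ts_config.num_features
    ts = torch.randn(B, seq_len, num_features)
    attention_mask = torch.ones(B, seq_len, dtype=torch.bool)

    local_embeddings = model(ts, attention_mask)  # (B, seq_len, num_features, d_proj)

    recon_mask = torch.zeros(B, seq_len)
    recon_mask[:, :8] = 1  # pretend the first 8 timesteps were masked
    labels = torch.randint(0, 2, (B, seq_len)).float()

    rec_loss = model.masked_reconstruction_loss(local_embeddings, ts, recon_mask)
    ad_loss = model.anomaly_detection_loss(local_embeddings, labels)
    print(f"reconstruction: {rec_loss.item():.4f}, anomaly: {ad_loss.item():.4f}")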
def create_random_mask(
        time_series: torch.Tensor,     # (B, max_seq_len, num_features)
        attention_mask: torch.Tensor,  # (B, max_seq_len)
        mask_ratio: float = 0.15
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create a random patch mask for time series, masking only valid (non-padded) parts."""
    batch_size, seq_len, num_features = time_series.shape
    patch_size = default_config.ts_config.patch_size
    mask = torch.zeros(batch_size, seq_len)  # (B, max_seq_len)

    for i in range(batch_size):
        # Valid (unpadded) sequence length for this sample
        valid_length = int(attention_mask[i].sum().item())

        # Number of patches covering the valid sequence, then how many to mask
        num_valid_patches = (valid_length - 1) // patch_size + 1
        num_masked = int(num_valid_patches * mask_ratio)

        if num_masked > 0:
            # Select patches only from the valid part of the sequence
            masked_patches = torch.randperm(num_valid_patches)[:num_masked]
            for j in masked_patches.tolist():
                start_idx = j * patch_size
                end_idx = min((j + 1) * patch_size, valid_length)  # don't exceed valid length
                mask[i, start_idx:end_idx] = 1

    # Replace masked positions with small random noise, leaving padding untouched
    masked_time_series = time_series.clone()
    mask_indices = mask.bool() & attention_mask  # mask only where both are True
    mask_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, num_features)  # (B, max_seq_len, num_features)
    masked_time_series[mask_expanded] = torch.randn_like(masked_time_series[mask_expanded]) * 0.1

    # Zero out any mask entries that fell on padding
    mask = mask * attention_mask.float()

    return masked_time_series, mask  # (B, max_seq_len, num_features), (B, max_seq_len)
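# Usage sketch for create_random_mask. `_demo_random_mask` is a hypothetical helper,
# not called by the pipeline; it checks the key invariant that padded timesteps are
# never selected for masking.
def _demo_random_mask() -> None:
    B, seq_len, num_features = 2, 32, 3
    ts = torch.randn(B, seq_len, num_features)
    attention_mask = torch.zeros(B, seq_len, dtype=torch.bool)
    attention_mask[0, :] = True    # full-length sample
    attention_mask[1, :20] = True  # shorter sample, padded to 32

    masked_ts, mask = create_random_mask(ts, attention_mask, mask_ratio=0.25)
    assert masked_ts.shape == ts.shape
    assert mask.shape == (B, seq_len)
    assert not mask[1, 20:].bool().any()  # padding is never masked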
def collate_fn(batch):
    """Collate function for the pretraining dataset."""
    time_series_list, normal_time_series_list, labels_list, attribute_list = zip(*batch)

    # Ensure every series has an explicit feature dimension: (seq_len, num_features)
    if time_series_list[0].ndim == 1:
        time_series_tensors = [ts.unsqueeze(-1) for ts in time_series_list]
        normal_time_series_tensors = [nts.unsqueeze(-1) for nts in normal_time_series_list]
    else:
        time_series_tensors = list(time_series_list)
        normal_time_series_tensors = list(normal_time_series_list)

    # Standardize per feature using statistics pooled over the whole batch
    concatenated = torch.cat(time_series_tensors, dim=0)  # (total_length, num_features)
    mean = concatenated.mean(dim=0, keepdim=True)         # (1, num_features)
    std = concatenated.std(dim=0, keepdim=True) + 1e-4    # (1, num_features)
    time_series_tensors = [(ts - mean) / std for ts in time_series_tensors]
    normal_time_series_tensors = [(nts - mean) / std for nts in normal_time_series_tensors]

    labels = list(labels_list)

    # Pad series and labels to the longest sequence in the batch
    padded_time_series = torch.nn.utils.rnn.pad_sequence(
        time_series_tensors, batch_first=True, padding_value=0.0
    )  # (B, max_seq_len, num_features)
    padded_normal_time_series = torch.nn.utils.rnn.pad_sequence(
        normal_time_series_tensors, batch_first=True, padding_value=0.0
    )  # (B, max_seq_len, num_features)
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-1
    )  # (B, max_seq_len)

    # Attention mask marking valid (unpadded) timesteps
    sequence_lengths = [ts.size(0) for ts in time_series_tensors]
    B, max_seq_len, num_features = padded_time_series.shape
    attention_mask = torch.zeros(B, max_seq_len, dtype=torch.bool)  # (B, max_seq_len)
    for i, length in enumerate(sequence_lengths):
        attention_mask[i, :length] = True

    # Random masks for the reconstruction task - only valid sequence parts are masked
    masked_time_series, mask = create_random_mask(padded_time_series, attention_mask)

    return {
        'time_series': padded_time_series,
        'normal_time_series': padded_normal_time_series,
        'masked_time_series': masked_time_series,
        'mask': mask,                      # for the reconstruction task
        'labels': padded_labels,
        'attention_mask': attention_mask,  # marks padding
        'attribute': attribute_list
    }


def test_collate_fn(batch):
    """Collate function for evaluation batches of equal-length (time_series, mask) pairs."""
    # Unpack the batch - a list of (time_series, mask) tuples
    time_series_list, mask_list = zip(*batch)

    # Stack into batch format, keeping the batch dimension: (B, seq_len, num_features)
    batched_time_series = torch.stack(time_series_list, dim=0)
    print(f"batched_time_series shape: {batched_time_series.shape}")

    # Stack masks into batch format: (B, seq_len)
    batched_mask = torch.stack(mask_list, dim=0)
    print(f"batched_mask shape: {batched_mask.shape}")

    return {
        'time_series': batched_time_series,
        'attention_mask': batched_mask,  # marks padding
    }
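# End-to-end wiring sketch. `_demo_dataloader` is a hypothetical helper, not part of
# the training pipeline; it presumes ChatTSTimeRCDPretrainDataset yields
# (time_series, normal_time_series, labels, attribute) tuples compatible with
# collate_fn above.
def _demo_dataloader(dataset: ChatTSTimeRCDPretrainDataset) -> None:
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    model = TimeSeriesPretrainModel(default_config)

    batch = next(iter(loader))
    local_embeddings = model(batch['masked_time_series'], batch['attention_mask'])

    rec_loss = model.masked_reconstruction_loss(
        local_embeddings, batch['time_series'], batch['mask']
    )
    ad_loss = model.anomaly_detection_loss(local_embeddings, batch['labels'])
    (rec_loss + ad_loss).backward()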