# Adapted from dominant-shuffle (https://github.com/zuojie2024/dominant-shuffle) and FrAug implementations
# Original works: Dominant Shuffle by Kai Zhao et al., FrAug by Muxi Chen et al.
# Modified and extended by Jafar Bakhshaliyev (2025)
import torch
import torch.nn as nn
import numpy as np
import random
from pytorch_wavelets import DWT1DForward, DWT1DInverse
import torch.nn.functional as F
from typing import List, Tuple
from joblib import Parallel, delayed
from dtaidistance import dtw
from statsmodels.tsa.seasonal import STL
from arch.bootstrap import MovingBlockBootstrap
def augmentation(augment_time):
    """Return BatchAugmentation for 'batch'-level use or DatasetAugmentation for 'dataset'-level use."""
if augment_time == 'batch':
return BatchAugmentation()
elif augment_time == 'dataset':
return DatasetAugmentation()
class BatchAugmentation():
def __init__(self):
pass
def flipping(self,x,y,rate=0):
xy = torch.cat([x,y],dim=1)
idxs = np.arange(xy.shape[1])
idxs = list(idxs)[::-1]
xy = xy[:,idxs,:]
return xy
    def warping(self,x,y,rate=0):
        # Time-warp: repeat each step of the second half twice so it stretches
        # over the full concatenated length.
        xy = torch.cat([x,y],dim=1)
        new_xy = torch.zeros_like(xy)
        for i in range(new_xy.shape[1]//2):
            new_xy[:,i*2,:] = xy[:,i + xy.shape[1]//2,:]
            new_xy[:,i*2 + 1,:] = xy[:,i + xy.shape[1]//2,:]
        return new_xy
    def noise(self,x,y,rate=0.05):
        # Add uniform noise in [-rate, rate] to the concatenated window and horizon.
        xy = torch.cat([x,y],dim=1)
        noise_xy = (torch.rand(xy.shape, device=xy.device)-0.5) * 2 * rate
        xy = xy + noise_xy
        return xy
    def noise_input(self,x,y,rate=0.05):
        # Add uniform noise in [-rate, rate] to the look-back window only.
        noise = (torch.rand(x.shape, device=x.device)-0.5) * 2 * rate
        x = x + noise
        xy = torch.cat([x,y],dim=1)
        return xy
def masking(self,x,y,rate=0.5):
xy = torch.cat([x,y],dim=1)
b_idx = np.arange(xy.shape[1])
np.random.shuffle(b_idx)
crop_num = int(xy.shape[1]*rate)
xy[:,b_idx[:crop_num],:] = 0
return xy
def masking_seg(self,x,y,rate=0.5):
xy = torch.cat([x,y],dim=1)
b_idx = int(np.random.rand(1)*xy.shape[1]//2)
xy[:,b_idx:b_idx+xy.shape[1]//2,:] = 0
return xy
def freq_mask(self,x, y, rate=0.5, dim=1):
xy = torch.cat([x,y],dim=1)
xy_f = torch.fft.rfft(xy,dim=dim)
        m = torch.rand(xy_f.shape, device=xy_f.device) < rate  # random mask (device-agnostic)
amp = abs(xy_f)
_, index = amp.sort(dim=dim, descending=True)
dominant_mask = index > 5
m = torch.bitwise_and(m,dominant_mask)
freal = xy_f.real.masked_fill(m,0)
fimag = xy_f.imag.masked_fill(m,0)
xy_f = torch.complex(freal,fimag)
xy = torch.fft.irfft(xy_f,dim=dim)
return xy
def freq_mix(self, x, y, rate=0.5, dim=1):
xy = torch.cat([x,y],dim=dim)
xy_f = torch.fft.rfft(xy,dim=dim)
        m = torch.rand(xy_f.shape, device=xy_f.device) < rate  # random mask (device-agnostic)
amp = abs(xy_f)
_,index = amp.sort(dim=dim, descending=True)
dominant_mask = index > 2
m = torch.bitwise_and(m,dominant_mask)
freal = xy_f.real.masked_fill(m,0)
fimag = xy_f.imag.masked_fill(m,0)
b_idx = np.arange(x.shape[0])
np.random.shuffle(b_idx)
x2, y2 = x[b_idx], y[b_idx]
xy2 = torch.cat([x2,y2],dim=dim)
xy2_f = torch.fft.rfft(xy2,dim=dim)
m = torch.bitwise_not(m)
freal2 = xy2_f.real.masked_fill(m,0)
fimag2 = xy2_f.imag.masked_fill(m,0)
freal += freal2
fimag += fimag2
xy_f = torch.complex(freal,fimag)
xy = torch.fft.irfft(xy_f,dim=dim)
return xy
    def dom_shuffle(self, x, y, rate=4, dim=1):
        # Dominant shuffle (Zhao et al.): per channel, randomly permute the
        # `rate` most dominant frequency coefficients after the largest one,
        # which is kept fixed.
xy = torch.cat([x,y],dim=1)
xy_f = torch.fft.rfft(xy,dim=dim)
magnitude = abs(xy_f)
topk_indices = torch.argsort(magnitude, dim=1, descending=True)[:, 1:int(rate+1)]
#minor_indices = torch.argsort(magnitude, dim=1, descending=True)[:, 10:]
new_xy_f = xy_f
for i in range(topk_indices.size(2)):
for j in range(xy.size(0)):
random_indices = torch.randperm(topk_indices.size(1))
shuffled_tensor1 = topk_indices[:,:,i][j][random_indices]
new_xy_f[:,:,i][j][topk_indices[:,:,i][j]] = new_xy_f[:,:,i][j][shuffled_tensor1]
xy = torch.fft.irfft(new_xy_f,dim=dim)
return xy
def wave_mask(self, x, y, rates, wavelet='db1', level=2, uniform = False, sampling_rate=0.5, dim=1):
xy = torch.cat([x, y], dim=1) # (B, T, C)
if uniform:
rates = [rates[0]] * (level + 1)
rate_tensor = torch.tensor(rates, device=x.device)
xy = xy.permute(0, 2, 1).to(x.device) # (B, C, T)
dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
cA, cDs = dwt(xy)
mask = torch.rand_like(cA).to(cA.device) < rate_tensor[0]
cA = cA.masked_fill(mask, 0)
masked_cDs = []
for i, cD in enumerate(cDs):
mask_cD = torch.rand_like(cD).to(cD.device) < rate_tensor[i+1] # Create mask
cD = cD.masked_fill(mask_cD, 0)
masked_cDs.append(cD)
idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
xy = idwt((cA, masked_cDs))
xy = xy.permute(0, 2, 1) # (B, T, C)
# Sample a subset of batches
batch_size = xy.shape[0]
sampling_steps = int(batch_size * sampling_rate)
indices = torch.randperm(batch_size)[:sampling_steps]
xy = xy[indices]
return xy, indices
def wave_mix(self, x, y, rates, wavelet='db1', level=2, uniform = False, sampling_rate=0.5, dim=1):
xy = torch.cat([x, y], dim=1) # (B, T, C)
batch_size, _, _ = xy.shape
if uniform:
rates = [rates[0]] * (level + 1)
rate_tensor = torch.tensor(rates, device=x.device)
xy = xy.permute(0, 2, 1).to(x.device) # (B, C, T)
# Shuffle the batch for mixing
b_idx = torch.randperm(batch_size)
xy2 = xy[b_idx]
dwt = DWT1DForward(J=level, wave=wavelet, mode='symmetric').to(x.device)
cA1, cDs1 = dwt(xy)
cA2, cDs2 = dwt(xy2)
mask = torch.rand_like(cA1).to(cA1.device) < rate_tensor[0] # Create mask
cA_mixed = cA1.masked_fill(mask, 0) + cA2.masked_fill(~mask, 0)
mixed_cDs = []
for i, (cD1, cD2) in enumerate(zip(cDs1, cDs2)):
mask = torch.rand_like(cD1).to(cD1.device) < rate_tensor[i+1] # Create mask
cD_mixed = cD1.masked_fill(mask, 0) + cD2.masked_fill(~mask, 0)
mixed_cDs.append(cD_mixed)
idwt = DWT1DInverse(wave=wavelet, mode='symmetric').to(x.device)
xy = idwt((cA_mixed, mixed_cDs))
xy = xy.permute(0, 2, 1) # (B, T, C)
# Sample a subset of batches
batch_size = xy.shape[0]
sampling_steps = int(batch_size * sampling_rate)
indices = torch.randperm(batch_size)[:sampling_steps]
xy = xy[indices]
return xy, indices
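    # Hedged usage sketch for the wavelet-domain methods (tensor names and
    # shapes are illustrative, not from the original training script). Unlike
    # the other augmentations, wave_mask / wave_mix also return the indices of
    # the sampled sub-batch, so the caller must subset the targets to match:
    #   aug = BatchAugmentation()
    #   # x: (B, T_x, C) look-back window, y: (B, T_y, C) horizon
    #   xy_aug, idx = aug.wave_mask(x, y, rates=[0.2, 0.2, 0.2], wavelet='db1',
    #                               level=2, sampling_rate=0.5)
    #   y_ref = y[idx]  # targets aligned with the sampled sub-batch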
def freq_add(self, x, y, rate = 0.5):
"""
Perturb a single randomly selected low-frequency component per channel by setting its magnitude
to half of the maximum magnitude per channel.
"""
xy = torch.cat([x, y], dim=1) # Shape: (B, T_total, D)
device = xy.device
X = torch.fft.rfft(xy, dim=1) # Shape: (B, num_freqs, D)
mag = torch.abs(X) # Shape: (B, num_freqs, D)
max_mag = mag.amax(dim=1, keepdim=True) # Shape: (B, 1, D)
num_freqs = X.size(1)
N_low = num_freqs // 2
low_freq_indices = torch.arange(1, N_low, device=device) # Indices from 1 to N_low - 1
num_low_freqs = len(low_freq_indices)
batch_size = x.size(0)
num_channels = x.size(2)
selected_indices = torch.randint(0, num_low_freqs, (batch_size, num_channels), device=device) # Shape: (B, D)
freq_indices = low_freq_indices[selected_indices] # Shape: (B, D)
batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_channels) # Shape: (B, D)
channel_indices = torch.arange(num_channels, device=device).unsqueeze(0).expand(batch_size, -1) # Shape: (B, D)
orig_phase = torch.angle(X[batch_indices, freq_indices, channel_indices]) # Shape: (B, D)
new_mag = max_mag.squeeze(1) * rate # Shape: (B, D)
X[batch_indices, freq_indices, channel_indices] = new_mag * torch.exp(1j * orig_phase)
xy = torch.fft.irfft(X, n=xy.size(1), dim=1) # Shape: (B, T_total, D)
return xy
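    # Illustrative sketch (assumed shapes, not part of the original pipeline):
    #   xy_aug = BatchAugmentation().freq_add(x, y, rate=0.5)  # (B, T_x + T_y, C)
    # One randomly chosen low-frequency bin per (sample, channel) gets its
    # magnitude set to rate * max magnitude while its phase is preserved.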
def freq_pool(self, x, y, rate=0.5):
"""
Implements FreqPool augmentation from Chen et al. (2023b).
Applies maximum pooling to the entire frequency spectrum with given pool size.
"""
pool_size=int(rate)
xy = torch.cat([x, y], dim=1) # Shape: (B, T_total, D)
xy_f = torch.fft.rfft(xy, dim=1) # Shape: (B, num_freqs, D)
# Get real and imaginary components
real = xy_f.real
imag = xy_f.imag
        # Renamed to avoid shadowing the torch.nn.functional alias F.
        B, n_freq, D = real.shape
        pad_size = (pool_size - (n_freq % pool_size)) % pool_size
        if pad_size > 0:
            real = torch.nn.functional.pad(real, (0, 0, 0, pad_size))
            imag = torch.nn.functional.pad(imag, (0, 0, 0, pad_size))
        real = real.unsqueeze(1)  # Shape: (B, 1, n_freq + pad, D)
        imag = imag.unsqueeze(1)  # Shape: (B, 1, n_freq + pad, D)
        # max pooling layer
        max_pool = torch.nn.MaxPool2d(kernel_size=(pool_size, 1), stride=(pool_size, 1))
        # apply max pooling separately to the real and imaginary parts
        real_pooled = max_pool(real)
        imag_pooled = max_pool(imag)
        real_pooled = real_pooled.squeeze(1)[:, :n_freq, :]
        imag_pooled = imag_pooled.squeeze(1)[:, :n_freq, :]
# Reconstruction
xy_f_pooled = torch.complex(real_pooled, imag_pooled)
xy = torch.fft.irfft(xy_f_pooled, n=xy.size(1), dim=1)
return xy
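    # Illustrative sketch (assumed shapes): `rate` is truncated to an integer
    # pool size, so the default 0.5 would give pool_size=0 and fail; pass an
    # integer-valued rate of at least 2 for meaningful pooling:
    #   xy_aug = BatchAugmentation().freq_pool(x, y, rate=2)  # (B, T_x + T_y, C)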
def temporal_patch_shuffle(self, x, y, patch_len=16, stride=5, rate=0.5):
xy = torch.cat([x, y], dim=1) # (B, T, C)
B, T, C = xy.shape
xy = xy.permute(0, 2, 1)
patches = xy.unfold(dimension=2, size=patch_len, step=stride)
B, C, num_patches, _ = patches.shape
patches = patches.permute(0, 2, 1, 3).contiguous() # (B, num_patches, C, patch_len)
# scores based on variance
variance_scores = patches.var(dim=(2, 3)) # (B, num_patches)
# Decide which patches to shuffle based on shuffle rate
num_to_shuffle = int(num_patches * rate)
if num_to_shuffle > 0:
for b in range(B):
scores = variance_scores[b]
sorted_indices = torch.argsort(scores)
shuffle_indices = sorted_indices[:num_to_shuffle]
# Shuffle selected patches
patches_to_shuffle = patches[b, shuffle_indices]
shuffled_order = torch.randperm(num_to_shuffle)
patches[b, shuffle_indices] = patches_to_shuffle[shuffled_order]
# Reconstruction
patches = patches.permute(0, 2, 1, 3)
reconstructed = torch.zeros(B, C, T, device=xy.device)
counts = torch.zeros(B, C, T, device=xy.device)
for i in range(num_patches):
start = i * stride
end = start + patch_len
reconstructed[:, :, start:end] += patches[:, :, i]
counts[:, :, start:end] += 1
counts[counts == 0] = 1
xy = reconstructed / counts
# Permute back to (B, T, C)
xy = xy.permute(0, 2, 1)
return xy
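    # Illustrative sketch (assumed shapes): overlapping patches of length
    # patch_len are extracted with the given stride, the lowest-variance
    # `rate` fraction of them is permuted, and the patches are overlap-averaged
    # back into a sequence. Time steps not covered by any patch (possible when
    # T - patch_len is not a multiple of stride) remain zero.
    #   xy_aug = BatchAugmentation().temporal_patch_shuffle(x, y, patch_len=16,
    #                                                       stride=5, rate=0.5)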
def robusttad_m(self, x, y, rate=0.5, m_K=3, segment_ratio=0.2):
"""
RobustTAD magnitude augmentation with multiple segments.
Args:
            rate: q_A parameter controlling the perturbation strength
            m_K: number of frequency segments to perturb
            segment_ratio: segment length as a fraction of the rFFT bins
"""
xy = torch.cat([x, y], dim=1)
xy_f = torch.fft.rfft(xy, dim=1) # N' = N//2 + 1 frequencies for real signal
magnitude = torch.abs(xy_f)
phase = torch.angle(xy_f)
# Calculate segment length K = rN'
B, N_prime, D = xy_f.shape
K = int(N_prime * segment_ratio) # segment length
# Ensure |ki - ki+1| >= K/2 for overlapping constraint
valid_starts = []
for _ in range(m_K):
if not valid_starts:
start = torch.randint(0, N_prime - K, (1,)).item()
else:
# Create mask for valid start positions
mask = torch.ones(N_prime - K, dtype=torch.bool)
for prev_start in valid_starts:
# minimum distance K/2 from previous segments
invalid_range_start = max(0, prev_start - K//2)
invalid_range_end = min(N_prime - K, prev_start + K//2)
mask[invalid_range_start:invalid_range_end] = False
# valid positions
valid_positions = torch.where(mask)[0]
if len(valid_positions) == 0:
break
# Randomly select from valid positions
start = valid_positions[torch.randint(0, len(valid_positions), (1,))].item()
valid_starts.append(start)
for start_idx in valid_starts:
end_idx = start_idx + K
magnitude_segment = magnitude[:, start_idx:end_idx, :]
# segment statistics
mu_bar_A = magnitude_segment.mean(dim=(1, 2), keepdim=True)
delta_bar_A_sq = magnitude_segment.var(dim=(1, 2), unbiased=False, keepdim=True)
mu = torch.zeros_like(mu_bar_A)
# Sampling
std = (rate * delta_bar_A_sq).sqrt()
v = torch.normal(mean=mu, std=std)
v = v.expand(B, K, D)
# Replacing magnitude values in segment
magnitude[:, start_idx:end_idx, :] = v
# Reconstruction
xy_f_aug = torch.polar(magnitude, phase)
xy = torch.fft.irfft(xy_f_aug, n=xy.size(1), dim=1)
return xy
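    # Illustrative sketch (assumed shapes): m_K frequency segments, each
    # spanning segment_ratio of the rFFT bins and at least K/2 apart, have
    # their magnitudes replaced by zero-mean Gaussian draws with variance
    # rate * segment variance; phases are left untouched.
    #   xy_aug = BatchAugmentation().robusttad_m(x, y, rate=0.5, m_K=3,
    #                                            segment_ratio=0.2)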
def robusttad_p(self, x, y, rate=0.5, m_K=3, segment_ratio=0.2):
"""
RobustTAD phase augmentation with multiple segments.
Args:
            rate: q_theta parameter controlling the perturbation strength
            m_K: number of frequency segments to perturb
            segment_ratio: segment length as a fraction of the rFFT bins
"""
xy = torch.cat([x, y], dim=1)
xy_f = torch.fft.rfft(xy, dim=1) # N' = N//2 + 1 frequencies for real signal
magnitude = torch.abs(xy_f)
phase = torch.angle(xy_f)
# Calculate segment length K = rN'
B, N_prime, D = xy_f.shape
K = int(N_prime * segment_ratio) # segment length
# Ensure |ki - ki+1| >= K/2 for overlapping constraint
valid_starts = []
for _ in range(m_K):
if not valid_starts:
start = torch.randint(0, N_prime - K, (1,)).item()
else:
mask = torch.ones(N_prime - K, dtype=torch.bool)
for prev_start in valid_starts:
# minimum distance K/2 from previous segments
invalid_range_start = max(0, prev_start - K//2)
invalid_range_end = min(N_prime - K, prev_start + K//2)
mask[invalid_range_start:invalid_range_end] = False
# valid positions
valid_positions = torch.where(mask)[0]
if len(valid_positions) == 0:
break
# Randomly select from valid positions
start = valid_positions[torch.randint(0, len(valid_positions), (1,))].item()
valid_starts.append(start)
for start_idx in valid_starts:
end_idx = start_idx + K
phase_segment = phase[:, start_idx:end_idx, :]
delta_theta_sq = phase_segment.var(dim=(1, 2), unbiased=False, keepdim=True)
# Sampling
std_theta = (rate * delta_theta_sq).sqrt()
theta = torch.normal(mean=torch.zeros_like(std_theta), std=std_theta)
theta = theta.expand(B, K, D)
# Add perturbation
phase[:, start_idx:end_idx, :] += theta
# Reconstruction
xy_f_aug = torch.polar(magnitude, phase)
xy = torch.fft.irfft(xy_f_aug, n=xy.size(1), dim=1)
return xy
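    # Illustrative sketch: same segment selection as robusttad_m, but zero-mean
    # Gaussian noise (variance = rate * segment phase variance) is added to the
    # phases while the magnitudes are preserved.
    #   xy_aug = BatchAugmentation().robusttad_p(x, y, rate=0.5, m_K=3,
    #                                            segment_ratio=0.2)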
def upsample(self, x, y, rate=0.1):
"""
Upsample method based on Semenoglou et al. (2023).
Parameters:
- x: Input data tensor of shape (B, T_x, D)
- y: Labels tensor of shape (B, T_y, D)
- K: Number of points to select (must be less than T_total)
- rate: If K is not provided, K = int(T_total * rate)
Returns:
- Augmented data tensors x_aug and y_aug of shapes matching x and y
"""
xy = torch.cat([x, y], dim=1) # Shape: (B, T_total, D)
B, T_total, D = xy.shape
K = int(T_total * rate)
# randomly select start indices for each sample
max_start = T_total - K
start_indices = torch.randint(0, max_start + 1, (B,), device=xy.device)
xy_aug_list = []
for b in range(B):
s = start_indices[b].item()
# Extract the subsequence of length K
subseq = xy[b, s:s+K, :] # Shape: (K, D)
subseq = subseq.unsqueeze(0).permute(0, 2, 1) # Shape: (1, D, K)
# linear interpolation
upsampled = F.interpolate(subseq, size=T_total, mode='linear', align_corners=False)
upsampled = upsampled.squeeze(0).permute(1, 0) # Shape: (T_total, D)
xy_aug_list.append(upsampled)
xy = torch.stack(xy_aug_list, dim=0) # Shape: (B, T_total, D)
return xy
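    # Illustrative sketch (assumed shapes): a random window of
    # K = int(T_total * rate) consecutive steps is stretched back to the full
    # length with linear interpolation, independently for each sample.
    #   xy_aug = BatchAugmentation().upsample(x, y, rate=0.1)  # (B, T_x + T_y, C)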
def asd(self, x, y, rate = 0.5, top_k = 5, dtw_window_ratio = 0.1):
"""
Perform ASD-based data augmentation on a batch of time series.
x: (B, T_x, D) # look-back window
y: (B, T_y, D) # horizon
- both on the same device (CPU or GPU)
Returns
-------
xy_aug : torch.Tensor
Shape: (B, T_x+T_y, D)
The newly synthesized samples for each of the B original samples.
You can then concatenate these to your original data if desired.
"""
device = x.device
xy = torch.cat([x, y], dim=1)
B, T_total, D = xy.shape
# Move to CPU for dtw calculation (dtaidistance uses NumPy/C)
xy_np = xy.detach().cpu().numpy() # shape: (B, T_total, D)
# pairwise DTW distances
distances = np.zeros((B, B), dtype=np.float64)
# a simple approach: for each dimension, compute distance_matrix_fast, sum results.
for d_idx in range(D):
series_d = xy_np[:, :, d_idx]
series_d = np.ascontiguousarray(series_d, dtype=np.double)
window = None
if dtw_window_ratio is not None:
window = int(dtw_window_ratio * T_total)
try:
dist_d = dtw.distance_matrix_fast(series_d,
window=window,
parallel=True)
except Exception as e:
print(f"[ASD] Fallback (dimension={d_idx}): {e}, using distance_matrix()")
dist_d = dtw.distance_matrix(series_d, window=window)
distances += dist_d
if D > 1:
distances /= D
distances_t = torch.from_numpy(distances).float().to(device) # shape (B, B)
# for each sample i, get its top_k neighbors, build an exponentially weighted sum
xy_syn_list = []
for i in range(B):
dist_i = distances_t[i].clone()
dist_i[i] = float('inf')
# top_k neighbors
vals, idxs = torch.topk(dist_i, k=top_k, largest=False) # shape: (top_k,)
# Compute weights
d_nn = vals[0].clamp(min=1e-8)
factor = vals / d_nn # ratio
weights = torch.exp(torch.log(torch.tensor(rate, device=device)) * factor)
weights = weights / weights.sum() # normalize
# Weighted sum of neighbors
neighbors = xy[idxs]
weighted_sum = (neighbors * weights.view(top_k, 1, 1)).sum(dim=0) # (T_total, D)
new_sample = weighted_sum
xy_syn_list.append(new_sample)
xy = torch.stack(xy_syn_list, dim=0)
xy = xy.to(device, dtype=torch.float32)
return xy
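    # Illustrative sketch (assumed shapes; requires dtaidistance): each sample
    # is replaced by an exponentially weighted average of its top_k DTW
    # neighbours, with weights proportional to rate ** (d_i / d_nearest).
    #   xy_aug = BatchAugmentation().asd(x, y, rate=0.5, top_k=5,
    #                                    dtw_window_ratio=0.1)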
def process_one_sample_cpu(self, sample_i_cpu, block_size=24, stl_period=7):
"""
        Decompose a single (T, D) sample with STL and bootstrap its remainder with arch.bootstrap.MovingBlockBootstrap.
Args:
sample_i_cpu: numpy array of shape (T, D)
block_size: size of blocks for MBB
stl_period: period for STL decomposition
"""
T, D = sample_i_cpu.shape
trends_np = np.zeros_like(sample_i_cpu)
seasonals_np = np.zeros_like(sample_i_cpu)
remainders_np = np.zeros_like(sample_i_cpu)
# STL decomposition
for d in range(D):
series_d = sample_i_cpu[:, d]
stl = STL(series_d, period=stl_period, robust=True)
res = stl.fit()
trends_np[:, d] = res.trend
seasonals_np[:, d] = res.seasonal
remainders_np[:, d] = res.resid
# MBB to remainder series
new_remainders = np.zeros_like(remainders_np)
for d in range(D):
bs = MovingBlockBootstrap(block_size, remainders_np[:, d], seed=None)
for data in bs.bootstrap(1):
new_remainders[:, d] = data[0][0] # First bootstrap replicate
break # We only need one sample
# Reconstruction
new_sample = trends_np + seasonals_np + new_remainders
return new_sample
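    # Illustrative sketch on a single (T, D) NumPy array (values assumed):
    #   sample = np.random.randn(96, 3)
    #   new_sample = BatchAugmentation().process_one_sample_cpu(
    #       sample, block_size=24, stl_period=7)  # same (96, 3) shape
    # The STL trend and seasonal components are kept; only the remainder is
    # re-drawn with the moving-block bootstrap.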
def mbb(self, x, y, block_size=24, stl_period=7):
"""
Parallel MBB-based STL augmentation using arch.bootstrap.
Args:
x: input tensor (B, T_x, D)
y: target tensor (B, T_y, D)
block_size: size of moving blocks
stl_period: period for STL decomposition
"""
xy = torch.cat([x, y], dim=1)
device = xy.device
dtype = xy.dtype
B, T, D = xy.shape
x_cpu = xy.detach().cpu().numpy()
results = Parallel(n_jobs=-1)(
delayed(self.process_one_sample_cpu)(
x_cpu[i], block_size, stl_period
)
for i in range(B)
)
results_np = np.stack(results, axis=0)
xy = torch.from_numpy(results_np).to(device=device, dtype=dtype)
return xy
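    # Illustrative sketch (assumed shapes): runs the per-sample STL + MBB
    # routine above in parallel over the batch via joblib.
    #   xy_aug = BatchAugmentation().mbb(x, y, block_size=24, stl_period=7)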
class DatasetAugmentation():
def __init__(self):
pass
def freq_dropout(self, x, y, dropout_rate=0.2, dim=0, keep_dominant=True):
x, y = torch.from_numpy(x), torch.from_numpy(y)
xy = torch.cat([x,y],dim=0)
xy_f = torch.fft.rfft(xy,dim=0)
m = torch.FloatTensor(xy_f.shape).uniform_() < dropout_rate
freal = xy_f.real.masked_fill(m,0)
fimag = xy_f.imag.masked_fill(m,0)
xy_f = torch.complex(freal,fimag)
xy = torch.fft.irfft(xy_f,dim=dim)
x, y = xy[:x.shape[0],:].numpy(), xy[-y.shape[0]:,:].numpy()
return x, y
def freq_mix(self, x, y, x2, y2, dropout_rate=0.2):
x, y = torch.from_numpy(x), torch.from_numpy(y)
xy = torch.cat([x,y],dim=0)
xy_f = torch.fft.rfft(xy,dim=0)
m = torch.FloatTensor(xy_f.shape).uniform_() < dropout_rate
amp = abs(xy_f)
_,index = amp.sort(dim=0, descending=True)
dominant_mask = index > 2
m = torch.bitwise_and(m,dominant_mask)
freal = xy_f.real.masked_fill(m,0)
fimag = xy_f.imag.masked_fill(m,0)
x2, y2 = torch.from_numpy(x2), torch.from_numpy(y2)
xy2 = torch.cat([x2,y2],dim=0)
xy2_f = torch.fft.rfft(xy2,dim=0)
m = torch.bitwise_not(m)
freal2 = xy2_f.real.masked_fill(m,0)
fimag2 = xy2_f.imag.masked_fill(m,0)
freal += freal2
fimag += fimag2
xy_f = torch.complex(freal,fimag)
xy = torch.fft.irfft(xy_f,dim=0)
x, y = xy[:x.shape[0],:].numpy(), xy[-y.shape[0]:,:].numpy()
return x, y
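# Minimal smoke test (a sketch, not part of the original training code). It
# assumes the dependencies imported above are installed and exercises a few
# CPU-safe augmentations on random tensors; shapes and seeds are illustrative.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T_x, T_y, C = 8, 96, 48, 3
    x = torch.randn(B, T_x, C)
    y = torch.randn(B, T_y, C)
    aug = augmentation('batch')
    for name in ['flipping', 'masking', 'dom_shuffle', 'freq_add',
                 'temporal_patch_shuffle', 'upsample']:
        out = getattr(aug, name)(x.clone(), y.clone())
        print(f"{name}: {tuple(out.shape)}")
    # Dataset-level augmentation operates on single NumPy samples.
    ds_aug = augmentation('dataset')
    x_aug, y_aug = ds_aug.freq_dropout(x[0].numpy(), y[0].numpy(), dropout_rate=0.2)
    print("freq_dropout:", x_aug.shape, y_aug.shape)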