# Source: TempoPFN/src/synthetic_generation/augmentations/offline_per_sample_iid_augmentations.py
# Last commit: "Apply ruff formatting" (0a58567) by Vladyslav Moroshan
import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
import torch
from src.data.augmentations import (
CensorAugmenter,
DifferentialAugmenter,
MixUpAugmenter,
QuantizationAugmenter,
RandomConvAugmenter,
TimeFlipAugmenter,
YFlipAugmenter,
)
from src.data.constants import LENGTH_CHOICES
from src.data.datasets import CyclicalBatchDataset
from src.data.filter import is_low_quality
from src.data.frequency import Frequency, parse_frequency
from src.data.scalers import MeanScaler, MedianScaler, MinMaxScaler, RobustScaler
class TimeSeriesDatasetManager:
    """Persist generated series as numbered Feather/Arrow batch files, with resume support.

    Each call to :meth:`append_batch` writes one ``batch_XXXXXXXX.arrow`` file into
    ``output_path``. On construction the directory is scanned so that writing can
    continue after the batches that already exist.
    """

    def __init__(self, output_path: str, batch_size: int = 2**13):
        """Create the output directory, define the Arrow schema and restore counters.

        Args:
            output_path: Directory that receives the batch files (created if missing).
            batch_size: Number of series a batch must hold to be considered complete
                when resuming.
        """
        self.output_path = Path(output_path)
        self.output_path.mkdir(parents=True, exist_ok=True)
        # Batch files live directly in the output directory, which was created above.
        self.batches_dir = self.output_path
        self.batch_size = batch_size
        self.batch_counter = 0
        self.series_counter = 0
        self.schema = pa.schema(
            [
                ("series_id", pa.int64()),
                ("values", pa.list_(pa.list_(pa.float64()))),
                ("length", pa.int32()),
                ("num_channels", pa.int32()),
                ("generator_type", pa.string()),
                ("start", pa.timestamp("ns")),
                ("frequency", pa.string()),
                ("generation_timestamp", pa.timestamp("ns")),
            ]
        )
        self._initialize_state()

    def _initialize_state(self) -> None:
        """Restore ``batch_counter`` and ``series_counter`` from batch files on disk.

        Every existing batch is read once; its row count is added to
        ``series_counter``. ``batch_counter`` is set to one past the highest batch
        number, unless that last batch holds fewer than ``batch_size`` rows, in
        which case ``batch_counter`` points at it again.

        NOTE(review): ``append_batch`` writes to ``batch_{batch_counter}``, so an
        incomplete last batch is replaced by the next write while its rows remain
        counted in ``series_counter`` - confirm this is the intended resume policy.
        """
        existing_batches = sorted(self.batches_dir.glob("batch_*.arrow"))
        if not existing_batches:
            logging.info("No existing batches found. Starting from scratch.")
            return

        batch_numbers: list[int] = []
        # Row count per successfully read batch, so the last batch is not read twice.
        rows_per_batch: dict[int, int] = {}
        total_series = 0
        for batch_file in existing_batches:
            try:
                batch_num = int(batch_file.stem.split("_")[1])
                batch_numbers.append(batch_num)
                num_rows = len(feather.read_table(batch_file))
                total_series += num_rows
                rows_per_batch[batch_num] = num_rows
            except Exception as e:
                # Unreadable or oddly named files are skipped, not fatal.
                logging.warning(f"Error reading batch {batch_file}: {e}")
                continue

        if batch_numbers:
            max_batch_num = max(batch_numbers)
            self.batch_counter = max_batch_num + 1
            self.series_counter = total_series
            last_rows = rows_per_batch.get(max_batch_num)
            if last_rows is not None and last_rows < self.batch_size:
                self.batch_counter = max_batch_num
                logging.info(f"Found incomplete last batch {max_batch_num} with {last_rows} series")
            logging.info(f"Resuming from: batch_counter={self.batch_counter}, series_counter={self.series_counter}")

    def append_batch(self, batch_data: list[dict[str, Any]]) -> None:
        """Write ``batch_data`` as the next batch file and advance the counters.

        Args:
            batch_data: One dict per series, keyed by the schema field names. The
                ``start`` and ``generation_timestamp`` entries must expose ``.value``
                (as ``pandas.Timestamp`` does); that integer is stored as a
                nanosecond timestamp.

        Raises:
            Exception: Any error raised while building or writing the table is
                logged and re-raised; the counters are left untouched in that case.
        """
        if not batch_data:
            return
        try:
            arrays = []
            for field in self.schema:
                field_name = field.name
                if field_name in ["start", "generation_timestamp"]:
                    timestamps = [row[field_name] for row in batch_data]
                    arrays.append(pa.array([ts.value for ts in timestamps], type=pa.timestamp("ns")))
                else:
                    arrays.append(pa.array([row[field_name] for row in batch_data]))
            new_table = pa.Table.from_arrays(arrays, schema=self.schema)
            batch_filename = f"batch_{self.batch_counter:08d}.arrow"
            batch_filepath = self.batches_dir / batch_filename
            feather.write_feather(new_table, batch_filepath)
            self.series_counter += len(batch_data)
            self.batch_counter += 1
        except Exception as e:
            logging.error(f"Error writing batch: {e}")
            raise
class UnivariateOfflineAugmentor:
def __init__(
self,
augmentations: dict[str, bool] | None = None,
augmentation_probabilities: dict[str, float] | None = None,
global_seed: int = 42,
):
self.global_seed = global_seed
np.random.seed(global_seed)
torch.manual_seed(global_seed)
self.rng = np.random.default_rng(global_seed)
self.augmentation_probabilities = augmentation_probabilities
self.augmentations = augmentations
self.apply_augmentations = any(self.augmentations.values())
self.time_flip_augmenter = None
if self.augmentations["time_flip_augmentation"]:
self.time_flip_augmenter = TimeFlipAugmenter(
p_flip=self.augmentation_probabilities["time_flip_augmentation"]
)
self.yflip_augmenter = None
if self.augmentations["yflip_augmentation"]:
self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities["yflip_augmentation"])
self.censor_augmenter = None
if self.augmentations["censor_augmentation"]:
self.censor_augmenter = CensorAugmenter()
self.quantization_augmenter = None
if self.augmentations["quantization_augmentation"]:
self.quantization_augmenter = QuantizationAugmenter(
p_quantize=self.augmentation_probabilities["censor_or_quantization_augmentation"],
level_range=(5, 15),
)
if self.augmentations["differential_augmentation"]:
self.differential_augmentor = DifferentialAugmenter(
p_transform=self.augmentation_probabilities["differential_augmentation"]
)
def apply(
self,
history_values: torch.Tensor,
starts: list[pd.Timestamp] | None = None,
frequencies: list[str] | None = None,
) -> torch.Tensor:
if not self.apply_augmentations:
return history_values
batch_size = int(history_values.shape[0])
# 0) Combination (MixUp) – handled early at batch level due to dependency on other series
if self.augmentations.get("mixup_augmentation", False) and self.mixup_augmenter is not None:
history_values = self.mixup_augmenter.transform(history_values)
# Per-series plan: sample categories and apply in fixed order per series
# Categories (max one op per category):
# invariances, structure, seasonality, artifacts, analytic, discrete
for b in range(batch_size):
series = history_values[b : b + 1].clone()
# Determine eligible categories and weights for this series
categories = [
"invariances",
"structure",
"seasonality",
"artifacts",
"analytic",
"discrete",
]
weights = {
"invariances": 0.6,
"structure": 0.6,
"seasonality": 0.5,
"artifacts": 0.3,
"analytic": 0.4,
"discrete": 0.6,
}
# Remove disabled categories
if not (
self.augmentations.get("time_flip_augmentation", False)
or self.augmentations.get("yflip_augmentation", False)
):
weights["invariances"] = 0.0
if not (
self.augmentations.get("regime_change_augmentation", False)
or self.augmentations.get("shock_recovery_augmentation", False)
):
weights["structure"] = 0.0
if not (
self.augmentations.get("calendar_augmentation", False)
or self.augmentations.get("amplitude_modulation_augmentation", False)
):
weights["seasonality"] = 0.0
if not self.augmentations.get("differential_augmentation", False):
weights["analytic"] = 0.0
if not (
self.augmentations.get("quantization_augmentation", False)
or self.augmentations.get("censor_augmentation", False)
):
weights["discrete"] = 0.0
# Sample number of operations in [2, 5]
num_ops = int(self.rng.integers(2, 6))
# Build candidate list and normalized probabilities
candidates = [c for c in categories if weights[c] > 0.0]
if not candidates:
# Nothing to do for this series
history_values[b : b + 1] = series
continue
num_ops = min(num_ops, len(candidates))
probs = np.array([weights[c] for c in candidates], dtype=float)
probs = probs / probs.sum()
chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
# Apply in the fixed global order, only if selected
# 1) Invariances
if "invariances" in chosen_categories:
# Choose one: time_flip or yflip
choices = []
if self.augmentations.get("time_flip_augmentation", False):
choices.append("time_flip")
if self.augmentations.get("yflip_augmentation", False):
choices.append("yflip")
if choices:
pick = str(self.rng.choice(choices))
if pick == "time_flip":
series = torch.flip(series, dims=[1])
elif pick == "yflip":
series = -series
# 2) Structural edits
if "structure" in chosen_categories:
choices = []
if self.augmentations.get("regime_change_augmentation", False):
choices.append("regime")
if self.augmentations.get("shock_recovery_augmentation", False):
choices.append("shock")
if choices:
pick = str(self.rng.choice(choices))
if pick == "regime":
series = self._apply_regime_change(series, p_apply=1.0)
else:
series = self._apply_shock_recovery(series, p_apply=1.0)
# 3) Seasonality/context
if "seasonality" in chosen_categories:
choices = []
if self.augmentations.get("calendar_augmentation", False):
choices.append("calendar")
if self.augmentations.get("amplitude_modulation_augmentation", False):
choices.append("amplitude")
if choices:
pick = str(self.rng.choice(choices))
if pick == "calendar":
series = self._apply_calendar_injections(
series,
[starts[b]] if (starts is not None and b < len(starts)) else None,
[frequencies[b]] if (frequencies is not None and b < len(frequencies)) else None,
p_apply=1.0,
)
else:
series = self._apply_seasonality_amplitude_modulation(series, p_apply=1.0)
# 4) Sampling artifacts
if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
series = self._apply_resample_artifacts(series, p_apply=1.0)
# 5) Analytic transforms
if (
"analytic" in chosen_categories
and self.augmentations.get("differential_augmentation", False)
and hasattr(self, "differential_augmentor")
):
series = self.differential_augmentor.transform(series)
# 6) Discretization/clipping (mutually exclusive)
if "discrete" in chosen_categories:
can_quant = (
self.augmentations.get("quantization_augmentation", False)
and self.quantization_augmenter is not None
)
can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
if can_quant and can_cens:
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
if method == "quantize":
series = self.quantization_augmenter.transform(series)
else:
series = self.censor_augmenter.transform(series)
elif can_quant:
series = self.quantization_augmenter.transform(series)
elif can_cens:
series = self.censor_augmenter.transform(series)
# Write back series
history_values[b : b + 1] = series
# 7) Scaling then Noise (last, optional, batch-level)
if self.augmentations.get("scaling_augmentation", False):
if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
scale_factor = float(self.rng.uniform(0.95, 1.05))
history_values = history_values * scale_factor
if self.augmentations.get("noise_augmentation", False):
if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
noise_std = 0.01 * torch.std(history_values)
if torch.isfinite(noise_std) and (noise_std > 0):
noise = torch.normal(0, noise_std, size=history_values.shape)
history_values = history_values + noise
return history_values
def apply_per_series_only(
self,
series: torch.Tensor,
start: pd.Timestamp | None = None,
frequency: str | None = None,
) -> torch.Tensor:
"""
Apply all per-series augmentations (excluding mixup) to a single series tensor,
preserving ordering and probabilities used in apply().
Args:
series: Tensor of shape [1, length, 1]
start: Optional pandas.Timestamp for calendar injections
frequency: Optional frequency string for calendar injections
"""
if not self.apply_augmentations:
return series
categories = [
"invariances",
"structure",
"seasonality",
"artifacts",
"analytic",
"discrete",
]
weights = {
"invariances": 0.6,
"structure": 0.6,
"seasonality": 0.5,
"artifacts": 0.3,
"analytic": 0.4,
"discrete": 0.6,
}
# Disable categories not enabled
if not (
self.augmentations.get("time_flip_augmentation", False)
or self.augmentations.get("yflip_augmentation", False)
):
weights["invariances"] = 0.0
if not (
self.augmentations.get("regime_change_augmentation", False)
or self.augmentations.get("shock_recovery_augmentation", False)
):
weights["structure"] = 0.0
if not (
self.augmentations.get("calendar_augmentation", False)
or self.augmentations.get("amplitude_modulation_augmentation", False)
):
weights["seasonality"] = 0.0
if not self.augmentations.get("differential_augmentation", False):
weights["analytic"] = 0.0
if not (
self.augmentations.get("quantization_augmentation", False)
or self.augmentations.get("censor_augmentation", False)
):
weights["discrete"] = 0.0
# Sample number of operations in [2, 5]
num_ops = int(self.rng.integers(2, 6))
candidates = [c for c in categories if weights[c] > 0.0]
if not candidates:
result = series
else:
num_ops = min(num_ops, len(candidates))
probs = np.array([weights[c] for c in candidates], dtype=float)
probs = probs / probs.sum()
chosen_categories = list(self.rng.choice(candidates, size=num_ops, replace=False, p=probs))
result = series.clone()
# 1) Invariances
if "invariances" in chosen_categories:
choices = []
if self.augmentations.get("time_flip_augmentation", False):
choices.append("time_flip")
if self.augmentations.get("yflip_augmentation", False):
choices.append("yflip")
if choices:
pick = str(self.rng.choice(choices))
if pick == "time_flip":
result = torch.flip(result, dims=[1])
elif pick == "yflip":
result = -result
# 2) Structural edits
if "structure" in chosen_categories:
choices = []
if self.augmentations.get("regime_change_augmentation", False):
choices.append("regime")
if self.augmentations.get("shock_recovery_augmentation", False):
choices.append("shock")
if choices:
pick = str(self.rng.choice(choices))
if pick == "regime":
result = self._apply_regime_change(result, p_apply=1.0)
else:
result = self._apply_shock_recovery(result, p_apply=1.0)
# 3) Seasonality/context
if "seasonality" in chosen_categories:
choices = []
if self.augmentations.get("calendar_augmentation", False):
choices.append("calendar")
if self.augmentations.get("amplitude_modulation_augmentation", False):
choices.append("amplitude")
if choices:
pick = str(self.rng.choice(choices))
if pick == "calendar":
result = self._apply_calendar_injections(
result,
[start] if start is not None else None,
[frequency] if frequency is not None else None,
p_apply=1.0,
)
else:
result = self._apply_seasonality_amplitude_modulation(result, p_apply=1.0)
# 4) Sampling artifacts
if "artifacts" in chosen_categories and self.augmentations.get("resample_artifacts_augmentation", False):
result = self._apply_resample_artifacts(result, p_apply=1.0)
# 5) Analytic transforms
if (
"analytic" in chosen_categories
and self.augmentations.get("differential_augmentation", False)
and hasattr(self, "differential_augmentor")
):
result = self.differential_augmentor.transform(result)
# 6) Discretization/clipping (mutually exclusive)
if "discrete" in chosen_categories:
can_quant = (
self.augmentations.get("quantization_augmentation", False)
and self.quantization_augmenter is not None
)
can_cens = self.augmentations.get("censor_augmentation", False) and self.censor_augmenter is not None
if can_quant and can_cens:
method = self.rng.choice(["quantize", "censor"], p=[0.6, 0.4])
if method == "quantize":
result = self.quantization_augmenter.transform(result)
else:
result = self.censor_augmenter.transform(result)
elif can_quant:
result = self.quantization_augmenter.transform(result)
elif can_cens:
result = self.censor_augmenter.transform(result)
# Optional scaling and noise (applied to this single series)
if self.augmentations.get("scaling_augmentation", False):
if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.0):
scale_factor = float(self.rng.uniform(0.95, 1.05))
result = result * scale_factor
if self.augmentations.get("noise_augmentation", False):
if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.0):
noise_std = 0.01 * torch.std(result)
if torch.isfinite(noise_std) and (noise_std > 0):
noise = torch.normal(0, noise_std, size=result.shape)
result = result + noise
return result
@property
def mixup_augmenter(self) -> MixUpAugmenter | None:
if not hasattr(self, "_mixup_augmenter"):
self._mixup_augmenter = (
MixUpAugmenter(p_combine=self.augmentation_probabilities["mixup_augmentation"])
if self.augmentations["mixup_augmentation"]
else None
)
return self._mixup_augmenter
def _apply_regime_change(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
"""
Apply piecewise affine transforms with 1-3 change-points per series.
series shape: [batch, length, 1]
"""
if series.numel() == 0:
return series
batch_size, length, _ = series.shape
result = series.clone()
# Iterate per-series to allow different change-points
for b in range(batch_size):
if self.rng.random() >= p_apply:
continue
# sample number of change points and ensure minimum segment length
num_cp = int(self.rng.integers(1, 4))
min_seg = max(8, length // 32)
if length <= (num_cp + 1) * min_seg:
num_cp = max(1, length // (2 * min_seg) - 1)
if num_cp <= 0:
num_cp = 1
# pick change-point indices
valid_positions = np.arange(min_seg, length - min_seg)
if valid_positions.size == 0:
continue
cp = np.sort(self.rng.choice(valid_positions, size=num_cp, replace=False))
boundaries = np.concatenate([[0], cp, [length]])
# compute per-segment scale/shift
series_b = result[b, :, 0]
seg_scales = []
seg_shifts = []
overall_std = torch.std(series_b).item()
if not np.isfinite(overall_std) or overall_std == 0:
overall_std = 1.0
for _ in range(len(boundaries) - 1):
scale = float(self.rng.uniform(0.8, 1.25))
shift = float(self.rng.normal(0.0, 0.15 * overall_std))
seg_scales.append(scale)
seg_shifts.append(shift)
# apply per segment
for i in range(len(boundaries) - 1):
s, e = int(boundaries[i]), int(boundaries[i + 1])
if e <= s:
continue
segment = series_b[s:e]
# preserve segment mean roughly while scaling deviations
seg_mean = torch.mean(segment)
transformed = (segment - seg_mean) * seg_scales[i] + seg_mean + seg_shifts[i]
result[b, s:e, 0] = transformed
return result
def _apply_shock_recovery(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
"""
Add an impulse at a random time and exponentially decay to baseline.
series shape: [batch, length, 1]
"""
if series.numel() == 0:
return series
batch_size, length, _ = series.shape
device = series.device
result = series.clone()
time_idx = torch.arange(length, device=device).float()
for b in range(batch_size):
if self.rng.random() >= p_apply:
continue
# choose shock time away from edges
t0 = int(self.rng.integers(low=max(1, length // 16), high=max(2, length - length // 16)))
# magnitude relative to series std
s_b = result[b, :, 0]
std_b = torch.std(s_b).item()
if not np.isfinite(std_b) or std_b == 0:
std_b = 1.0
mag = float(self.rng.uniform(0.5, 2.0) * std_b)
if self.rng.random() < 0.5:
mag = -mag
# decay constant
half_life = float(self.rng.uniform(0.03, 0.25) * length)
decay = torch.exp(-(time_idx - t0).clamp(min=0) / max(1.0, half_life))
effect = mag * decay
result[b, :, 0] = s_b + effect
return result
def _apply_calendar_injections(
self,
series: torch.Tensor,
starts: list[pd.Timestamp] | None,
frequencies: list[str] | None,
p_apply: float,
) -> torch.Tensor:
if series.numel() == 0:
return series
if starts is None or frequencies is None:
return series
batch_size, length, _ = series.shape
result = series.clone()
for b in range(batch_size):
if b >= len(starts) or b >= len(frequencies):
continue
if self.rng.random() >= p_apply:
continue
start_ts = starts[b]
try:
freq_enum = parse_frequency(str(frequencies[b]))
freq_alias = freq_enum.to_pandas_freq(for_date_range=True)
except Exception:
freq_alias = "D"
try:
index = pd.date_range(start=start_ts, periods=length, freq=freq_alias)
except Exception:
index = pd.date_range(start=start_ts, periods=length, freq="D")
factors = np.ones(length, dtype=np.float32)
# Weekend dips (for daily/hourly-like)
try:
freq_enum_check = parse_frequency(str(frequencies[b]))
except Exception:
freq_enum_check = Frequency.D
if freq_enum_check in [
Frequency.H,
Frequency.D,
Frequency.S,
Frequency.T1,
Frequency.T5,
Frequency.T10,
Frequency.T15,
Frequency.T30,
]:
dow = index.dayofweek
if (dow >= 5).any():
dip = float(self.rng.uniform(0.7, 0.95))
factors[dow >= 5] *= dip
# Month-end bumps
if hasattr(index, "is_month_end"):
me = np.asarray(index.is_month_end, dtype=bool)
if me.any():
bump = float(self.rng.uniform(1.05, 1.3))
factors[me] *= bump
# Holiday-like one-off effects (1-2 random impulses)
n_imp = int(self.rng.integers(1, 3))
imp_positions = self.rng.integers(0, length, size=n_imp)
for pos in np.atleast_1d(imp_positions):
if 0 <= pos < length:
impulse = float(self.rng.uniform(0.8, 1.4))
factors[pos] *= impulse
# Apply multiplicatively around mean to avoid drift
s = result[b, :, 0].cpu().numpy()
mean_val = float(np.mean(s))
s_new = (s - mean_val) * factors + mean_val
result[b, :, 0] = torch.from_numpy(s_new).to(result.device)
return result
def _apply_seasonality_amplitude_modulation(self, series: torch.Tensor, p_apply: float) -> torch.Tensor:
if series.numel() == 0:
return series
batch_size, length, _ = series.shape
result = series.clone()
for b in range(batch_size):
if self.rng.random() >= p_apply:
continue
min_w = max(8, length // 16)
max_w = max(min_w + 1, length // 2)
win = int(self.rng.integers(min_w, max_w + 1))
start = int(self.rng.integers(0, max(1, length - win)))
end = start + win
seg = result[b, start:end, 0]
if seg.numel() == 0:
continue
seg_mean = torch.mean(seg)
amp = float(self.rng.uniform(0.5, 1.8))
result[b, start:end, 0] = (seg - seg_mean) * amp + seg_mean
return result
def _apply_resample_artifacts(
self,
series: torch.Tensor,
p_apply: float,
) -> torch.Tensor:
"""
Downsample then upsample with interpolation to introduce artifacts.
"""
if series.numel() == 0:
return series
batch_size, length, _ = series.shape
result = series.clone()
for b in range(batch_size):
if self.rng.random() >= p_apply:
continue
s_np = result[b, :, 0].cpu().numpy()
max_factor = max(2, min(8, length // 32))
if max_factor <= 1:
continue
factor = int(self.rng.integers(2, max_factor + 1))
offset = int(self.rng.integers(0, factor))
ds_idx = np.arange(offset, length, factor)
if ds_idx.size < 3:
continue
ds_vals = s_np[ds_idx]
base_idx = np.arange(length)
mode = self.rng.choice(["linear", "hold", "linear_smooth"], p=[0.5, 0.2, 0.3])
if mode == "linear":
us = np.interp(base_idx, ds_idx, ds_vals)
elif mode == "hold":
us = np.empty(length, dtype=s_np.dtype)
last = ds_vals[0]
j = 0
for i in range(length):
while j + 1 < ds_idx.size and i >= ds_idx[j + 1]:
j += 1
last = ds_vals[j]
us[i] = last
else:
us = np.interp(base_idx, ds_idx, ds_vals)
k = max(3, length // 128)
kernel = np.ones(k) / k
us = np.convolve(us, kernel, mode="same")
result[b, :, 0] = torch.from_numpy(us).to(result.device)
return result
class OfflinePerSampleAugmentedGenerator:
def __init__(
    self,
    base_data_dir: str,
    output_dir: str,
    length: int | None,
    chunk_size: int = 2**13,
    generator_proportions: dict[str, float] | None = None,
    augmentations: dict[str, bool] | None = None,
    augmentation_probabilities: dict[str, float] | None = None,
    global_seed: int = 42,
    mixup_position: str = "both",
    change_threshold: float = 0.05,
    max_tries: int = 3,
    enable_quality_filter: bool = False,
    rc_batch_size: int = 8,
):
    """Configure the generator, its output writer, augmentor and source datasets.

    Args:
        base_data_dir: Directory with one sub-directory of batches per generator.
        output_dir: Parent directory; results go to ``augmented_per_sample_{length}``
            (or ``augmented_per_sample`` when ``length`` is ``None``) inside it.
        length: Stored as ``self.length``; also part of the output folder name.
        chunk_size: Batch size handed to the dataset manager.
        generator_proportions: Sampling weights per generator; ``None`` gives
            equal weight to every sub-directory of ``base_data_dir``.
        augmentations: On/off flags forwarded to ``UnivariateOfflineAugmentor``.
        augmentation_probabilities: Probabilities forwarded to the augmentor.
        global_seed: Seed for NumPy (global and private generator) and torch.
        mixup_position: Stored as given.
        change_threshold: Stored as ``float``.
        max_tries: Stored as ``int``.
        enable_quality_filter: Stored as ``bool``.
        rc_batch_size: Size of the temporary batch used for random convolution.
    """
    # Plain configuration
    self.base_data_dir = base_data_dir
    self.length = length
    self.chunk_size = chunk_size
    self.global_seed = global_seed

    # Reproducibility: seed the global NumPy RNG, torch and a private generator.
    np.random.seed(global_seed)
    torch.manual_seed(global_seed)
    self.rng = np.random.default_rng(global_seed)

    # Normalised tuning knobs
    self.mixup_position = mixup_position
    self.change_threshold = float(change_threshold)
    self.max_tries = int(max_tries)
    self.enable_quality_filter = bool(enable_quality_filter)
    self.rc_batch_size = int(rc_batch_size)

    # Output writer
    if length is not None:
        out_dir_name = f"augmented_per_sample_{length}"
    else:
        out_dir_name = "augmented_per_sample"
    target_dir = Path(output_dir) / out_dir_name
    self.dataset_manager = TimeSeriesDatasetManager(str(target_dir), batch_size=chunk_size)

    # Augmentation engine
    self.augmentor = UnivariateOfflineAugmentor(
        augmentations=augmentations,
        augmentation_probabilities=augmentation_probabilities,
        global_seed=global_seed,
    )

    # Source data
    self.generator_proportions = self._setup_proportions(generator_proportions)
    self.datasets = self._initialize_datasets()
# -------------------- Per-sample scaler utilities --------------------
def _choose_scaler(self) -> object | None:
    """Pick a scaler for one sample: ``None`` half of the time, otherwise one of
    the robust, min-max, median or mean scalers with equal probability."""
    if self.rng.random() < 0.5:
        return None

    name = str(self.rng.choice(["robust", "minmax", "median", "mean"]))
    # "mean" is the fall-through case of the lookup.
    scaler_types = {
        "robust": RobustScaler,
        "minmax": MinMaxScaler,
        "median": MedianScaler,
    }
    return scaler_types.get(name, MeanScaler)()
def _apply_scaler(self, values: torch.Tensor, scaler: object | None) -> torch.Tensor:
"""Apply the provided scaler to values of shape [1, length, channels]."""
if scaler is None:
return values
stats = scaler.compute_statistics(values)
return scaler.scale(values, stats)
# -------------------- Mixup utilities (per-sample) --------------------
def _mix_sources_static(self, source_tensor: torch.Tensor, alpha: float) -> torch.Tensor:
"""Static Dirichlet mix of k sources -> [1, L, C]."""
k = int(source_tensor.shape[0])
device = source_tensor.device
concentration = torch.full((k,), float(alpha), device=device)
weights = torch.distributions.Dirichlet(concentration).sample()
mixed = (source_tensor * weights.view(k, 1, 1)).sum(dim=0, keepdim=True)
return mixed
def _apply_mixup_to_series(
self,
base_series: torch.Tensor,
total_length_for_batch: int,
scaler: object | None,
) -> torch.Tensor:
"""Mix base with k-1 additional sources; returns [1, L, 1]."""
mixup = self.augmentor.mixup_augmenter
if mixup is None:
return base_series
# Decide k
current_k = mixup._sample_k() if not mixup.randomize_k else int(self.rng.integers(2, mixup.max_k + 1))
# Ensure at least 2 and include base in the set
current_k = max(2, int(current_k))
num_sources_needed = current_k - 1
chosen_gens = self._choose_generators_for_mixup(current_k)
# If we sampled k gens but need only k-1 external sources, trim
chosen_gens = chosen_gens[:num_sources_needed]
sources: list[torch.Tensor] = []
# Base (already possibly scaled) first
sources.append(base_series)
# Additional sources
for gen in chosen_gens:
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
if scaler is not None:
src_values = self._apply_scaler(src_values, scaler)
sources.append(src_values)
source_tensor = torch.cat(sources, dim=0)
alpha = mixup._sample_alpha()
mixed_series = self._mix_sources_static(source_tensor, alpha=alpha)
return mixed_series
# -------------------- RandomConv (temp batch) utilities --------------------
def _apply_random_conv_with_temp_batch(
self,
base_series: torch.Tensor,
total_length_for_batch: int,
scaler: object | None,
) -> torch.Tensor:
"""Apply RandomConvAugmenter by creating a small temp batch and taking the transformed base element."""
if not hasattr(self, "random_conv_augmenter"):
# Lazy init if not present but enabled in config
if self.augmentor.augmentations.get("random_conv_augmentation", False):
p_val = self.augmentor.augmentation_probabilities.get("random_conv_augmentation", 0.3)
self.random_conv_augmenter = RandomConvAugmenter(p_transform=p_val)
else:
return base_series
# Assemble temp batch: base + (rc_batch_size-1) sources
temp_series_list: list[torch.Tensor] = [base_series]
for _ in range(max(0, self.rc_batch_size - 1)):
try:
gen = self._sample_generator_name()
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
if scaler is not None:
src_values = self._apply_scaler(src_values, scaler)
temp_series_list.append(src_values)
except Exception:
break
temp_batch = torch.cat(temp_series_list, dim=0)
transformed = self.random_conv_augmenter.transform(temp_batch)
return transformed[0:1]
# -------------------- Selection and quality helpers --------------------
def _compute_change_score(self, original: torch.Tensor, augmented: torch.Tensor) -> float:
"""
Computes a normalized change score between original and augmented series.
The score is the Mean Absolute Error (MAE) normalized by a robust
measure of the original series' scale (its Interquartile Range).
This makes the score less sensitive to outliers and absolute scale.
"""
original_flat = original.flatten()
# Use the standard Interquartile Range (IQR) for robust scaling.
q25 = torch.quantile(original_flat, 0.25)
q75 = torch.quantile(original_flat, 0.75)
iqr = (q75 - q25).item()
# Use a robust epsilon to prevent division by zero for flat series.
series_range = (torch.max(original_flat) - torch.min(original_flat)).item()
scale = max(iqr, 1e-6 * series_range, 1e-8)
# Compute Mean Absolute Error
mae = torch.mean(torch.abs(augmented - original)).item()
return float(mae / scale)
# moved to src/synthetic_generation/augmentations/filter.py
def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]:
# Default uniform proportions across discovered generators
if generator_proportions is None:
# Discover generator directories
base = Path(self.base_data_dir)
discovered = [p.name for p in base.iterdir() if p.is_dir()]
proportions = dict.fromkeys(discovered, 1.0)
else:
proportions = dict(generator_proportions)
total = sum(proportions.values())
if total <= 0:
raise ValueError("Total generator proportions must be positive")
return {k: v / total for k, v in proportions.items()}
def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]:
    """Open a ``CyclicalBatchDataset`` for every generator with positive weight.

    Generators whose directory is missing, or whose dataset fails to load,
    are skipped with a warning.

    Raises:
        ValueError: If no dataset could be loaded at all.
    """
    loaded: dict[str, CyclicalBatchDataset] = {}
    root = Path(self.base_data_dir)

    for name, weight in self.generator_proportions.items():
        # Only generators that are listed with a positive proportion are loaded.
        if weight <= 0:
            continue

        batches_dir = root / name
        if not batches_dir.is_dir():
            logging.warning(f"Skipping '{name}' because directory does not exist: {batches_dir}")
            continue

        try:
            loaded[name] = CyclicalBatchDataset(
                batches_dir=str(batches_dir),
                generator_type=name,
                device=None,
                prefetch_next=True,
                prefetch_threshold=32,
            )
            logging.info(f"Loaded dataset for {name}")
        except Exception as e:
            logging.warning(f"Failed to load dataset for {name}: {e}")

    if not loaded:
        raise ValueError("No valid datasets loaded from base_data_dir")
    return loaded
def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, Any, str, int]:
num_channels = sample.get("num_channels", 1)
values_data = sample["values"]
if num_channels == 1:
if isinstance(values_data[0], list):
values = torch.tensor(values_data[0], dtype=torch.float32)
else:
values = torch.tensor(values_data, dtype=torch.float32)
values = values.unsqueeze(0).unsqueeze(-1)
else:
channel_tensors = []
for channel_values in values_data:
channel_tensor = torch.tensor(channel_values, dtype=torch.float32)
channel_tensors.append(channel_tensor)
values = torch.stack(channel_tensors, dim=-1).unsqueeze(0)
freq_str = sample["frequency"]
start_val = sample["start"]
# Keep start as pandas.Timestamp for Arrow writing later
if isinstance(start_val, pd.Timestamp):
start = start_val
else:
start = pd.Timestamp(start_val)
return values, start, freq_str, num_channels
def _maybe_resize(self, values: torch.Tensor, target_len: int) -> torch.Tensor:
if values.shape[1] == target_len:
return values
if values.shape[1] > target_len:
max_start_idx = values.shape[1] - target_len
start_idx = np.random.randint(0, max_start_idx + 1)
return values[:, start_idx : start_idx + target_len, :]
# Subsample evenly to reach target_len
indices = np.linspace(0, values.shape[1] - 1, target_len, dtype=int)
return values[:, indices, :]
def _sample_generator_name(self) -> str:
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
probs = np.array([self.generator_proportions[g] for g in available], dtype=float)
probs = probs / probs.sum()
return str(np.random.choice(available, p=probs))
def _get_one_sample(self, total_length_for_batch: int) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
attempts = 0
while attempts < 20:
attempts += 1
gen_name = self._sample_generator_name()
dataset = self.datasets[gen_name]
sample = dataset.get_samples(1)[0]
values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
values = self._maybe_resize(values, total_length_for_batch)
if values.shape[2] != 1:
continue
return values, start, freq_str, num_channels
raise RuntimeError("Failed to sample a valid univariate series after multiple attempts")
def _get_one_sample_from_generator(
self, gen_name: str, total_length_for_batch: int
) -> tuple[torch.Tensor, pd.Timestamp, str, int]:
attempts = 0
dataset = self.datasets[gen_name]
while attempts < 20:
attempts += 1
sample = dataset.get_samples(1)[0]
values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample)
values = self._maybe_resize(values, total_length_for_batch)
if values.shape[2] != 1:
continue
return values, start, freq_str, num_channels
raise RuntimeError(
f"Failed to sample a valid univariate series from generator '{gen_name}' after multiple attempts"
)
def _choose_generators_for_mixup(self, k: int) -> list[str]:
available = [g for g in self.generator_proportions.keys() if g in self.datasets]
if not available:
raise RuntimeError("No available generators to sample from for mixup")
k_eff = min(k, len(available))
# Weighted sampling without replacement by sequential renormalization
chosen: list[str] = []
remaining = available.copy()
while len(chosen) < k_eff:
weights = np.array([self.generator_proportions[g] for g in remaining], dtype=float)
if weights.sum() <= 0:
# fallback to uniform
probs = np.ones(len(remaining)) / len(remaining)
else:
probs = weights / weights.sum()
pick = str(np.random.choice(remaining, p=probs))
chosen.append(pick)
remaining.remove(pick)
return chosen
def _maybe_apply_mixup_to_single(self, base_series: torch.Tensor, total_length_for_batch: int) -> torch.Tensor:
do_mixup = self.augmentor.augmentations.get(
"mixup_augmentation", False
) and self.augmentor.rng.random() < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
if not do_mixup:
return base_series
# Use MixUpAugmenter to avoid duplication
mixup = self.augmentor.mixup_augmenter
if mixup is None:
return base_series
# Decide number of sources k consistent with MixUpAugmenter behavior
current_k = mixup._sample_k() if not mixup.randomize_k else int(self.augmentor.rng.integers(2, mixup.max_k + 1))
# Choose distinct generators for sources according to proportions
chosen_gens = self._choose_generators_for_mixup(current_k)
# Collect one source per chosen generator
sources: list[torch.Tensor] = []
for gen in chosen_gens:
src_values, _, _, _ = self._get_one_sample_from_generator(gen, total_length_for_batch)
sources.append(src_values)
source_tensor = torch.cat(sources, dim=0)
# Sample alpha via MixUpAugmenter, then mix
alpha = mixup._sample_alpha()
mixed_series = mixup.mix_sources(source_tensor, alpha=alpha)
return mixed_series
def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]:
# series_tensor shape: [1, seq_len, num_channels]
seq_len = int(series_tensor.shape[1])
num_channels = int(series_tensor.shape[2])
if num_channels == 1:
return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1
channels: list[list[float]] = []
for ch in range(num_channels):
channels.append(series_tensor[0, :, ch].tolist())
return channels, seq_len, num_channels
def run(self, num_batches: int) -> None:
    """Produce augmented series and write them out in chunks.

    Loops until ``self.dataset_manager.batch_counter`` reaches
    ``num_batches``. Each iteration makes up to ``max(1, self.max_tries)``
    attempts to build one acceptable series: sample a base series, scale
    it, optionally mix it up early, apply the per-series augmentations,
    optionally apply the random-convolution step, optionally mix it up
    late, then reject it if its change score is below
    ``self.change_threshold`` or (when enabled) the quality filter flags
    it. The first accepted candidate is buffered; once the buffer holds
    ``self.chunk_size`` records it is handed to
    ``dataset_manager.append_batch``.

    A ``KeyboardInterrupt`` is logged rather than propagated. Whatever is
    left in the buffer is flushed in the ``finally`` clause.

    Args:
        num_batches: Value of the manager's batch counter at which to stop.
    """
    logging.info(
        f"Starting offline augmentation into {self.dataset_manager.batches_dir} | chunk_size={self.chunk_size}"
    )

    def mixup_triggered(positions: tuple[str, ...]) -> bool:
        # Short-circuits left to right, so the RNG is only consumed when
        # mixup is enabled and configured for one of these positions.
        return bool(
            self.augmentor.augmentations.get("mixup_augmentation", False)
            and self.mixup_position in positions
            and self.augmentor.rng.random()
            < self.augmentor.augmentation_probabilities.get("mixup_augmentation", 0.0)
        )

    augmented_buffer: list[dict[str, Any]] = []
    target_batches = num_batches
    start_time = time.time()
    # The manager's counter can already be non-zero when this run starts;
    # remember it so the reported rate covers only series from this run.
    series_at_start = self.dataset_manager.series_counter

    try:
        while self.dataset_manager.batch_counter < target_batches:
            # Decide target length for this sample
            total_length_for_batch = (
                self.length if self.length is not None else int(np.random.choice(LENGTH_CHOICES))
            )

            for _ in range(max(1, self.max_tries)):
                # Sample one base series and keep an untouched copy for the change score
                base_values, base_start, base_freq, _ = self._get_one_sample(total_length_for_batch)
                original_base = base_values.clone()

                # Scaler is chosen per sample
                per_sample_scaler = self._choose_scaler()
                base_values = self._apply_scaler(base_values, per_sample_scaler)

                # Early mixup (if enabled and position includes first)
                if mixup_triggered(("first", "both")):
                    base_values = self._apply_mixup_to_series(
                        base_values, total_length_for_batch, per_sample_scaler
                    )

                # Apply per-series augmentations
                augmented_single = self.augmentor.apply_per_series_only(
                    base_values, start=base_start, frequency=base_freq
                )

                # Optional RandomConvAugmenter via temp batch (before late mixup)
                if self.augmentor.augmentations.get(
                    "random_conv_augmentation", False
                ) and self.rng.random() < self.augmentor.augmentation_probabilities.get(
                    "random_conv_augmentation", 0.3
                ):
                    augmented_single = self._apply_random_conv_with_temp_batch(
                        augmented_single,
                        total_length_for_batch,
                        per_sample_scaler,
                    )

                # Late mixup (if enabled and position includes last)
                if mixup_triggered(("last", "both")):
                    augmented_single = self._apply_mixup_to_series(
                        augmented_single, total_length_for_batch, per_sample_scaler
                    )

                # Reject candidates that barely differ from the base series
                score = self._compute_change_score(original_base, augmented_single)
                if score < self.change_threshold:
                    continue

                # Optional quality filter
                if self.enable_quality_filter and is_low_quality(augmented_single):
                    continue

                # Accept first candidate that passes thresholds
                values_list, seq_len, num_channels = self._tensor_to_values_list(augmented_single)
                record = {
                    # series_counter is not advanced in this method, so add
                    # the buffer position to keep ids unique within a chunk.
                    # NOTE(review): assumes append_batch advances
                    # series_counter by the number of records written and
                    # keeps the ids given here - confirm.
                    "series_id": self.dataset_manager.series_counter + len(augmented_buffer),
                    "values": values_list,
                    "length": int(seq_len),
                    "num_channels": int(num_channels),
                    "generator_type": "augmented",
                    "start": pd.Timestamp(base_start),
                    "frequency": base_freq,
                    "generation_timestamp": pd.Timestamp.now(),
                }
                augmented_buffer.append(record)
                break

            if len(augmented_buffer) >= self.chunk_size:
                write_start = time.time()
                self.dataset_manager.append_batch(augmented_buffer)
                write_time = time.time() - write_start
                elapsed = time.time() - start_time
                produced = self.dataset_manager.series_counter - series_at_start
                series_per_sec = produced / elapsed if elapsed > 0 else 0
                print(
                    f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | "
                    f"Series: {self.dataset_manager.series_counter:,} | "
                    f"Rate: {series_per_sec:.1f}/s | "
                    f"Write: {write_time:.2f}s"
                )
                augmented_buffer = []

    except KeyboardInterrupt:
        logging.info(
            f"Interrupted. Generated {self.dataset_manager.series_counter} series, "
            f"{self.dataset_manager.batch_counter} batches."
        )
    finally:
        # Flush remaining buffer if any
        if augmented_buffer:
            self.dataset_manager.append_batch(augmented_buffer)
        logging.info("Offline augmentation completed.")
def setup_logging(verbose: bool = False) -> None:
    """Configure root logging to emit timestamped records on stdout.

    Args:
        verbose: Use ``DEBUG`` level when true, ``INFO`` otherwise.
    """
    log_level = logging.INFO
    if verbose:
        log_level = logging.DEBUG
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[stdout_handler],
        level=log_level,
    )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the offline augmentation script."""
    parser = argparse.ArgumentParser(
        description="Offline augmentation script to precompute augmented series",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--base-data-dir",
        type=str,
        required=True,
        help="Base directory with generator subdirectories (inputs)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Base output directory for augmented datasets",
    )
    parser.add_argument(
        "--length",
        type=int,
        default=None,
        help="Fixed length of augmented series. If set, saves under augmented{length}",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=2**13,  # 8192
        help="Number of series per written Arrow batch",
    )
    parser.add_argument(
        "--num-batches",
        type=int,
        default=1000,
        help="Number of Arrow batches to write",
    )
    parser.add_argument(
        "--mixup-position",
        type=str,
        default="both",
        choices=["first", "last", "both"],
        help="Where to apply mixup in the pipeline (first, last, or both)",
    )
    parser.add_argument(
        "--change-threshold",
        type=float,
        default=0.05,
        help="Minimum normalized change score (vs IQR) required to keep series",
    )
    parser.add_argument(
        "--max-tries",
        type=int,
        default=3,
        help="Max attempts to produce an acceptable augmented series per output",
    )
    parser.add_argument(
        "--enable-quality-filter",
        action="store_true",
        help="Enable low-quality series filter (noise-like removal)",
    )
    # Quality filter thresholds moved to filter module defaults
    parser.add_argument(
        "--rc-batch-size",
        type=int,
        default=8,
        help="Temporary batch size used for RandomConvAugmenter",
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
    parser.add_argument("--global-seed", type=int, default=42, help="Global random seed")
    return parser


def main():
    """Command-line entry point: parse arguments, then run the offline generator.

    The generator mix, the set of enabled augmentations and their
    probabilities are fixed here; everything else comes from the command
    line. Any exception raised while building or running the generator is
    logged with its traceback and the process exits with status 1.
    """
    args = _build_arg_parser().parse_args()
    setup_logging(args.verbose)

    # Relative sampling weight of each base generator.
    generator_proportions = {
        "forecast_pfn": 1.0,
        "gp": 1.0,
        "kernel": 1.0,
        "sinewave": 1.0,
        "sawtooth": 1.0,
        "step": 0.1,
        "anomaly": 1.0,
        "spike": 1.0,
        "cauker_univariate": 2.0,
        "ou_process": 1.0,
        "audio_financial_volatility": 0.1,
        "audio_multi_scale_fractal": 0.1,
        "audio_network_topology": 0.5,
        "audio_stochastic_rhythm": 1.0,
    }

    # Which augmentations are switched on.
    augmentations = {
        "censor_augmentation": True,
        "quantization_augmentation": True,
        "mixup_augmentation": True,
        "time_flip_augmentation": True,
        "yflip_augmentation": True,
        "differential_augmentation": True,
        "regime_change_augmentation": True,
        "shock_recovery_augmentation": True,
        "calendar_augmentation": True,
        "amplitude_modulation_augmentation": True,
        "resample_artifacts_augmentation": True,
        "scaling_augmentation": True,
        "noise_augmentation": True,
        "random_conv_augmentation": True,
    }

    # Probability with which each enabled augmentation is applied.
    augmentation_probabilities = {
        "censor_or_quantization_augmentation": 0.40,
        "mixup_augmentation": 0.50,
        "time_flip_augmentation": 0.30,
        "yflip_augmentation": 0.30,
        "differential_augmentation": 0.40,
        "regime_change_augmentation": 0.40,
        "shock_recovery_augmentation": 0.40,
        "calendar_augmentation": 0.40,
        "amplitude_modulation_augmentation": 0.35,
        "resample_artifacts_augmentation": 0.40,
        "scaling_augmentation": 0.50,
        "noise_augmentation": 0.1,
        "random_conv_augmentation": 0.30,
    }

    try:
        generator = OfflinePerSampleAugmentedGenerator(
            base_data_dir=args.base_data_dir,
            output_dir=args.output_dir,
            length=args.length,
            chunk_size=args.chunk_size,
            generator_proportions=generator_proportions,
            augmentations=augmentations,
            augmentation_probabilities=augmentation_probabilities,
            global_seed=args.global_seed,
            mixup_position=args.mixup_position,
            change_threshold=args.change_threshold,
            max_tries=args.max_tries,
            enable_quality_filter=args.enable_quality_filter,
            rc_batch_size=args.rc_batch_size,
        )
        generator.run(num_batches=args.num_batches)
    except Exception as e:
        # logging.exception keeps the traceback, which a plain error message
        # would discard and which is needed to diagnose a failed run.
        logging.exception(f"Fatal error: {e}")
        sys.exit(1)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()