import argparse import logging import sys import time from pathlib import Path from typing import Any import numpy as np import pandas as pd import torch from src.data.augmentations import ( CensorAugmenter, DifferentialAugmenter, MixUpAugmenter, QuantizationAugmenter, RandomConvAugmenter, TimeFlipAugmenter, YFlipAugmenter, ) from src.data.constants import LENGTH_CHOICES from src.data.datasets import CyclicalBatchDataset from src.data.filter import is_low_quality from src.data.scalers import MeanScaler, MedianScaler, MinMaxScaler, RobustScaler from src.synthetic_generation.augmentations.offline_per_sample_iid_augmentations import ( TimeSeriesDatasetManager, UnivariateOfflineAugmentor, ) class OfflineTempBatchAugmentedGenerator: def __init__( self, base_data_dir: str, output_dir: str, length: int | None, mixed_batch_size: int = 10, chunk_size: int = 2**13, generator_proportions: dict[str, float] | None = None, augmentations: dict[str, bool] | None = None, augmentation_probabilities: dict[str, float] | None = None, global_seed: int = 42, mixup_position: str = "both", selection_strategy: str = "random", change_threshold: float = 0.05, enable_quality_filter: bool = False, temp_batch_retries: int = 3, ): self.base_data_dir = base_data_dir self.length = length self.mixed_batch_size = mixed_batch_size self.chunk_size = chunk_size self.global_seed = global_seed np.random.seed(global_seed) torch.manual_seed(global_seed) out_dir_name = f"augmented_temp_batch_{length}" if length is not None else "augmented_temp_batch" self.dataset_manager = TimeSeriesDatasetManager(str(Path(output_dir) / out_dir_name), batch_size=chunk_size) # Augmentation config self.augmentation_probabilities = augmentation_probabilities or {} self.augmentations = augmentations or {} self.apply_augmentations = any(self.augmentations.values()) # RNG for category choices and sampling self.rng = np.random.default_rng(global_seed) # Mixup placement and selection strategy self.mixup_position = mixup_position self.selection_strategy = selection_strategy self.change_threshold = float(change_threshold) self.enable_quality_filter = bool(enable_quality_filter) self.temp_batch_retries = int(temp_batch_retries) # Initialize augmenters as in old composer ordering self.flip_augmenter = None if self.augmentations.get("time_flip_augmentation", False): self.flip_augmenter = TimeFlipAugmenter( p_flip=self.augmentation_probabilities.get("time_flip_augmentation", 0.5) ) self.yflip_augmenter = None if self.augmentations.get("yflip_augmentation", False): self.yflip_augmenter = YFlipAugmenter(p_flip=self.augmentation_probabilities.get("yflip_augmentation", 0.5)) self.censor_augmenter = None if self.augmentations.get("censor_augmentation", False): self.censor_augmenter = CensorAugmenter() self.quantization_augmenter = None if self.augmentations.get("quantization_augmentation", False): self.quantization_augmenter = QuantizationAugmenter( p_quantize=self.augmentation_probabilities.get("censor_or_quantization_augmentation", 0.5), level_range=(5, 15), ) self.mixup_augmenter = None if self.augmentations.get("mixup_augmentation", False): self.mixup_augmenter = MixUpAugmenter( p_combine=self.augmentation_probabilities.get("mixup_augmentation", 0.5) ) self.differential_augmentor = None if self.augmentations.get("differential_augmentation", False): self.differential_augmentor = DifferentialAugmenter( p_transform=self.augmentation_probabilities.get("differential_augmentation", 0.5) ) self.random_conv_augmenter = None if self.augmentations.get("random_conv_augmentation", False): self.random_conv_augmenter = RandomConvAugmenter( p_transform=self.augmentation_probabilities.get("random_conv_augmentation", 0.3) ) self.generator_proportions = self._setup_proportions(generator_proportions) self.datasets = self._initialize_datasets() # Per-series augmentor from offline_augmentations.py (categories only) self.per_series_augmentor = UnivariateOfflineAugmentor( augmentations=self.augmentations, augmentation_probabilities=self.augmentation_probabilities, global_seed=global_seed, ) def _compute_change_scores(self, original_batch: torch.Tensor, augmented_batch: torch.Tensor) -> np.ndarray: # Normalized MAE vs IQR (q25-q75) per element bsz = augmented_batch.shape[0] scores: list[float] = [] for i in range(bsz): base_flat = original_batch[i].reshape(-1) q25 = torch.quantile(base_flat, 0.25) q75 = torch.quantile(base_flat, 0.75) iqr = (q75 - q25).item() iqr = iqr if iqr > 0 else 1e-6 mae = torch.mean(torch.abs(augmented_batch[i] - original_batch[i])).item() scores.append(mae / iqr) return np.asarray(scores, dtype=float) def _setup_proportions(self, generator_proportions: dict[str, float] | None) -> dict[str, float]: # Default uniform across discovered generators if generator_proportions is None: base = Path(self.base_data_dir) discovered = [p.name for p in base.iterdir() if p.is_dir()] proportions = dict.fromkeys(discovered, 1.0) else: proportions = dict(generator_proportions) total = sum(proportions.values()) if total <= 0: raise ValueError("Total generator proportions must be positive") return {k: v / total for k, v in proportions.items()} def _initialize_datasets(self) -> dict[str, CyclicalBatchDataset]: datasets: dict[str, CyclicalBatchDataset] = {} for generator_name, proportion in self.generator_proportions.items(): if proportion <= 0: continue batches_dir = Path(self.base_data_dir) / generator_name if not batches_dir.is_dir(): logging.warning(f"Skipping '{generator_name}' because directory does not exist: {batches_dir}") continue try: dataset = CyclicalBatchDataset( batches_dir=str(batches_dir), generator_type=generator_name, device=None, prefetch_next=True, prefetch_threshold=32, ) datasets[generator_name] = dataset logging.info(f"Loaded dataset for {generator_name}") except Exception as e: logging.warning(f"Failed to load dataset for {generator_name}: {e}") if not datasets: raise ValueError("No valid datasets loaded from base_data_dir") return datasets def _sample_generator_name(self) -> str: available = [g for g in self.generator_proportions.keys() if g in self.datasets] probs = np.array([self.generator_proportions[g] for g in available], dtype=float) probs = probs / probs.sum() return str(self.rng.choice(available, p=probs)) def _series_key(self, gen_name: str, sample: dict, values: torch.Tensor) -> str: series_id = sample.get("series_id", None) if series_id is not None: return f"{gen_name}:{series_id}" # Fallback: hash by values and metadata try: arr = values.detach().cpu().numpy() h = hash( ( gen_name, sample.get("start", None), sample.get("frequency", None), arr.shape, float(arr.mean()), float(arr.std()), ) ) return f"{gen_name}:hash:{h}" except Exception: return f"{gen_name}:rand:{self.rng.integers(0, 1 << 31)}" def _convert_sample_to_tensor(self, sample: dict) -> tuple[torch.Tensor, pd.Timestamp, str, int]: num_channels = sample.get("num_channels", 1) values_data = sample["values"] if num_channels == 1: if isinstance(values_data[0], list): values = torch.tensor(values_data[0], dtype=torch.float32) else: values = torch.tensor(values_data, dtype=torch.float32) values = values.unsqueeze(0).unsqueeze(-1) else: channel_tensors = [] for channel_values in values_data: channel_tensor = torch.tensor(channel_values, dtype=torch.float32) channel_tensors.append(channel_tensor) values = torch.stack(channel_tensors, dim=-1).unsqueeze(0) freq_str = sample["frequency"] start_val = sample["start"] start = start_val if isinstance(start_val, pd.Timestamp) else pd.Timestamp(start_val) return values, start, freq_str, num_channels def _shorten_like_batch_composer(self, values: torch.Tensor, target_len: int) -> torch.Tensor | None: # Only shorten if longer; if shorter than target_len, reject (to keep batch aligned) seq_len = int(values.shape[1]) if seq_len == target_len: return values if seq_len < target_len: return None # Randomly choose cut or subsample with equal probability strategy = str(self.rng.choice(["cut", "subsample"])) if strategy == "cut": max_start_idx = seq_len - target_len start_idx = int(self.rng.integers(0, max_start_idx + 1)) return values[:, start_idx : start_idx + target_len, :] # Subsample evenly spaced indices indices = np.linspace(0, seq_len - 1, target_len, dtype=int) return values[:, indices, :] def _maybe_apply_scaler(self, values: torch.Tensor) -> torch.Tensor: scaler_choice = str(self.rng.choice(["robust", "minmax", "median", "mean", "none"])) scaler = None if scaler_choice == "robust": scaler = RobustScaler() elif scaler_choice == "minmax": scaler = MinMaxScaler() elif scaler_choice == "median": scaler = MedianScaler() elif scaler_choice == "mean": scaler = MeanScaler() if scaler is not None: values = scaler.scale(values, scaler.compute_statistics(values)) return values def _apply_augmentations( self, batch_values: torch.Tensor, starts: list[pd.Timestamp], freqs: list[str], ) -> torch.Tensor: if not self.apply_augmentations: return batch_values # 1) Early mixup (batch-level) if ( self.mixup_position in ["first", "both"] and self.augmentations.get("mixup_augmentation", False) and self.mixup_augmenter is not None ): batch_values = self.mixup_augmenter.transform(batch_values) # 2) Per-series categories (apply to ALL series with starts/freqs) batch_size = int(batch_values.shape[0]) augmented_list = [] for i in range(batch_size): s = batch_values[i : i + 1] start_i = starts[i] if i < len(starts) else None freq_i = freqs[i] if i < len(freqs) else None s_aug = self.per_series_augmentor.apply_per_series_only(s, start=start_i, frequency=freq_i) augmented_list.append(s_aug) batch_values = torch.cat(augmented_list, dim=0) # 3) Noise augmentation (batch-level) if self.augmentations.get("noise_augmentation", False): if self.rng.random() < self.augmentation_probabilities.get("noise_augmentation", 0.5): noise_std = 0.01 * torch.std(batch_values) if torch.isfinite(noise_std) and (noise_std > 0): noise = torch.normal(0, noise_std, size=batch_values.shape) batch_values = batch_values + noise # 4) Scaling augmentation (batch-level) if self.augmentations.get("scaling_augmentation", False): if self.rng.random() < self.augmentation_probabilities.get("scaling_augmentation", 0.5): scale_factor = float(self.rng.uniform(0.95, 1.05)) batch_values = batch_values * scale_factor # 5) RandomConvAugmenter (batch-level) if self.augmentations.get("random_conv_augmentation", False) and self.random_conv_augmenter is not None: if self.rng.random() < self.augmentation_probabilities.get("random_conv_augmentation", 0.3): batch_values = self.random_conv_augmenter.transform(batch_values) # 6) Late mixup (batch-level) if ( self.mixup_position in ["last", "both"] and self.augmentations.get("mixup_augmentation", False) and self.mixup_augmenter is not None ): batch_values = self.mixup_augmenter.transform(batch_values) return batch_values def _get_one_source_sample( self, total_length_for_batch: int, used_source_keys: set ) -> tuple[torch.Tensor, pd.Timestamp, str, str] | None: # Returns (values, start, freq, source_key) or None if cannot fetch attempts = 0 while attempts < 50: attempts += 1 gen_name = self._sample_generator_name() dataset = self.datasets[gen_name] sample = dataset.get_samples(1)[0] values, start, freq_str, num_channels = self._convert_sample_to_tensor(sample) if num_channels != 1: continue # Reject NaNs if torch.isnan(values).any(): continue # Shorten to target_len; reject if too short shortened = self._shorten_like_batch_composer(values, total_length_for_batch) if shortened is None: continue values = shortened # Random scaler values = self._maybe_apply_scaler(values) # Uniqueness check key = self._series_key(gen_name, sample, values) if key in used_source_keys: continue # Reserve key immediately to avoid re-use in same temp batch used_source_keys.add(key) return values, start, freq_str, key return None def _tensor_to_values_list(self, series_tensor: torch.Tensor) -> tuple[list[list[float]], int, int]: seq_len = int(series_tensor.shape[1]) num_channels = int(series_tensor.shape[2]) if num_channels == 1: return [series_tensor.squeeze(0).squeeze(-1).tolist()], seq_len, 1 channels: list[list[float]] = [] for ch in range(num_channels): channels.append(series_tensor[0, :, ch].tolist()) return channels, seq_len, num_channels def run(self, num_batches: int) -> None: logging.info( f"Starting offline IID augmentation into {self.dataset_manager.batches_dir} | " f"chunk_size={self.chunk_size} | " f"mixed_batch_size={self.mixed_batch_size}" ) augmented_buffer: list[dict[str, Any]] = [] target_batches = num_batches start_time = time.time() try: while self.dataset_manager.batch_counter < target_batches: # Decide target length for this temp batch total_length_for_batch = ( self.length if self.length is not None else int(self.rng.choice(LENGTH_CHOICES)) ) selected_record: dict[str, Any] | None = None for _retry in range(max(1, self.temp_batch_retries + 1)): # Collect a temporary mixed batch without reusing sources temp_values_list: list[torch.Tensor] = [] temp_starts: list[pd.Timestamp] = [] temp_freqs: list[str] = [] temp_used_keys: set = set() attempts = 0 while len(temp_values_list) < self.mixed_batch_size and attempts < self.mixed_batch_size * 200: attempts += 1 fetched = self._get_one_source_sample(total_length_for_batch, temp_used_keys) if fetched is None: continue values, start, freq, _ = fetched temp_values_list.append(values) temp_starts.append(start) temp_freqs.append(freq) if len(temp_values_list) == 0: # If we could not fetch anything, rebuild next retry continue temp_batch = torch.cat(temp_values_list, dim=0) original_temp_batch = temp_batch.clone() # Apply augmentations sequentially augmented_temp_batch = self._apply_augmentations(temp_batch, temp_starts, temp_freqs) # Compute change scores scores = self._compute_change_scores(original_temp_batch, augmented_temp_batch) # Build eligible indices by threshold eligible = np.where(scores >= self.change_threshold)[0].tolist() # Apply quality filter if enabled if self.enable_quality_filter: eligible_q: list[int] = [] for idx in eligible: cand = augmented_temp_batch[idx : idx + 1] if not is_low_quality(cand): eligible_q.append(idx) eligible = eligible_q sel_idx: int | None = None if self.selection_strategy == "max_change": if eligible: sel_idx = int(max(eligible, key=lambda i: scores[i])) else: # Fallback to best by score (respect quality if possible) if self.enable_quality_filter: qual_idxs = [ i for i in range(augmented_temp_batch.shape[0]) if not is_low_quality(augmented_temp_batch[i : i + 1]) ] if qual_idxs: sel_idx = int(max(qual_idxs, key=lambda i: scores[i])) if sel_idx is None: sel_idx = int(np.argmax(scores)) else: # random selection among eligible, else fallback to best if eligible: sel_idx = int(self.rng.choice(np.asarray(eligible, dtype=int))) else: if self.enable_quality_filter: qual_idxs = [ i for i in range(augmented_temp_batch.shape[0]) if not is_low_quality(augmented_temp_batch[i : i + 1]) ] if qual_idxs: sel_idx = int(max(qual_idxs, key=lambda i: scores[i])) if sel_idx is None: sel_idx = int(np.argmax(scores)) # If still none (shouldn't happen), rebuild if sel_idx is None: continue selected_series = augmented_temp_batch[sel_idx : sel_idx + 1] values_list, seq_len, num_channels = self._tensor_to_values_list(selected_series) selected_record = { "series_id": self.dataset_manager.series_counter, "values": values_list, "length": int(seq_len), "num_channels": int(num_channels), "generator_type": "augmented", "start": pd.Timestamp(temp_starts[sel_idx]), "frequency": temp_freqs[sel_idx], "generation_timestamp": pd.Timestamp.now(), } break if selected_record is None: # Could not assemble a valid candidate after retries; skip iteration continue augmented_buffer.append(selected_record) if len(augmented_buffer) >= self.chunk_size: write_start = time.time() self.dataset_manager.append_batch(augmented_buffer) write_time = time.time() - write_start elapsed = time.time() - start_time series_per_sec = self.dataset_manager.series_counter / elapsed if elapsed > 0 else 0 print( f"✓ Wrote batch {self.dataset_manager.batch_counter - 1}/{target_batches} | " f"Series: {self.dataset_manager.series_counter:,} | " f"Rate: {series_per_sec:.1f}/s | " f"Write: {write_time:.2f}s" ) augmented_buffer = [] except KeyboardInterrupt: logging.info( f"Interrupted. Generated {self.dataset_manager.series_counter} series, " f"{self.dataset_manager.batch_counter} batches." ) finally: if augmented_buffer: self.dataset_manager.append_batch(augmented_buffer) logging.info("Offline IID augmentation completed.") def setup_logging(verbose: bool = False) -> None: level = logging.DEBUG if verbose else logging.INFO logging.basicConfig( level=level, format="%(asctime)s - %(levelname)s - %(message)s", handlers=[logging.StreamHandler(sys.stdout)], ) def main(): parser = argparse.ArgumentParser( description="Offline IID augmentation script using temp mixed batches", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "--base-data-dir", type=str, required=True, help="Base directory with generator subdirectories (inputs)", ) parser.add_argument( "--output-dir", type=str, required=True, help="Base output directory for augmented datasets", ) parser.add_argument( "--length", type=int, default=None, help="Fixed length of augmented series. If set, saves under augmented{length}", ) parser.add_argument( "--mixed-batch-size", type=int, default=64, help="Temporary mixed batch size before selecting a single element", ) parser.add_argument( "--chunk-size", type=int, default=2**13, help="Number of series per written Arrow batch", ) parser.add_argument( "--num-batches", type=int, default=1000, help="Number of Arrow batches to write", ) parser.add_argument( "--mixup-position", type=str, default="both", choices=["first", "last", "both"], help="Where to apply mixup in the pipeline (first, last, or both)", ) parser.add_argument( "--selection-strategy", type=str, default="random", choices=["random", "max_change"], help="How to select the final series from the temp batch", ) parser.add_argument( "--change-threshold", type=float, default=0.05, help="Minimum normalized change score (vs IQR) required for selection", ) parser.add_argument( "--enable-quality-filter", action="store_true", help="Enable low-quality filter using autocorr/SNR/complexity", ) parser.add_argument( "--temp-batch-retries", type=int, default=3, help="Number of times to rebuild temp batch if selection fails thresholds", ) parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") parser.add_argument("--global-seed", type=int, default=42, help="Global random seed") args = parser.parse_args() setup_logging(args.verbose) generator_proportions = { "forecast_pfn": 1.0, "gp": 1.0, "kernel": 1.0, "sinewave": 1.0, "sawtooth": 1.0, "step": 0.1, "anomaly": 1.0, "spike": 1.0, "cauker_univariate": 2.0, "ou_process": 1.0, "audio_financial_volatility": 0.1, "audio_multi_scale_fractal": 0.1, "audio_network_topology": 0.5, "audio_stochastic_rhythm": 1.0, } # Defaults reflecting configs/train.yaml from the prompt augmentations = { "censor_augmentation": True, "quantization_augmentation": False, "mixup_augmentation": True, "time_flip_augmentation": True, "yflip_augmentation": True, "differential_augmentation": True, "regime_change_augmentation": True, "shock_recovery_augmentation": True, "calendar_augmentation": False, "amplitude_modulation_augmentation": True, "resample_artifacts_augmentation": True, "scaling_augmentation": True, "noise_augmentation": True, "random_conv_augmentation": True, } augmentation_probabilities = { "censor_or_quantization_augmentation": 0.40, "mixup_augmentation": 0.50, "time_flip_augmentation": 0.30, "yflip_augmentation": 0.30, "differential_augmentation": 0.40, "regime_change_augmentation": 0.40, "shock_recovery_augmentation": 0.40, "calendar_augmentation": 0.40, "amplitude_modulation_augmentation": 0.35, "resample_artifacts_augmentation": 0.40, "scaling_augmentation": 0.50, "noise_augmentation": 0.10, "random_conv_augmentation": 0.30, } try: generator = OfflineTempBatchAugmentedGenerator( base_data_dir=args.base_data_dir, output_dir=args.output_dir, length=args.length, mixed_batch_size=args.mixed_batch_size, chunk_size=args.chunk_size, generator_proportions=generator_proportions, augmentations=augmentations, augmentation_probabilities=augmentation_probabilities, global_seed=args.global_seed, mixup_position=args.mixup_position, selection_strategy=args.selection_strategy, change_threshold=args.change_threshold, enable_quality_filter=args.enable_quality_filter, temp_batch_retries=args.temp_batch_retries, ) generator.run(num_batches=args.num_batches) except Exception as e: logging.error(f"Fatal error: {e}") sys.exit(1) if __name__ == "__main__": main()