"""
Data Augmentation for Side-Channel Analysis
============================================

Implements on-the-fly data augmentation strategies for profiling traces.

Currently supported:
    - **Random shift**: Circular shift of each trace by a random integer in
      [-max_shift, +max_shift]. This simulates desynchronization and forces
      the CNN to learn shift-invariant features, acting as a powerful
      regularizer for protected AES implementations.

References:
    - Li, H. & Perin, G. (2024). A Systematic Study of Data Augmentation for
      Protected AES Implementations. J. Cryptographic Engineering.
    - Wu, L. et al. (2023). Breaking Free: Leakage Model-free DLSCA.
      IACR ePrint 2023/1110. (Optimal shift = 5 samples.)
"""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)


class RandomShiftAugmentor:
    """
    Applies random circular shift augmentation to 1D traces.

    Each trace in a batch is independently shifted by a random integer drawn
    uniformly from [-max_shift, +max_shift]. The shift is circular: samples
    pushed off one end wrap around to the other. Trace length and total
    energy are therefore preserved.

    Both entry points follow the ``np.roll`` convention. A shift ``s`` gives
    ``out[j] = in[(j - s) % T]``, so positive shifts move samples to later
    time indices.

    ``make_tf_dataset`` applies the augmentation on-the-fly inside a
    ``tf.data.Dataset`` map operation. Augmented copies of the training set
    are never materialized. The method does, however, hold one float32 copy
    of the traces as a ``tf.constant``.
    """

    def __init__(self, max_shift: int = 5) -> None:
        """
        Args:
            max_shift: Maximum shift in either direction (samples).
                Wu et al. (2023) found 5 to be optimal for ASCAD.

        Raises:
            ValueError: If ``max_shift`` is negative.
        """
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")
        self.max_shift = max_shift
        logger.info(
            "RandomShiftAugmentor initialized: max_shift=%d samples", max_shift
        )

    def augment_numpy(
        self, traces: np.ndarray, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """
        Apply random shift to a batch of traces (NumPy version).

        Each trace is rolled along its time axis (axis 1 of ``traces``) by its
        own shift, drawn from ``rng.integers(-max_shift, max_shift + 1)``.
        Any trailing axes, such as a channel axis, move together with the
        time axis. Values are never mixed across channels.

        Args:
            traces: Array of shape (N, T) or (N, T, C), e.g. (N, T, 1).
            rng: Optional NumPy random generator for reproducibility. A fresh
                ``np.random.default_rng()`` is used when omitted.

        Returns:
            Shifted traces with the same shape and dtype as the input. When
            ``max_shift == 0`` the input array itself is returned (no copy).

        Raises:
            ValueError: If ``max_shift > 0`` and ``traces`` has fewer than
                2 dimensions.
        """
        if self.max_shift == 0:
            return traces
        if traces.ndim < 2:
            raise ValueError(
                f"traces must have shape (N, T) or (N, T, C), got {traces.shape}"
            )
        if rng is None:
            rng = np.random.default_rng()

        shifts = rng.integers(
            -self.max_shift, self.max_shift + 1, size=traces.shape[0]
        )
        augmented = np.empty_like(traces)
        # Axis 0 of each row is the time axis. Without an explicit axis,
        # np.roll would flatten (T, C) rows and leak samples across channels.
        for out_row, row, shift in zip(augmented, traces, shifts):
            out_row[...] = np.roll(row, shift, axis=0)
        return augmented

    def make_tf_dataset(
        self,
        traces: np.ndarray,
        labels: Dict[str, np.ndarray],
        batch_size: int,
        seed: int = 42,
    ) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset with on-the-fly random shift augmentation.

        The dataset yields (augmented_traces, labels) tuples suitable for
        model.fit(). Augmentation is applied per-batch using a vectorized
        gather-based circular shift (no tf.map_fn). The yielded trace batch
        keeps the per-trace layout of ``traces``: (N, T, 1) input gives
        (B, T, 1) batches, and (N, T) input gives (B, T) batches. This holds
        whether or not shifting is enabled.

        Args:
            traces: Training traces of shape (N, T, 1); (N, T) is also
                accepted. The traces are cast to float32.
            labels: Dictionary mapping "byte_i" to one-hot label arrays whose
                first axis has length N. The arrays are cast to float32.
            batch_size: Training batch size. The last batch may be smaller
                because ``drop_remainder=False`` is used.
            seed: Seed for the shuffle order of the examples. The per-batch
                shifts come from ``tf.random.uniform`` without an op-level
                seed. They are therefore governed by TensorFlow's global seed
                (``tf.random.set_seed``), not by this argument.

        Returns:
            A tf.data.Dataset that yields (traces, labels) with augmentation.
            ``labels`` in each element is a dict with the same keys as the
            ``labels`` argument.
        """
        n_samples = traces.shape[0]
        trace_len = traces.shape[1]
        max_shift = self.max_shift

        # Shuffle and batch the example indices rather than the data.
        # Traces and labels are then gathered per batch in the map function.
        # The buffer is clamped to >= 1 because tf.data rejects a zero
        # shuffle buffer, which would otherwise happen for an empty dataset.
        indices_ds = (
            tf.data.Dataset.from_tensor_slices(
                tf.range(n_samples, dtype=tf.int32)
            )
            .shuffle(buffer_size=max(1, min(n_samples, 50000)), seed=seed)
            .batch(batch_size, drop_remainder=False)
        )

        # Full-dataset tensors captured by the map function below.
        traces_tensor = tf.constant(traces, dtype=tf.float32)
        label_keys = sorted(labels.keys())
        label_tensors = {
            k: tf.constant(labels[k], dtype=tf.float32) for k in label_keys
        }

        # Time positions [0, 1, ..., T-1], used to build per-trace source indices.
        base_indices = tf.range(trace_len, dtype=tf.int32)  # (T,)

        def gather_and_augment(batch_indices):
            """Gather traces/labels for one batch and apply the random shift."""
            batch_traces = tf.gather(traces_tensor, batch_indices)  # (B, T, ...)
            if max_shift > 0:
                # One integer shift per trace, uniform in [-max_shift, max_shift].
                shifts = tf.random.uniform(
                    shape=[tf.shape(batch_indices)[0]],
                    minval=-max_shift,
                    maxval=max_shift + 1,
                    dtype=tf.int32,
                )
                # Source index for every output position, matching np.roll:
                # out[b, j] = in[b, (j - shifts[b]) % T].
                shifted_indices = tf.math.floormod(
                    base_indices[tf.newaxis, :] - shifts[:, tf.newaxis],
                    trace_len,
                )  # (B, T)
                # batch_dims=1 pairs row b of shifted_indices with trace b.
                # Any trailing channel axis is carried through unchanged, so
                # the result has exactly the shape of batch_traces.
                batch_traces = tf.gather(
                    batch_traces, shifted_indices, axis=1, batch_dims=1
                )
            batch_labels = {
                k: tf.gather(label_tensors[k], batch_indices) for k in label_keys
            }
            return batch_traces, batch_labels

        dataset = indices_ds.map(
            gather_and_augment,
            num_parallel_calls=tf.data.AUTOTUNE,
        ).prefetch(tf.data.AUTOTUNE)

        logger.info(
            "Created augmented tf.data.Dataset: %d samples, batch=%d, "
            "max_shift=%d, ~%d batches/epoch (vectorized shift)",
            n_samples,
            batch_size,
            max_shift,
            (n_samples + batch_size - 1) // batch_size,
        )
        return dataset