lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19 | """ | |
| Data Augmentation for Side-Channel Analysis | |
| ============================================ | |
| Implements on-the-fly data augmentation strategies for profiling traces. | |
| Currently supported: | |
| - **Random shift**: Circular shift of each trace by a random integer | |
| in [-max_shift, +max_shift]. This simulates desynchronization and | |
| forces the CNN to learn shift-invariant features, acting as a | |
| powerful regularizer for protected AES implementations. | |
| References: | |
| - Li, H. & Perin, G. (2024). A Systematic Study of Data Augmentation | |
| for Protected AES Implementations. J. Cryptographic Engineering. | |
| - Wu, L. et al. (2023). Breaking Free: Leakage Model-free DLSCA. | |
| IACR ePrint 2023/1110. (Optimal shift = 5 samples.) | |
| """ | |
| import logging | |
| from typing import Dict, Optional, Tuple | |
| import numpy as np | |
| import tensorflow as tf | |
| logger = logging.getLogger(__name__) | |
class RandomShiftAugmentor:
    """
    Applies random circular shift augmentation to 1D traces.

    Each trace in a batch is independently shifted by a random integer
    drawn uniformly from [-max_shift, +max_shift]. The shift is circular
    (wraps around), so trace length and sample values are preserved.

    Two entry points are provided:

    - ``augment_numpy``: eager NumPy augmentation of an in-memory batch.
    - ``make_tf_dataset``: a ``tf.data.Dataset`` that gathers and shifts
      each batch on the fly inside a ``map`` operation.

    Both use the same shift direction as ``np.roll``: a positive shift
    moves samples towards higher time indices.
    """

    def __init__(self, max_shift: int = 5) -> None:
        """
        Args:
            max_shift: Maximum shift in either direction (samples).
                Wu et al. (2023) found 5 to be optimal for ASCAD.
                A value of 0 disables the augmentation.

        Raises:
            ValueError: If ``max_shift`` is negative.
        """
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")
        self.max_shift = max_shift
        logger.info(
            "RandomShiftAugmentor initialized: max_shift=%d samples", max_shift
        )

    def augment_numpy(
        self, traces: np.ndarray, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """
        Apply random shift to a batch of traces (NumPy version).

        Args:
            traces: Array of shape (N, T) or (N, T, C); axis 1 is time.
            rng: Optional NumPy random generator for reproducibility.
                A fresh unseeded generator is created when omitted.

        Returns:
            Shifted traces with the same shape and dtype as the input.
            When ``max_shift`` is 0 the input array itself is returned
            (no copy is made).

        Raises:
            ValueError: If ``traces`` has fewer than two dimensions.
        """
        if self.max_shift == 0:
            return traces
        if traces.ndim < 2:
            raise ValueError(
                f"traces must have shape (N, T) or (N, T, C), got {traces.shape}"
            )
        if rng is None:
            rng = np.random.default_rng()

        n = traces.shape[0]
        shifts = rng.integers(-self.max_shift, self.max_shift + 1, size=n)

        augmented = np.empty_like(traces)
        # There are at most 2 * max_shift + 1 distinct shift values, so roll
        # all traces sharing a value in one vectorized call instead of
        # looping over every trace. Rolling along axis=1 (time) keeps the
        # channels of multi-channel traces separate.
        for shift in np.unique(shifts):
            rows = shifts == shift
            augmented[rows] = np.roll(traces[rows], int(shift), axis=1)
        return augmented

    def make_tf_dataset(
        self,
        traces: np.ndarray,
        labels: Dict[str, np.ndarray],
        batch_size: int,
        seed: int = 42,
    ) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset with on-the-fly random shift augmentation.

        The dataset yields (augmented_traces, labels) tuples suitable for
        model.fit(). Sample indices are shuffled and batched; each batch of
        indices is then mapped to its traces and labels, and the traces are
        circularly shifted with a single batched gather (no tf.map_fn).

        Args:
            traces: Training traces of shape (N, T, 1). An array of shape
                (N, T) is also accepted and is given a trailing channel
                axis, so batches always have shape (B, T, C).
            labels: Dictionary mapping "byte_i" to one-hot label arrays,
                each with N rows.
            batch_size: Training batch size. The last batch may be smaller.
            seed: Seed for the index shuffle. The per-trace shifts are
                drawn by ``tf.random.uniform`` without an op-level seed,
                so they are not controlled by this argument.

        Returns:
            A tf.data.Dataset that yields (traces, labels) with augmentation.
            ``traces`` is float32 of shape (B, T, C); ``labels`` is a dict
            with the same keys as the input, each value float32.
        """
        n_samples = traces.shape[0]
        trace_len = traces.shape[1]
        max_shift = self.max_shift

        # Shuffle and batch sample indices rather than the traces
        # themselves; the map function below gathers the actual data.
        indices_ds = tf.data.Dataset.from_tensor_slices(
            tf.range(n_samples, dtype=tf.int32)
        )
        indices_ds = indices_ds.shuffle(
            # shuffle() rejects a zero-sized buffer, so keep it >= 1 even
            # when the input is empty.
            buffer_size=max(1, min(n_samples, 50000)),
            seed=seed,
        ).batch(batch_size, drop_remainder=False)

        # Hold the full data as constants captured by the map function.
        traces_tensor = tf.constant(traces, dtype=tf.float32)
        if traces_tensor.shape.rank == 2:
            # Normalize (N, T) input to (N, T, 1) so the output rank does
            # not depend on the input rank or on whether shifting is on.
            traces_tensor = tf.expand_dims(traces_tensor, axis=-1)

        label_keys = sorted(labels.keys())
        label_tensors = {
            k: tf.constant(labels[k], dtype=tf.float32) for k in label_keys
        }

        # Base index range [0, 1, ..., T-1] shared by every batch.
        base_indices = tf.range(trace_len, dtype=tf.int32)  # shape: (T,)

        def gather_and_augment(batch_indices):
            """Gather traces/labels for batch and apply vectorized random shift."""
            batch_traces = tf.gather(traces_tensor, batch_indices)  # (B, T, C)

            # trace_len > 0 avoids a modulo-by-zero on empty traces.
            if max_shift > 0 and trace_len > 0:
                batch_size_actual = tf.shape(batch_indices)[0]
                # One random shift per trace in the batch; maxval is
                # exclusive, hence max_shift + 1.
                shifts = tf.random.uniform(
                    shape=[batch_size_actual],
                    minval=-max_shift,
                    maxval=max_shift + 1,
                    dtype=tf.int32,
                )
                # out[b, t] = in[b, (t - shift[b]) mod T], i.e. the same
                # direction as np.roll(in[b], shift[b]).
                shifted_indices = tf.math.floormod(
                    base_indices - shifts[:, tf.newaxis], trace_len
                )  # shape: (B, T)
                # Batched gather along the time axis. Unlike gather_nd
                # followed by expand_dims, this keeps the channel axis
                # as-is, so the result stays (B, T, C) rather than
                # growing to (B, T, C, 1).
                batch_traces = tf.gather(
                    batch_traces, shifted_indices, axis=1, batch_dims=1
                )

            batch_labels = {
                k: tf.gather(label_tensors[k], batch_indices) for k in label_keys
            }
            return batch_traces, batch_labels

        dataset = indices_ds.map(
            gather_and_augment,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        logger.info(
            "Created augmented tf.data.Dataset: %d samples, batch=%d, "
            "max_shift=%d, ~%d batches/epoch (vectorized shift)",
            n_samples, batch_size, max_shift,
            (n_samples + batch_size - 1) // batch_size,
        )
        return dataset