ascad-training-pipeline / src /augmentation.py
lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19
"""
Data Augmentation for Side-Channel Analysis
============================================
Implements on-the-fly data augmentation strategies for profiling traces.
Currently supported:
- **Random shift**: Circular shift of each trace by a random integer
in [-max_shift, +max_shift]. This simulates desynchronization and
forces the CNN to learn shift-invariant features, acting as a
powerful regularizer for protected AES implementations.
References:
- Li, H. & Perin, G. (2024). A Systematic Study of Data Augmentation
for Protected AES Implementations. J. Cryptographic Engineering.
- Wu, L. et al. (2023). Breaking Free: Leakage Model-free DLSCA.
IACR ePrint 2023/1110. (Optimal shift = 5 samples.)
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
logger = logging.getLogger(__name__)
class RandomShiftAugmentor:
    """
    Applies random circular shift augmentation to 1D traces.

    Each trace in a batch is independently shifted by a random integer
    drawn uniformly from [-max_shift, +max_shift]. The shift is circular
    (wraps around), so the trace length and the set of sample values are
    preserved.

    Two entry points are provided:

    - ``augment_numpy``: eager NumPy augmentation of an in-memory batch.
    - ``make_tf_dataset``: a ``tf.data.Dataset`` that shuffles, batches and
      shifts on the fly, so no augmented copy of the training set is
      materialized.
    """

    def __init__(self, max_shift: int = 5) -> None:
        """
        Args:
            max_shift: Maximum shift in either direction (samples).
                Wu et al. (2023) found 5 to be optimal for ASCAD.
                A value of 0 disables the augmentation.

        Raises:
            ValueError: If ``max_shift`` is negative.
        """
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")
        self.max_shift = max_shift
        logger.info(
            "RandomShiftAugmentor initialized: max_shift=%d samples", max_shift
        )

    def augment_numpy(
        self, traces: np.ndarray, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """
        Apply random shift to a batch of traces (NumPy version).

        Args:
            traces: Array of shape (N, T) or (N, T, 1). Arrays with further
                trailing axes, e.g. (N, T, C), are also accepted; the shift
                is always applied along axis 1 (time) and is the same for
                every trailing element of a given trace.
            rng: Optional NumPy random generator for reproducibility.
                A fresh ``np.random.default_rng()`` is used when omitted.

        Returns:
            Shifted traces with the same shape and dtype as the input.
            When ``max_shift`` is 0 the input array itself is returned
            (no copy is made).

        Raises:
            ValueError: If ``traces`` has fewer than two dimensions.
        """
        if self.max_shift == 0:
            return traces
        if traces.ndim < 2:
            raise ValueError(
                f"traces must have shape (N, T) or (N, T, 1), got {traces.shape}"
            )
        if rng is None:
            rng = np.random.default_rng()

        n_traces, trace_len = traces.shape[0], traces.shape[1]
        shifts = rng.integers(-self.max_shift, self.max_shift + 1, size=n_traces)

        # np.roll(x, s)[j] == x[(j - s) % T]; build those source positions
        # for every trace at once instead of rolling trace by trace.
        source_idx = (np.arange(trace_len) - shifts[:, np.newaxis]) % trace_len

        # take_along_axis needs indices of the same rank as `traces`; the
        # trailing singleton axes broadcast over any channel dimensions, so
        # channels are never mixed by the shift.
        source_idx = source_idx.reshape(
            source_idx.shape + (1,) * (traces.ndim - 2)
        )
        return np.take_along_axis(traces, source_idx, axis=1)

    def make_tf_dataset(
        self,
        traces: np.ndarray,
        labels: Dict[str, np.ndarray],
        batch_size: int,
        seed: int = 42,
    ) -> "tf.data.Dataset":
        """
        Create a tf.data.Dataset with on-the-fly random shift augmentation.

        The dataset yields (augmented_traces, labels) tuples suitable for
        model.fit(). Augmentation is applied per-batch using a vectorized
        gather-based circular shift (no tf.map_fn). The last batch of an
        epoch may be smaller than ``batch_size``.

        Args:
            traces: Training traces of shape (N, T, 1). A 2-D array of
                shape (N, T) is also accepted and is given a channel axis.
            labels: Dictionary mapping "byte_i" to one-hot label arrays,
                each with at least N rows.
            batch_size: Training batch size.
            seed: Seed passed to the dataset shuffle. The per-trace shifts
                are drawn with ``tf.random.uniform`` without an op-level
                seed, so this value alone does not fix the shift values.

        Returns:
            A tf.data.Dataset that yields ``(traces, labels)`` where
            ``traces`` is float32 of shape (B, T, 1) and ``labels`` is a
            dict of float32 tensors keyed like the input ``labels``.

        Raises:
            ValueError: If a label array has fewer rows than ``traces``.
        """
        n_samples = traces.shape[0]
        trace_len = traces.shape[1]
        max_shift = self.max_shift

        # Fail early on misaligned labels rather than inside the pipeline,
        # where an out-of-range gather surfaces late (or not at all).
        label_keys = sorted(labels.keys())
        for key in label_keys:
            if len(labels[key]) < n_samples:
                raise ValueError(
                    f"labels[{key!r}] has {len(labels[key])} rows, "
                    f"expected at least {n_samples}"
                )

        # Shuffle and batch sample indices; the actual data is gathered in
        # the map function below. The buffer size must be positive even
        # for an empty training set.
        indices_ds = tf.data.Dataset.from_tensor_slices(
            tf.range(n_samples, dtype=tf.int32)
        )
        indices_ds = indices_ds.shuffle(
            buffer_size=max(1, min(n_samples, 50000)), seed=seed
        ).batch(batch_size, drop_remainder=False)

        # Tensors captured by the map function.
        traces_tensor = tf.constant(traces, dtype=tf.float32)
        if traces_tensor.shape.rank == 2:
            # Give 2-D input its channel axis once, so every batch is
            # (B, T, 1) whether or not a shift is applied.
            traces_tensor = tf.expand_dims(traces_tensor, axis=-1)
        label_tensors = {
            k: tf.constant(labels[k], dtype=tf.float32) for k in label_keys
        }

        # Base index range [0, 1, ..., T-1] for the vectorized shift.
        base_indices = tf.range(trace_len, dtype=tf.int32)  # shape: (T,)

        def gather_and_augment(batch_indices):
            """Gather traces/labels for batch and apply vectorized random shift."""
            batch_traces = tf.gather(traces_tensor, batch_indices)  # (B, T, 1)
            if max_shift > 0:
                batch_size_actual = tf.shape(batch_indices)[0]
                # One random shift per trace in the batch.
                shifts = tf.random.uniform(
                    shape=[batch_size_actual],
                    minval=-max_shift,
                    maxval=max_shift + 1,
                    dtype=tf.int32,
                )
                # shifted_indices[b, t] = (t - shift[b]) mod T, shape (B, T).
                shifted_indices = tf.math.floormod(
                    base_indices - tf.expand_dims(shifts, axis=1), trace_len
                )
                # Batched gather along the time axis. Unlike gather_nd
                # followed by expand_dims, this keeps the rank, so the
                # result stays (B, T, 1) instead of becoming (B, T, 1, 1).
                batch_traces = tf.gather(
                    batch_traces, shifted_indices, axis=1, batch_dims=1
                )
            batch_labels = {
                k: tf.gather(label_tensors[k], batch_indices) for k in label_keys
            }
            return batch_traces, batch_labels

        dataset = indices_ds.map(
            gather_and_augment,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        logger.info(
            "Created augmented tf.data.Dataset: %d samples, batch=%d, "
            "max_shift=%d, ~%d batches/epoch (vectorized shift)",
            n_samples, batch_size, max_shift,
            (n_samples + batch_size - 1) // batch_size,
        )
        return dataset