"""
Data Augmentation for Side-Channel Analysis
============================================

Implements on-the-fly data augmentation strategies for profiling traces.

Currently supported:

- **Random shift**: Circular shift of each trace by a random integer
  in [-max_shift, +max_shift]. This simulates desynchronization and
  forces the CNN to learn shift-invariant features, acting as a
  powerful regularizer for protected AES implementations.

References:

- Li, H. & Perin, G. (2024). A Systematic Study of Data Augmentation
  for Protected AES Implementations. J. Cryptographic Engineering.
- Wu, L. et al. (2023). Breaking Free: Leakage Model-free DLSCA.
  IACR ePrint 2023/1110. (Optimal shift = 5 samples.)
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
logger = logging.getLogger(__name__)
class RandomShiftAugmentor:
    """
    Applies random circular shift augmentation to 1D traces.

    Each trace in a batch is independently shifted along its time axis by
    a random integer drawn uniformly from [-max_shift, +max_shift]. The
    shift is circular (wraps around), preserving trace length and total
    energy.

    Two entry points are provided:

    - ``augment_numpy`` shifts an in-memory NumPy batch eagerly.
    - ``make_tf_dataset`` builds a tf.data.Dataset that applies the shift
      on-the-fly inside a ``map`` operation, so no augmented copy of the
      training set is materialized.
    """

    def __init__(self, max_shift: int = 5) -> None:
        """
        Args:
            max_shift: Maximum shift in either direction (samples).
                Wu et al. (2023) found 5 to be optimal for ASCAD.
                A value of 0 disables the augmentation.

        Raises:
            ValueError: If ``max_shift`` is negative.
        """
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")
        self.max_shift = max_shift
        logger.info(
            "RandomShiftAugmentor initialized: max_shift=%d samples", max_shift
        )

    def augment_numpy(
        self, traces: np.ndarray, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """
        Apply random shift to a batch of traces (NumPy version).

        Args:
            traces: Array of shape (N, T) or (N, T, C). Axis 0 indexes the
                traces and axis 1 is the time axis that gets shifted; any
                trailing axes (e.g. a channel axis) move with their sample.
            rng: Optional NumPy random generator for reproducibility.
                A fresh, unseeded generator is created when omitted.

        Returns:
            Shifted traces with the same shape and dtype as the input.
            When ``max_shift`` is 0 the input array itself is returned
            (no copy is made).

        Raises:
            ValueError: If ``traces`` has fewer than 2 dimensions.
        """
        if self.max_shift == 0:
            return traces
        if traces.ndim < 2:
            raise ValueError(
                "traces must have shape (N, T) or (N, T, C), "
                f"got {traces.ndim} dimension(s)"
            )
        if rng is None:
            rng = np.random.default_rng()

        # One shift per trace; the upper bound of integers() is exclusive.
        shifts = rng.integers(
            -self.max_shift, self.max_shift + 1, size=traces.shape[0]
        )

        augmented = np.empty_like(traces)
        for row, shift in enumerate(shifts):
            # axis=0 of a single trace is its time axis. Without an explicit
            # axis np.roll would flatten a (T, C) trace and mix channels.
            augmented[row] = np.roll(traces[row], shift, axis=0)
        return augmented

    def make_tf_dataset(
        self,
        traces: np.ndarray,
        labels: Dict[str, np.ndarray],
        batch_size: int,
        seed: int = 42,
    ) -> "tf.data.Dataset":
        """
        Create a tf.data.Dataset with on-the-fly random shift augmentation.

        The dataset yields (augmented_traces, labels) tuples suitable for
        model.fit(). Augmentation is applied per-batch using a vectorized
        gather-based circular shift (no tf.map_fn).

        Args:
            traces: Training traces of shape (N, T, C), typically (N, T, 1).
                A 2-D (N, T) array is also accepted and is given a trailing
                channel axis, so batches always have shape (B, T, C).
            labels: Dictionary mapping "byte_i" to one-hot label arrays.
                Each array must have at least N rows.
            batch_size: Training batch size (must be positive). The last
                batch of an epoch may be smaller.
            seed: Random seed for the index shuffle. The per-trace shift
                values are drawn with tf.random.uniform without an
                op-level seed, so they are not derived from this argument.

        Returns:
            A tf.data.Dataset that yields (traces, labels) with augmentation.
            Traces are float32 of shape (B, T, C); labels is a dict with
            the same keys as ``labels`` holding float32 batches.

        Raises:
            ValueError: If ``batch_size`` is not positive, ``traces`` has
                fewer than 2 dimensions, or a label array has fewer rows
                than ``traces``.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        n_dims = len(traces.shape)
        if n_dims < 2:
            raise ValueError(
                "traces must have shape (N, T) or (N, T, C), "
                f"got {n_dims} dimension(s)"
            )
        if n_dims == 2:
            # Normalize to channel-last so the output rank does not depend
            # on whether the shift branch runs.
            traces = traces[..., np.newaxis]

        n_samples = traces.shape[0]
        trace_len = traces.shape[1]
        max_shift = self.max_shift

        label_keys = sorted(labels.keys())
        for key in label_keys:
            n_rows = len(labels[key])
            if n_rows < n_samples:
                raise ValueError(
                    f"labels[{key!r}] has {n_rows} rows but traces has "
                    f"{n_samples}"
                )

        # Shuffle and batch indices rather than the traces themselves, so
        # the shuffle buffer only ever holds int32 scalars.
        indices_ds = tf.data.Dataset.from_tensor_slices(
            tf.range(n_samples, dtype=tf.int32)
        )
        indices_ds = indices_ds.shuffle(
            # A shuffle buffer must hold at least one element, even when
            # the dataset is empty.
            buffer_size=max(1, min(n_samples, 50000)),
            seed=seed,
        ).batch(batch_size, drop_remainder=False)

        # Tensors captured by the map function below.
        traces_tensor = tf.constant(traces, dtype=tf.float32)
        label_tensors = {
            k: tf.constant(labels[k], dtype=tf.float32) for k in label_keys
        }

        # Base index range [0, 1, ..., T-1] for the vectorized shift.
        base_indices = tf.range(trace_len, dtype=tf.int32)  # shape: (T,)

        def gather_and_augment(batch_indices):
            """Gather traces/labels for batch and apply vectorized random shift."""
            batch_traces = tf.gather(traces_tensor, batch_indices)  # (B, T, C)

            # trace_len == 0 is skipped to avoid a modulo by zero.
            if max_shift > 0 and trace_len > 0:
                batch_size_actual = tf.shape(batch_indices)[0]

                # Random shift per trace in the batch (maxval is exclusive).
                shifts = tf.random.uniform(
                    shape=[batch_size_actual],
                    minval=-max_shift,
                    maxval=max_shift + 1,
                    dtype=tf.int32,
                )

                # out[b, t] = in[b, (t - shift[b]) mod T], i.e. the same
                # result as np.roll(in[b], shift[b]) along the time axis.
                # (T,) - (B, 1) broadcasts to (B, T).
                shifted_indices = tf.math.floormod(
                    base_indices - shifts[:, tf.newaxis], trace_len
                )

                # batch_dims=1 pairs row b of the indices with trace b and
                # gathers along the time axis only, so the trailing channel
                # axis is preserved: (B, T, C) -> (B, T, C).
                batch_traces = tf.gather(
                    batch_traces, shifted_indices, axis=1, batch_dims=1
                )

            batch_labels = {
                k: tf.gather(label_tensors[k], batch_indices)
                for k in label_keys
            }
            return batch_traces, batch_labels

        dataset = indices_ds.map(
            gather_and_augment,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        logger.info(
            "Created augmented tf.data.Dataset: %d samples, batch=%d, "
            "max_shift=%d, ~%d batches/epoch (vectorized shift)",
            n_samples, batch_size, max_shift,
            (n_samples + batch_size - 1) // batch_size,
        )
        return dataset
|