File size: 6,985 Bytes
283a882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe1a19
 
283a882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe1a19
 
 
283a882
1fe1a19
283a882
 
 
1fe1a19
 
283a882
 
1fe1a19
283a882
 
 
 
1fe1a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283a882
 
 
 
 
 
 
 
 
 
 
 
1fe1a19
283a882
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Data Augmentation for Side-Channel Analysis
============================================
Implements on-the-fly data augmentation strategies for profiling traces.

Currently supported:
    - **Random shift**: Circular shift of each trace by a random integer
      in [-max_shift, +max_shift]. This simulates desynchronization and
      forces the CNN to learn shift-invariant features, acting as a
      powerful regularizer for protected AES implementations.

References:
    - Li, H. & Perin, G. (2024). A Systematic Study of Data Augmentation
      for Protected AES Implementations. J. Cryptographic Engineering.
    - Wu, L. et al. (2023). Breaking Free: Leakage Model-free DLSCA.
      IACR ePrint 2023/1110. (Optimal shift = 5 samples.)
"""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
import tensorflow as tf

logger = logging.getLogger(__name__)


class RandomShiftAugmentor:
    """
    Applies random circular shift augmentation to 1D traces.

    Each trace in a batch is independently shifted by a random integer
    drawn uniformly from [-max_shift, +max_shift]. The shift is circular
    (wraps around), preserving trace length and total energy.

    Two entry points are provided:
        - ``augment_numpy``: eager, NumPy-only augmentation of an array.
        - ``make_tf_dataset``: a tf.data.Dataset that gathers and shifts
          each batch on-the-fly, so no augmented copy of the full
          training set is materialized.
    """

    def __init__(self, max_shift: int = 5) -> None:
        """
        Args:
            max_shift: Maximum shift in either direction (samples).
                       Wu et al. (2023) found 5 to be optimal for ASCAD.
                       A value of 0 disables the augmentation.

        Raises:
            ValueError: If ``max_shift`` is negative.
        """
        if max_shift < 0:
            raise ValueError(f"max_shift must be non-negative, got {max_shift}")
        self.max_shift = max_shift
        logger.info(
            "RandomShiftAugmentor initialized: max_shift=%d samples", max_shift
        )

    def augment_numpy(
        self, traces: np.ndarray, rng: Optional[np.random.Generator] = None
    ) -> np.ndarray:
        """
        Apply random shift to a batch of traces (NumPy version).

        Axis 0 is the batch axis and axis 1 is the time axis; any trailing
        axes (e.g. a channel axis) are carried along unchanged, so every
        channel of a trace receives the same shift.

        Args:
            traces: Array of shape (N, T) or (N, T, 1). Arrays with more
                trailing axes, e.g. (N, T, C), are shifted along axis 1.
            rng: Optional NumPy random generator for reproducibility.
                A fresh ``np.random.default_rng()`` is used when omitted.

        Returns:
            Shifted traces with the same shape and dtype as the input.
            When ``max_shift`` is 0 the input array itself is returned
            (no copy is made).

        Raises:
            ValueError: If ``traces`` has fewer than 2 dimensions.
        """
        if self.max_shift == 0:
            return traces

        if traces.ndim < 2:
            raise ValueError(
                "traces must have shape (N, T) or (N, T, 1), "
                f"got array with shape {traces.shape}"
            )

        if rng is None:
            rng = np.random.default_rng()

        n_traces, trace_len = traces.shape[0], traces.shape[1]
        shifts = rng.integers(-self.max_shift, self.max_shift + 1, size=n_traces)

        # A circular shift by s satisfies out[j] = in[(j - s) mod T], which is
        # exactly what np.roll(row, s) produces. Building the (N, T) source
        # index table lets one gather replace a Python loop over the rows.
        source_idx = (np.arange(trace_len) - shifts[:, np.newaxis]) % trace_len

        # take_along_axis needs indices of the same rank as the data; size-1
        # trailing axes broadcast across channels.
        source_idx = source_idx.reshape(
            source_idx.shape + (1,) * (traces.ndim - 2)
        )
        return np.take_along_axis(traces, source_idx, axis=1)

    def make_tf_dataset(
        self,
        traces: np.ndarray,
        labels: Dict[str, np.ndarray],
        batch_size: int,
        seed: int = 42,
    ) -> "tf.data.Dataset":
        """
        Create a tf.data.Dataset with on-the-fly random shift augmentation.

        The dataset yields (augmented_traces, labels) tuples suitable for
        model.fit(). Sample indices are shuffled and batched; each batch of
        indices is then mapped to its traces/labels, and the traces are
        circularly shifted with a vectorized batched gather (no tf.map_fn).

        Args:
            traces: Training traces of shape (N, T, 1). The whole array is
                converted to a float32 tensor.
            labels: Dictionary mapping "byte_i" to one-hot label arrays.
                Each array is converted to a float32 tensor and indexed
                along axis 0 with the same sample indices as ``traces``.
            batch_size: Training batch size. The final batch may be smaller
                (``drop_remainder=False``).
            seed: Seed passed to the index shuffle. The per-trace shift
                draws use ``tf.random.uniform`` without an op-level seed,
                so they are not controlled by this argument; set
                TensorFlow's global seed if they must be reproducible.

        Returns:
            A tf.data.Dataset that yields (traces, labels) with augmentation.
            For (N, T, 1) input the traces have shape (B, T, 1), whether or
            not shifting is enabled.
        """
        n_samples = traces.shape[0]
        trace_len = traces.shape[1]
        max_shift = self.max_shift

        # Shuffle and batch sample *indices* rather than the data itself, so
        # the map function below can gather traces and labels consistently.
        indices_ds = tf.data.Dataset.from_tensor_slices(
            tf.range(n_samples, dtype=tf.int32)
        )
        # The buffer size is clamped to at least 1 so an empty trace set
        # does not request a zero-sized shuffle buffer.
        indices_ds = indices_ds.shuffle(
            buffer_size=max(1, min(n_samples, 50000)), seed=seed
        ).batch(batch_size, drop_remainder=False)

        # Tensors captured by the map function
        traces_tensor = tf.constant(traces, dtype=tf.float32)

        # Sorted keys give a stable iteration order for the label dict
        label_keys = sorted(labels.keys())
        label_tensors = {
            k: tf.constant(labels[k], dtype=tf.float32) for k in label_keys
        }

        # Base index range [0, 1, ..., T-1] for the vectorized shift
        base_indices = tf.range(trace_len, dtype=tf.int32)  # shape: (T,)

        # Channel-less (N, T) input gets a trailing channel axis after the
        # shift. Input that already carries one must not get a second, or
        # the batch would come out as (B, T, 1, 1).
        add_channel_axis = traces_tensor.shape.rank == 2

        def gather_and_augment(batch_indices):
            """Gather traces/labels for batch and apply vectorized random shift."""
            batch_traces = tf.gather(traces_tensor, batch_indices)

            if max_shift > 0:
                batch_size_actual = tf.shape(batch_indices)[0]

                # One random shift per trace; maxval is exclusive, hence +1
                shifts = tf.random.uniform(
                    shape=[batch_size_actual],
                    minval=-max_shift,
                    maxval=max_shift + 1,
                    dtype=tf.int32,
                )

                # shifted_indices[b, t] = (t - shift[b]) mod T
                # (T,) - (B, 1) broadcasts to (B, T)
                shifted_indices = tf.math.floormod(
                    base_indices - tf.expand_dims(shifts, axis=1), trace_len
                )

                # Batched gather along the time axis: row b of the indices
                # selects from row b of the traces. Trailing axes are kept,
                # so (B, T, 1) stays (B, T, 1).
                batch_traces = tf.gather(
                    batch_traces, shifted_indices, axis=1, batch_dims=1
                )
                if add_channel_axis:
                    # (B, T) → (B, T, 1)
                    batch_traces = tf.expand_dims(batch_traces, axis=-1)

            batch_labels = {
                k: tf.gather(label_tensors[k], batch_indices) for k in label_keys
            }
            return batch_traces, batch_labels

        dataset = indices_ds.map(
            gather_and_augment,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        logger.info(
            "Created augmented tf.data.Dataset: %d samples, batch=%d, "
            "max_shift=%d, ~%d batches/epoch (vectorized shift)",
            n_samples, batch_size, max_shift,
            (n_samples + batch_size - 1) // batch_size,
        )
        return dataset