"""
Key Rank Evaluation
===================
Computes the key rank metric used to evaluate side-channel attack models.
Matches the methodology from the official ASCAD evaluation scripts.

The key rank measures how many key candidates score higher than the correct
key when accumulating log-likelihood scores over increasing numbers of
attack traces. A rank of 0 means the correct key is the top candidate.
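
Illustration (a minimal sketch, not part of this module's API): the
accumulation described above can be tried on a toy 2-bit cipher in plain
Python. TOY_SBOX and toy_key_rank are hypothetical names invented for this
example; the real code below uses the AES S-box and 256 key candidates.

```python
import math

# Hypothetical 2-bit substitution table, for illustration only.
# It is NOT the AES S-box; it is just a bijection on {0, 1, 2, 3}.
TOY_SBOX = [2, 0, 3, 1]


def toy_key_rank(predictions, plaintexts, real_key):
    # Accumulate one log-likelihood score per key candidate.
    scores = [0.0] * len(TOY_SBOX)
    for probs, pt in zip(predictions, plaintexts):
        for k in range(len(TOY_SBOX)):
            scores[k] += math.log(probs[TOY_SBOX[pt ^ k]])
    # Rank = number of candidates scoring strictly higher than the real key.
    return sum(s > scores[real_key] for s in scores)


# Three toy traces leaked under key 1: each probability row favours the
# S-box output that the true key would produce for that plaintext.
plaintexts = [0, 3, 2]
predictions = []
for pt in plaintexts:
    row = [0.1, 0.1, 0.1, 0.1]
    row[TOY_SBOX[pt ^ 1]] = 0.7
    predictions.append(row)

rank = toy_key_rank(predictions, plaintexts, real_key=1)
```

Under these assumptions the correct key collects the 0.7 entries on every
trace, so its rank is 0, while any wrong candidate is outscored by it.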
"""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
import tensorflow as tf

from .constants import AES_SBOX

logger = logging.getLogger(__name__)


def bits_to_byte_probs(bit_predictions: np.ndarray) -> np.ndarray:
    """
    Convert 8 independent bit probabilities into 256 joint byte probabilities.

    For multi-bit models (Wu et al., TCHES 2024), each byte is predicted as
    8 independent bits. To compute key rank, we need to reconstruct the
    256-class probability distribution.

    For byte value v with bits b_0..b_7 (b_0 is the most significant bit,
    matching np.unpackbits):
        P(v) = prod_j P(b_j = bit_j(v))

    where P(b_j=1) = bit_predictions[:, j] and P(b_j=0) = 1 - bit_predictions[:, j].

    Args:
        bit_predictions: Model output, shape (N, 8), values in [0, 1].

    Returns:
        Byte probabilities, shape (N, 256), rows sum to 1.
    """
    N = bit_predictions.shape[0]
    byte_probs = np.ones((N, 256), dtype=np.float64)

    # Pre-compute bit patterns for all 256 byte values
    byte_values = np.arange(256, dtype=np.uint8)
    # bit_patterns[v, j] = j-th bit of value v (MSB first to match np.unpackbits)
    bit_patterns = np.unpackbits(byte_values[:, np.newaxis], axis=1)  # (256, 8)

    for j in range(8):
        p_j = bit_predictions[:, j:j+1].astype(np.float64)  # (N, 1)
        # For each byte value, multiply by P(b_j=bit_j(v))
        # bit_patterns[:, j] is 0 or 1 for each of 256 values
        mask = bit_patterns[:, j].astype(np.float64)  # (256,)
        # P(b_j = mask[v]) = mask[v]*p_j + (1-mask[v])*(1-p_j)
        byte_probs *= mask[np.newaxis, :] * p_j + (1 - mask[np.newaxis, :]) * (1 - p_j)

    # Normalize to ensure rows sum to 1 (handle numerical precision)
    row_sums = byte_probs.sum(axis=1, keepdims=True)
    row_sums = np.maximum(row_sums, 1e-300)  # prevent division by zero
    byte_probs /= row_sums

    return byte_probs


def compute_key_rank(
    predictions: np.ndarray,
    metadata: np.ndarray,
    real_key: int,
    target_byte: int,
    num_traces: Optional[int] = None,
    rank_step: int = 10,
) -> Tuple[np.ndarray, int]:
    """
    Compute the key rank over an increasing number of attack traces.

    For each key candidate k in [0..255], the log-likelihood score is:
        score[k] = sum_i log( predictions[i][ Sbox(plaintext_i XOR k) ] )

    The rank of the real key is its position when candidates are sorted
    by descending score.

    Args:
        predictions: Model output probabilities, shape (N, 256).
        metadata: Structured array; only the 'plaintext' field is read here.
        real_key: The true key byte value (0-255).
        target_byte: Index of the target byte (0-15).
        num_traces: Number of attack traces to use (default: all).
        rank_step: Step size for computing intermediate ranks.

    Returns:
        Tuple of (ranks_array, final_rank).
        ranks_array: shape (M, 2) with columns [num_traces, rank], one row
            per multiple of rank_step.
        final_rank: The key rank after all num_traces traces.
    """
    if num_traces is None:
        num_traces = len(predictions)
    num_traces = min(num_traces, len(predictions))

    log_scores = np.zeros(256, dtype=np.float64)
    ranks_list = []

    for i in range(num_traces):
        plaintext_byte = metadata[i]["plaintext"][target_byte]

        for k in range(256):
            sbox_out = AES_SBOX[plaintext_byte ^ k]
            prob = predictions[i][sbox_out]
            if prob > 0:
                log_scores[k] += np.log(prob)
            else:
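                # Back-off: use the square of the minimum non-zero probability.
                # log(p_min ** 2) is computed as 2 * log(p_min) so that a tiny
                # float32 p_min cannot underflow to 0 when squared.
                non_zero = predictions[i][predictions[i] > 0]
                if len(non_zero) > 0:
                    log_scores[k] += 2.0 * np.log(np.min(non_zero))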

        if (i + 1) % rank_step == 0:
            rank = _rank_of_key(log_scores, real_key)
            ranks_list.append([i + 1, rank])

    if ranks_list:
        ranks_array = np.array(ranks_list, dtype=np.uint32)
    else:
        ranks_array = np.empty((0, 2), dtype=np.uint32)
    final_rank = _rank_of_key(log_scores, real_key)

    return ranks_array, final_rank


def _rank_of_key(log_scores: np.ndarray, real_key: int) -> int:
    """Return the 0-based rank of the real key; ties resolve to the best rank."""
    sorted_scores = np.sort(log_scores)[::-1]
    real_score = log_scores[real_key]
    rank = int(np.where(sorted_scores == real_score)[0][0])
    return rank


def _get_rank_at_n(ranks_array: np.ndarray, n: int) -> int:
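    """Get the rank at n traces, or at the nearest earlier checkpoint.

    Falls back to the first checkpoint if none is at or before n, and to
    256 (one past the worst possible rank) if no ranks were recorded.
    """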
    if len(ranks_array) == 0:
        return 256
    idx = np.where(ranks_array[:, 0] == n)[0]
    if len(idx) > 0:
        return int(ranks_array[idx[0], 1])
    before = ranks_array[ranks_array[:, 0] <= n]
    if len(before) > 0:
        return int(before[-1, 1])
    return int(ranks_array[0, 1])


def evaluate_model(
    model: tf.keras.Model,
    attack_traces: np.ndarray,
    attack_metadata: np.ndarray,
    target_byte: int,
    real_key: int,
    model_type: str = "mlp",
    num_traces: int = 2000,
    rank_step: int = 10,
    output_index: Optional[int] = None,
    cached_predictions: Optional[Dict] = None,
) -> Dict:
    """
    Run the full key rank evaluation on a trained model.

    Args:
        model: Trained Keras model.
        attack_traces: Raw attack traces, shape (N, trace_length).
        attack_metadata: Structured metadata array.
        target_byte: Target key byte index (0-15).
        real_key: True key byte value (0-255).
        model_type: 'mlp', 'cnn', or 'mtan' (determines input reshaping).
        num_traces: Number of attack traces to evaluate.
        rank_step: Step size for intermediate rank computation.
        output_index: For multi-output models (MTAN), the index of the
            output to evaluate. If None, the model is assumed single-output.
        cached_predictions: Optional pre-computed predictions dict from a
            single model.predict() call. When provided, skips the forward
            pass entirely. This avoids redundant forward passes when
            evaluating all 16 bytes of a multi-output model.

    Returns:
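        Dictionary with evaluation results:
            'final_rank', 'ranks', 'pre_rank' (rank at the first checkpoint),
            'min_rank', 'max_rank', 'rank_at_500', 'rank_at_1000'.
            The checkpoint-derived fields fall back to 256 when no
            intermediate rank was recorded.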
    """
    if cached_predictions is not None:
        raw_predictions = cached_predictions
    else:
        traces = attack_traces[:num_traces].copy()

        # Reshape for CNN/MTAN (add channel dimension)
        if model_type in ("cnn", "mtan"):
            traces = traces.reshape((traces.shape[0], traces.shape[1], 1))

        raw_predictions = model.predict(traces, batch_size=256, verbose=0)

    # For multi-output models, extract the specific task's predictions.
    # Keras returns a dict when model outputs are named (e.g. {'byte_0': ..., 'byte_1': ...}).
    # It may also return a list for unnamed multi-output models.
    if output_index is not None:
        if isinstance(raw_predictions, dict):
            key = f"byte_{output_index}"
            if key in raw_predictions:
                predictions = raw_predictions[key]
            else:
                # Fallback: try by position in dict values
                predictions = list(raw_predictions.values())[output_index]
        elif isinstance(raw_predictions, (list, tuple)):
            predictions = raw_predictions[output_index]
        else:
            predictions = raw_predictions
    else:
        predictions = raw_predictions

    # Ensure predictions is a numpy array
    if not isinstance(predictions, np.ndarray):
        predictions = np.array(predictions)

    # Multi-bit mode: convert (N, 8) bit predictions to (N, 256) byte probs
    if predictions.ndim == 2 and predictions.shape[1] == 8:
        logger.info(
            "Multi-bit predictions detected (shape %s), converting to byte probs",
            predictions.shape,
        )
        predictions = bits_to_byte_probs(predictions)

    if predictions.ndim != 2 or predictions.shape[1] != 256:
        raise ValueError(
            f"Expected predictions shape (N, 256) or (N, 8), got {predictions.shape}. "
            f"raw_predictions type={type(raw_predictions).__name__}, "
            f"output_index={output_index}"
        )

    ranks_array, final_rank = compute_key_rank(
        predictions=predictions,
        metadata=attack_metadata[:num_traces],
        real_key=real_key,
        target_byte=target_byte,
        num_traces=num_traces,
        rank_step=rank_step,
    )

    min_rank = int(np.min(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
    max_rank = int(np.max(ranks_array[:, 1])) if len(ranks_array) > 0 else 256
    pre_rank = int(ranks_array[0, 1]) if len(ranks_array) > 0 else 256
    rank_at_500 = _get_rank_at_n(ranks_array, 500)
    rank_at_1000 = _get_rank_at_n(ranks_array, 1000)

    result = {
        "final_rank": final_rank,
        "ranks": ranks_array,
        "pre_rank": pre_rank,
        "min_rank": min_rank,
        "max_rank": max_rank,
        "rank_at_500": rank_at_500,
        "rank_at_1000": rank_at_1000,
    }

    logger.info(
        "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
        "rank@500=%d, rank@1000=%d",
        target_byte, pre_rank, final_rank, min_rank, max_rank,
        rank_at_500, rank_at_1000,
    )
    return result