"""
Key Rank Evaluation
===================
Computes the key rank metric used to evaluate side-channel attack models.
Matches the methodology from the official ASCAD evaluation scripts.
The key rank measures how many key candidates score higher than the correct
key when accumulating log-likelihood scores over increasing numbers of
attack traces. A rank of 0 means the correct key is the top candidate.
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from .constants import AES_SBOX
logger = logging.getLogger(__name__)
def bits_to_byte_probs(bit_predictions: np.ndarray) -> np.ndarray:
    """
    Expand per-bit probabilities into a distribution over all 256 byte values.

    Multi-bit models emit one probability per bit of the target byte. The
    bits are treated as independent, so the probability of a byte value v is
    the product over the 8 bit positions of the probability that the
    position takes the bit found in v:

        P(v) = prod_j ( bit_j(v) * p_j + (1 - bit_j(v)) * (1 - p_j) )

    where p_j = bit_predictions[:, j] is the probability that bit j is 1.
    Column j is matched against the j-th column of ``np.unpackbits``, i.e.
    column 0 is the most significant bit.

    Args:
        bit_predictions: Array of shape (N, 8) holding P(bit_j = 1).

    Returns:
        Array of shape (N, 256), dtype float64. Each row is divided by its
        own sum (clamped below at 1e-300 to avoid dividing by zero).
    """
    num_rows = bit_predictions.shape[0]

    # bit_table[v, j] is the j-th bit (MSB first) of byte value v
    all_bytes = np.arange(256, dtype=np.uint8).reshape(-1, 1)
    bit_table = np.unpackbits(all_bytes, axis=1)  # (256, 8)

    joint = np.ones((num_rows, 256), dtype=np.float64)
    for bit_idx in range(8):
        prob_one = bit_predictions[:, bit_idx:bit_idx + 1].astype(np.float64)  # (N, 1)
        is_set = bit_table[:, bit_idx].astype(np.float64)[np.newaxis, :]  # (1, 256)
        # Selects prob_one where the bit is 1 and (1 - prob_one) where it is 0
        joint *= is_set * prob_one + (1 - is_set) * (1 - prob_one)

    # Renormalise each row; the clamp keeps an all-zero row from dividing by 0
    totals = np.maximum(joint.sum(axis=1, keepdims=True), 1e-300)
    joint /= totals
    return joint
def compute_key_rank(
    predictions: np.ndarray,
    metadata: np.ndarray,
    real_key: int,
    target_byte: int,
    num_traces: Optional[int] = None,
    rank_step: int = 10,
) -> Tuple[np.ndarray, int]:
    """
    Compute the key rank over an increasing number of attack traces.

    For each key candidate k in [0..255], the accumulated log-likelihood is:

        score[k] = sum_i log( predictions[i][ Sbox(plaintext_i XOR k) ] )

    When a selected probability is not strictly positive, the square of the
    smallest strictly positive probability in that trace's prediction row is
    used in its place; if the row has no positive entry, that trace adds
    nothing for the affected candidates.

    All 256 candidates are updated with one vectorised gather per trace, and
    the back-off value is computed at most once per trace.

    Args:
        predictions: Model output probabilities, shape (N, 256).
        metadata: Array whose entries expose a 'plaintext' field indexable by
            ``target_byte``. Only 'plaintext' is read here.
        real_key: The true key byte value (0-255).
        target_byte: Index of the plaintext byte under attack.
        num_traces: Number of attack traces to use. Defaults to, and is
            capped at, ``len(predictions)``.
        rank_step: A rank is recorded every time the number of processed
            traces is a multiple of this value.

    Returns:
        Tuple of (ranks_array, final_rank).
        ranks_array: uint32 array of shape (M, 2) with columns
            [num_traces, rank]; shape (0, 2) when nothing was recorded.
        final_rank: Rank of ``real_key`` after the last processed trace.
    """
    if num_traces is None:
        num_traces = len(predictions)
    num_traces = min(num_traces, len(predictions))

    # Integer copy of the S-box so all 256 outputs can be gathered at once
    sbox = np.array(list(AES_SBOX), dtype=np.intp)
    candidates = np.arange(256, dtype=np.intp)

    log_scores = np.zeros(256, dtype=np.float64)
    ranks_list = []
    for i in range(num_traces):
        plaintext_byte = int(metadata[i]["plaintext"][target_byte])
        row = predictions[i]
        # probs[k] is the predicted probability of Sbox(plaintext XOR k)
        probs = row[sbox[plaintext_byte ^ candidates]]
        positive = probs > 0
        if positive.all():
            log_scores += np.log(probs)
        else:
            log_scores[positive] += np.log(probs[positive])
            # Back-off: use square of minimum non-zero probability of the row
            non_zero = row[row > 0]
            if len(non_zero) > 0:
                log_scores[~positive] += np.log(np.min(non_zero) ** 2)
        if (i + 1) % rank_step == 0:
            ranks_list.append([i + 1, _rank_of_key(log_scores, real_key)])

    if ranks_list:
        ranks_array = np.array(ranks_list, dtype=np.uint32)
    else:
        ranks_array = np.empty((0, 2), dtype=np.uint32)
    final_rank = _rank_of_key(log_scores, real_key)
    return ranks_array, final_rank
def _rank_of_key(log_scores: np.ndarray, real_key: int) -> int:
    """
    Return the position of the real key's score in the descending score list.

    The scores are sorted from highest to lowest and the index of the first
    entry equal to ``log_scores[real_key]`` is returned, so candidates that
    tie with the real key do not push its rank down.
    """
    target = log_scores[real_key]
    descending = np.sort(log_scores)[::-1]
    first_match = np.flatnonzero(descending == target)[0]
    return int(first_match)
def _get_rank_at_n(ranks_array: np.ndarray, n: int) -> int:
    """
    Look up the recorded rank for a trace count of ``n``.

    Lookup order:
      1. the first row whose trace count (column 0) equals ``n``;
      2. otherwise the last row whose trace count is <= ``n``;
      3. otherwise the first row of the array.

    Returns 256 when ``ranks_array`` has no rows.
    """
    if len(ranks_array) == 0:
        return 256

    counts = ranks_array[:, 0]
    exact_rows = np.nonzero(counts == n)[0]
    if len(exact_rows) > 0:
        return int(ranks_array[exact_rows[0], 1])

    # No exact hit: fall back to the latest row not beyond n
    not_after = ranks_array[counts <= n]
    if len(not_after) > 0:
        return int(not_after[-1, 1])

    # Every recorded count is larger than n
    return int(ranks_array[0, 1])
def evaluate_model(
    model: tf.keras.Model,
    attack_traces: np.ndarray,
    attack_metadata: np.ndarray,
    target_byte: int,
    real_key: int,
    model_type: str = "mlp",
    num_traces: int = 2000,
    rank_step: int = 10,
    output_index: Optional[int] = None,
    cached_predictions: Optional[Dict] = None,
) -> Dict:
    """
    Evaluate a trained model with the key rank metric.

    Predictions are taken from ``cached_predictions`` when supplied, otherwise
    from ``model.predict`` on the first ``num_traces`` attack traces. For
    multi-output results the requested output is selected, (N, 8) bit
    predictions are expanded to (N, 256) byte probabilities, and the key rank
    curve is computed with ``compute_key_rank``.

    Args:
        model: Trained Keras model. Not called when ``cached_predictions``
            is provided.
        attack_traces: Attack traces; the first ``num_traces`` rows are used.
        attack_metadata: Metadata array passed to ``compute_key_rank``.
        target_byte: Target key byte index.
        real_key: True key byte value.
        model_type: For 'cnn' and 'mtan' a trailing channel axis of size 1 is
            appended to the traces before prediction.
        num_traces: Number of attack traces to evaluate.
        rank_step: Step size for intermediate rank computation.
        output_index: Which output to evaluate when the predictions are a
            dict (key ``byte_<output_index>``, falling back to position among
            the dict values) or a list/tuple. If None, the predictions are
            used as-is.
        cached_predictions: Pre-computed predictions to use instead of
            running a forward pass.

    Returns:
        Dictionary with keys 'final_rank', 'ranks', 'pre_rank', 'min_rank',
        'max_rank', 'rank_at_500', 'rank_at_1000'. 'pre_rank', 'min_rank'
        and 'max_rank' are 256 when no intermediate rank was recorded.

    Raises:
        ValueError: If the selected predictions are not of shape (N, 256)
            after the optional (N, 8) conversion.
    """
    if cached_predictions is None:
        traces = attack_traces[:num_traces].copy()
        # CNN/MTAN inputs carry an explicit channel dimension
        if model_type in ("cnn", "mtan"):
            traces = traces.reshape(traces.shape[0], traces.shape[1], 1)
        raw_predictions = model.predict(traces, batch_size=256, verbose=0)
    else:
        raw_predictions = cached_predictions

    # Start from the raw output and narrow it down only when a specific
    # output was requested and the container is a dict or a list/tuple.
    predictions = raw_predictions
    if output_index is not None:
        if isinstance(raw_predictions, dict):
            named_output = f"byte_{output_index}"
            if named_output in raw_predictions:
                predictions = raw_predictions[named_output]
            else:
                # Name not present: fall back to position among the values
                predictions = list(raw_predictions.values())[output_index]
        elif isinstance(raw_predictions, (list, tuple)):
            predictions = raw_predictions[output_index]

    if not isinstance(predictions, np.ndarray):
        predictions = np.array(predictions)

    # Multi-bit mode: (N, 8) bit predictions become (N, 256) byte probs
    if predictions.ndim == 2 and predictions.shape[1] == 8:
        logger.info(
            "Multi-bit predictions detected (shape %s), converting to byte probs",
            predictions.shape,
        )
        predictions = bits_to_byte_probs(predictions)

    if predictions.ndim != 2 or predictions.shape[1] != 256:
        raise ValueError(
            f"Expected predictions shape (N, 256) or (N, 8), got {predictions.shape}. "
            f"raw_predictions type={type(raw_predictions).__name__}, "
            f"output_index={output_index}"
        )

    ranks_array, final_rank = compute_key_rank(
        predictions=predictions,
        metadata=attack_metadata[:num_traces],
        real_key=real_key,
        target_byte=target_byte,
        num_traces=num_traces,
        rank_step=rank_step,
    )

    # Summary statistics over the recorded rank curve (256 = nothing recorded)
    if len(ranks_array) > 0:
        rank_column = ranks_array[:, 1]
        min_rank = int(np.min(rank_column))
        max_rank = int(np.max(rank_column))
        pre_rank = int(ranks_array[0, 1])
    else:
        min_rank = 256
        max_rank = 256
        pre_rank = 256

    rank_at_500 = _get_rank_at_n(ranks_array, 500)
    rank_at_1000 = _get_rank_at_n(ranks_array, 1000)

    logger.info(
        "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
        "rank@500=%d, rank@1000=%d",
        target_byte, pre_rank, final_rank, min_rank, max_rank,
        rank_at_500, rank_at_1000,
    )

    return {
        "final_rank": final_rank,
        "ranks": ranks_array,
        "pre_rank": pre_rank,
        "min_rank": min_rank,
        "max_rank": max_rank,
        "rank_at_500": rank_at_500,
        "rank_at_1000": rank_at_1000,
    }
|