# ascad-training-pipeline/src/evaluation.py
# lemousehunter
# feat: comprehensive W&B logging + retry=0 hardening
# e63fc4e
"""
Key Rank Evaluation
===================
Computes the key rank metric used to evaluate side-channel attack models.
Matches the methodology from the official ASCAD evaluation scripts.
The key rank measures how many key candidates score higher than the correct
key when accumulating log-likelihood scores over increasing numbers of
attack traces. A rank of 0 means the correct key is the top candidate.
"""
import logging
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from .constants import AES_SBOX
logger = logging.getLogger(__name__)
def bits_to_byte_probs(bit_predictions: np.ndarray) -> np.ndarray:
    """
    Expand per-bit probabilities into a distribution over all 256 byte values.

    Multi-bit models (Wu et al., TCHES 2024) emit eight independent bit
    probabilities per byte, whereas key ranking needs one probability per
    byte value. Treating the bits as independent, the probability of a byte
    value v is the product over its eight bit positions:

        P(v) = prod_j P(b_j = bit_j(v))

    with P(b_j = 1) = bit_predictions[:, j] and
    P(b_j = 0) = 1 - bit_predictions[:, j]. Column j corresponds to bit j in
    MSB-first order, which is the order np.unpackbits produces.

    Args:
        bit_predictions: Model output, shape (N, 8), values in [0, 1].

    Returns:
        Byte probabilities, shape (N, 256), dtype float64, rows sum to 1.
    """
    sample_count = bit_predictions.shape[0]

    # bit_table[v, j] is 1.0 when bit j (MSB first) of byte value v is set.
    all_bytes = np.arange(256, dtype=np.uint8)
    bit_table = np.unpackbits(all_bytes[:, np.newaxis], axis=1).astype(np.float64)

    joint = np.ones((sample_count, 256), dtype=np.float64)
    for bit_idx, is_set in enumerate(bit_table.T):
        p_one = bit_predictions[:, bit_idx:bit_idx + 1].astype(np.float64)  # (N, 1)
        # Blend selects P(b=1) where the byte value has this bit set and
        # P(b=0) where it does not; broadcasting yields an (N, 256) factor.
        joint *= is_set[np.newaxis, :] * p_one + (1 - is_set[np.newaxis, :]) * (1 - p_one)

    # Renormalise to absorb rounding error; the floor guards against 0/0.
    totals = np.maximum(joint.sum(axis=1, keepdims=True), 1e-300)
    joint /= totals
    return joint
def compute_key_rank(
    predictions: np.ndarray,
    metadata: np.ndarray,
    real_key: int,
    target_byte: int,
    num_traces: Optional[int] = None,
    rank_step: int = 10,
) -> Tuple[np.ndarray, int]:
    """
    Compute the key rank over an increasing number of attack traces.

    For each key candidate k in [0..255], the log-likelihood score is:

        score[k] = sum_i log( predictions[i][ Sbox(plaintext_i XOR k) ] )

    The rank of the real key is its position when candidates are sorted
    by descending score. All 256 candidates are scored per trace with a
    single indexed lookup instead of a Python loop over k.

    A probability that is not strictly positive contributes a back-off term
    instead of log(0): the log of the squared smallest positive probability
    in that trace. If a trace has no positive probability at all, such
    candidates receive no contribution for that trace.

    Args:
        predictions: Model output probabilities, shape (N, 256).
        metadata: Structured array with 'plaintext' and 'key' fields.
        real_key: The true key byte value (0-255).
        target_byte: Index of the target byte (0-15).
        num_traces: Number of attack traces to use (default: all). Values
            larger than len(predictions) are clamped.
        rank_step: Step size for computing intermediate ranks. Must be >= 1.

    Returns:
        Tuple of (ranks_array, final_rank).
        ranks_array: shape (M, 2), dtype uint32, columns [num_traces, rank],
            one row per multiple of rank_step.
        final_rank: The key rank after the last trace used.

    Raises:
        ValueError: If rank_step is smaller than 1.
    """
    if rank_step < 1:
        raise ValueError(f"rank_step must be a positive integer, got {rank_step}")

    if num_traces is None:
        num_traces = len(predictions)
    num_traces = min(num_traces, len(predictions))

    # Hoisted loop invariants: the S-box as an indexable array and the vector
    # of all key candidates.
    sbox = np.asarray(AES_SBOX)
    candidates = np.arange(256)

    log_scores = np.zeros(256, dtype=np.float64)
    ranks_list = []

    for i in range(num_traces):
        trace_probs = np.asarray(predictions[i])
        plaintext_byte = int(metadata[i]["plaintext"][target_byte])

        # probs[k] is the predicted probability of the S-box output implied
        # by key candidate k for this trace.
        probs = trace_probs[sbox[plaintext_byte ^ candidates]]
        positive = probs > 0

        if positive.all():
            log_scores += np.log(probs)
        else:
            log_scores[positive] += np.log(probs[positive])
            non_zero = trace_probs[trace_probs > 0]
            if non_zero.size > 0:
                # log(min ** 2) written as 2 * log(min): squaring a small
                # float32 probability underflows to 0 and would yield -inf,
                # defeating the purpose of the back-off.
                log_scores[~positive] += 2.0 * np.log(np.min(non_zero))

        if (i + 1) % rank_step == 0:
            ranks_list.append([i + 1, _rank_of_key(log_scores, real_key)])

    if ranks_list:
        ranks_array = np.array(ranks_list, dtype=np.uint32)
    else:
        ranks_array = np.empty((0, 2), dtype=np.uint32)

    final_rank = _rank_of_key(log_scores, real_key)
    return ranks_array, final_rank
def _rank_of_key(log_scores: np.ndarray, real_key: int) -> int:
    """
    Return the position of the real key's score in descending score order.

    Candidates tied with the real key do not push it down: the rank is the
    first index at which the real key's score appears in the sorted list.
    """
    target = log_scores[real_key]
    descending = np.flip(np.sort(log_scores))
    first_hit = np.flatnonzero(descending == target)[0]
    return int(first_hit)
def _get_rank_at_n(ranks_array: np.ndarray, n: int) -> int:
    """
    Look up the rank recorded at n traces, falling back to a nearby row.

    Resolution order:
      1. the first row whose trace count equals n;
      2. otherwise the last row whose trace count is <= n;
      3. otherwise the first row of the array.

    An empty ranks_array yields 256.
    """
    if len(ranks_array) == 0:
        return 256

    trace_counts = ranks_array[:, 0]
    exact = np.flatnonzero(trace_counts == n)
    if exact.size > 0:
        row = ranks_array[exact[0]]
    else:
        earlier = ranks_array[trace_counts <= n]
        row = earlier[-1] if len(earlier) > 0 else ranks_array[0]
    return int(row[1])
def evaluate_model(
    model: tf.keras.Model,
    attack_traces: np.ndarray,
    attack_metadata: np.ndarray,
    target_byte: int,
    real_key: int,
    model_type: str = "mlp",
    num_traces: int = 2000,
    rank_step: int = 10,
    output_index: Optional[int] = None,
    cached_predictions: Optional[Dict] = None,
) -> Dict:
    """
    Evaluate a trained model on attack traces and summarise its key rank.

    The function obtains predictions (from the model, or from a supplied
    cache), picks the output belonging to the requested task, converts
    8-column bit predictions to 256-column byte probabilities when needed,
    and runs the key rank computation.

    Args:
        model: Trained Keras model.
        attack_traces: Raw attack traces, shape (N, trace_length).
        attack_metadata: Structured metadata array.
        target_byte: Target key byte index (0-15).
        real_key: True key byte value (0-255).
        model_type: 'mlp', 'cnn', or 'mtan'. For 'cnn' and 'mtan' a trailing
            channel dimension is added to the traces before prediction.
        num_traces: Number of attack traces to evaluate.
        rank_step: Step size for intermediate rank computation.
        output_index: For multi-output models (MTAN), the index of the
            output to evaluate. If None, the raw predictions are used as-is.
        cached_predictions: Optional pre-computed result of a single
            model.predict() call. When given, no forward pass is run, which
            avoids repeating it for each byte of a multi-output model.

    Returns:
        Dictionary with the keys:
            'final_rank':   rank after the last evaluated trace.
            'ranks':        array of [num_traces, rank] checkpoints.
            'pre_rank':     rank at the first checkpoint.
            'min_rank':     lowest rank over all checkpoints.
            'max_rank':     highest rank over all checkpoints.
            'rank_at_500':  rank at (or nearest to) 500 traces.
            'rank_at_1000': rank at (or nearest to) 1000 traces.
        'pre_rank', 'min_rank' and 'max_rank' are 256 when there are no
        checkpoints.

    Raises:
        ValueError: If the selected predictions are neither (N, 256) nor
            (N, 8) in shape.
    """
    # Step 1: raw model output, reusing the caller's forward pass if supplied.
    if cached_predictions is None:
        traces = attack_traces[:num_traces].copy()
        if model_type in ("cnn", "mtan"):
            # These model types take an explicit trailing channel axis.
            n_rows, n_samples = traces.shape[0], traces.shape[1]
            traces = traces.reshape((n_rows, n_samples, 1))
        raw_predictions = model.predict(traces, batch_size=256, verbose=0)
    else:
        raw_predictions = cached_predictions

    # Step 2: pick this task's output. Keras returns a dict for named outputs
    # (e.g. {'byte_0': ..., 'byte_1': ...}) and may return a list for unnamed
    # multi-output models; anything else is used unchanged.
    predictions = raw_predictions
    if output_index is not None:
        if isinstance(raw_predictions, dict):
            output_name = f"byte_{output_index}"
            if output_name in raw_predictions:
                predictions = raw_predictions[output_name]
            else:
                # No matching name: fall back to positional order.
                predictions = list(raw_predictions.values())[output_index]
        elif isinstance(raw_predictions, (list, tuple)):
            predictions = raw_predictions[output_index]

    # Step 3: normalise to an (N, 256) numpy array.
    if not isinstance(predictions, np.ndarray):
        predictions = np.array(predictions)

    is_multi_bit = predictions.ndim == 2 and predictions.shape[1] == 8
    if is_multi_bit:
        logger.info(
            "Multi-bit predictions detected (shape %s), converting to byte probs",
            predictions.shape,
        )
        predictions = bits_to_byte_probs(predictions)

    if not (predictions.ndim == 2 and predictions.shape[1] == 256):
        raise ValueError(
            f"Expected predictions shape (N, 256) or (N, 8), got {predictions.shape}. "
            f"raw_predictions type={type(raw_predictions).__name__}, "
            f"output_index={output_index}"
        )

    # Step 4: key rank over a growing number of traces.
    ranks_array, final_rank = compute_key_rank(
        predictions=predictions,
        metadata=attack_metadata[:num_traces],
        real_key=real_key,
        target_byte=target_byte,
        num_traces=num_traces,
        rank_step=rank_step,
    )

    # Step 5: summary statistics; 256 marks "no checkpoint available".
    if len(ranks_array) > 0:
        checkpoint_ranks = ranks_array[:, 1]
        min_rank = int(np.min(checkpoint_ranks))
        max_rank = int(np.max(checkpoint_ranks))
        pre_rank = int(ranks_array[0, 1])
    else:
        min_rank = max_rank = pre_rank = 256

    rank_at_500 = _get_rank_at_n(ranks_array, 500)
    rank_at_1000 = _get_rank_at_n(ranks_array, 1000)

    logger.info(
        "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
        "rank@500=%d, rank@1000=%d",
        target_byte, pre_rank, final_rank, min_rank, max_rank,
        rank_at_500, rank_at_1000,
    )

    return {
        "final_rank": final_rank,
        "ranks": ranks_array,
        "pre_rank": pre_rank,
        "min_rank": min_rank,
        "max_rank": max_rank,
        "rank_at_500": rank_at_500,
        "rank_at_1000": rank_at_1000,
    }