| """ |
| Key Rank Evaluation |
| =================== |
| Computes the key rank metric used to evaluate side-channel attack models. |
| Matches the methodology from the official ASCAD evaluation scripts. |
| |
| The key rank measures how many key candidates score higher than the correct |
| key when accumulating log-likelihood scores over increasing numbers of |
| attack traces. A rank of 0 means the correct key is the top candidate. |
| """ |
|
|
| import logging |
| from typing import Dict, Optional, Tuple |
|
|
| import numpy as np |
| import tensorflow as tf |
|
|
| from .constants import AES_SBOX |
|
|
# Module-level logger; evaluate_model reports multi-bit conversion and the
# per-byte rank summary through it.
logger = logging.getLogger(__name__)
|
|
|
|
def bits_to_byte_probs(bit_predictions: np.ndarray) -> np.ndarray:
    """
    Turn 8 per-bit probabilities into a 256-way byte distribution.

    Multi-bit models (Wu et al., TCHES 2024) predict each of the 8 bits of
    the target byte independently. Key ranking needs one probability per
    byte value. Under the independence assumption, that probability is the
    product of the matching per-bit terms:

        P(v) = prod_{j=0..7} ( p_j       if bit j of v is set
                               1 - p_j   otherwise )

    with p_j = bit_predictions[:, j]. Bit j is taken from ``np.unpackbits``
    with its default ``bitorder='big'``, so column 0 is paired with the most
    significant bit of v and column 7 with the least significant bit.
    NOTE(review): this must match the bit order used to build the training
    labels -- confirm against the label-generation code.

    Args:
        bit_predictions: Model output, shape (N, 8), values in [0, 1].

    Returns:
        float64 array of shape (N, 256) whose rows are renormalised to sum to 1.
    """
    # bit_table[v, j] is bit j of byte value v, most significant bit first.
    bit_table = np.unpackbits(
        np.arange(256, dtype=np.uint8).reshape(256, 1), axis=1
    ).astype(bool)

    joint = np.ones((bit_predictions.shape[0], 256), dtype=np.float64)
    for bit in range(8):
        p_one = bit_predictions[:, bit:bit + 1].astype(np.float64)  # (N, 1)
        # Pick p for byte values with this bit set, 1 - p otherwise.
        joint *= np.where(bit_table[:, bit], p_one, 1.0 - p_one)

    # Renormalise to absorb floating-point drift; the floor prevents a
    # division by zero for a row whose products are all zero.
    joint /= np.maximum(joint.sum(axis=1, keepdims=True), 1e-300)
    return joint
|
|
|
|
def compute_key_rank(
    predictions: np.ndarray,
    metadata: np.ndarray,
    real_key: int,
    target_byte: int,
    num_traces: Optional[int] = None,
    rank_step: int = 10,
) -> Tuple[np.ndarray, int]:
    """
    Compute the key rank over an increasing number of attack traces.

    For each key candidate k in [0..255], the log-likelihood score is:
        score[k] = sum_i log( predictions[i][ Sbox(plaintext_i XOR k) ] )

    A zero probability would contribute -inf. As in the ASCAD reference
    scripts, it is replaced by log(m_i ** 2), where m_i is the smallest
    strictly positive probability of trace i. That term is evaluated as
    2 * log(m_i) so the square cannot underflow to 0.0 and bring the -inf
    back; float32 probabilities below roughly 1e-23 square to 0.0. A trace
    whose probabilities are all zero adds nothing to any candidate.

    Probabilities are promoted to float64 before taking the log so each
    per-trace term has the same precision as the float64 accumulator.

    The rank of the real key is its position when candidates are sorted
    by descending score.

    Args:
        predictions: Model output probabilities, shape (N, 256).
        metadata: Per-trace records; metadata[i]["plaintext"][target_byte]
            must give the plaintext byte of trace i.
        real_key: The true key byte value (0-255).
        target_byte: Index of the target byte (0-15).
        num_traces: Number of attack traces to use (default: all). Clipped
            to len(predictions).
        rank_step: An intermediate rank is recorded every rank_step traces.
            Must be >= 1.

    Returns:
        Tuple of (ranks_array, final_rank).
        ranks_array: uint32, shape (M, 2) with columns [num_traces, rank];
            shape (0, 2) when fewer than rank_step traces are used.
        final_rank: The key rank after all num_traces traces.

    Raises:
        ValueError: If rank_step is smaller than 1.
    """
    if rank_step < 1:
        raise ValueError(f"rank_step must be >= 1, got {rank_step}")
    if num_traces is None:
        num_traces = len(predictions)
    num_traces = min(num_traces, len(predictions))

    # S-box as an integer index array so all 256 candidates are gathered at once.
    sbox = np.array([int(v) for v in AES_SBOX], dtype=np.intp)
    candidates = np.arange(256, dtype=np.intp)

    log_scores = np.zeros(256, dtype=np.float64)
    ranks_list = []

    for i in range(num_traces):
        plaintext_byte = int(metadata[i]["plaintext"][target_byte])
        row = np.asarray(predictions[i], dtype=np.float64)

        # probs[k] = predicted probability of Sbox(plaintext XOR k).
        probs = row[sbox[plaintext_byte ^ candidates]]
        positive = probs > 0

        trace_scores = np.zeros(256, dtype=np.float64)
        trace_scores[positive] = np.log(probs[positive])
        if not positive.all():
            non_zero = row[row > 0]
            if non_zero.size > 0:
                # log(min ** 2) written as 2 * log(min) to avoid underflow.
                trace_scores[~positive] = 2.0 * np.log(non_zero.min())
        log_scores += trace_scores

        if (i + 1) % rank_step == 0:
            ranks_list.append([i + 1, _rank_of_key(log_scores, real_key)])

    ranks_array = (
        np.array(ranks_list, dtype=np.uint32)
        if ranks_list
        else np.empty((0, 2), dtype=np.uint32)
    )
    final_rank = _rank_of_key(log_scores, real_key)
    return ranks_array, final_rank
|
|
|
|
| def _rank_of_key(log_scores: np.ndarray, real_key: int) -> int: |
| """Compute the rank of the real key in the sorted score list.""" |
| sorted_scores = np.sort(log_scores)[::-1] |
| real_score = log_scores[real_key] |
| rank = int(np.where(sorted_scores == real_score)[0][0]) |
| return rank |
|
|
|
|
| def _get_rank_at_n(ranks_array: np.ndarray, n: int) -> int: |
| """Get the rank at exactly n traces, or the closest available.""" |
| if len(ranks_array) == 0: |
| return 256 |
| idx = np.where(ranks_array[:, 0] == n)[0] |
| if len(idx) > 0: |
| return int(ranks_array[idx[0], 1]) |
| before = ranks_array[ranks_array[:, 0] <= n] |
| if len(before) > 0: |
| return int(before[-1, 1]) |
| return int(ranks_array[0, 1]) |
|
|
|
|
def evaluate_model(
    model: tf.keras.Model,
    attack_traces: np.ndarray,
    attack_metadata: np.ndarray,
    target_byte: int,
    real_key: int,
    model_type: str = "mlp",
    num_traces: int = 2000,
    rank_step: int = 10,
    output_index: Optional[int] = None,
    cached_predictions: Optional[Dict] = None,
) -> Dict:
    """
    Run the full key rank evaluation on a trained model.

    Args:
        model: Trained Keras model. Not used when cached_predictions is given.
        attack_traces: Raw attack traces, shape (N, trace_length).
        attack_metadata: Structured metadata array (needs 'plaintext').
        target_byte: Target key byte index (0-15).
        real_key: True key byte value (0-255).
        model_type: 'mlp', 'cnn', or 'mtan'. For 'cnn' and 'mtan', traces are
            reshaped to (N, trace_length, 1) before prediction.
        num_traces: Number of attack traces to evaluate.
        rank_step: Step size for intermediate rank computation.
        output_index: For multi-output models (MTAN), the index of the
            output to evaluate. Dict outputs are looked up by the key
            "byte_{output_index}". If that key is missing, the dict's value
            order is used positionally and a warning is logged. List/tuple
            outputs are indexed directly. If None, the model is assumed
            single-output.
        cached_predictions: Optional pre-computed predictions from a
            single model.predict() call. When provided, skips the forward
            pass entirely. This avoids redundant forward passes when
            evaluating all 16 bytes of a multi-output model.

    Returns:
        Dictionary with evaluation results:
        'final_rank': rank after all used traces.
        'ranks': (M, 2) array of [num_traces, rank] checkpoints.
        'pre_rank', 'min_rank', 'max_rank': first, minimum and maximum
            checkpoint rank. These use checkpoints only, so traces after the
            last multiple of rank_step are not reflected. Each is 256 when
            there are no checkpoints.
        'rank_at_500', 'rank_at_1000': checkpoint rank at (or nearest below)
            500 / 1000 traces.

    Raises:
        ValueError: If the selected predictions are not shaped (N, 256) or (N, 8).
    """
    if cached_predictions is not None:
        raw_predictions = cached_predictions
    else:
        # predict() only reads its input, so the slice needs no defensive copy.
        traces = attack_traces[:num_traces]
        if model_type in ("cnn", "mtan"):
            # Convolutional front-ends expect an explicit channel axis.
            traces = traces.reshape((traces.shape[0], traces.shape[1], 1))
        raw_predictions = model.predict(traces, batch_size=256, verbose=0)

    # Select the head for this byte from a (possibly) multi-output result.
    predictions = raw_predictions
    if output_index is not None:
        if isinstance(raw_predictions, dict):
            key = f"byte_{output_index}"
            if key in raw_predictions:
                predictions = raw_predictions[key]
            else:
                # Positional fallback is only correct if the dict's iteration
                # order matches byte order, so make the fallback visible.
                logger.warning(
                    "Output %r not found in predictions (keys: %s); "
                    "falling back to positional index %d",
                    key,
                    list(raw_predictions),
                    output_index,
                )
                predictions = list(raw_predictions.values())[output_index]
        elif isinstance(raw_predictions, (list, tuple)):
            predictions = raw_predictions[output_index]

    predictions = np.asarray(predictions)

    # Multi-bit heads emit 8 independent bit probabilities per trace.
    if predictions.ndim == 2 and predictions.shape[1] == 8:
        logger.info(
            "Multi-bit predictions detected (shape %s), converting to byte probs",
            predictions.shape,
        )
        predictions = bits_to_byte_probs(predictions)

    if predictions.ndim != 2 or predictions.shape[1] != 256:
        raise ValueError(
            f"Expected predictions shape (N, 256) or (N, 8), got {predictions.shape}. "
            f"raw_predictions type={type(raw_predictions).__name__}, "
            f"output_index={output_index}"
        )

    ranks_array, final_rank = compute_key_rank(
        predictions=predictions,
        metadata=attack_metadata[:num_traces],
        real_key=real_key,
        target_byte=target_byte,
        num_traces=num_traces,
        rank_step=rank_step,
    )

    # 256 is the sentinel used when no checkpoint was recorded.
    has_checkpoints = len(ranks_array) > 0
    checkpoint_ranks = ranks_array[:, 1]
    min_rank = int(np.min(checkpoint_ranks)) if has_checkpoints else 256
    max_rank = int(np.max(checkpoint_ranks)) if has_checkpoints else 256
    pre_rank = int(checkpoint_ranks[0]) if has_checkpoints else 256
    rank_at_500 = _get_rank_at_n(ranks_array, 500)
    rank_at_1000 = _get_rank_at_n(ranks_array, 1000)

    result = {
        "final_rank": final_rank,
        "ranks": ranks_array,
        "pre_rank": pre_rank,
        "min_rank": min_rank,
        "max_rank": max_rank,
        "rank_at_500": rank_at_500,
        "rank_at_1000": rank_at_1000,
    }

    logger.info(
        "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
        "rank@500=%d, rank@1000=%d",
        target_byte, pre_rank, final_rank, min_rank, max_rank,
        rank_at_500, rank_at_1000,
    )
    return result
|
|