"""
Key Rank Evaluation
===================

Computes the key rank metric used to evaluate side-channel attack models.
Matches the methodology from the official ASCAD evaluation scripts.

The key rank measures how many key candidates score higher than the correct
key when accumulating log-likelihood scores over increasing numbers of attack
traces. A rank of 0 means the correct key is the top candidate.
"""

import logging
from typing import Dict, Optional, Tuple

import numpy as np
import tensorflow as tf

from .constants import AES_SBOX

logger = logging.getLogger(__name__)


def bits_to_byte_probs(bit_predictions: np.ndarray) -> np.ndarray:
    """
    Convert 8 independent bit probabilities into 256 joint byte probabilities.

    For multi-bit models (Wu et al., TCHES 2024), each byte is predicted as 8
    independent bits. To compute key rank, we need to reconstruct the
    256-class probability distribution.

    For byte value v with bits b_0..b_7:
        P(v) = prod_j P(b_j = bit_j(v))
    where P(b_j=1) = bit_predictions[j] and P(b_j=0) = 1 - bit_predictions[j].

    Bit order: bit j follows ``np.unpackbits`` ordering, i.e. MSB first, so
    column 0 of ``bit_predictions`` is paired with the most significant bit of
    the byte value and column 7 with the least significant bit.

    Args:
        bit_predictions: Model output, shape (N, 8), values in [0, 1].

    Returns:
        Byte probabilities, shape (N, 256), float64, rows sum to 1.

    Raises:
        ValueError: If ``bit_predictions`` is not a 2-D array with 8 columns.
    """
    bit_predictions = np.asarray(bit_predictions)
    if bit_predictions.ndim != 2 or bit_predictions.shape[1] != 8:
        raise ValueError(
            f"Expected bit predictions of shape (N, 8), got {bit_predictions.shape}"
        )

    n_rows = bit_predictions.shape[0]
    byte_probs = np.ones((n_rows, 256), dtype=np.float64)

    # bit_patterns[v, j] = j-th bit of value v (MSB first, np.unpackbits order).
    byte_values = np.arange(256, dtype=np.uint8)
    bit_patterns = np.unpackbits(byte_values[:, np.newaxis], axis=1)  # (256, 8)

    for j in range(8):
        p_one = bit_predictions[:, j : j + 1].astype(np.float64)  # (N, 1)
        bit_is_set = bit_patterns[np.newaxis, :, j] == 1  # (1, 256)
        # Multiply in P(b_j = bit_j(v)) for every byte value v.
        byte_probs *= np.where(bit_is_set, p_one, 1.0 - p_one)

    # Renormalize to absorb floating-point drift; the floor avoids 0/0.
    row_sums = np.maximum(byte_probs.sum(axis=1, keepdims=True), 1e-300)
    byte_probs /= row_sums
    return byte_probs


def compute_key_rank(
    predictions: np.ndarray,
    metadata: np.ndarray,
    real_key: int,
    target_byte: int,
    num_traces: Optional[int] = None,
    rank_step: int = 10,
) -> Tuple[np.ndarray, int]:
    """
    Compute the key rank over an increasing number of attack traces.

    For each key candidate k in [0..255], the log-likelihood score is:
        score[k] = sum_i log( predictions[i][ Sbox(plaintext_i XOR k) ] )

    When the model assigns probability 0 to a hypothesis, the ASCAD back-off
    is used instead: log(p_min ** 2), where p_min is the smallest non-zero
    probability of that trace. It is evaluated as 2 * log(p_min) so it cannot
    underflow to log(0) = -inf (which happens for float32 outputs with
    p_min below ~1e-23). If a trace has no non-zero probability at all, it
    contributes nothing to those candidates.

    The rank of the real key is its position when candidates are sorted by
    descending score.

    Args:
        predictions: Model output probabilities, shape (N, 256).
        metadata: Records indexable as ``metadata[i]["plaintext"][target_byte]``
            (e.g. a structured array with a 'plaintext' field).
        real_key: The true key byte value (0-255).
        target_byte: Index of the target byte (0-15).
        num_traces: Number of attack traces to use (default: all). Clamped to
            ``len(predictions)``.
        rank_step: Step size for computing intermediate ranks; must be >= 1.

    Returns:
        Tuple of (ranks_array, final_rank).
        ranks_array: uint32, shape (M, 2) with columns [num_traces, rank],
            one row every ``rank_step`` traces.
        final_rank: The key rank using all ``num_traces`` traces.

    Raises:
        ValueError: If ``rank_step`` is smaller than 1.
    """
    if rank_step < 1:
        raise ValueError(f"rank_step must be a positive integer, got {rank_step}")

    if num_traces is None:
        num_traces = len(predictions)
    num_traces = min(num_traces, len(predictions))

    sbox = np.asarray(AES_SBOX, dtype=np.intp)
    candidates = np.arange(256, dtype=np.intp)

    log_scores = np.zeros(256, dtype=np.float64)
    ranks_list = []

    for i in range(num_traces):
        # float64 so that logs are taken and accumulated at full precision.
        row = np.asarray(predictions[i], dtype=np.float64)
        plaintext_byte = int(metadata[i]["plaintext"][target_byte])

        # Probability of the hypothetical Sbox output for every candidate k.
        probs = row[sbox[plaintext_byte ^ candidates]]

        positive = probs > 0
        trace_scores = np.zeros(256, dtype=np.float64)
        trace_scores[positive] = np.log(probs[positive])
        if not positive.all():
            non_zero = row[row > 0]
            if non_zero.size > 0:
                # log(p_min ** 2) written as 2 * log(p_min) to avoid underflow.
                trace_scores[~positive] = 2.0 * np.log(non_zero.min())
        log_scores += trace_scores

        if (i + 1) % rank_step == 0:
            ranks_list.append([i + 1, _rank_of_key(log_scores, real_key)])

    if ranks_list:
        ranks_array = np.array(ranks_list, dtype=np.uint32)
    else:
        ranks_array = np.empty((0, 2), dtype=np.uint32)
    final_rank = _rank_of_key(log_scores, real_key)
    return ranks_array, final_rank


def _rank_of_key(log_scores: np.ndarray, real_key: int) -> int:
    """
    Compute the rank of the real key in the descending-sorted score list.

    The rank is the index of the first occurrence of the real key's score, so
    ties are resolved in the real key's favour (ASCAD convention).
    """
    sorted_scores = np.sort(log_scores)[::-1]
    real_score = log_scores[real_key]
    return int(np.where(sorted_scores == real_score)[0][0])


def _get_rank_at_n(ranks_array: np.ndarray, n: int) -> int:
    """
    Get the rank at exactly n traces, or the closest available.

    Falls back to the last recorded rank at or before n. If none exists, it
    falls back to the first recorded rank. Returns 256 when no ranks were
    recorded.
    """
    if len(ranks_array) == 0:
        return 256
    idx = np.where(ranks_array[:, 0] == n)[0]
    if len(idx) > 0:
        return int(ranks_array[idx[0], 1])
    before = ranks_array[ranks_array[:, 0] <= n]
    if len(before) > 0:
        return int(before[-1, 1])
    return int(ranks_array[0, 1])


def _select_output(raw_predictions, output_index: Optional[int]):
    """
    Extract one task's predictions from a (possibly multi-output) result.

    Keras returns a dict when model outputs are named (e.g.
    {'byte_0': ..., 'byte_1': ...}) and may return a list for unnamed
    multi-output models. With ``output_index=None`` the input is returned
    unchanged.
    """
    if output_index is None:
        return raw_predictions
    if isinstance(raw_predictions, dict):
        key = f"byte_{output_index}"
        if key in raw_predictions:
            return raw_predictions[key]
        # Fallback: try by position in dict values (insertion order).
        return list(raw_predictions.values())[output_index]
    if isinstance(raw_predictions, (list, tuple)):
        return raw_predictions[output_index]
    return raw_predictions


def evaluate_model(
    model: tf.keras.Model,
    attack_traces: np.ndarray,
    attack_metadata: np.ndarray,
    target_byte: int,
    real_key: int,
    model_type: str = "mlp",
    num_traces: int = 2000,
    rank_step: int = 10,
    output_index: Optional[int] = None,
    cached_predictions: Optional[Dict] = None,
) -> Dict:
    """
    Run the full key rank evaluation on a trained model.

    Args:
        model: Trained Keras model.
        attack_traces: Raw attack traces, shape (N, trace_length).
        attack_metadata: Structured metadata array.
        target_byte: Target key byte index (0-15).
        real_key: True key byte value (0-255).
        model_type: 'mlp', 'cnn', or 'mtan' (determines input reshaping).
        num_traces: Number of attack traces to evaluate.
        rank_step: Step size for intermediate rank computation.
        output_index: For multi-output models (MTAN), the index of the output
            to evaluate. If None, the model is assumed single-output.
        cached_predictions: Optional pre-computed predictions from a single
            model.predict() call. When provided, the forward pass is skipped
            entirely. This avoids redundant forward passes when evaluating all
            16 bytes of a multi-output model.

    Returns:
        Dictionary with evaluation results: 'final_rank', 'ranks', 'pre_rank',
        'min_rank', 'max_rank', 'rank_at_500', 'rank_at_1000'.

    Raises:
        ValueError: If the selected predictions are not shaped (N, 256) or
            (N, 8).
    """
    if cached_predictions is not None:
        raw_predictions = cached_predictions
    else:
        traces = attack_traces[:num_traces].copy()
        # Reshape for CNN/MTAN (add channel dimension).
        if model_type in ("cnn", "mtan"):
            traces = traces.reshape((traces.shape[0], traces.shape[1], 1))
        raw_predictions = model.predict(traces, batch_size=256, verbose=0)

    predictions = _select_output(raw_predictions, output_index)

    # Ensure predictions is a numpy array.
    if not isinstance(predictions, np.ndarray):
        predictions = np.array(predictions)

    # Multi-bit mode: convert (N, 8) bit predictions to (N, 256) byte probs.
    if predictions.ndim == 2 and predictions.shape[1] == 8:
        logger.info(
            "Multi-bit predictions detected (shape %s), converting to byte probs",
            predictions.shape,
        )
        predictions = bits_to_byte_probs(predictions)

    if predictions.ndim != 2 or predictions.shape[1] != 256:
        raise ValueError(
            f"Expected predictions shape (N, 256) or (N, 8), got {predictions.shape}. "
            f"raw_predictions type={type(raw_predictions).__name__}, "
            f"output_index={output_index}"
        )

    ranks_array, final_rank = compute_key_rank(
        predictions=predictions,
        metadata=attack_metadata[:num_traces],
        real_key=real_key,
        target_byte=target_byte,
        num_traces=num_traces,
        rank_step=rank_step,
    )

    has_ranks = len(ranks_array) > 0
    min_rank = int(np.min(ranks_array[:, 1])) if has_ranks else 256
    max_rank = int(np.max(ranks_array[:, 1])) if has_ranks else 256
    pre_rank = int(ranks_array[0, 1]) if has_ranks else 256
    rank_at_500 = _get_rank_at_n(ranks_array, 500)
    rank_at_1000 = _get_rank_at_n(ranks_array, 1000)

    result = {
        "final_rank": final_rank,
        "ranks": ranks_array,
        "pre_rank": pre_rank,
        "min_rank": min_rank,
        "max_rank": max_rank,
        "rank_at_500": rank_at_500,
        "rank_at_1000": rank_at_1000,
    }

    logger.info(
        "Byte %d: pre_rank=%d, final_rank=%d, min_rank=%d, max_rank=%d, "
        "rank@500=%d, rank@1000=%d",
        target_byte,
        pre_rank,
        final_rank,
        min_rank,
        max_rank,
        rank_at_500,
        rank_at_1000,
    )

    return result