| |
| """ |
| Gradient Saliency Map Generator for ASCAD Models |
| ================================================= |
| Computes input-space gradient saliency maps for all completed training runs. |
| For each model, this script: |
| 1. Downloads the model from HuggingFace |
| 2. Loads attack traces from the ASCAD dataset |
| 3. Computes gradients of the model output w.r.t. input traces |
4. Saves the averaged gradient data (NumPy .npy) and a metadata.json summary
| |
Approach:
- Single-byte models with a softmax output: build a logit model (softmax
  removed) and compute the gradient of the correct-class logit w.r.t. input
- Multi-task models with 256-way outputs: compute the gradient of the
  correct-class output as returned by the model (softmax is not removed)
- Multi-bit binary models: compute the gradient of the sum of the sigmoid
  outputs at the bit positions whose target bit is 1
- Average absolute gradients over N attack traces for stable estimates
| |
Output per model (under <output_dir>/<job name>/):
- gradient_map.npy (single-byte models) or gradient_map_byte{i}.npy
  (multi-task models): averaged |gradient| array, one value per input sample
- metadata.json: job parameters and saliency statistics
| |
| References: |
| - Masure, Dumas, Prouff (2020) "Gradient Visualization for General |
| Characterization in Profiling Attacks" (TCHES 2020) |
| - Simonyan, Vedaldi, Zisserman (2014) "Deep Inside Convolutional Networks: |
| Visualising Image Classification Models and Saliency Maps" |
| """ |
|
|
| import os |
| import sys |
| import json |
| import logging |
| import argparse |
| import gc |
| import traceback |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
# Reduce TensorFlow's C++ log verbosity ('2' hides INFO and WARNING
# messages); set here, before the TensorFlow import below.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2'})
|
|
| import numpy as np |
| import tensorflow as tf |
| from huggingface_hub import hf_hub_download |
| import h5py |
| import keras |
|
|
| try: |
| import tf_keras |
| HAS_TF_KERAS = True |
| except ImportError: |
| HAS_TF_KERAS = False |
|
|
| |
| |
| |
|
|
| BYTE_POI_WINDOWS = { |
| 0: (30838, 31538), |
| 1: (24525, 25225), |
| 2: (45400, 46100), |
| 3: (32824, 33524), |
| 4: (47508, 48208), |
| 5: (41258, 41958), |
| 6: (37094, 37794), |
| 7: (35018, 35718), |
| 8: (26631, 27331), |
| 9: (39145, 39845), |
| 10: (28766, 29466), |
| 11: (43333, 44033), |
| 12: (20418, 21118), |
| 13: (22499, 23199), |
| 14: (49571, 50271), |
| 15: (18363, 19063), |
| } |
|
|
| GLOBAL_WINDOW_START = 18000 |
| GLOBAL_WINDOW_END = 50272 |
| GLOBAL_WINDOW_SIZE = GLOBAL_WINDOW_END - GLOBAL_WINDOW_START |
| WINDOW_SIZE = 700 |
| NUM_CLASSES = 256 |
|
|
# AES forward S-box (FIPS-197), 16 rows of 16 entries, as uint8.
# compute_labels() indexes it with (plaintext_byte ^ key_byte) to obtain the
# first-round S-box output used as the class label.
AES_SBOX = np.array([
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
], dtype=np.uint8)
|
|
| |
| |
| |
|
|
class SigmoidGate(tf.keras.layers.Layer):
    """Scale each filter by its own trainable sigmoid gate.

    The layer owns one scalar per filter and returns
    ``inputs * sigmoid(gate)``, broadcast along the last axis of ``inputs``.

    Args:
        num_filters: Number of gates (one per filter / last-axis channel).
        init_value: Initial pre-sigmoid gate value (0.0 starts every gate at 0.5).
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters, self.init_value = num_filters, init_value

    def build(self, input_shape):
        constant_init = tf.keras.initializers.Constant(self.init_value)
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=constant_init,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        gate_values = tf.sigmoid(self.gate)
        return inputs * gate_values

    def get_config(self):
        base_config = super().get_config()
        return {**base_config, "num_filters": self.num_filters, "init_value": self.init_value}
|
|
|
|
class FocalCategoricalCrossentropy(tf.keras.losses.Loss):
    """Focal loss for categorical classification.

    Per sample: ``sum_c -y_c * (1 - p_c) ** gamma * log(p_c)``, where ``p``
    are the predicted class probabilities clipped to ``[1e-7, 1 - 1e-7]``.
    With ``gamma == 0`` this is plain categorical cross-entropy.

    The class is listed in CUSTOM_OBJECTS so that saved models referencing
    it can be deserialized.

    Args:
        gamma: Focusing parameter.
        label_smoothing: If > 0, targets are smoothed towards uniform:
            ``y * (1 - label_smoothing) + label_smoothing / num_classes``.
        from_logits: If True, ``y_pred`` holds logits and a softmax over the
            last axis is applied first.
    """

    def __init__(self, gamma=2.0, label_smoothing=0.0, from_logits=False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        # Previously accepted but ignored: logits were clipped as if they were probabilities.
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        y_true = tf.cast(y_true, y_pred.dtype)
        if self.label_smoothing:
            num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype)
            y_true = y_true * (1.0 - self.label_smoothing) + self.label_smoothing / num_classes

        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        ce = -y_true * tf.math.log(y_pred)
        # The modulating factor is not multiplied by y_true again: for one-hot
        # targets that is identical (y**2 == y), for smoothed targets it is correct.
        modulating = tf.pow(1.0 - y_pred, self.gamma)
        return tf.reduce_sum(modulating * ce, axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({
            "gamma": self.gamma,
            "label_smoothing": self.label_smoothing,
            "from_logits": self.from_logits,
        })
        return config
|
|
|
|
class SoftAttentionBlock(tf.keras.layers.Layer):
    """Soft attention block for MTAN-Lite (channels-based) and SNRMTAN (filters-based).

    Handles both constructor signatures:
      - MTAN-Lite: SoftAttentionBlock(channels=N, bottleneck_ratio=4)
      - SNRMTAN:   SoftAttentionBlock(filters=N, task_id=X, block_id=Y)

    The output is ``inputs * att``. ``att`` comes from two pointwise
    (kernel size 1) Conv1D layers:
      - a ReLU bottleneck with ``max(channels // bottleneck_ratio, 16)`` filters;
      - a sigmoid projection back to ``channels`` filters.

    Raises:
        ValueError: if neither ``channels`` nor ``filters`` is given.
    """

    def __init__(self, channels=None, filters=None, bottleneck_ratio=4, task_id=0, block_id=0, **kwargs):
        super().__init__(**kwargs)
        # `channels` takes precedence; get_config() stores both with the same value.
        resolved = channels if channels is not None else filters
        if resolved is None:
            raise ValueError("SoftAttentionBlock requires either `channels` or `filters`.")
        self.channels = resolved
        self.filters = resolved
        self.bottleneck_ratio = bottleneck_ratio
        self.task_id = task_id
        self.block_id = block_id
        self.bottleneck = max(self.channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        self.conv_down = tf.keras.layers.Conv1D(
            self.bottleneck, 1, padding='same', activation='relu',
            kernel_initializer='he_uniform',
        )
        self.conv_up = tf.keras.layers.Conv1D(
            self.channels, 1, padding='same', activation='sigmoid',
            kernel_initializer='glorot_uniform',
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        return inputs * att

    def get_config(self):
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "filters": self.filters,
            "bottleneck_ratio": self.bottleneck_ratio,
            "task_id": self.task_id,
            "block_id": self.block_id,
        })
        return config
|
|
|
|
# Custom classes that saved models may reference by class name; passed as
# `custom_objects` when loading models.
CUSTOM_OBJECTS = {
    custom_cls.__name__: custom_cls
    for custom_cls in (SigmoidGate, FocalCategoricalCrossentropy, SoftAttentionBlock)
}
|
|
|
|
def _read_keras_version(model_path: str, name: str) -> str:
    """Return the ``keras_version`` attribute of an HDF5 model file, or "unknown"."""
    try:
        with h5py.File(model_path, 'r') as f:
            kv = f.attrs.get('keras_version')
        if kv is None:
            return "unknown"
        if isinstance(kv, bytes):
            kv = kv.decode()
        # str() guards against non-string attribute types reaching .startswith().
        return str(kv)
    except Exception as e:
        # Best effort: an unreadable attribute only means the version is unknown.
        logging.debug(f"  Could not read keras_version for {name}: {e}")
        return "unknown"


def load_model_smart(model_path: str, name: str):
    """
    Smart model loader that detects the keras version from the h5 file
    and uses the appropriate loading API.

    Loaders are tried in order until one succeeds:
      1. tf_keras.models.load_model() - only if the file was saved with
         keras 2.x and tf_keras is installed
      2. keras.saving.load_model() (safe_mode=False)
      3. tf.keras.models.load_model()

    All loaders use CUSTOM_OBJECTS and compile=False.

    Args:
        model_path: Local path to the .h5 model file.
        name: Job name, used only in log messages.

    Returns:
        The loaded model, or None if every loader failed (each failure is logged).
    """
    keras_version = _read_keras_version(model_path, name)
    is_keras2 = keras_version.startswith('2.')
    logging.debug(f"  Model {name} saved with keras {keras_version}, is_keras2={is_keras2}")

    if is_keras2 and HAS_TF_KERAS:
        try:
            return tf_keras.models.load_model(
                model_path, custom_objects=CUSTOM_OBJECTS, compile=False
            )
        except Exception as e:
            logging.warning(f"  tf_keras load failed for {name}: {e}")

    try:
        return keras.saving.load_model(
            model_path, safe_mode=False, compile=False,
            custom_objects=CUSTOM_OBJECTS,
        )
    except Exception as e:
        logging.warning(f"  keras.saving load failed for {name}: {e}")

    try:
        return tf.keras.models.load_model(
            model_path, custom_objects=CUSTOM_OBJECTS, compile=False
        )
    except Exception as e:
        logging.error(f"Failed to load model for {name} (all loaders failed): {e}")
        return None
|
|
|
|
| |
| |
| |
|
|
def load_ascad_data(h5_path: str, desync: int = 0, n_traces: int = 1000):
    """Load attack traces and metadata from ASCAD HDF5 file.

    Supports two layouts:
      - the raw traces file (ATMega8515_raw_traces.h5), with top-level
        'traces' and 'metadata' datasets;
      - the windowed ASCAD.h5 file, with 'Attack_traces/traces' and
        'Attack_traces/metadata'.

    For the raw file, attack traces are the last 10,000 (indices 50000:60000).

    Args:
        h5_path: Path to the HDF5 file.
        desync: Maximum random left-shift (in samples) applied per trace;
            0 disables it. Shifts are reproducible (fixed seed 42).
        n_traces: Number of attack traces to read. Fewer are returned if
            the file holds fewer.

    Returns:
        (traces, plaintexts, keys): traces as a float64 array with one row per
        trace; plaintexts and keys with one row per trace.
    """
    with h5py.File(h5_path, 'r') as f:
        if 'Attack_traces' in f:
            traces = np.array(f['Attack_traces/traces'][:n_traces], dtype=np.float64)
            metadata = f['Attack_traces/metadata'][:n_traces]
        else:
            attack_start = 50000
            stop = attack_start + n_traces
            traces = np.array(f['traces'][attack_start:stop], dtype=np.float64)
            metadata = f['metadata'][attack_start:stop]

    plaintexts, keys = _split_ascad_metadata(metadata)

    if desync > 0:
        traces = _apply_desync(traces, desync)

    return traces, plaintexts, keys


def _split_ascad_metadata(metadata):
    """Return (plaintexts, keys) from an ASCAD metadata record array."""
    field_names = metadata.dtype.names or ()
    if 'plaintext' in field_names and 'key' in field_names:
        # Named access does not depend on the field order of the file.
        return np.asarray(metadata['plaintext']), np.asarray(metadata['key'])
    # Fallback: positional fields (plaintext at 0, key at 2), as before.
    return np.array([m[0] for m in metadata]), np.array([m[2] for m in metadata])


def _apply_desync(traces, desync, seed=42):
    """Shift each trace left by a random amount in [0, desync], zero-filling the tail."""
    # A private RandomState yields the same shifts as np.random.seed(seed) did,
    # without mutating NumPy's global RNG state.
    rng = np.random.RandomState(seed)
    # Size by the traces actually read, which may be fewer than requested.
    shifts = rng.randint(0, desync + 1, size=len(traces))
    shifted_traces = np.zeros_like(traces)
    for i, s in enumerate(shifts):
        if s > 0:
            shifted_traces[i, :-s] = traces[i, s:]
        else:
            shifted_traces[i] = traces[i]
    return shifted_traces
|
|
|
|
def extract_byte_window(traces: np.ndarray, byte_idx: int) -> np.ndarray:
    """Return the 700-sample POI columns of `traces` for AES byte `byte_idx`.

    `traces` is indexed (trace, sample); the column range comes from
    BYTE_POI_WINDOWS, so an unknown byte index raises KeyError.
    """
    poi_columns = slice(*BYTE_POI_WINDOWS[byte_idx])
    return traces[:, poi_columns]
|
|
|
|
def extract_global_window(traces: np.ndarray) -> np.ndarray:
    """Return the global-window columns of `traces` used by multi-task models.

    Columns run from GLOBAL_WINDOW_START (inclusive) to GLOBAL_WINDOW_END (exclusive).
    """
    global_columns = slice(GLOBAL_WINDOW_START, GLOBAL_WINDOW_END)
    return traces[:, global_columns]
|
|
|
|
def normalize_traces(traces: np.ndarray) -> np.ndarray:
    """Z-score every trace (row) independently.

    Each row is centred on its own mean and divided by its own standard
    deviation. Rows with zero spread are divided by 1.0 instead, so constant
    traces become all zeros rather than NaN.
    """
    centred = traces - traces.mean(axis=1, keepdims=True)
    spread = traces.std(axis=1, keepdims=True)
    spread[spread == 0] = 1.0
    return centred / spread
|
|
|
|
def compute_labels(plaintexts: np.ndarray, keys: np.ndarray, byte_idx: int) -> np.ndarray:
    """Return Sbox(plaintext[byte_idx] XOR key[byte_idx]) for every trace (row)."""
    sbox_input = np.bitwise_xor(plaintexts[:, byte_idx], keys[:, byte_idx])
    return AES_SBOX[sbox_input]
|
|
|
|
| |
| |
| |
|
|
def _is_dense(layer) -> bool:
    """True if `layer` is a Dense layer from tf.keras, keras or tf_keras."""
    dense_types = [tf.keras.layers.Dense, keras.layers.Dense]
    if HAS_TF_KERAS:
        dense_types.append(tf_keras.layers.Dense)
    return isinstance(layer, tuple(dense_types))


def _is_softmax_layer(layer) -> bool:
    """True if `layer` is a standalone Softmax or Activation('softmax') layer."""
    cls_name = type(layer).__name__
    if cls_name == 'Softmax':
        return True
    if cls_name == 'Activation':
        return layer.get_config().get('activation') == 'softmax'
    return False


def _functional_model_class(model):
    """Return the `Model` base class from the same Keras implementation as `model`."""
    for cls in type(model).__mro__:
        if cls.__name__ == 'Model':
            return cls
    return tf.keras.Model


def make_logit_model(model: tf.keras.Model) -> tf.keras.Model:
    """
    Create a model that outputs pre-activation logits instead of softmax probs.

    Gradients of logits avoid the saturation (vanishing gradients) of
    confident softmax outputs.

    Handled layouts:
      - The model ends in a standalone Softmax / Activation('softmax') layer.
        The input of that layer becomes the output.
      - The last Dense layer has activation='softmax'. A copy of it with the
        same config and weights but a linear activation replaces it.

    Any other model (sigmoid, linear, ...) is returned unchanged.

    The new layer and model use the classes of the loaded model itself, so
    models loaded through tf_keras and through Keras 3 are both supported.
    Previously, an isinstance check against tf.keras.layers.Dense could miss
    tf_keras layers, and a freshly built default Dense broke on use_bias=False.
    """
    if not model.layers:
        return model

    model_cls = _functional_model_class(model)

    final_layer = model.layers[-1]
    if _is_softmax_layer(final_layer):
        return model_cls(inputs=model.input, outputs=final_layer.input)

    last_dense = next((layer for layer in reversed(model.layers) if _is_dense(layer)), None)
    if last_dense is None:
        return model

    config = last_dense.get_config()
    if config.get('activation', 'linear') != 'softmax':
        # Sigmoid / linear outputs are used as-is.
        return model

    # Clone the Dense from its own config so use_bias, dtype, etc. match the
    # stored weights; only the activation (and name) change.
    config.update(activation='linear', name='logits_output')
    logit_layer = type(last_dense).from_config(config)
    logits = logit_layer(last_dense.input)
    logit_model = model_cls(inputs=model.input, outputs=logits)

    logit_layer.set_weights(last_dense.get_weights())
    return logit_model
|
|
|
|
def compute_saliency_single_byte(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels: np.ndarray,
    batch_size: int = 128,
    is_multibit: bool = False,
) -> np.ndarray:
    """
    Compute input-space saliency map for a single-byte model.

    Per trace, the target score is:
      - multi-bit: the sum of the outputs at bit positions whose label bit is 1;
      - otherwise: the output at the correct class index (a logit when
        `model` comes from make_logit_model).
    The result is the mean |d score / d input| over ALL traces. Each trace
    carries equal weight, so a smaller final batch does not bias the average.

    Args:
        model: The trained model (or logit model).
        traces: Input traces, shape (N, time_steps) or (N, time_steps, 1);
            reshaped to the rank the model expects.
        labels: Correct labels, shape (N,) for identity or (N, 8) for multibit.
        batch_size: Batch size for gradient computation.
        is_multibit: Whether the model uses multi-bit binary encoding.

    Returns:
        Averaged absolute gradient, shape (time_steps,). All zeros (with a
        warning) if no gradient could be computed.
    """
    n_traces = len(traces)
    input_shape = model.input_shape[1:]

    # Match the rank the model expects: (N, T, 1) for 2-D inputs, (N, T) for 1-D inputs.
    if len(input_shape) == 2 and traces.ndim == 2:
        traces = traces[..., np.newaxis]
    elif len(input_shape) == 1 and traces.ndim == 3:
        traces = traces.squeeze(-1)

    grad_sum = None
    n_used = 0

    for start in range(0, n_traces, batch_size):
        end = min(start + batch_size, n_traces)
        batch_traces = tf.constant(traces[start:end], dtype=tf.float32)
        batch_labels = labels[start:end]

        # The score must be computed inside the tape so it is differentiable.
        with tf.GradientTape() as tape:
            tape.watch(batch_traces)
            outputs = model(batch_traces, training=False)

            if is_multibit:
                bit_mask = tf.constant(np.asarray(batch_labels, dtype=np.float32))
                target = tf.reduce_sum(outputs * bit_mask, axis=-1)
            else:
                rows = tf.range(end - start)
                cols = tf.constant(np.asarray(batch_labels, dtype=np.int32))
                target = tf.gather_nd(outputs, tf.stack([rows, cols], axis=1))

            loss = tf.reduce_sum(target)

        grad = tape.gradient(loss, batch_traces)
        if grad is None:
            continue
        # Sum (not mean) per batch; divided by the trace count at the end.
        batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy()
        grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
        n_used += end - start

    if grad_sum is None:
        logging.warning("No gradients computed!")
        # traces.shape[1] is the time axis and, unlike input_shape[0], is never None.
        return np.zeros(traces.shape[1])

    avg_grad = grad_sum / n_used

    if avg_grad.ndim > 1:
        avg_grad = avg_grad.squeeze(-1)

    return avg_grad
|
|
|
|
def compute_saliency_multitask(
    model: tf.keras.Model,
    traces_dict: Dict[str, np.ndarray],
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 64,
    is_multibit: bool = True,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for multi-task (LMIC/LMIC-TSBN) models.

    The model takes named inputs ``byte_{i}_input``. Byte i's output is found
    in one of three ways:
      - key ``byte_{i}`` when the model returns a dict;
      - position i when it returns a list or tuple;
      - the single output tensor otherwise.
    For each byte, the gradient of that byte's target score is taken
    w.r.t. that byte's own input only.

    Target score per trace:
      - multi-bit: sum of the outputs at bit positions whose label bit is 1;
      - otherwise: the output at the correct class index (softmax is not removed).

    Each batch goes through the model once. A persistent tape then yields
    every byte's gradient from that forward pass; previously there was one
    forward pass per byte per batch.

    Args:
        model: Multi-input model.
        traces_dict: ``byte_{i}_input`` -> traces, one row per trace
            (e.g. (N, 700, 1) as built by process_lmic_job).
        labels_dict: ``byte_{i}`` -> labels, (N,) class indices or (N, 8) bits.
        batch_size: Traces per forward pass.
        is_multibit: Whether outputs are per-bit scores.

    Returns:
        Dict mapping byte_idx -> mean |gradient| over all traces (shape: 700,).
        Bytes without an input, labels, or a gradient are omitted.
    """
    byte_indices = [
        byte_idx
        for byte_idx in range(16)
        if f"byte_{byte_idx}_input" in traces_dict
        and labels_dict.get(f"byte_{byte_idx}") is not None
    ]
    if not byte_indices:
        return {}

    n_traces = len(next(iter(traces_dict.values())))
    grad_sums: Dict[int, np.ndarray] = {}
    grad_counts: Dict[int, int] = {}

    for start in range(0, n_traces, batch_size):
        end = min(start + batch_size, n_traces)
        batch_dict = {
            key: tf.constant(val[start:end], dtype=tf.float32)
            for key, val in traces_dict.items()
        }

        losses = {}
        with tf.GradientTape(persistent=True) as tape:
            for byte_idx in byte_indices:
                tape.watch(batch_dict[f"byte_{byte_idx}_input"])
            outputs = model(batch_dict, training=False)

            for byte_idx in byte_indices:
                if isinstance(outputs, dict):
                    byte_output = outputs[f"byte_{byte_idx}"]
                elif isinstance(outputs, (list, tuple)):
                    byte_output = outputs[byte_idx]
                else:
                    byte_output = outputs

                batch_labels = labels_dict[f"byte_{byte_idx}"][start:end]
                if is_multibit:
                    bit_mask = tf.constant(np.asarray(batch_labels, dtype=np.float32))
                    target = tf.reduce_sum(byte_output * bit_mask, axis=-1)
                else:
                    rows = tf.range(end - start)
                    cols = tf.constant(np.asarray(batch_labels, dtype=np.int32))
                    target = tf.gather_nd(byte_output, tf.stack([rows, cols], axis=1))
                losses[byte_idx] = tf.reduce_sum(target)

        for byte_idx, loss in losses.items():
            grad = tape.gradient(loss, batch_dict[f"byte_{byte_idx}_input"])
            if grad is None:
                continue
            # Accumulate per-trace sums so every trace carries equal weight.
            batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy()
            grad_sums[byte_idx] = grad_sums.get(byte_idx, 0) + batch_sum
            grad_counts[byte_idx] = grad_counts.get(byte_idx, 0) + (end - start)
        del tape  # release the persistent tape's recorded activations

    saliency_maps = {}
    for byte_idx in byte_indices:
        if byte_idx not in grad_sums:
            continue
        avg_grad = grad_sums[byte_idx] / grad_counts[byte_idx]
        if avg_grad.ndim > 1:
            avg_grad = avg_grad.squeeze(-1)
        saliency_maps[byte_idx] = avg_grad

    return saliency_maps
|
|
|
|
def compute_saliency_global_multitask(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 32,
    is_multibit: bool = False,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for global-window multi-task models
    (HPS, MTAN-Lite) that take a single global input.

    Byte i's output is found in one of three ways:
      - key ``byte_{i}`` when the model returns a dict;
      - position i when it returns a list or tuple;
      - the single output tensor otherwise.
    The target score is built as in compute_saliency_multitask: the bit-mask
    sum for multi-bit outputs, else the correct-class output.

    Each batch goes through the model once. A persistent tape then gives the
    gradient of every byte's score w.r.t. the shared input; previously there
    was one forward pass per byte per batch.

    Args:
        model: Single-input, multi-output model.
        traces: Global-window traces, (N, T) or (N, T, 1).
        labels_dict: ``byte_{i}`` -> labels, (N,) class indices or (N, 8) bits.
        batch_size: Traces per forward pass.
        is_multibit: Whether outputs are per-bit scores.

    Returns:
        Dict mapping byte_idx -> mean |gradient| over all traces
        (shape: global_window_size,). Bytes without labels or a gradient are omitted.
    """
    n_traces = len(traces)
    if traces.ndim == 2:
        traces = traces[..., np.newaxis]

    byte_indices = [
        byte_idx for byte_idx in range(16)
        if labels_dict.get(f"byte_{byte_idx}") is not None
    ]
    if not byte_indices:
        return {}

    grad_sums: Dict[int, np.ndarray] = {}
    grad_counts: Dict[int, int] = {}

    for start in range(0, n_traces, batch_size):
        end = min(start + batch_size, n_traces)
        batch_traces = tf.constant(traces[start:end], dtype=tf.float32)

        losses = {}
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(batch_traces)
            outputs = model(batch_traces, training=False)

            for byte_idx in byte_indices:
                if isinstance(outputs, dict):
                    byte_output = outputs[f"byte_{byte_idx}"]
                elif isinstance(outputs, (list, tuple)):
                    byte_output = outputs[byte_idx]
                else:
                    byte_output = outputs

                batch_labels = labels_dict[f"byte_{byte_idx}"][start:end]
                if is_multibit:
                    bit_mask = tf.constant(np.asarray(batch_labels, dtype=np.float32))
                    target = tf.reduce_sum(byte_output * bit_mask, axis=-1)
                else:
                    rows = tf.range(end - start)
                    cols = tf.constant(np.asarray(batch_labels, dtype=np.int32))
                    target = tf.gather_nd(byte_output, tf.stack([rows, cols], axis=1))
                losses[byte_idx] = tf.reduce_sum(target)

        for byte_idx, loss in losses.items():
            grad = tape.gradient(loss, batch_traces)
            if grad is None:
                continue
            # Accumulate per-trace sums so every trace carries equal weight.
            batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy()
            grad_sums[byte_idx] = grad_sums.get(byte_idx, 0) + batch_sum
            grad_counts[byte_idx] = grad_counts.get(byte_idx, 0) + (end - start)
        del tape  # release the persistent tape's recorded activations

    saliency_maps = {}
    for byte_idx in byte_indices:
        if byte_idx not in grad_sums:
            continue
        avg_grad = grad_sums[byte_idx] / grad_counts[byte_idx]
        if avg_grad.ndim > 1:
            avg_grad = avg_grad.squeeze(-1)
        saliency_maps[byte_idx] = avg_grad

    return saliency_maps
|
|
|
|
| |
| |
| |
|
|
def process_single_byte_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 1000,
) -> dict:
    """Process a single-byte (MLP/CNN) job.

    Downloads ``model.h5`` and computes the saliency map on the POI window of
    the targeted byte. Writes ``gradient_map.npy`` and ``metadata.json``
    under ``<output_dir>/<job name>/``.

    Args:
        job: Job record with 'name', 'model_type', 'hf_url' and
            'params' (containing 'desync'). The HF URL is expected to contain
            ``/tree/main/<path>`` and end in ``/byte<N>`` (a trailing slash
            is tolerated).
        h5_path: Path to the ASCAD HDF5 file.
        output_dir: Root output directory.
        hf_token: HuggingFace access token.
        n_traces: Number of attack traces to average over.

    Returns:
        The metadata dict on success. Otherwise a dict with 'name', a
        non-'success' 'status' and (usually) 'error'.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url']
    # A trailing slash would break both the byte index and the path parsing.
    url = hf_url.rstrip('/')

    try:
        byte_idx = int(url.split('/byte')[-1])
    except ValueError:
        byte_idx = None
    if byte_idx not in BYTE_POI_WINDOWS:
        logging.error(f"Cannot determine target byte for {name}: {hf_url}")
        return {"name": name, "status": "bad_url", "error": hf_url}

    # Same repo resolution as the multi-task jobs: unknown repos are an error
    # rather than a silent default.
    if 'ascad-v1-models' in url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}

    hf_path_prefix = url.split('/tree/main/')[-1]

    logging.info(f"Processing {name}: {model_type} byte{byte_idx} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    logit_model = None
    try:
        logit_model = make_logit_model(model)

        # An 8-unit output means one score per S-box output bit; indexing it
        # with 0-255 class labels would fail, so use bit labels instead.
        output_shape = getattr(logit_model, 'output_shape', None)
        is_multibit = (
            isinstance(output_shape, (tuple, list))
            and len(output_shape) > 0
            and output_shape[-1] == 8
        )

        traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

        byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
        del traces  # the full-length traces are no longer needed

        labels = compute_labels(plaintexts, keys, byte_idx)
        if is_multibit:
            # NOTE(review): MSB-first bit order (np.unpackbits default), as in
            # process_lmic_job; confirm against the training code.
            labels = np.unpackbits(labels.astype(np.uint8).reshape(-1, 1), axis=1)

        saliency = compute_saliency_single_byte(
            logit_model, byte_traces, labels,
            batch_size=128, is_multibit=is_multibit,
        )

        job_output_dir = os.path.join(output_dir, name)
        os.makedirs(job_output_dir, exist_ok=True)

        np.save(os.path.join(job_output_dir, 'gradient_map.npy'), saliency)

        meta = {
            "name": name,
            "model_type": model_type,
            "byte_idx": byte_idx,
            "desync": desync,
            "n_traces": n_traces,
            "is_multibit": bool(is_multibit),
            "saliency_max": float(np.max(saliency)),
            "saliency_mean": float(np.mean(saliency)),
            "saliency_std": float(np.std(saliency)),
            "input_shape": list(model.input_shape[1:]),
            "status": "success",
        }
        with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
            json.dump(meta, f, indent=2)
    finally:
        # Release the models even when gradient computation fails.
        del model, logit_model
        tf.keras.backend.clear_session()
        gc.collect()

    logging.info(f"  Done: max={saliency.max():.6f}, mean={saliency.mean():.6f}")
    return meta
|
|
|
|
def process_lmic_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 500,
) -> dict:
    """Process an LMIC/LMIC-TSBN multi-task job.

    Downloads ``model.h5``, builds one normalized POI window per byte
    (inputs ``byte_{i}_input``, shape (N, 700, 1)) and computes per-byte
    saliency maps. Writes ``gradient_map_byte{i}.npy`` files and
    ``metadata.json`` under ``<output_dir>/<job name>/``.

    Args:
        job: Job record with 'name', 'model_type', 'hf_url' and 'params'
            (containing 'desync').
        h5_path: Path to the ASCAD HDF5 file.
        output_dir: Root output directory.
        hf_token: HuggingFace access token.
        n_traces: Number of attack traces to average over.

    Returns:
        The metadata dict. Its 'status' is 'success', or 'no_gradients' when
        no byte produced a gradient. Early failures return a smaller dict
        with a failure 'status'.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url']
    # A trailing slash would otherwise produce a "//model.h5" path.
    url = hf_url.rstrip('/')

    if 'ascad-v1-models' in url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}

    hf_path_prefix = url.split('/tree/main/')[-1]

    logging.info(f"Processing {name}: {model_type} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    try:
        # Infer the label encoding from byte 0's output width: 8 units means
        # one score per S-box output bit, anything else a 256-way output.
        output_shape = getattr(model, 'output_shape', None)
        if isinstance(output_shape, dict):
            first_shape = output_shape.get("byte_0", (None, 256))
        elif isinstance(output_shape, list) and output_shape:
            first_shape = output_shape[0]
        else:
            first_shape = output_shape
        is_multibit = (
            isinstance(first_shape, (tuple, list))
            and len(first_shape) > 0
            and first_shape[-1] == 8
        )

        traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

        traces_dict = {}
        labels_dict = {}
        for byte_idx in range(16):
            byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
            traces_dict[f"byte_{byte_idx}_input"] = byte_traces[..., np.newaxis]

            sbox_out = compute_labels(plaintexts, keys, byte_idx)
            if is_multibit:
                # NOTE(review): MSB-first bit order (np.unpackbits default);
                # confirm it matches the training label encoding.
                labels_dict[f"byte_{byte_idx}"] = np.unpackbits(
                    sbox_out.astype(np.uint8).reshape(-1, 1), axis=1
                )
            else:
                labels_dict[f"byte_{byte_idx}"] = sbox_out
        del traces  # the full-length traces are no longer needed

        saliency_maps = compute_saliency_multitask(
            model, traces_dict, labels_dict,
            batch_size=64, is_multibit=is_multibit,
        )

        job_output_dir = os.path.join(output_dir, name)
        os.makedirs(job_output_dir, exist_ok=True)

        for byte_idx, saliency in saliency_maps.items():
            np.save(os.path.join(job_output_dir, f'gradient_map_byte{byte_idx}.npy'), saliency)

        if not saliency_maps:
            logging.warning(f"  No gradients computed for {name}")

        meta = {
            "name": name,
            "model_type": model_type,
            "desync": desync,
            "n_traces": n_traces,
            "is_multibit": bool(is_multibit),
            "bytes_computed": list(saliency_maps.keys()),
            "saliency_stats": {
                str(b): {
                    "max": float(np.max(s)),
                    "mean": float(np.mean(s)),
                    "std": float(np.std(s)),
                }
                for b, s in saliency_maps.items()
            },
            # An empty result is not reported as success (so --skip-existing retries it).
            "status": "success" if saliency_maps else "no_gradients",
        }
        with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
            json.dump(meta, f, indent=2)
    finally:
        # Release the model even when gradient computation fails.
        del model
        tf.keras.backend.clear_session()
        gc.collect()

    logging.info(f"  Done: {len(saliency_maps)} bytes computed")
    return meta
|
|
|
|
def process_global_mtl_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 300,
) -> dict:
    """Process a global-window multi-task job (HPS, MTAN-Lite).

    Downloads ``model.h5`` and computes per-byte saliency maps w.r.t. the
    normalized global-window input. Writes ``gradient_map_byte{i}.npy``
    files and ``metadata.json`` under ``<output_dir>/<job name>/``.

    Args:
        job: Job record with 'name', 'model_type', 'hf_url' and 'params'
            (containing 'desync').
        h5_path: Path to the ASCAD HDF5 file.
        output_dir: Root output directory.
        hf_token: HuggingFace access token.
        n_traces: Number of attack traces to average over.

    Returns:
        The metadata dict. Its 'status' is 'success', or 'no_gradients' when
        no byte produced a gradient. Early failures return a smaller dict
        with a failure 'status'.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url']

    if not hf_url:
        return {"name": name, "status": "no_hf_url"}

    # A trailing slash would otherwise produce a "//model.h5" path.
    url = hf_url.rstrip('/')

    if 'ascad-v1-models' in url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}

    hf_path_prefix = url.split('/tree/main/')[-1]

    logging.info(f"Processing {name}: {model_type} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    # Same version-aware loader as the other jobs; a plain tf.keras load can
    # fail for models saved with keras 2.x when Keras 3 is installed.
    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    try:
        # Infer the label encoding from byte 0's output width: 8 units means
        # one score per S-box output bit, anything else a 256-way output.
        output_shape = getattr(model, 'output_shape', None)
        if isinstance(output_shape, dict):
            first_shape = output_shape.get("byte_0", (None, 256))
        elif isinstance(output_shape, list) and output_shape:
            first_shape = output_shape[0]
        else:
            first_shape = output_shape
        is_multibit = (
            isinstance(first_shape, (tuple, list))
            and len(first_shape) > 0
            and first_shape[-1] == 8
        )

        traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

        global_traces = normalize_traces(extract_global_window(traces))
        del traces  # the full-length traces are no longer needed

        labels_dict = {}
        for byte_idx in range(16):
            sbox_out = compute_labels(plaintexts, keys, byte_idx)
            if is_multibit:
                # NOTE(review): MSB-first bit order (np.unpackbits default);
                # confirm it matches the training label encoding.
                labels_dict[f"byte_{byte_idx}"] = np.unpackbits(
                    sbox_out.astype(np.uint8).reshape(-1, 1), axis=1
                )
            else:
                labels_dict[f"byte_{byte_idx}"] = sbox_out

        saliency_maps = compute_saliency_global_multitask(
            model, global_traces, labels_dict,
            batch_size=32, is_multibit=is_multibit,
        )

        job_output_dir = os.path.join(output_dir, name)
        os.makedirs(job_output_dir, exist_ok=True)

        for byte_idx, saliency in saliency_maps.items():
            np.save(os.path.join(job_output_dir, f'gradient_map_byte{byte_idx}.npy'), saliency)

        if not saliency_maps:
            logging.warning(f"  No gradients computed for {name}")

        meta = {
            "name": name,
            "model_type": model_type,
            "desync": desync,
            "n_traces": n_traces,
            "is_multibit": bool(is_multibit),
            "bytes_computed": list(saliency_maps.keys()),
            "saliency_stats": {
                str(b): {
                    "max": float(np.max(s)),
                    "mean": float(np.mean(s)),
                    "std": float(np.std(s)),
                }
                for b, s in saliency_maps.items()
            },
            # An empty result is not reported as success (so --skip-existing retries it).
            "status": "success" if saliency_maps else "no_gradients",
        }
        with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
            json.dump(meta, f, indent=2)
    finally:
        # Release the model even when gradient computation fails.
        del model
        tf.keras.backend.clear_session()
        gc.collect()

    logging.info(f"  Done: {len(saliency_maps)} bytes computed")
    return meta
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: generate saliency maps for completed jobs.

    Steps:
      1. Read the job list from the YAML file (key 'jobs').
      2. Keep jobs with status 'completed' and an 'hf_url', then apply the
         optional type/name filters.
      3. Dispatch each job by 'model_type':
         - mlp/cnn -> process_single_byte_job
         - lmic/lmic_tsbn -> process_lmic_job (at most 500 traces)
         - hps/mtan_lite -> process_global_mtl_job (at most 300 traces)
      4. Write summary.json (one result dict per job) to the output directory.
    """
    parser = argparse.ArgumentParser(description="Generate gradient saliency maps")
    parser.add_argument("--jobs-yaml", required=True, help="Path to orchestrator_jobs_updated.yaml")
    parser.add_argument("--h5-path", required=True, help="Path to ASCAD HDF5 dataset")
    parser.add_argument("--output-dir", default="./gradient_maps", help="Output directory")
    parser.add_argument("--hf-token", required=True, help="HuggingFace token")
    parser.add_argument("--n-traces", type=int, default=1000, help="Number of attack traces")
    parser.add_argument("--filter-type", default=None, help="Filter by model type")
    parser.add_argument("--filter-name", default=None, help="Filter by job name (substring)")
    parser.add_argument("--skip-existing", action="store_true", help="Skip jobs with existing output")
    args = parser.parse_args()

    # The log file lives in the output directory. Create the directory first,
    # otherwise the FileHandler raises on a fresh run.
    os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'generation.log')),
            logging.StreamHandler(),
        ],
        # basicConfig is a no-op if the root logger already has handlers
        # (e.g. installed by an imported library); force ours.
        force=True,
    )

    import yaml  # only the CLI needs YAML
    with open(args.jobs_yaml) as f:
        data = yaml.safe_load(f)
    jobs = data['jobs']

    jobs = [j for j in jobs if j.get('status') == 'completed' and j.get('hf_url')]

    if args.filter_type:
        jobs = [j for j in jobs if j['model_type'] == args.filter_type]
    if args.filter_name:
        jobs = [j for j in jobs if args.filter_name in j['name']]

    logging.info(f"Processing {len(jobs)} jobs")
    logging.info(f"Output directory: {args.output_dir}")
    logging.info(f"N traces: {args.n_traces}")

    results = []
    for i, job in enumerate(jobs):
        name = job['name']

        if args.skip_existing:
            meta_path = os.path.join(args.output_dir, name, 'metadata.json')
            existing = None
            if os.path.exists(meta_path):
                try:
                    with open(meta_path) as f:
                        existing = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    # A truncated/corrupt file from an interrupted run: just redo the job.
                    logging.warning(f"[{i+1}/{len(jobs)}] Unreadable {meta_path} ({e}); reprocessing {name}")
            if isinstance(existing, dict) and existing.get('status') == 'success':
                logging.info(f"[{i+1}/{len(jobs)}] Skipping {name} (already done)")
                results.append(existing)
                continue

        logging.info(f"[{i+1}/{len(jobs)}] {name}")

        try:
            if job['model_type'] in ('mlp', 'cnn'):
                result = process_single_byte_job(
                    job, args.h5_path, args.output_dir, args.hf_token, args.n_traces
                )
            elif job['model_type'] in ('lmic', 'lmic_tsbn'):
                result = process_lmic_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 500),
                )
            elif job['model_type'] in ('hps', 'mtan_lite'):
                result = process_global_mtl_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 300),
                )
            else:
                result = {"name": name, "status": "unknown_type", "model_type": job['model_type']}
        except Exception as e:
            # logging.exception records the traceback in the log file as well as
            # on the console (traceback.print_exc only reached stderr).
            logging.exception(f"  FAILED: {e}")
            result = {"name": name, "status": "error", "error": str(e)}
            tf.keras.backend.clear_session()
            gc.collect()

        results.append(result)

    summary_path = os.path.join(args.output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(results, f, indent=2)

    success = sum(1 for r in results if r.get('status') == 'success')
    failed = len(results) - success
    logging.info(f"\nDone! Success: {success}/{len(results)}, Failed: {failed}")


if __name__ == "__main__":
    main()
|
|