#!/usr/bin/env python3
"""
Gradient Saliency Map Generator for ASCAD Models
=================================================

Computes input-space gradient saliency maps for all completed training runs.
For each model, this script:
  1. Downloads the model from HuggingFace
  2. Loads attack traces from the ASCAD dataset
  3. Computes gradients of the model output w.r.t. input traces
  4. Saves the raw gradient data (numpy) and a metadata.json per job

Approach:
  - For models with softmax output: create a logit model (remove softmax)
    and compute gradient of the correct-class logit w.r.t. input
  - For multi-bit binary models: compute gradient of the sum of the sigmoid
    outputs selected by the label bits w.r.t. input
  - Average absolute gradients over N attack traces for stable estimates

Output per model:
  - gradient_map.npy (single-byte models) or gradient_map_byte{X}.npy
    (multi-task models): raw averaged |gradient| array
  - metadata.json: job description and summary statistics
  NOTE: no PNG visualization is written by this script; plotting has to be
  done separately from the saved .npy files.

References:
  - Masure, Dumas, Prouff (2020) "Gradient Visualization for General
    Characterization in Profiling Attacks" (TCHES 2020)
  - Simonyan, Vedaldi, Zisserman (2014) "Deep Inside Convolutional Networks:
    Visualising Image Classification Models and Saliency Maps"
"""

import os
import sys
import json
import logging
import argparse
import gc
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Must be set before TensorFlow is imported. setdefault keeps an explicit
# value exported by the user (e.g. 0 for debugging) instead of overriding it.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')

import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
import h5py
import keras

try:
    import tf_keras
    HAS_TF_KERAS = True
except ImportError:
    HAS_TF_KERAS = False

# ============================================================================
# Constants (from pipeline)
# ============================================================================

# Per-byte points-of-interest window as (start, end) sample indices into the
# full-length trace; every window is 700 samples wide.
BYTE_POI_WINDOWS = {
    0: (30838, 31538),
    1: (24525, 25225),
    2: (45400, 46100),
    3: (32824, 33524),
    4: (47508, 48208),
    5: (41258, 41958),
    6: (37094, 37794),
    7: (35018, 35718),
    8: (26631, 27331),
    9: (39145, 39845),
    10: (28766, 29466),
    11: (43333, 44033),
    12: (20418, 21118),
    13: (22499, 23199),
    14: (49571, 50271),
    15: (18363, 19063),
}

# Single window used by the global-input multi-task models.
GLOBAL_WINDOW_START = 18000
GLOBAL_WINDOW_END = 50272
GLOBAL_WINDOW_SIZE = GLOBAL_WINDOW_END - GLOBAL_WINDOW_START  # 32,272
WINDOW_SIZE = 700
NUM_CLASSES = 256

# AES S-Box lookup table, indexed by (plaintext byte XOR key byte).
AES_SBOX = np.array([
    0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
    0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
    0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
    0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
    0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
    0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
    0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
    0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
    0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
    0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
    0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
    0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
    0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
    0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
    0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
    0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
], dtype=np.uint8)

# ============================================================================
# Custom layers for model loading
# ============================================================================
class SigmoidGate(tf.keras.layers.Layer):
    """Learnable per-filter sigmoid gate.

    Multiplies its input by ``sigmoid(gate)``, where ``gate`` is a trainable
    vector of length ``num_filters`` initialised to ``init_value``.
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.init_value = init_value

    def build(self, input_shape):
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=tf.keras.initializers.Constant(self.init_value),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs * tf.sigmoid(self.gate)

    def get_config(self):
        config = super().get_config()
        config.update({"num_filters": self.num_filters, "init_value": self.init_value})
        return config


class FocalCategoricalCrossentropy(tf.keras.losses.Loss):
    """Focal loss for categorical classification.

    NOTE: ``label_smoothing`` and ``from_logits`` are stored and serialized
    but are not consulted by ``call``, which always treats ``y_pred`` as
    probabilities. Every loader in this file passes ``compile=False``.
    """

    def __init__(self, gamma=2.0, label_smoothing=0.0, from_logits=False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        # Clip to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        ce = -y_true * tf.math.log(y_pred)
        weight = tf.pow(1.0 - y_pred, self.gamma) * y_true
        return tf.reduce_sum(weight * ce, axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({
            "gamma": self.gamma,
            "label_smoothing": self.label_smoothing,
            "from_logits": self.from_logits,
        })
        return config


class SoftAttentionBlock(tf.keras.layers.Layer):
    """Soft attention block for MTAN-Lite (channels-based) and SNRMTAN (filters-based).

    Handles both constructor signatures:
      - MTAN-Lite: SoftAttentionBlock(channels=N, bottleneck_ratio=4)
      - SNRMTAN:   SoftAttentionBlock(filters=N, task_id=X, block_id=Y)

    Raises:
        ValueError: if neither ``channels`` nor ``filters`` is given.
    """

    def __init__(self, channels=None, filters=None, bottleneck_ratio=4,
                 task_id=0, block_id=0, **kwargs):
        super().__init__(**kwargs)
        # Support both signatures
        self.channels = channels or filters
        if self.channels is None:
            # Fail with a clear message instead of a TypeError on `None // int`.
            raise ValueError("SoftAttentionBlock requires either `channels` or `filters`")
        self.filters = self.channels  # alias for get_config compatibility
        self.bottleneck_ratio = bottleneck_ratio
        self.task_id = task_id
        self.block_id = block_id
        self.bottleneck = max(self.channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        self.conv_down = tf.keras.layers.Conv1D(
            self.bottleneck, 1, padding='same', activation='relu',
            kernel_initializer='he_uniform',
        )
        self.conv_up = tf.keras.layers.Conv1D(
            self.channels, 1, padding='same', activation='sigmoid',
            kernel_initializer='glorot_uniform',
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        return inputs * att

    def get_config(self):
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "filters": self.filters,
            "bottleneck_ratio": self.bottleneck_ratio,
            "task_id": self.task_id,
            "block_id": self.block_id,
        })
        return config


CUSTOM_OBJECTS = {
    'SigmoidGate': SigmoidGate,
    'FocalCategoricalCrossentropy': FocalCategoricalCrossentropy,
    'SoftAttentionBlock': SoftAttentionBlock,
}


def load_model_smart(model_path: str, name: str):
    """
    Smart model loader that detects the keras version from the h5 file
    and uses the appropriate loading API.

    - Models saved with keras >= 3.x: use keras.saving.load_model()
    - Models saved with keras 2.x (tf.keras): use tf_keras.models.load_model()

    Loaders are tried in order (tf_keras for Keras 2.x files when available,
    then keras.saving, then tf.keras); the first one that succeeds wins.

    Args:
        model_path: Path to the .h5 model file.
        name: Job name, used only in log messages.

    Returns:
        The loaded (uncompiled) model, or None if every loader failed.
    """
    # Detect keras version from h5 file attributes (best effort).
    keras_version = "unknown"
    try:
        with h5py.File(model_path, 'r') as f:
            if 'keras_version' in f.attrs:
                kv = f.attrs['keras_version']
                keras_version = kv.decode() if isinstance(kv, bytes) else str(kv)
    except Exception as e:
        # Detection is only a hint for loader ordering; keep going but leave a trace.
        logging.debug(f" Could not read keras_version for {name}: {e}")

    is_keras2 = keras_version.startswith('2.')
    logging.debug(f" Model {name} saved with keras {keras_version}, is_keras2={is_keras2}")

    # Try loading with appropriate loader
    if is_keras2 and HAS_TF_KERAS:
        # Use tf_keras for legacy Keras 2.x models
        try:
            return tf_keras.models.load_model(
                model_path, custom_objects=CUSTOM_OBJECTS, compile=False
            )
        except Exception as e:
            logging.warning(f" tf_keras load failed for {name}: {e}")
            # Fall through to try keras 3

    # Try Keras 3 native loader
    try:
        return keras.saving.load_model(
            model_path,
            safe_mode=False,
            compile=False,
            custom_objects=CUSTOM_OBJECTS,
        )
    except Exception as e:
        logging.warning(f" keras.saving load failed for {name}: {e}")

    # Final fallback: tf.keras (may work for some models)
    try:
        return tf.keras.models.load_model(
            model_path, custom_objects=CUSTOM_OBJECTS, compile=False
        )
    except Exception as e:
        logging.error(f"Failed to load model for {name} (all loaders failed): {e}")
        return None


# ============================================================================
# Data loading
# ============================================================================

def load_ascad_data(h5_path: str, desync: int = 0, n_traces: int = 1000):
    """Load attack traces and metadata from ASCAD HDF5 file.

    Supports both the raw traces file (ATMega8515_raw_traces.h5) with
    top-level 'traces' and 'metadata' datasets, and the windowed ASCAD.h5
    file with 'Attack_traces/traces' and 'Attack_traces/metadata' groups.

    For the raw file, attack traces are the last 10,000 (indices 50000:60000).

    Args:
        h5_path: Path to the HDF5 file.
        desync: Maximum random left-shift (in samples) applied per trace;
            0 disables desynchronisation.
        n_traces: Maximum number of attack traces to read. Fewer are returned
            if the file holds fewer.

    Returns:
        Tuple ``(traces, plaintexts, keys)``; ``traces`` is float64 and the
        other two are built from metadata fields 0 and 2 respectively.
    """
    with h5py.File(h5_path, 'r') as f:
        if 'Attack_traces' in f:
            # Windowed ASCAD.h5 format
            traces = np.array(f['Attack_traces/traces'][:n_traces], dtype=np.float64)
            metadata = f['Attack_traces/metadata'][:n_traces]
        else:
            # Raw traces format: attack set is last 10,000 traces
            attack_start = 50000
            traces = np.array(
                f['traces'][attack_start:attack_start + n_traces],
                dtype=np.float64,
            )
            metadata = f['metadata'][attack_start:attack_start + n_traces]

    plaintexts = np.array([m[0] for m in metadata])  # plaintext bytes (16,)
    keys = np.array([m[2] for m in metadata])  # key bytes (16,)

    # Apply desync if needed
    if desync > 0:
        # Use the number of traces actually read: the file may hold fewer than
        # n_traces, and indexing up to n_traces would then raise IndexError.
        n_loaded = len(traces)
        # Local generator: same stream as np.random.seed(42) followed by
        # np.random.randint, without clobbering the global RNG state.
        rng = np.random.RandomState(42)
        shifts = rng.randint(0, desync + 1, size=n_loaded)
        shifted_traces = np.zeros_like(traces)
        for i, s in enumerate(shifts):
            if s > 0:
                # Shift left by s samples; the tail stays zero-padded.
                shifted_traces[i, :-s] = traces[i, s:]
            else:
                shifted_traces[i] = traces[i]
        traces = shifted_traces

    return traces, plaintexts, keys


def extract_byte_window(traces: np.ndarray, byte_idx: int) -> np.ndarray:
    """Extract the 700-sample POI window for a specific byte."""
    start, end = BYTE_POI_WINDOWS[byte_idx]
    return traces[:, start:end]


def extract_global_window(traces: np.ndarray) -> np.ndarray:
    """Extract the global window for multi-task models."""
    return traces[:, GLOBAL_WINDOW_START:GLOBAL_WINDOW_END]


def normalize_traces(traces: np.ndarray) -> np.ndarray:
    """Per-trace z-score normalization (constant traces map to all zeros)."""
    mean = traces.mean(axis=1, keepdims=True)
    std = traces.std(axis=1, keepdims=True)
    std = np.where(std == 0, 1.0, std)  # avoid division by zero
    return (traces - mean) / std


def compute_labels(plaintexts: np.ndarray, keys: np.ndarray, byte_idx: int) -> np.ndarray:
    """Compute AES S-Box output labels for a specific byte."""
    return AES_SBOX[plaintexts[:, byte_idx] ^ keys[:, byte_idx]]


# ============================================================================
# Gradient computation
# ============================================================================

def _keras_base_class(obj, class_name: str, default=None):
    """Return the first class named ``class_name`` in ``obj``'s MRO, else ``default``."""
    for cls in type(obj).__mro__:
        if cls.__name__ == class_name:
            return cls
    return default


def make_logit_model(model: tf.keras.Model) -> tf.keras.Model:
    """
    Create a model that outputs pre-activation logits instead of softmax probs.
    This avoids the vanishing gradient problem with softmax outputs.

    The last Dense layer is located by class name rather than
    ``isinstance(..., tf.keras.layers.Dense)``: load_model_smart may return a
    model built from tf_keras or keras classes, which need not be the same
    classes as ``tf.keras``. The replacement layer and model are built from
    the loaded model's own class hierarchy for the same reason.

    Returns:
        A new model ending in a linear copy of the last Dense layer when that
        layer uses softmax; otherwise the input model unchanged.
    """
    # Find the last dense layer (predictions)
    last_layer = next(
        (layer for layer in reversed(model.layers)
         if _keras_base_class(layer, 'Dense') is not None),
        None,
    )
    if last_layer is None:
        return model

    activation = last_layer.get_config().get('activation', 'linear')
    if activation != 'softmax':
        # Sigmoid (multi-bit) and linear heads are used as they are.
        return model

    dense_cls = _keras_base_class(last_layer, 'Dense', tf.keras.layers.Dense)
    model_cls = _keras_base_class(model, 'Model', tf.keras.Model)

    # Re-apply the last layer's transform with a linear activation on the
    # same input tensor; use_bias must match so set_weights() lines up.
    logits = dense_cls(
        last_layer.units,
        activation='linear',
        use_bias=last_layer.use_bias,
        name='logits_output',
    )(last_layer.input)
    logit_model = model_cls(inputs=model.input, outputs=logits)
    # Copy weights from original last layer
    logit_model.get_layer('logits_output').set_weights(last_layer.get_weights())
    return logit_model


def _select_head(outputs, output_name: str, byte_idx: int):
    """Pick one byte's output from a dict-, list- or single-tensor model output."""
    if isinstance(outputs, dict):
        return outputs[output_name]
    if isinstance(outputs, (list, tuple)):
        return outputs[byte_idx]
    return outputs


def _target_scores(outputs, batch_labels, is_multibit: bool):
    """Return the per-trace scalar whose gradient defines the saliency."""
    if is_multibit:
        # Sum of the outputs at positions where the label bit is 1.
        # NOTE(review): bits whose label is 0 do not contribute to the target.
        return tf.reduce_sum(
            outputs * tf.constant(batch_labels, dtype=tf.float32), axis=-1
        )
    # Identity encoding: score of the correct class for each trace.
    batch_indices = tf.range(tf.shape(outputs)[0])
    indices = tf.stack(
        [batch_indices, tf.constant(batch_labels, dtype=tf.int32)], axis=1
    )
    return tf.gather_nd(outputs, indices)


def _abs_grad_sum(model, model_inputs, watched, batch_labels, is_multibit: bool,
                  output_name: Optional[str] = None, byte_idx: int = 0):
    """Sum of |d target / d watched| over one batch, or None if no gradient.

    ``model_inputs`` is what gets fed to the model; ``watched`` is the tensor
    (the same object or one entry of an input dict) to differentiate against.
    When ``output_name`` is given, the matching head is selected first.
    """
    with tf.GradientTape() as tape:
        tape.watch(watched)
        outputs = model(model_inputs, training=False)
        if output_name is not None:
            outputs = _select_head(outputs, output_name, byte_idx)
        loss = tf.reduce_sum(_target_scores(outputs, batch_labels, is_multibit))
    grad = tape.gradient(loss, watched)
    if grad is None:
        return None
    return tf.reduce_sum(tf.abs(grad), axis=0).numpy().astype(np.float64)


def _finalize_saliency(grad_sum: np.ndarray, n_seen: int) -> np.ndarray:
    """Turn an accumulated |gradient| sum into a float32 per-sample mean."""
    avg_grad = (grad_sum / n_seen).astype(np.float32)
    # Squeeze channel dimension if present
    if avg_grad.ndim > 1:
        avg_grad = avg_grad.squeeze(-1)
    return avg_grad


def compute_saliency_single_byte(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels: np.ndarray,
    batch_size: int = 128,
    is_multibit: bool = False,
) -> np.ndarray:
    """
    Compute input-space saliency map for a single-byte model.

    Uses the gradient of the correct-class logit (or correct-bit outputs)
    w.r.t. the input trace, averaged over multiple traces. The average is
    taken per trace, so a shorter final batch is not over-weighted.

    Args:
        model: The trained model (or logit model)
        traces: Input traces, shape (N, time_steps) or (N, time_steps, 1)
        labels: Correct labels, shape (N,) for identity or (N, 8) for multibit
        batch_size: Batch size for gradient computation
        is_multibit: Whether the model uses multi-bit binary encoding

    Returns:
        Averaged absolute gradient, shape (time_steps,)
    """
    n_traces = len(traces)
    input_shape = model.input_shape[1:]

    # Reshape traces if needed
    if len(input_shape) == 2 and traces.ndim == 2:
        # CNN: needs (N, time, 1)
        traces = traces[..., np.newaxis]
    elif len(input_shape) == 1 and traces.ndim == 3:
        # MLP: needs (N, time)
        traces = traces.squeeze(-1)

    grad_sum = None
    n_seen = 0
    for start in range(0, n_traces, batch_size):
        end = min(start + batch_size, n_traces)
        batch_traces = tf.constant(traces[start:end], dtype=tf.float32)
        batch_sum = _abs_grad_sum(
            model, batch_traces, batch_traces, labels[start:end], is_multibit
        )
        if batch_sum is not None:
            grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
            n_seen += end - start

    if grad_sum is None:
        logging.warning("No gradients computed!")
        return np.zeros(input_shape[0])

    return _finalize_saliency(grad_sum, n_seen)


def compute_saliency_multitask(
    model: tf.keras.Model,
    traces_dict: Dict[str, np.ndarray],
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 64,
    is_multibit: bool = True,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for multi-task (LMIC/LMIC-TSBN) models.

    For LMIC models with 16 named inputs, computes the gradient of each
    byte's output w.r.t. its corresponding input. Inputs are looked up as
    ``byte_{i}_input`` and labels/outputs as ``byte_{i}``; bytes missing from
    either dict are skipped.

    Returns:
        Dict mapping byte_idx -> averaged absolute gradient (shape: 700,)
    """
    saliency_maps: Dict[int, np.ndarray] = {}
    if not traces_dict:
        return saliency_maps

    n_traces = len(next(iter(traces_dict.values())))

    for byte_idx in range(16):
        input_name = f"byte_{byte_idx}_input"
        output_name = f"byte_{byte_idx}"

        if input_name not in traces_dict:
            continue
        byte_labels = labels_dict.get(output_name)
        if byte_labels is None:
            continue

        grad_sum = None
        n_seen = 0
        for start in range(0, n_traces, batch_size):
            end = min(start + batch_size, n_traces)

            # Build input dict for this batch
            batch_dict = {
                key: tf.constant(val[start:end], dtype=tf.float32)
                for key, val in traces_dict.items()
            }
            # We only watch the target byte's input
            batch_sum = _abs_grad_sum(
                model, batch_dict, batch_dict[input_name],
                byte_labels[start:end], is_multibit,
                output_name=output_name, byte_idx=byte_idx,
            )
            if batch_sum is not None:
                grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
                n_seen += end - start

        if grad_sum is not None:
            saliency_maps[byte_idx] = _finalize_saliency(grad_sum, n_seen)

    return saliency_maps


def compute_saliency_global_multitask(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 32,
    is_multibit: bool = False,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for global-window multi-task models
    (HPS, MTAN-Lite) that take a single global input.

    Each byte's head is differentiated against the shared input in a
    separate pass; bytes without an entry ``byte_{i}`` in ``labels_dict``
    are skipped.

    Returns:
        Dict mapping byte_idx -> averaged absolute gradient (shape: global_window_size,)
    """
    n_traces = len(traces)
    if traces.ndim == 2:
        traces = traces[..., np.newaxis]

    saliency_maps: Dict[int, np.ndarray] = {}

    for byte_idx in range(16):
        output_name = f"byte_{byte_idx}"
        byte_labels = labels_dict.get(output_name)
        if byte_labels is None:
            continue

        grad_sum = None
        n_seen = 0
        for start in range(0, n_traces, batch_size):
            end = min(start + batch_size, n_traces)
            batch_traces = tf.constant(traces[start:end], dtype=tf.float32)
            batch_sum = _abs_grad_sum(
                model, batch_traces, batch_traces,
                byte_labels[start:end], is_multibit,
                output_name=output_name, byte_idx=byte_idx,
            )
            if batch_sum is not None:
                grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
                n_seen += end - start

        if grad_sum is not None:
            saliency_maps[byte_idx] = _finalize_saliency(grad_sum, n_seen)

    return saliency_maps
# ============================================================================
# Job processing
# ============================================================================

def _resolve_repo_id(hf_url: str, default: Optional[str] = None) -> Optional[str]:
    """Map a HuggingFace tree URL to its repo id, or ``default`` if unrecognised."""
    if 'ascad-v1-models' in hf_url:
        return 'lemousehunter/ascad-v1-models'
    if 'ascad-mtan-rank0-models' in hf_url:
        return 'lemousehunter/ascad-mtan-rank0-models'
    return default


def _download_and_load(name: str, repo_id: str, hf_path_prefix: str, hf_token: str):
    """Download ``model.h5`` and load it; return ``(model, None)`` or ``(None, failure_dict)``."""
    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return None, {"name": name, "status": "download_failed", "error": str(e)}

    # Detect keras version from h5 file and use appropriate loader
    model = load_model_smart(model_path, name)
    if model is None:
        return None, {"name": name, "status": "load_failed", "error": "all loaders failed"}
    return model, None


def _is_multibit_model(model) -> bool:
    """True when the first/'byte_0' head has 8 outputs (multi-bit) rather than 256."""
    output_shape = getattr(model, 'output_shape', None)
    if isinstance(output_shape, dict):
        return output_shape.get("byte_0", (None, 256))[-1] == 8
    if isinstance(output_shape, list):
        return output_shape[0][-1] == 8
    return False


def _build_labels_dict(plaintexts: np.ndarray, keys: np.ndarray,
                       is_multibit: bool) -> Dict[str, np.ndarray]:
    """Build ``byte_{i}`` -> labels for all 16 bytes (S-Box value or its 8 bits)."""
    labels_dict = {}
    for byte_idx in range(16):
        sbox_out = compute_labels(plaintexts, keys, byte_idx)
        if is_multibit:
            # Convert to 8-bit binary
            sbox_out = np.unpackbits(sbox_out.astype(np.uint8).reshape(-1, 1), axis=1)
        labels_dict[f"byte_{byte_idx}"] = sbox_out
    return labels_dict


def _saliency_stats(saliency: np.ndarray) -> dict:
    """Summary statistics stored in metadata.json for one saliency map."""
    return {
        "max": float(np.max(saliency)),
        "mean": float(np.mean(saliency)),
        "std": float(np.std(saliency)),
    }


def _save_per_byte_results(job_output_dir: str, saliency_maps: Dict[int, np.ndarray],
                           meta: dict) -> None:
    """Write one gradient_map_byte{X}.npy per byte plus metadata.json."""
    os.makedirs(job_output_dir, exist_ok=True)
    for byte_idx, saliency in saliency_maps.items():
        np.save(os.path.join(job_output_dir, f'gradient_map_byte{byte_idx}.npy'), saliency)
    with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)


def process_single_byte_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 1000,
) -> dict:
    """Process a single-byte (MLP/CNN) job.

    Downloads and loads the model, computes the saliency map for the byte
    encoded in the job's HF URL, and writes ``gradient_map.npy`` and
    ``metadata.json`` under ``output_dir/<job name>``.

    Returns:
        The metadata dict on success, or a dict with ``status`` set to
        ``download_failed`` / ``load_failed`` describing the failure.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    # A trailing slash would break both the byte parse and the repo path.
    hf_url = job['hf_url'].rstrip('/')

    # Parse byte index from HF URL
    # URL format: .../desync{N}/{model_type}/byte{X}
    byte_idx = int(hf_url.split('/byte')[-1])

    # Unrecognised URLs fall back to the rank0 repo for single-byte jobs.
    repo_id = _resolve_repo_id(hf_url, default='lemousehunter/ascad-mtan-rank0-models')
    # Extract the path within the repo
    hf_path_prefix = hf_url.split('/tree/main/')[-1]

    logging.info(f"Processing {name}: {model_type} byte{byte_idx} desync={desync}")

    model, failure = _download_and_load(name, repo_id, hf_path_prefix, hf_token)
    if failure is not None:
        return failure

    # Create logit model (removes softmax for better gradients)
    logit_model = make_logit_model(model)

    # Load data, extract byte window and normalize
    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)
    byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
    labels = compute_labels(plaintexts, keys, byte_idx)

    saliency = compute_saliency_single_byte(
        logit_model, byte_traces, labels,
        batch_size=128, is_multibit=False,
    )

    # Save results
    job_output_dir = os.path.join(output_dir, name)
    os.makedirs(job_output_dir, exist_ok=True)
    np.save(os.path.join(job_output_dir, 'gradient_map.npy'), saliency)

    meta = {
        "name": name,
        "model_type": model_type,
        "byte_idx": byte_idx,
        "desync": desync,
        "n_traces": n_traces,
        "saliency_max": float(np.max(saliency)),
        "saliency_mean": float(np.mean(saliency)),
        "saliency_std": float(np.std(saliency)),
        "input_shape": list(model.input_shape[1:]),
        "status": "success",
    }
    with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    # Cleanup
    del model, logit_model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f" Done: max={saliency.max():.6f}, mean={saliency.mean():.6f}")
    return meta


def process_lmic_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 500,
) -> dict:
    """Process an LMIC/LMIC-TSBN multi-task job.

    Feeds the model 16 named per-byte inputs (``byte_{i}_input``) and writes
    one ``gradient_map_byte{X}.npy`` per byte plus ``metadata.json``.

    Returns:
        The metadata dict on success, or a dict whose ``status`` is
        ``unknown_repo`` / ``download_failed`` / ``load_failed``.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url'].rstrip('/')

    repo_id = _resolve_repo_id(hf_url)
    if repo_id is None:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}

    hf_path_prefix = hf_url.split('/tree/main/')[-1]
    logging.info(f"Processing {name}: {model_type} desync={desync}")

    model, failure = _download_and_load(name, repo_id, hf_path_prefix, hf_token)
    if failure is not None:
        return failure

    # Multi-bit models have 8 outputs per byte, identity has 256
    is_multibit = _is_multibit_model(model)

    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

    # Prepare LMIC multi-input dict
    traces_dict = {}
    for byte_idx in range(16):
        byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
        # Add channel dim for CNN
        traces_dict[f"byte_{byte_idx}_input"] = byte_traces[..., np.newaxis]
    labels_dict = _build_labels_dict(plaintexts, keys, is_multibit)

    saliency_maps = compute_saliency_multitask(
        model, traces_dict, labels_dict,
        batch_size=64, is_multibit=is_multibit,
    )

    meta = {
        "name": name,
        "model_type": model_type,
        "desync": desync,
        "n_traces": n_traces,
        "is_multibit": is_multibit,
        "bytes_computed": list(saliency_maps.keys()),
        "saliency_stats": {str(b): _saliency_stats(s) for b, s in saliency_maps.items()},
        "status": "success",
    }
    _save_per_byte_results(os.path.join(output_dir, name), saliency_maps, meta)

    # Cleanup
    del model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f" Done: {len(saliency_maps)} bytes computed")
    return meta


def process_global_mtl_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 300,
) -> dict:
    """Process a global-window multi-task job (HPS, MTAN-Lite).

    Uses the same loader (load_model_smart) and the same multi-bit detection
    as the LMIC path, so Keras 2.x files and 8-output heads are handled
    consistently across job types.

    Returns:
        The metadata dict on success, or a dict whose ``status`` is
        ``no_hf_url`` / ``unknown_repo`` / ``download_failed`` / ``load_failed``.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url']

    if not hf_url:
        return {"name": name, "status": "no_hf_url"}
    hf_url = hf_url.rstrip('/')

    repo_id = _resolve_repo_id(hf_url)
    if repo_id is None:
        return {"name": name, "status": "unknown_repo", "error": hf_url}

    hf_path_prefix = hf_url.split('/tree/main/')[-1]
    logging.info(f"Processing {name}: {model_type} desync={desync}")

    model, failure = _download_and_load(name, repo_id, hf_path_prefix, hf_token)
    if failure is not None:
        return failure

    is_multibit = _is_multibit_model(model)

    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

    # Extract global window and normalize
    global_traces = normalize_traces(extract_global_window(traces))
    labels_dict = _build_labels_dict(plaintexts, keys, is_multibit)

    saliency_maps = compute_saliency_global_multitask(
        model, global_traces, labels_dict,
        batch_size=32, is_multibit=is_multibit,
    )

    meta = {
        "name": name,
        "model_type": model_type,
        "desync": desync,
        "n_traces": n_traces,
        "is_multibit": is_multibit,
        "bytes_computed": list(saliency_maps.keys()),
        "saliency_stats": {str(b): _saliency_stats(s) for b, s in saliency_maps.items()},
        "status": "success",
    }
    _save_per_byte_results(os.path.join(output_dir, name), saliency_maps, meta)

    # Cleanup
    del model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f" Done: {len(saliency_maps)} bytes computed")
    return meta


# ============================================================================
# Main
# ============================================================================

def _load_existing_success(meta_path: str) -> Optional[dict]:
    """Return previously saved metadata if it records a success, else None."""
    if not os.path.exists(meta_path):
        return None
    try:
        with open(meta_path) as f:
            existing = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # A truncated or corrupt file means the job must be redone, not that the run aborts.
        logging.warning(f"Ignoring unreadable metadata {meta_path}: {e}")
        return None
    if isinstance(existing, dict) and existing.get('status') == 'success':
        return existing
    return None


def main():
    """Command-line entry point: run every completed job and write summary.json."""
    parser = argparse.ArgumentParser(description="Generate gradient saliency maps")
    parser.add_argument("--jobs-yaml", required=True, help="Path to orchestrator_jobs_updated.yaml")
    parser.add_argument("--h5-path", required=True, help="Path to ASCAD HDF5 dataset")
    parser.add_argument("--output-dir", default="./gradient_maps", help="Output directory")
    parser.add_argument("--hf-token", required=True, help="HuggingFace token")
    parser.add_argument("--n-traces", type=int, default=1000, help="Number of attack traces")
    parser.add_argument("--filter-type", default=None, help="Filter by model type")
    parser.add_argument("--filter-name", default=None, help="Filter by job name (substring)")
    parser.add_argument("--skip-existing", action="store_true", help="Skip jobs with existing output")
    args = parser.parse_args()

    # The log file lives inside the output directory, and FileHandler opens it
    # immediately, so the directory has to exist before logging is configured.
    os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'generation.log')),
            logging.StreamHandler(),
        ]
    )

    import yaml
    with open(args.jobs_yaml) as f:
        data = yaml.safe_load(f)

    # Filter to completed jobs with HF URLs
    jobs = [j for j in data['jobs'] if j.get('status') == 'completed' and j.get('hf_url')]
    if args.filter_type:
        jobs = [j for j in jobs if j['model_type'] == args.filter_type]
    if args.filter_name:
        jobs = [j for j in jobs if args.filter_name in j['name']]

    logging.info(f"Processing {len(jobs)} jobs")
    logging.info(f"Output directory: {args.output_dir}")
    logging.info(f"N traces: {args.n_traces}")

    results = []
    for i, job in enumerate(jobs):
        name = job['name']

        # Skip existing
        if args.skip_existing:
            meta_path = os.path.join(args.output_dir, name, 'metadata.json')
            existing = _load_existing_success(meta_path)
            if existing is not None:
                logging.info(f"[{i+1}/{len(jobs)}] Skipping {name} (already done)")
                results.append(existing)
                continue

        logging.info(f"[{i+1}/{len(jobs)}] {name}")

        try:
            if job['model_type'] in ('mlp', 'cnn'):
                result = process_single_byte_job(
                    job, args.h5_path, args.output_dir, args.hf_token, args.n_traces
                )
            elif job['model_type'] in ('lmic', 'lmic_tsbn'):
                result = process_lmic_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 500),
                )
            elif job['model_type'] in ('hps', 'mtan_lite'):
                result = process_global_mtl_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 300),
                )
            else:
                result = {"name": name, "status": "unknown_type", "model_type": job['model_type']}
        except Exception as e:
            # logging.exception records the traceback in generation.log as
            # well as on the console.
            logging.exception(f" FAILED: {e}")
            result = {"name": name, "status": "error", "error": str(e)}
            tf.keras.backend.clear_session()
            gc.collect()

        results.append(result)

    # Save summary
    summary_path = os.path.join(args.output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(results, f, indent=2)

    # Print summary
    success = sum(1 for r in results if r.get('status') == 'success')
    failed = len(results) - success
    logging.info(f"\nDone! Success: {success}/{len(results)}, Failed: {failed}")


if __name__ == "__main__":
    main()