# Source: ascad-training-pipeline/tools/generate_gradient_maps.py (author: lemousehunter)
# Commit d078926: "Add gradient saliency map generation and visualization tools"
#!/usr/bin/env python3
"""
Gradient Saliency Map Generator for ASCAD Models
=================================================
Computes input-space gradient saliency maps for all completed training runs.
For each model, this script:
1. Downloads the model from HuggingFace
2. Loads attack traces from the ASCAD dataset
3. Computes gradients of the model output w.r.t. input traces
4. Saves the raw gradient data (numpy) and a metadata.json summary
   (this script does not render PNG visualizations)
Approach:
- For models with softmax output: create a logit model (remove softmax) and
compute gradient of the correct-class logit w.r.t. input
- For multi-bit binary models: compute gradient of the sum of correct-bit
sigmoid outputs w.r.t. input
- Average absolute gradients over N attack traces for stable estimates
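Output per model (written to <output-dir>/<job name>/):
- gradient_map.npy: Raw averaged |gradient| array (shape: input_length,)
  for single-byte MLP/CNN models
- gradient_map_byte{X}.npy: One averaged |gradient| array per key byte for
  multi-task models (LMIC/LMIC-TSBN, HPS/MTAN-Lite)
- metadata.json: Job parameters and saliency statistics
No PNG is produced here; visualization (e.g. with an SNR overlay) must be
done separately.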
References:
- Masure, Dumas, Prouff (2020) "Gradient Visualization for General
Characterization in Profiling Attacks" (TCHES 2020)
- Simonyan, Vedaldi, Zisserman (2014) "Deep Inside Convolutional Networks:
Visualising Image Classification Models and Saliency Maps"
"""
import os
import sys
import json
import logging
import argparse
import gc
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Tuple
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
import h5py
import keras
# Optional legacy Keras 2 API. When present, models saved with Keras 2.x are
# loaded through it; when absent, HAS_TF_KERAS is False and `tf_keras` is
# left undefined, so callers must check the flag first.
try:
    import tf_keras
except ImportError:
    HAS_TF_KERAS = False
else:
    HAS_TF_KERAS = True
# ============================================================================
# Constants (from pipeline)
# ============================================================================
# Per-byte points-of-interest windows: key byte index -> (start, end) sample
# offsets, sliced as traces[:, start:end] by extract_byte_window. Every
# window spans end - start = 700 samples (WINDOW_SIZE).
# NOTE(review): the offsets (18363..50271) presumably index the full raw
# traces (ATMega8515_raw_traces.h5), not pre-windowed traces — confirm.
BYTE_POI_WINDOWS = {
0: (30838, 31538),
1: (24525, 25225),
2: (45400, 46100),
3: (32824, 33524),
4: (47508, 48208),
5: (41258, 41958),
6: (37094, 37794),
7: (35018, 35718),
8: (26631, 27331),
9: (39145, 39845),
10: (28766, 29466),
11: (43333, 44033),
12: (20418, 21118),
13: (22499, 23199),
14: (49571, 50271),
15: (18363, 19063),
}
# Half-open global window [GLOBAL_WINDOW_START, GLOBAL_WINDOW_END) fed to the
# single-input multi-task models (see extract_global_window). It contains
# every BYTE_POI_WINDOWS entry (min start 18363, max end 50271).
GLOBAL_WINDOW_START = 18000
GLOBAL_WINDOW_END = 50272
GLOBAL_WINDOW_SIZE = GLOBAL_WINDOW_END - GLOBAL_WINDOW_START # 32,272
# Width in samples of each per-byte POI window.
WINDOW_SIZE = 700
# Identity-encoding class count: one class per possible S-box output byte.
NUM_CLASSES = 256
# Standard AES S-box (FIPS-197), indexed by the input byte. compute_labels
# uses it to derive labels Sbox[plaintext ^ key]; dtype uint8.
AES_SBOX = np.array([
0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16,
], dtype=np.uint8)
# ============================================================================
# Custom layers for model loading
# ============================================================================
class SigmoidGate(tf.keras.layers.Layer):
    """Learnable per-filter gate that scales each filter by ``sigmoid(gate)``.

    ``gate`` is a trainable vector of length ``num_filters`` initialised to
    ``init_value`` (the default 0.0 gives an initial scale of 0.5). This class
    is registered in the custom-object table so saved models that use it can
    be deserialized.
    """

    def __init__(self, num_filters: int, init_value: float = 0.0, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.init_value = init_value

    def build(self, input_shape):
        # The weight name must stay "gate" so saved weights map back onto it.
        constant_init = tf.keras.initializers.Constant(self.init_value)
        self.gate = self.add_weight(
            name="gate",
            shape=(self.num_filters,),
            initializer=constant_init,
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        scale = tf.sigmoid(self.gate)
        return inputs * scale

    def get_config(self):
        return {
            **super().get_config(),
            "num_filters": self.num_filters,
            "init_value": self.init_value,
        }
class FocalCategoricalCrossentropy(tf.keras.losses.Loss):
    """Focal loss for categorical classification (Lin et al., 2017).

    loss = -sum_c y_c * (1 - p_c) ** gamma * log(p_c)

    This script loads every model with ``compile=False``, so the class mainly
    has to exist for deserialization. It still honours all of its constructor
    arguments rather than silently ignoring some of them.

    Args:
        gamma: Focusing parameter; 0 reduces to plain cross-entropy.
        label_smoothing: If > 0, targets become
            ``y * (1 - label_smoothing) + label_smoothing / num_classes``.
        from_logits: If True, ``y_pred`` holds logits and a softmax over the
            last axis is applied first.
    """

    def __init__(self, gamma=2.0, label_smoothing=0.0, from_logits=False, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma
        self.label_smoothing = label_smoothing
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        if self.from_logits:
            y_pred = tf.nn.softmax(y_pred, axis=-1)
        if self.label_smoothing:
            num_classes = tf.cast(tf.shape(y_true)[-1], y_pred.dtype)
            y_true = y_true * (1.0 - self.label_smoothing) + self.label_smoothing / num_classes
        # Clip so log() stays finite.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        focal_weight = tf.pow(1.0 - y_pred, self.gamma)
        # y_true appears exactly once. The previous form multiplied it twice,
        # which is only equivalent for hard one-hot labels.
        return -tf.reduce_sum(y_true * focal_weight * tf.math.log(y_pred), axis=-1)

    def get_config(self):
        config = super().get_config()
        config.update({
            "gamma": self.gamma,
            "label_smoothing": self.label_smoothing,
            "from_logits": self.from_logits,
        })
        return config
class SoftAttentionBlock(tf.keras.layers.Layer):
    """Soft attention block for MTAN-Lite (channels-based) and SNRMTAN (filters-based).

    Computes ``inputs * sigmoid(conv_up(relu(conv_down(inputs))))`` with two
    kernel-size-1 Conv1D layers. The bottleneck width is
    ``max(channels // bottleneck_ratio, 16)``.

    Handles both constructor signatures:
    - MTAN-Lite: SoftAttentionBlock(channels=N, bottleneck_ratio=4)
    - SNRMTAN: SoftAttentionBlock(filters=N, task_id=X, block_id=Y)

    Raises:
        ValueError: If neither ``channels`` nor ``filters`` is given.
    """

    def __init__(self, channels=None, filters=None, bottleneck_ratio=4, task_id=0, block_id=0, **kwargs):
        super().__init__(**kwargs)
        # Accept either signature; `channels` wins when both are supplied.
        if channels is None:
            channels = filters
        if channels is None:
            raise ValueError("SoftAttentionBlock requires either `channels` or `filters`.")
        self.channels = channels
        self.filters = self.channels  # alias for get_config compatibility
        self.bottleneck_ratio = bottleneck_ratio
        self.task_id = task_id
        self.block_id = block_id
        self.bottleneck = max(self.channels // bottleneck_ratio, 16)

    def build(self, input_shape):
        # Sub-layers are created exactly as at training time so saved weights load.
        self.conv_down = tf.keras.layers.Conv1D(
            self.bottleneck, 1, padding='same', activation='relu',
            kernel_initializer='he_uniform',
        )
        self.conv_up = tf.keras.layers.Conv1D(
            self.channels, 1, padding='same', activation='sigmoid',
            kernel_initializer='glorot_uniform',
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        att = self.conv_down(inputs)
        att = self.conv_up(att)
        return inputs * att

    def get_config(self):
        config = super().get_config()
        config.update({
            "channels": self.channels,
            "filters": self.filters,
            "bottleneck_ratio": self.bottleneck_ratio,
            "task_id": self.task_id,
            "block_id": self.block_id,
        })
        return config
# Project-specific classes passed as ``custom_objects`` to every model loader
# in this script, so saved models that reference them can be rebuilt.
CUSTOM_OBJECTS = dict(
    SigmoidGate=SigmoidGate,
    FocalCategoricalCrossentropy=FocalCategoricalCrossentropy,
    SoftAttentionBlock=SoftAttentionBlock,
)
def load_model_smart(model_path: str, name: str):
    """
    Load a saved ``.h5`` model with the Keras API matching the one that saved it.

    The ``keras_version`` attribute of the HDF5 file selects the first loader:

    - ``2.x`` and ``tf_keras`` installed: ``tf_keras.models.load_model``.
    - Otherwise, or if that fails: ``keras.saving.load_model`` (Keras 3).
    - Last resort: ``tf.keras.models.load_model``.

    Every loader is called with ``compile=False`` and ``CUSTOM_OBJECTS``.

    Args:
        model_path: Local path to the downloaded ``model.h5``.
        name: Job name, used only in log messages.

    Returns:
        The loaded model, or ``None`` if every loader failed (each failure is logged).
    """
    keras_version = "unknown"
    try:
        with h5py.File(model_path, 'r') as f:
            if 'keras_version' in f.attrs:
                kv = f.attrs['keras_version']
                keras_version = kv.decode() if isinstance(kv, bytes) else str(kv)
    except Exception as e:
        # Best-effort probe: an unreadable attribute just leaves the version
        # "unknown", and the loaders below are still tried in order.
        logging.debug(f"  Could not read keras_version for {name} from {model_path}: {e}")

    is_keras2 = keras_version.startswith('2.')
    logging.debug(f"  Model {name} saved with keras {keras_version}, is_keras2={is_keras2}")

    if is_keras2 and HAS_TF_KERAS:
        try:
            return tf_keras.models.load_model(
                model_path, custom_objects=CUSTOM_OBJECTS, compile=False
            )
        except Exception as e:
            logging.warning(f"  tf_keras load failed for {name}: {e}")
            # Fall through to try keras 3.
    elif is_keras2:
        logging.warning(
            f"  {name} was saved with keras {keras_version} but tf_keras is not "
            f"installed; trying Keras 3 loaders"
        )

    try:
        return keras.saving.load_model(
            model_path, safe_mode=False, compile=False,
            custom_objects=CUSTOM_OBJECTS,
        )
    except Exception as e:
        logging.warning(f"  keras.saving load failed for {name}: {e}")

    # Final fallback: tf.keras (may work for some models).
    try:
        return tf.keras.models.load_model(
            model_path, custom_objects=CUSTOM_OBJECTS, compile=False
        )
    except Exception as e:
        logging.error(f"Failed to load model for {name} (all loaders failed): {e}")
        return None
# ============================================================================
# Data loading
# ============================================================================
def load_ascad_data(h5_path: str, desync: int = 0, n_traces: int = 1000):
    """Load attack traces and metadata from an ASCAD HDF5 file.

    Supports both the raw traces file (ATMega8515_raw_traces.h5) with
    top-level 'traces' and 'metadata' datasets, and the windowed ASCAD.h5
    file with 'Attack_traces/traces' and 'Attack_traces/metadata' groups.
    For the raw file, attack traces are the last 10,000 (indices 50000:60000).

    NOTE(review): the per-byte/global window offsets used downstream are large
    sample indices, so presumably only the raw file yields usable windows.

    Args:
        h5_path: Path to the HDF5 file.
        desync: Maximum random shift. Each trace is shifted left by
            0..desync samples and zero-padded at the end. Shifts come from a
            local RandomState(42), so they are reproducible and leave the
            global NumPy RNG untouched.
        n_traces: Maximum number of attack traces to read. Fewer are returned
            if the file has fewer, and a warning is logged.

    Returns:
        (traces, plaintexts, keys). traces is float64 of shape
        (n_loaded, n_samples). plaintexts and keys are field 0 and field 2 of
        each metadata record (assumed ASCAD layout: plaintext, ciphertext,
        key, ... — TODO confirm for the raw file).
    """
    with h5py.File(h5_path, 'r') as f:
        if 'Attack_traces' in f:
            # Windowed ASCAD.h5 format.
            traces = np.array(f['Attack_traces/traces'][:n_traces], dtype=np.float64)
            metadata = f['Attack_traces/metadata'][:n_traces]
        else:
            # Raw traces format: attack set is the last 10,000 traces.
            attack_start = 50000
            attack_stop = attack_start + n_traces
            traces = np.array(f['traces'][attack_start:attack_stop], dtype=np.float64)
            metadata = f['metadata'][attack_start:attack_stop]
    plaintexts = np.array([m[0] for m in metadata])  # plaintext bytes (16,)
    keys = np.array([m[2] for m in metadata])  # key bytes (16,)

    # Slicing past the end of the dataset returns fewer rows. Everything
    # below works on the count actually loaded.
    n_loaded = len(traces)
    if n_loaded < n_traces:
        logging.warning(f"Requested {n_traces} attack traces but only {n_loaded} are available")

    if desync > 0:
        rng = np.random.RandomState(42)  # same draws as np.random.seed(42) + randint
        shifts = rng.randint(0, desync + 1, size=n_loaded)
        shifted_traces = np.zeros_like(traces)
        for i, s in enumerate(shifts):
            if s > 0:
                shifted_traces[i, :-s] = traces[i, s:]
            else:
                shifted_traces[i] = traces[i]
        traces = shifted_traces
    return traces, plaintexts, keys
def extract_byte_window(
    traces: np.ndarray,
    byte_idx: int,
    windows: Optional[Dict[int, Tuple[int, int]]] = None,
) -> np.ndarray:
    """Extract the POI window of one key byte from every trace.

    Args:
        traces: 2-D array (n_traces, n_samples).
        byte_idx: Key byte index; must be a key of ``windows``.
        windows: Optional mapping byte_idx -> (start, end). Defaults to
            ``BYTE_POI_WINDOWS`` (700-sample windows).

    Returns:
        ``traces[:, start:end]`` (a view, not a copy).

    Raises:
        KeyError: If ``byte_idx`` has no window.
        ValueError: If the traces are too short to contain the window. NumPy
            would otherwise silently return a truncated or empty slice.
    """
    if windows is None:
        windows = BYTE_POI_WINDOWS
    start, end = windows[byte_idx]
    n_samples = traces.shape[1]
    if n_samples < end:
        raise ValueError(
            f"Traces have {n_samples} samples but the byte {byte_idx} window "
            f"is [{start}, {end}); full-length traces are required"
        )
    return traces[:, start:end]
def extract_global_window(
    traces: np.ndarray,
    start: Optional[int] = None,
    end: Optional[int] = None,
) -> np.ndarray:
    """Extract the global window fed to single-input multi-task models.

    Args:
        traces: 2-D array (n_traces, n_samples).
        start: First sample (inclusive); defaults to ``GLOBAL_WINDOW_START``.
        end: Last sample (exclusive); defaults to ``GLOBAL_WINDOW_END``.

    Returns:
        ``traces[:, start:end]`` (a view, not a copy).

    Raises:
        ValueError: If the traces are shorter than ``end``. NumPy would
            otherwise silently return a truncated or empty slice.
    """
    if start is None:
        start = GLOBAL_WINDOW_START
    if end is None:
        end = GLOBAL_WINDOW_END
    n_samples = traces.shape[1]
    if n_samples < end:
        raise ValueError(
            f"Traces have {n_samples} samples but the global window is "
            f"[{start}, {end}); full-length traces are required"
        )
    return traces[:, start:end]
def normalize_traces(traces: np.ndarray) -> np.ndarray:
    """Z-score each trace (row) independently: (x - mean) / std.

    Rows with zero standard deviation are divided by 1 instead, so constant
    traces become all zeros rather than NaN. The input is not modified.
    """
    centred = traces - traces.mean(axis=1, keepdims=True)
    spread = traces.std(axis=1, keepdims=True)
    spread[spread == 0] = 1.0
    return centred / spread
def compute_labels(plaintexts: np.ndarray, keys: np.ndarray, byte_idx: int) -> np.ndarray:
    """Return the first-round S-box output Sbox[pt[b] ^ k[b]] for every trace.

    ``plaintexts`` and ``keys`` are indexed as [trace, byte]. The result has
    one uint8 label per trace (the dtype of ``AES_SBOX``).
    """
    sbox_input = np.bitwise_xor(plaintexts[:, byte_idx], keys[:, byte_idx])
    return AES_SBOX[sbox_input]
# ============================================================================
# Gradient computation
# ============================================================================
def make_logit_model(model: tf.keras.Model) -> tf.keras.Model:
    """
    Return a model that outputs pre-softmax logits instead of softmax probabilities.

    Gradients of logits avoid the saturation of softmax outputs. The last
    Dense layer is searched from the end of ``model.layers``:

    - If its activation is ``softmax``, a new model is built that reuses every
      layer up to that Dense's input and adds a linear Dense (same units and
      bias setting) named ``logits_output``, with the original weights copied.
    - Otherwise (e.g. sigmoid multi-bit outputs, or no Dense layer at all),
      ``model`` is returned unchanged.

    The replacement head is built with the same Keras API as the loaded model:
    ``tf_keras`` for legacy Keras 2 models, ``tf.keras`` otherwise. Mixing the
    two fails, and a ``tf.keras`` isinstance check never matches ``tf_keras``
    layers.

    NOTE(review): only the Dense layer's own activation is inspected; a
    separate Softmax/Activation layer after it is not removed.
    """
    dense_classes = (tf.keras.layers.Dense,)
    if HAS_TF_KERAS:
        dense_classes += (tf_keras.layers.Dense,)

    last_layer = next(
        (layer for layer in reversed(model.layers) if isinstance(layer, dense_classes)),
        None,
    )
    if last_layer is None:
        return model

    activation = last_layer.get_config().get('activation', 'linear')
    if activation != 'softmax':
        # Sigmoid (multi-bit) or already-linear heads are used as-is.
        return model

    if HAS_TF_KERAS and isinstance(last_layer, tf_keras.layers.Dense):
        api = tf_keras
    else:
        api = tf.keras

    logits = api.layers.Dense(
        last_layer.units,
        activation='linear',
        use_bias=last_layer.use_bias,  # weight list must match for set_weights
        name='logits_output',
    )(last_layer.input)
    logit_model = api.Model(inputs=model.input, outputs=logits)
    logit_model.get_layer('logits_output').set_weights(last_layer.get_weights())
    return logit_model
def compute_saliency_single_byte(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels: np.ndarray,
    batch_size: int = 128,
    is_multibit: bool = False,
) -> np.ndarray:
    """
    Compute the input-space saliency map of a single-byte model.

    Per trace, the target is:

    - identity encoding: the output for the correct class (a logit when
      ``model`` came from ``make_logit_model``);
    - multi-bit encoding: the sum over bits of ``p_b`` where the correct bit
      is 1 and ``1 - p_b`` where it is 0. Previously the 0-bits were dropped,
      which made the map depend on the label's Hamming weight.

    The absolute gradient of the target w.r.t. the input is averaged over all
    traces. Each trace is weighted equally, including those in a short last
    batch.

    Args:
        model: The trained model (or logit model).
        traces: Input traces, shape (N, time_steps) or (N, time_steps, 1).
            A channel axis is added or removed to match ``model.input_shape``.
        labels: Correct labels, shape (N,) for identity or (N, 8) for multibit.
        batch_size: Batch size for gradient computation.
        is_multibit: Whether the model uses multi-bit binary encoding.

    Returns:
        Mean absolute gradient, shape (time_steps,). All zeros (with a warning)
        if no gradient could be computed.
    """
    n_traces = len(traces)
    input_shape = model.input_shape[1:]
    # Match the rank the model expects.
    if len(input_shape) == 2 and traces.ndim == 2:
        traces = traces[..., np.newaxis]  # CNN: (N, time, 1)
    elif len(input_shape) == 1 and traces.ndim == 3:
        traces = traces.squeeze(-1)  # MLP: (N, time)

    grad_sum = None
    n_used = 0
    for start in range(0, n_traces, batch_size):
        end = min(start + batch_size, n_traces)
        batch_traces = tf.constant(traces[start:end], dtype=tf.float32)
        batch_labels = labels[start:end]
        with tf.GradientTape() as tape:
            tape.watch(batch_traces)
            outputs = model(batch_traces, training=False)
            if is_multibit:
                # outputs, labels: (batch, 8)
                bits = tf.constant(batch_labels, dtype=tf.float32)
                target = tf.reduce_sum(
                    bits * outputs + (1.0 - bits) * (1.0 - outputs), axis=-1
                )
            else:
                # outputs: (batch, 256); labels: (batch,)
                indices = tf.stack(
                    [tf.range(end - start), tf.constant(batch_labels, dtype=tf.int32)],
                    axis=1,
                )
                target = tf.gather_nd(outputs, indices)
            loss = tf.reduce_sum(target)
        grad = tape.gradient(loss, batch_traces)
        if grad is None:
            continue
        # Sum (not mean) per batch, then divide once by the trace count.
        batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy().astype(np.float64)
        grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
        n_used += end - start

    if grad_sum is None:
        logging.warning("No gradients computed!")
        return np.zeros(traces.shape[1])

    avg_grad = grad_sum / n_used
    if avg_grad.ndim > 1:
        avg_grad = avg_grad.squeeze(-1)  # drop channel axis
    return avg_grad
def compute_saliency_multitask(
    model: tf.keras.Model,
    traces_dict: Dict[str, np.ndarray],
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 64,
    is_multibit: bool = True,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for multi-input (LMIC/LMIC-TSBN) models.

    The model takes 16 named inputs ``byte_{b}_input``. For each byte ``b``,
    the gradient of byte ``b``'s output target w.r.t. its own input
    ``byte_{b}_input`` is taken, while all 16 inputs are fed. The target is
    the correct-class output (identity) or the sum of correct-bit scores
    (``p`` for 1-bits, ``1 - p`` for 0-bits; multi-bit). Absolute gradients
    are averaged with equal weight per trace.

    Args:
        model: Model whose outputs are a dict keyed ``byte_{b}``, a list/tuple
            indexed by byte, or a single tensor.
        traces_dict: ``byte_{b}_input`` -> traces; all entries have the same N.
        labels_dict: ``byte_{b}`` -> labels, shape (N,) or (N, 8) for multi-bit.
        batch_size: Traces per forward/backward pass.
        is_multibit: Whether outputs are per-bit scores.

    Returns:
        Dict mapping byte_idx -> mean absolute gradient (shape: 700,). Bytes
        with no input, no labels, or no gradient are omitted. Empty input
        gives ``{}``.
    """
    if not traces_dict:
        return {}
    n_traces = len(next(iter(traces_dict.values())))
    saliency_maps = {}
    for byte_idx in range(16):
        input_name = f"byte_{byte_idx}_input"
        output_name = f"byte_{byte_idx}"
        byte_labels = labels_dict.get(output_name)
        if input_name not in traces_dict or byte_labels is None:
            continue

        grad_sum = None
        n_used = 0
        for start in range(0, n_traces, batch_size):
            end = min(start + batch_size, n_traces)
            batch_dict = {
                key: tf.constant(val[start:end], dtype=tf.float32)
                for key, val in traces_dict.items()
            }
            # Only the target byte's input is watched.
            target_input = batch_dict[input_name]
            batch_labels = byte_labels[start:end]
            with tf.GradientTape() as tape:
                tape.watch(target_input)
                outputs = model(batch_dict, training=False)
                if isinstance(outputs, dict):
                    byte_output = outputs[output_name]
                elif isinstance(outputs, (list, tuple)):
                    byte_output = outputs[byte_idx]
                else:
                    byte_output = outputs
                if is_multibit:
                    bits = tf.constant(batch_labels, dtype=tf.float32)
                    target = tf.reduce_sum(
                        bits * byte_output + (1.0 - bits) * (1.0 - byte_output), axis=-1
                    )
                else:
                    indices = tf.stack(
                        [tf.range(end - start), tf.constant(batch_labels, dtype=tf.int32)],
                        axis=1,
                    )
                    target = tf.gather_nd(byte_output, indices)
                loss = tf.reduce_sum(target)
            grad = tape.gradient(loss, target_input)
            if grad is None:
                continue
            batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy().astype(np.float64)
            grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
            n_used += end - start

        if grad_sum is not None:
            avg_grad = grad_sum / n_used
            if avg_grad.ndim > 1:
                avg_grad = avg_grad.squeeze(-1)  # drop channel axis
            saliency_maps[byte_idx] = avg_grad
    return saliency_maps
def compute_saliency_global_multitask(
    model: tf.keras.Model,
    traces: np.ndarray,
    labels_dict: Dict[str, np.ndarray],
    batch_size: int = 32,
    is_multibit: bool = False,
) -> Dict[int, np.ndarray]:
    """
    Compute per-byte saliency maps for global-window multi-task models
    (HPS, MTAN-Lite) that take a single global input.

    For each byte ``b`` with labels in ``labels_dict["byte_b"]``, the target is
    byte ``b``'s correct-class output (identity) or its summed correct-bit
    scores (``p`` for 1-bits, ``1 - p`` for 0-bits; multi-bit). Its absolute
    gradient w.r.t. the shared input is averaged with equal weight per trace.

    Args:
        model: Single-input model whose outputs are a dict keyed ``byte_{b}``,
            a list/tuple indexed by byte, or a single tensor.
        traces: Global-window traces, (N, time) or (N, time, 1). A channel
            axis is added to 2-D input.
        labels_dict: ``byte_{b}`` -> labels, shape (N,) or (N, 8) for multi-bit.
        batch_size: Traces per forward/backward pass.
        is_multibit: Whether outputs are per-bit scores.

    Returns:
        Dict mapping byte_idx -> mean absolute gradient
        (shape: global_window_size,). Bytes without labels or gradients are omitted.
    """
    n_traces = len(traces)
    if traces.ndim == 2:
        traces = traces[..., np.newaxis]
    saliency_maps = {}
    for byte_idx in range(16):
        output_name = f"byte_{byte_idx}"
        byte_labels = labels_dict.get(output_name)
        if byte_labels is None:
            continue

        grad_sum = None
        n_used = 0
        for start in range(0, n_traces, batch_size):
            end = min(start + batch_size, n_traces)
            batch_traces = tf.constant(traces[start:end], dtype=tf.float32)
            batch_labels = byte_labels[start:end]
            with tf.GradientTape() as tape:
                tape.watch(batch_traces)
                outputs = model(batch_traces, training=False)
                if isinstance(outputs, dict):
                    byte_output = outputs[output_name]
                elif isinstance(outputs, (list, tuple)):
                    byte_output = outputs[byte_idx]
                else:
                    byte_output = outputs
                if is_multibit:
                    bits = tf.constant(batch_labels, dtype=tf.float32)
                    target = tf.reduce_sum(
                        bits * byte_output + (1.0 - bits) * (1.0 - byte_output), axis=-1
                    )
                else:
                    indices = tf.stack(
                        [tf.range(end - start), tf.constant(batch_labels, dtype=tf.int32)],
                        axis=1,
                    )
                    target = tf.gather_nd(byte_output, indices)
                loss = tf.reduce_sum(target)
            grad = tape.gradient(loss, batch_traces)
            if grad is None:
                continue
            batch_sum = tf.reduce_sum(tf.abs(grad), axis=0).numpy().astype(np.float64)
            grad_sum = batch_sum if grad_sum is None else grad_sum + batch_sum
            n_used += end - start

        if grad_sum is not None:
            avg_grad = grad_sum / n_used
            if avg_grad.ndim > 1:
                avg_grad = avg_grad.squeeze(-1)  # drop channel axis
            saliency_maps[byte_idx] = avg_grad
    return saliency_maps
# ============================================================================
# Job processing
# ============================================================================
def process_single_byte_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 1000,
) -> dict:
    """Compute and save the saliency map of a single-byte (MLP/CNN) model.

    Steps:
    1. Download ``model.h5`` from the HF repo named in ``job['hf_url']``.
    2. Load it and strip a final softmax (``make_logit_model``).
    3. Load up to ``n_traces`` attack traces, then cut and z-normalize the
       byte's POI window.
    4. Average |d correct-class output / d input| over the traces.

    Writes ``gradient_map.npy`` and ``metadata.json`` to
    ``<output_dir>/<job name>/``.

    Args:
        job: Job record with ``name``, ``model_type``, ``params['desync']`` and
            ``hf_url`` (ending in ``.../byte{X}``).
        h5_path: ASCAD HDF5 file.
        output_dir: Root output directory.
        hf_token: HuggingFace token.
        n_traces: Maximum number of attack traces.

    Returns:
        The metadata dict on success. Otherwise a dict whose ``status`` is
        ``unknown_repo``, ``download_failed`` or ``load_failed``.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    # A trailing slash would break both the byte parse and the repo path.
    hf_url = job['hf_url'].rstrip('/')

    # URL format: .../desync{N}/{model_type}/byte{X}
    byte_idx = int(hf_url.split('/byte')[-1])

    # Same repo resolution as the multi-task processors: reject unknown repos
    # rather than silently guessing one.
    if 'ascad-v1-models' in hf_url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in hf_url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}
    hf_path_prefix = hf_url.split('/tree/main/')[-1]
    logging.info(f"Processing {name}: {model_type} byte{byte_idx} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    # Detect the keras version from the h5 file and use the matching loader.
    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    # Remove softmax for non-saturating gradients.
    logit_model = make_logit_model(model)

    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)
    byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
    labels = compute_labels(plaintexts, keys, byte_idx)

    saliency = compute_saliency_single_byte(
        logit_model, byte_traces, labels,
        batch_size=128, is_multibit=False,
    )

    job_output_dir = os.path.join(output_dir, name)
    os.makedirs(job_output_dir, exist_ok=True)
    np.save(os.path.join(job_output_dir, 'gradient_map.npy'), saliency)

    meta = {
        "name": name,
        "model_type": model_type,
        "byte_idx": byte_idx,
        "desync": desync,
        # Record the traces actually used; the file may hold fewer than requested.
        "n_traces": int(len(traces)),
        "saliency_max": float(np.max(saliency)),
        "saliency_mean": float(np.mean(saliency)),
        "saliency_std": float(np.std(saliency)),
        "input_shape": list(model.input_shape[1:]),
        "status": "success",
    }
    with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    del model, logit_model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f"  Done: max={saliency.max():.6f}, mean={saliency.mean():.6f}")
    return meta
def process_lmic_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 500,
) -> dict:
    """Compute and save per-byte saliency maps of an LMIC/LMIC-TSBN model.

    The model is fed 16 named inputs ``byte_{b}_input``. Each is the
    z-normalized POI window of byte ``b`` with a channel axis. Multi-bit vs
    identity encoding is inferred from the width of the byte-0 output:
    8 means multi-bit, anything else means identity.

    Writes ``gradient_map_byte{b}.npy`` for each computed byte plus
    ``metadata.json`` to ``<output_dir>/<job name>/``.

    Returns:
        The metadata dict on success. Otherwise a dict whose ``status`` is
        ``unknown_repo``, ``download_failed`` or ``load_failed``.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job['hf_url'].rstrip('/')

    if 'ascad-v1-models' in hf_url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in hf_url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}
    hf_path_prefix = hf_url.split('/tree/main/')[-1]
    logging.info(f"Processing {name}: {model_type} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    # Multi-bit heads have 8 units per byte; identity heads have 256.
    output_shape = getattr(model, 'output_shape', None)
    if isinstance(output_shape, dict):
        out_shape = output_shape.get("byte_0", (None, 256))
    elif isinstance(output_shape, list):
        out_shape = output_shape[0] if output_shape else None
    else:
        out_shape = output_shape  # single output tensor shape
    is_multibit = bool(out_shape) and out_shape[-1] == 8

    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)

    traces_dict = {}
    labels_dict = {}
    for byte_idx in range(16):
        byte_traces = normalize_traces(extract_byte_window(traces, byte_idx))
        traces_dict[f"byte_{byte_idx}_input"] = byte_traces[..., np.newaxis]  # channel axis
        sbox_out = compute_labels(plaintexts, keys, byte_idx)
        if is_multibit:
            # np.unpackbits is MSB-first. NOTE(review): assumes training used
            # the same bit order — confirm.
            labels_dict[f"byte_{byte_idx}"] = np.unpackbits(
                sbox_out.astype(np.uint8).reshape(-1, 1), axis=1
            )
        else:
            labels_dict[f"byte_{byte_idx}"] = sbox_out

    saliency_maps = compute_saliency_multitask(
        model, traces_dict, labels_dict,
        batch_size=64, is_multibit=is_multibit,
    )

    job_output_dir = os.path.join(output_dir, name)
    os.makedirs(job_output_dir, exist_ok=True)
    for byte_idx, saliency in saliency_maps.items():
        np.save(os.path.join(job_output_dir, f'gradient_map_byte{byte_idx}.npy'), saliency)

    meta = {
        "name": name,
        "model_type": model_type,
        "desync": desync,
        "n_traces": int(len(traces)),  # actually used, may be < requested
        "is_multibit": bool(is_multibit),
        "bytes_computed": list(saliency_maps.keys()),
        "saliency_stats": {
            str(b): {
                "max": float(np.max(s)),
                "mean": float(np.mean(s)),
                "std": float(np.std(s)),
            }
            for b, s in saliency_maps.items()
        },
        "status": "success",
    }
    with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    del model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f"  Done: {len(saliency_maps)} bytes computed")
    return meta
def process_global_mtl_job(
    job: dict,
    h5_path: str,
    output_dir: str,
    hf_token: str,
    n_traces: int = 300,
) -> dict:
    """Compute and save per-byte saliency maps of a global-window multi-task
    model (HPS, MTAN-Lite).

    The model takes one z-normalized global window
    (GLOBAL_WINDOW_START:GLOBAL_WINDOW_END) and produces one output per byte.
    It is loaded with ``load_model_smart`` like every other job type, so
    Keras-2 h5 files work too. Multi-bit vs identity encoding is inferred
    from the width of the byte-0 output (8 means multi-bit).

    Writes ``gradient_map_byte{b}.npy`` per computed byte plus
    ``metadata.json`` to ``<output_dir>/<job name>/``.

    Returns:
        The metadata dict on success. Otherwise a dict whose ``status`` is
        ``no_hf_url``, ``unknown_repo``, ``download_failed`` or ``load_failed``.
    """
    name = job['name']
    model_type = job['model_type']
    desync = job['params']['desync']
    hf_url = job.get('hf_url')
    if not hf_url:
        return {"name": name, "status": "no_hf_url"}
    hf_url = hf_url.rstrip('/')

    if 'ascad-v1-models' in hf_url:
        repo_id = 'lemousehunter/ascad-v1-models'
    elif 'ascad-mtan-rank0-models' in hf_url:
        repo_id = 'lemousehunter/ascad-mtan-rank0-models'
    else:
        logging.error(f"Unknown HF repo for {name}: {hf_url}")
        return {"name": name, "status": "unknown_repo", "error": hf_url}
    hf_path_prefix = hf_url.split('/tree/main/')[-1]
    logging.info(f"Processing {name}: {model_type} desync={desync}")

    try:
        model_path = hf_hub_download(
            repo_id=repo_id,
            filename=f"{hf_path_prefix}/model.h5",
            token=hf_token,
        )
    except Exception as e:
        logging.error(f"Failed to download model for {name}: {e}")
        return {"name": name, "status": "download_failed", "error": str(e)}

    model = load_model_smart(model_path, name)
    if model is None:
        return {"name": name, "status": "load_failed", "error": "all loaders failed"}

    # Multi-bit heads have 8 units per byte; identity heads have 256.
    output_shape = getattr(model, 'output_shape', None)
    if isinstance(output_shape, dict):
        out_shape = output_shape.get("byte_0", (None, 256))
    elif isinstance(output_shape, list):
        out_shape = output_shape[0] if output_shape else None
    else:
        out_shape = output_shape
    is_multibit = bool(out_shape) and out_shape[-1] == 8

    traces, plaintexts, keys = load_ascad_data(h5_path, desync=desync, n_traces=n_traces)
    global_traces = normalize_traces(extract_global_window(traces))

    labels_dict = {}
    for byte_idx in range(16):
        sbox_out = compute_labels(plaintexts, keys, byte_idx)
        if is_multibit:
            # MSB-first bits. NOTE(review): assumes training used the same order.
            labels_dict[f"byte_{byte_idx}"] = np.unpackbits(
                sbox_out.astype(np.uint8).reshape(-1, 1), axis=1
            )
        else:
            labels_dict[f"byte_{byte_idx}"] = sbox_out

    saliency_maps = compute_saliency_global_multitask(
        model, global_traces, labels_dict,
        batch_size=32, is_multibit=is_multibit,
    )

    job_output_dir = os.path.join(output_dir, name)
    os.makedirs(job_output_dir, exist_ok=True)
    for byte_idx, saliency in saliency_maps.items():
        np.save(os.path.join(job_output_dir, f'gradient_map_byte{byte_idx}.npy'), saliency)

    meta = {
        "name": name,
        "model_type": model_type,
        "desync": desync,
        "n_traces": int(len(traces)),  # actually used, may be < requested
        "is_multibit": bool(is_multibit),
        "bytes_computed": list(saliency_maps.keys()),
        "saliency_stats": {
            str(b): {
                "max": float(np.max(s)),
                "mean": float(np.mean(s)),
                "std": float(np.std(s)),
            }
            for b, s in saliency_maps.items()
        },
        "status": "success",
    }
    with open(os.path.join(job_output_dir, 'metadata.json'), 'w') as f:
        json.dump(meta, f, indent=2)

    del model
    tf.keras.backend.clear_session()
    gc.collect()

    logging.info(f"  Done: {len(saliency_maps)} bytes computed")
    return meta
# ============================================================================
# Main
# ============================================================================
def main():
    """Command-line entry point: compute saliency maps for completed training jobs.

    Steps:
    1. Read the job list from ``--jobs-yaml``.
    2. Keep jobs with ``status == 'completed'`` and a non-empty ``hf_url``,
       optionally filtered by ``--filter-type`` / ``--filter-name``.
    3. Dispatch each job to the processor for its ``model_type``.
    4. Write ``generation.log`` and ``summary.json`` to ``--output-dir``.
    """
    env_token = os.environ.get("HF_TOKEN")
    parser = argparse.ArgumentParser(description="Generate gradient saliency maps")
    parser.add_argument("--jobs-yaml", required=True, help="Path to orchestrator_jobs_updated.yaml")
    parser.add_argument("--h5-path", required=True, help="Path to ASCAD HDF5 dataset")
    parser.add_argument("--output-dir", default="./gradient_maps", help="Output directory")
    # A token on the command line is visible in process listings and shell
    # history, so also accept it from the HF_TOKEN environment variable.
    parser.add_argument(
        "--hf-token",
        default=env_token,
        required=env_token is None,
        help="HuggingFace token (defaults to $HF_TOKEN)",
    )
    parser.add_argument("--n-traces", type=int, default=1000, help="Number of attack traces")
    parser.add_argument("--filter-type", default=None, help="Filter by model type")
    parser.add_argument("--filter-name", default=None, help="Filter by job name (substring)")
    parser.add_argument("--skip-existing", action="store_true", help="Skip jobs with existing output")
    args = parser.parse_args()

    # generation.log lives inside output_dir, so the directory must exist
    # before the FileHandler opens it. Otherwise a fresh run crashes with
    # FileNotFoundError.
    os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(levelname)s] %(message)s',
        handlers=[
            logging.FileHandler(os.path.join(args.output_dir, 'generation.log')),
            logging.StreamHandler(),
        ],
    )

    import yaml  # only the CLI needs PyYAML

    with open(args.jobs_yaml) as f:
        data = yaml.safe_load(f)

    jobs = [j for j in data['jobs'] if j.get('status') == 'completed' and j.get('hf_url')]
    if args.filter_type:
        jobs = [j for j in jobs if j.get('model_type') == args.filter_type]
    if args.filter_name:
        jobs = [j for j in jobs if args.filter_name in j['name']]

    n_jobs = len(jobs)
    logging.info(f"Processing {n_jobs} jobs")
    logging.info(f"Output directory: {args.output_dir}")
    logging.info(f"N traces: {args.n_traces}")

    results = []
    for i, job in enumerate(jobs, start=1):
        name = job['name']

        if args.skip_existing:
            meta_path = os.path.join(args.output_dir, name, 'metadata.json')
            if os.path.exists(meta_path):
                try:
                    with open(meta_path) as f:
                        existing = json.load(f)
                except (OSError, json.JSONDecodeError) as e:
                    # A corrupt or partial file means the job is redone, not a crash.
                    logging.warning(f"[{i}/{n_jobs}] Unreadable {meta_path} ({e}); reprocessing")
                    existing = {}
                if isinstance(existing, dict) and existing.get('status') == 'success':
                    logging.info(f"[{i}/{n_jobs}] Skipping {name} (already done)")
                    results.append(existing)
                    continue

        logging.info(f"[{i}/{n_jobs}] {name}")
        try:
            model_type = job.get('model_type')
            if model_type in ('mlp', 'cnn'):
                result = process_single_byte_job(
                    job, args.h5_path, args.output_dir, args.hf_token, args.n_traces
                )
            elif model_type in ('lmic', 'lmic_tsbn'):
                result = process_lmic_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 500),
                )
            elif model_type in ('hps', 'mtan_lite'):
                result = process_global_mtl_job(
                    job, args.h5_path, args.output_dir, args.hf_token,
                    n_traces=min(args.n_traces, 300),
                )
            else:
                result = {"name": name, "status": "unknown_type", "model_type": model_type}
        except Exception as e:
            # logging.exception records the traceback in generation.log as
            # well as on the console.
            logging.exception(f"  FAILED: {e}")
            result = {"name": name, "status": "error", "error": str(e)}
            tf.keras.backend.clear_session()
            gc.collect()
        results.append(result)

    summary_path = os.path.join(args.output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(results, f, indent=2)

    success = sum(1 for r in results if r.get('status') == 'success')
    failed = len(results) - success
    logging.info(f"\nDone! Success: {success}/{len(results)}, Failed: {failed}")


if __name__ == "__main__":
    main()