# lemousehunter's picture
# Section 43-45: dataset.py per-trace shift fix, orchestrator DELETE FK fix, queue rebuild, clean MLP/CNN resubmission
# 8c414b1
"""
ASCAD Dataset Module
====================
Handles downloading, loading, and preprocessing of the ASCAD database.
Supports per-byte POI windows extracted from raw traces via SNR analysis,
as well as a global window mode for multi-task learning.
Desynchronization is applied by shifting POI windows by a per-trace random
offset in [0, desync], matching the official ASCAD benchmark protocol.
"""
import os
import logging
import zipfile
from typing import Optional, Tuple
import h5py
import numpy as np
import requests
from tqdm import tqdm
from .constants import (
AES_SBOX,
BYTE_POI_WINDOWS,
DEFAULT_PROFILING_RANGE,
DEFAULT_ATTACK_RANGE,
DESYNC_FILES,
ASCAD_DOWNLOAD_URL,
ASCAD_ZIP_FILENAME,
ASCAD_DB_SUBDIR,
ASCAD_RAW_FILENAME,
GLOBAL_WINDOW_START,
GLOBAL_WINDOW_END,
)
logger = logging.getLogger(name=__name__)

# Desync shifts come from fixed RNG seeds so they are reproducible from run
# to run; profiling and attack splits use distinct seeds so that the two
# splits receive independent shift sequences.
_PROFILING_DESYNC_SEED, _ATTACK_DESYNC_SEED = 20260503, 20260504
class ASCADDataset:
    """
    Manages the ASCAD database: downloading, loading, and preprocessing.

    Supports three loading modes:
      1. Pre-extracted mode (byte 2 only): uses ASCAD.h5 / desync variants.
      2. Raw trace mode (per-byte): extracts per-byte POI windows from
         ATMega8515_raw_traces.h5 using the windows in BYTE_POI_WINDOWS.
      3. Global window mode (multi-task): extracts a single large window
         [GLOBAL_WINDOW_START:GLOBAL_WINDOW_END] for multi-task models.

    When desync > 0, per-trace random shifts are applied to POI windows
    to simulate clock jitter, matching the official ASCAD benchmark protocol.
    In the raw-trace modes the shifts are drawn from a fixed seed per split,
    so a given trace receives the same shift whichever window is requested.
    (Byte 2 uses the pre-extracted files, whose shifts come from those files.)

    HDF5 handles are opened lazily; use the instance as a context manager or
    call close() to release them.
    """

    # Number of AES-128 key bytes.
    _NUM_BYTES = 16

    def __init__(self, data_dir: str = "/tmp/ascad_data", desync: int = 0) -> None:
        """
        Args:
            data_dir: Root directory where the ASCAD data is stored.
            desync: Desynchronization level; must be a key of DESYNC_FILES.

        Raises:
            ValueError: If desync is not a key of DESYNC_FILES.
        """
        if desync not in DESYNC_FILES:
            raise ValueError(
                f"Invalid desync level: {desync}. Must be one of {list(DESYNC_FILES.keys())}."
            )
        self.data_dir = data_dir
        self.desync = desync
        self._preextracted_path = os.path.join(
            data_dir, ASCAD_DB_SUBDIR, DESYNC_FILES[desync]
        )
        self._raw_path = os.path.join(data_dir, ASCAD_DB_SUBDIR, ASCAD_RAW_FILENAME)
        self._h5_file: Optional["h5py.File"] = None
        self._raw_h5_file: Optional["h5py.File"] = None

    # ------------------------------------------------------------------
    # Context manager support
    # ------------------------------------------------------------------
    def __enter__(self) -> "ASCADDataset":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Download & extract
    # ------------------------------------------------------------------
    def download(self, force: bool = False) -> None:
        """
        Download and extract the ASCAD dataset if not already present.

        Does nothing when both the raw traces file and the pre-extracted
        file for this desync level already exist (unless ``force``). The zip
        archive itself is only (re-)downloaded when it is missing or when
        ``force`` is set; otherwise the existing archive is re-extracted.

        Args:
            force: Re-download and re-extract even if the files exist.
        """
        zip_path = os.path.join(self.data_dir, ASCAD_ZIP_FILENAME)
        # Both files are needed: raw traces for most bytes, pre-extracted for byte 2.
        required = (self._raw_path, self._preextracted_path)
        if not force and all(os.path.isfile(path) for path in required):
            logger.info("Raw traces already exist at %s", self._raw_path)
            return
        os.makedirs(self.data_dir, exist_ok=True)
        if force or not os.path.isfile(zip_path):
            logger.info("Downloading ASCAD dataset (~4.2 GB) ...")
            self._download_file(ASCAD_DOWNLOAD_URL, zip_path)
        logger.info("Extracting %s ...", zip_path)
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)
        logger.info("Extraction complete.")

    @staticmethod
    def _download_file(url: str, dest: str) -> None:
        """
        Download ``url`` to ``dest`` with a progress bar.

        The payload is streamed into ``dest + ".part"`` and renamed to
        ``dest`` only after the transfer finishes. An interrupted download
        therefore never leaves a truncated file at ``dest`` that a later
        call would mistake for a complete archive.

        Raises:
            requests.HTTPError: On a non-2xx HTTP status.
        """
        tmp_path = dest + ".part"
        try:
            with requests.get(url, stream=True, timeout=600) as resp:
                resp.raise_for_status()
                total = int(resp.headers.get("content-length", 0))
                with open(tmp_path, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc="Downloading"
                ) as pbar:
                    for chunk in resp.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
                        pbar.update(len(chunk))
            # Atomic on the same filesystem: dest is either absent or complete.
            os.replace(tmp_path, dest)
        finally:
            # Best-effort removal of a partial download; any original error
            # keeps propagating.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    logger.warning("Could not remove partial download %s", tmp_path)

    # ------------------------------------------------------------------
    # HDF5 access (lazy open)
    # ------------------------------------------------------------------
    def _open_preextracted(self) -> "h5py.File":
        """Open (once) the pre-extracted ASCAD HDF5 file for this desync level."""
        if self._h5_file is None:
            if not os.path.isfile(self._preextracted_path):
                raise FileNotFoundError(
                    f"ASCAD HDF5 not found at {self._preextracted_path}. "
                    "Call download() first."
                )
            self._h5_file = h5py.File(self._preextracted_path, "r")
        return self._h5_file

    def _open_raw(self) -> "h5py.File":
        """Open (once) the raw traces HDF5 file."""
        if self._raw_h5_file is None:
            if not os.path.isfile(self._raw_path):
                raise FileNotFoundError(
                    f"Raw traces not found at {self._raw_path}. "
                    "Call download() first."
                )
            self._raw_h5_file = h5py.File(self._raw_path, "r")
        return self._raw_h5_file

    def close(self) -> None:
        """Close all open HDF5 file handles (safe to call repeatedly)."""
        for attr in ("_h5_file", "_raw_h5_file"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------
    @staticmethod
    def _check_target_byte(target_byte: int) -> None:
        """
        Raise ValueError unless ``target_byte`` is in [0, 15].

        Guards against negative indices silently wrapping to another byte.
        """
        if not 0 <= target_byte < ASCADDataset._NUM_BYTES:
            raise ValueError(
                f"Invalid target byte: {target_byte}. Must be in [0, 15]."
            )

    # ------------------------------------------------------------------
    # Desync shift generation
    # ------------------------------------------------------------------
    def _generate_desync_shifts(
        self, num_traces: int, seed: int
    ) -> np.ndarray:
        """
        Generate per-trace random desync shifts.

        When desync=0, returns zeros (no shift). Otherwise draws uniform
        integer shifts in [0, desync] (inclusive) from a generator seeded
        with ``seed``, so the result is reproducible.

        Args:
            num_traces: Number of traces to generate shifts for.
            seed: Random seed for reproducibility.

        Returns:
            int32 array of shape (num_traces,).
        """
        if self.desync == 0:
            return np.zeros(num_traces, dtype=np.int32)
        rng = np.random.default_rng(seed=seed)
        # Generator.integers excludes `high`, hence desync + 1.
        return rng.integers(0, self.desync + 1, size=num_traces, dtype=np.int32)

    @staticmethod
    def _apply_shifts(
        traces_wide: np.ndarray, shifts: np.ndarray, window_size: int
    ) -> np.ndarray:
        """
        Cut a ``window_size``-sample window out of each row of ``traces_wide``,
        starting at that row's entry in ``shifts`` (vectorized gather).

        Args:
            traces_wide: 2-D array whose rows cover the nominal window plus
                at least ``max(shifts)`` extra samples.
            shifts: Per-row start offsets, shape (rows,).
            window_size: Width of the output window.

        Returns:
            Array of shape (rows, window_size).
        """
        row_idx = np.arange(traces_wide.shape[0])[:, None]
        col_idx = np.asarray(shifts)[:, None] + np.arange(window_size)[None, :]
        return traces_wide[row_idx, col_idx]

    # ------------------------------------------------------------------
    # Label computation
    # ------------------------------------------------------------------
    @staticmethod
    def compute_labels(metadata: np.ndarray, target_byte: int) -> np.ndarray:
        """
        Compute Sbox(plaintext[byte] XOR key[byte]) labels.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key'
                fields (each a per-record byte vector).
            target_byte: Index of the target key byte (0-15).

        Returns:
            Labels looked up in AES_SBOX, shape (N,).

        Raises:
            ValueError: If target_byte is outside [0, 15].
        """
        ASCADDataset._check_target_byte(target_byte)
        # Field access yields (N, 16) arrays; select the byte column at once
        # instead of looping over records in Python.
        plaintexts = np.asarray(metadata["plaintext"], dtype=np.uint8)[..., target_byte]
        keys = np.asarray(metadata["key"], dtype=np.uint8)[..., target_byte]
        return AES_SBOX[plaintexts ^ keys]

    def compute_all_labels(
        self, metadata: np.ndarray
    ) -> np.ndarray:
        """
        Compute labels for all 16 bytes simultaneously.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key' fields.

        Returns:
            Array of shape (N, 16), dtype uint8, where column i contains
            the Sbox labels for byte i.
        """
        n_bytes = self._NUM_BYTES
        plaintexts = np.asarray(metadata["plaintext"], dtype=np.uint8)
        keys = np.asarray(metadata["key"], dtype=np.uint8)
        plaintexts = plaintexts.reshape(len(metadata), -1)[:, :n_bytes]
        keys = keys.reshape(len(metadata), -1)[:, :n_bytes]
        return np.asarray(AES_SBOX[plaintexts ^ keys], dtype=np.uint8)

    # ------------------------------------------------------------------
    # Internal loading helpers
    # ------------------------------------------------------------------
    def _read_traces(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        shifts: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """
        Read raw trace samples [window_start:window_end] for a trace range.

        With ``shifts``, each trace's window is moved right by its own shift.
        To do this, the method reads the wider span
        [window_start : window_end + max(shifts)] once and then gathers
        per-row windows.

        Args:
            trace_range: (start_idx, end_idx) for trace selection.
            window_start: Nominal start of the window.
            window_end: Nominal (exclusive) end of the window.
            shifts: Optional per-trace shifts, one per selected trace.

        Returns:
            int8 array of shape (num_traces, window_end - window_start).

        Raises:
            ValueError: If the shifted window runs past the end of the traces,
                or the number of shifts does not match the number of traces.
        """
        f = self._open_raw()
        start, end = trace_range
        if shifts is None:
            return np.array(
                f["traces"][start:end, window_start:window_end], dtype=np.int8
            )

        shifts = np.asarray(shifts)
        max_shift = int(shifts.max()) if shifts.size else 0
        num_samples = f["traces"].shape[1]
        # HDF5 silently clamps out-of-range slices; fail loudly instead.
        if window_end + max_shift > num_samples:
            raise ValueError(
                f"Window [{window_start}:{window_end}] shifted by up to {max_shift} "
                f"samples exceeds the trace length ({num_samples})."
            )
        traces_wide = np.array(
            f["traces"][start:end, window_start:window_end + max_shift],
            dtype=np.int8,
        )
        if traces_wide.shape[0] != shifts.shape[0]:
            raise ValueError(
                f"Got {shifts.shape[0]} shifts for {traces_wide.shape[0]} traces."
            )
        return self._apply_shifts(traces_wide, shifts, window_end - window_start)

    def _load_metadata(self, trace_range: Tuple[int, int]) -> np.ndarray:
        """Load the raw-file metadata records for ``trace_range``."""
        start, end = trace_range
        return np.array(self._open_raw()["metadata"][start:end])

    def _load_raw_window(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load a window of raw traces and their metadata (no desync)."""
        traces = self._read_traces(trace_range, window_start, window_end)
        metadata = self._load_metadata(trace_range)
        logger.info(
            "Loaded %d traces, window [%d:%d], shape=%s",
            traces.shape[0], window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_raw_window_desync(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        shifts: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a window of raw traces with per-trace desync shifts applied.

        Args:
            trace_range: (start_idx, end_idx) for trace selection.
            window_start: Nominal start of the POI window.
            window_end: Nominal end of the POI window.
            shifts: Per-trace shift values, shape (num_traces,).

        Returns:
            Tuple of (traces, metadata).
        """
        traces = self._read_traces(trace_range, window_start, window_end, shifts)
        metadata = self._load_metadata(trace_range)
        logger.info(
            "Loaded %d traces with desync=%d, window [%d:%d], shape=%s",
            traces.shape[0], self.desync, window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_preextracted(
        self,
        group: str,
        target_byte: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load traces, labels and metadata for ``group`` from the pre-extracted HDF5."""
        f = self._open_preextracted()
        traces = np.array(f[group]["traces"], dtype=np.int8)
        metadata = np.array(f[group]["metadata"])
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    def _load_per_byte(
        self,
        target_byte: int,
        group: str,
        trace_range: Tuple[int, int],
        seed: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Shared implementation of load_profiling / load_attack."""
        self._check_target_byte(target_byte)
        if target_byte == 2:
            # The pre-extracted files already contain byte 2's window
            # with desync applied.
            return self._load_preextracted(group, target_byte)
        w_start, w_end = BYTE_POI_WINDOWS[target_byte]
        if self.desync == 0:
            traces, metadata = self._load_raw_window(trace_range, w_start, w_end)
        else:
            start, end = trace_range
            shifts = self._generate_desync_shifts(end - start, seed)
            traces, metadata = self._load_raw_window_desync(
                trace_range, w_start, w_end, shifts
            )
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    def _load_global(
        self, trace_range: Tuple[int, int], seed: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Shared implementation of load_profiling_global / load_attack_global."""
        if self.desync == 0:
            return self._load_raw_window(
                trace_range, GLOBAL_WINDOW_START, GLOBAL_WINDOW_END
            )
        start, end = trace_range
        shifts = self._generate_desync_shifts(end - start, seed)
        return self._load_raw_window_desync(
            trace_range, GLOBAL_WINDOW_START, GLOBAL_WINDOW_END, shifts
        )

    def _load_multi_input(
        self, trace_range: Tuple[int, int], seed: int, phase: str
    ) -> Tuple[dict, np.ndarray]:
        """
        Shared implementation of the multi-input loaders.

        One shift vector is generated per call and reused for all 16 byte
        windows, so a trace is shifted uniformly across its windows.
        ``phase`` is only used in log messages.
        """
        metadata = self._load_metadata(trace_range)
        start, end = trace_range
        shifts = (
            None if self.desync == 0
            else self._generate_desync_shifts(end - start, seed)
        )
        traces_dict = {}
        for byte_idx in range(self._NUM_BYTES):
            w_start, w_end = BYTE_POI_WINDOWS[byte_idx]
            traces = self._read_traces(trace_range, w_start, w_end, shifts)
            traces_dict[f"byte_{byte_idx}_input"] = traces
            logger.info(
                "Loaded byte %d %s traces: window [%d:%d], desync=%d, shape=%s",
                byte_idx, phase, w_start, w_end, self.desync, traces.shape,
            )
        return traces_dict, metadata

    # ------------------------------------------------------------------
    # Public API: per-byte loading (for independent models)
    # ------------------------------------------------------------------
    def load_profiling(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load profiling (training) data for a single target byte.

        For byte 2 at any desync level, uses the pre-extracted ASCAD HDF5
        files, which already have desync applied. For all other bytes,
        extracts from raw traces and applies desync shifts.

        Returns:
            Tuple of (traces, labels, metadata).

        Raises:
            ValueError: If target_byte is outside [0, 15].
        """
        return self._load_per_byte(
            target_byte, "Profiling_traces", DEFAULT_PROFILING_RANGE,
            _PROFILING_DESYNC_SEED,
        )

    def load_attack(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load attack (evaluation) data for a single target byte.

        Same source selection as load_profiling, over the attack range.

        Returns:
            Tuple of (traces, labels, metadata).

        Raises:
            ValueError: If target_byte is outside [0, 15].
        """
        return self._load_per_byte(
            target_byte, "Attack_traces", DEFAULT_ATTACK_RANGE,
            _ATTACK_DESYNC_SEED,
        )

    # ------------------------------------------------------------------
    # Public API: global window loading (for multi-task models)
    # ------------------------------------------------------------------
    def load_profiling_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load profiling traces using the global window covering all 16 POI regions.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
            traces: int8, shape (num_profiling_traces,
                GLOBAL_WINDOW_END - GLOBAL_WINDOW_START).
            metadata: structured numpy array.
        """
        return self._load_global(DEFAULT_PROFILING_RANGE, _PROFILING_DESYNC_SEED)

    def load_attack_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load attack traces using the global window.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
            traces: int8, shape (num_attack_traces,
                GLOBAL_WINDOW_END - GLOBAL_WINDOW_START).
            metadata: structured numpy array.
        """
        return self._load_global(DEFAULT_ATTACK_RANGE, _ATTACK_DESYNC_SEED)

    # ------------------------------------------------------------------
    # Public API: multi-input loading (for LMIC model)
    # ------------------------------------------------------------------
    def load_profiling_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load profiling traces as 16 separate per-byte POI windows.

        Each byte gets its own window (from BYTE_POI_WINDOWS) extracted from
        the raw traces. When desync > 0, the same per-trace shift is applied
        to all 16 byte windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
            traces_dict: {"byte_0_input": ..., ..., "byte_15_input": ...}
            metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_PROFILING_RANGE, _PROFILING_DESYNC_SEED, "profiling"
        )

    def load_attack_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load attack traces as 16 separate per-byte POI windows.

        When desync > 0, the same per-trace shift is applied to all 16
        byte windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
            traces_dict: {"byte_0_input": ..., ..., "byte_15_input": ...}
            metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_ATTACK_RANGE, _ATTACK_DESYNC_SEED, "attack"
        )

    def get_real_key(self, target_byte: int) -> int:
        """
        Return the key byte at ``target_byte`` from the first raw trace's metadata.

        This assumes the fixed-key ASCAD dataset, where every trace shares
        the same key.

        Raises:
            ValueError: If target_byte is outside [0, 15].
        """
        self._check_target_byte(target_byte)
        metadata = self._open_raw()["metadata"][0]
        return int(metadata["key"][target_byte])

    def get_all_real_keys(self) -> np.ndarray:
        """Return all 16 key bytes (uint8) from the first raw trace's metadata."""
        metadata = self._open_raw()["metadata"][0]
        return np.array(metadata["key"], dtype=np.uint8)