"""
ASCAD Dataset Module
====================

Handles downloading, loading, and preprocessing of the ASCAD database.
Supports per-byte POI windows extracted from raw traces via SNR analysis,
as well as a global window mode for multi-task learning and a multi-input
mode that returns all 16 per-byte windows at once.

Desynchronization is applied by shifting POI windows by a per-trace random
offset in [0, desync], matching the official ASCAD benchmark protocol.
"""

import contextlib
import logging
import os
import zipfile
from typing import Optional, Tuple

import h5py
import numpy as np
import requests
from tqdm import tqdm

from .constants import (
    AES_SBOX,
    BYTE_POI_WINDOWS,
    DEFAULT_PROFILING_RANGE,
    DEFAULT_ATTACK_RANGE,
    DESYNC_FILES,
    ASCAD_DOWNLOAD_URL,
    ASCAD_ZIP_FILENAME,
    ASCAD_DB_SUBDIR,
    ASCAD_RAW_FILENAME,
    GLOBAL_WINDOW_START,
    GLOBAL_WINDOW_END,
)

logger = logging.getLogger(__name__)

# Fixed seeds for reproducible desync shifts (different for profiling vs attack)
_PROFILING_DESYNC_SEED = 20260503
_ATTACK_DESYNC_SEED = 20260504


class ASCADDataset:
    """
    Manages the ASCAD database: downloading, loading, and preprocessing.

    Supports four loading modes:
        1. Pre-extracted mode (byte 2 only): Uses ASCAD.h5 / desync variants.
        2. Raw trace mode (per-byte): Extracts the POI window for one byte
           (700 samples per the original design) from the raw traces file
           using the SNR-derived windows in BYTE_POI_WINDOWS.
        3. Global window mode (multi-task): Extracts a single large window
           [GLOBAL_WINDOW_START:GLOBAL_WINDOW_END] covering all 16 POI
           regions for multi-task models.
        4. Multi-input mode (LMIC): Extracts all 16 per-byte POI windows in
           one call, returned as a dict keyed "byte_<i>_input".

    When desync > 0, per-trace random shifts in [0, desync] are applied to
    the POI windows to simulate clock jitter. Shifts are drawn from fixed
    seeds (one for profiling, one for attack). As a result, repeated loads are
    reproducible, and the raw per-byte, global, and multi-input modes all
    apply the same shift to a given trace. Byte 2 is the exception: it uses
    the pre-extracted files, which carry their own desync.

    HDF5 files are opened lazily. Use the object as a context manager, or
    call close(), to release the handles.
    """

    def __init__(self, data_dir: str = "/tmp/ascad_data", desync: int = 0) -> None:
        """
        Args:
            data_dir: Root directory where the ASCAD data is stored.
            desync: Desynchronization level; must be a key of DESYNC_FILES
                (0, 50, or 100 in the standard configuration).

        Raises:
            ValueError: If desync is not a key of DESYNC_FILES.
        """
        if desync not in DESYNC_FILES:
            raise ValueError(
                f"Invalid desync level: {desync}. Must be one of {list(DESYNC_FILES.keys())}."
            )
        self.data_dir = data_dir
        self.desync = desync
        self._preextracted_path = os.path.join(
            data_dir, ASCAD_DB_SUBDIR, DESYNC_FILES[desync]
        )
        self._raw_path = os.path.join(data_dir, ASCAD_DB_SUBDIR, ASCAD_RAW_FILENAME)
        # Opened lazily by _open_preextracted()/_open_raw(), released by close().
        self._h5_file: Optional[h5py.File] = None
        self._raw_h5_file: Optional[h5py.File] = None

    # ------------------------------------------------------------------
    # Context manager support
    # ------------------------------------------------------------------

    def __enter__(self) -> "ASCADDataset":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    # ------------------------------------------------------------------
    # Download & extract
    # ------------------------------------------------------------------

    def download(self, force: bool = False) -> None:
        """
        Download and extract the ASCAD dataset if not already present.

        Args:
            force: Re-download the archive and re-extract it even if the raw
                traces file (or the archive) already exists.
        """
        zip_path = os.path.join(self.data_dir, ASCAD_ZIP_FILENAME)

        if os.path.isfile(self._raw_path) and not force:
            logger.info("Raw traces already exist at %s", self._raw_path)
            return

        os.makedirs(self.data_dir, exist_ok=True)

        if not os.path.isfile(zip_path) or force:
            logger.info("Downloading ASCAD dataset (~4.2 GB) ...")
            self._download_file(ASCAD_DOWNLOAD_URL, zip_path)

        logger.info("Extracting %s ...", zip_path)
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)
        logger.info("Extraction complete.")

    @staticmethod
    def _download_file(url: str, dest: str) -> None:
        """
        Download *url* to *dest* with a progress bar.

        The payload is streamed into ``dest + ".part"`` and renamed to *dest*
        only after the transfer finishes. An interrupted download therefore
        never leaves a truncated file at *dest*. Otherwise, download() would
        mistake that file for a complete archive on its next call.

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        part_path = dest + ".part"
        try:
            # Using the response as a context manager releases the connection.
            with requests.get(url, stream=True, timeout=600) as resp:
                resp.raise_for_status()
                # content-length may be missing; None gives tqdm an open-ended
                # counter instead of a bogus total of 0 bytes.
                total = int(resp.headers.get("content-length", 0)) or None
                with open(part_path, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc="Downloading"
                ) as pbar:
                    for chunk in resp.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
                        pbar.update(len(chunk))
            os.replace(part_path, dest)
        except BaseException:
            # Also covers KeyboardInterrupt: drop the partial file, then
            # propagate the original error.
            with contextlib.suppress(OSError):
                os.remove(part_path)
            raise

    # ------------------------------------------------------------------
    # HDF5 access (lazy open)
    # ------------------------------------------------------------------

    def _open_preextracted(self) -> h5py.File:
        """Open (once) the pre-extracted ASCAD HDF5 file for this desync level."""
        if self._h5_file is None:
            if not os.path.isfile(self._preextracted_path):
                raise FileNotFoundError(
                    f"ASCAD HDF5 not found at {self._preextracted_path}. "
                    "Call download() first."
                )
            self._h5_file = h5py.File(self._preextracted_path, "r")
        return self._h5_file

    def _open_raw(self) -> h5py.File:
        """Open (once) the raw traces HDF5 file."""
        if self._raw_h5_file is None:
            if not os.path.isfile(self._raw_path):
                raise FileNotFoundError(
                    f"Raw traces not found at {self._raw_path}. "
                    "Call download() first."
                )
            self._raw_h5_file = h5py.File(self._raw_path, "r")
        return self._raw_h5_file

    def close(self) -> None:
        """Close all open HDF5 file handles (safe to call repeatedly)."""
        for attr in ("_h5_file", "_raw_h5_file"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    # ------------------------------------------------------------------
    # Desync shift generation
    # ------------------------------------------------------------------

    def _generate_desync_shifts(
        self, num_traces: int, seed: int
    ) -> np.ndarray:
        """
        Generate per-trace random desync shifts.

        When desync=0, returns zeros (no shift). Otherwise, generates uniform
        random shifts in [0, desync] (inclusive) using a fixed seed for
        reproducibility.

        Args:
            num_traces: Number of traces to generate shifts for.
            seed: Random seed for reproducibility.

        Returns:
            Array of shape (num_traces,) with int32 shifts.
        """
        if self.desync == 0:
            return np.zeros(num_traces, dtype=np.int32)
        rng = np.random.default_rng(seed=seed)
        return rng.integers(0, self.desync + 1, size=num_traces, dtype=np.int32)

    # ------------------------------------------------------------------
    # Label computation
    # ------------------------------------------------------------------

    @staticmethod
    def _byte_column(metadata, field: str, target_byte: int) -> np.ndarray:
        """Return ``record[field][target_byte]`` for every record, as uint8."""
        if isinstance(metadata, np.ndarray) and metadata.dtype.names:
            column = metadata[field]
            if column.ndim == 2:
                # Fast path: a structured array whose field is a fixed-size
                # sub-array exposes it as an (N, k) array, so the byte column
                # is selected without a Python loop over records.
                return np.asarray(column[:, target_byte], dtype=np.uint8)
        # Generic path for any iterable of indexable records.
        return np.array([m[field][target_byte] for m in metadata], dtype=np.uint8)

    @staticmethod
    def compute_labels(metadata: np.ndarray, target_byte: int) -> np.ndarray:
        """
        Compute Sbox(plaintext[byte] XOR key[byte]) labels.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key' fields.
                Any other iterable of records supporting
                ``record["plaintext"][i]`` is also accepted, via a slower
                per-record path.
            target_byte: Index of the target key byte (0-15).

        Returns:
            Array of labels looked up in AES_SBOX (uint8), shape (N,).
        """
        plaintexts = ASCADDataset._byte_column(metadata, "plaintext", target_byte)
        keys = ASCADDataset._byte_column(metadata, "key", target_byte)
        return AES_SBOX[plaintexts ^ keys]

    def compute_all_labels(
        self, metadata: np.ndarray
    ) -> np.ndarray:
        """
        Compute labels for all 16 bytes simultaneously.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key' fields.

        Returns:
            Array of shape (N, 16), dtype uint8, where column i contains the
            Sbox labels for byte i.
        """
        n = len(metadata)
        all_labels = np.empty((n, 16), dtype=np.uint8)
        for byte_idx in range(16):
            all_labels[:, byte_idx] = self.compute_labels(metadata, byte_idx)
        return all_labels

    # ------------------------------------------------------------------
    # Internal loading helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_shifted_window(
        f: h5py.File,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        shifts: np.ndarray,
    ) -> np.ndarray:
        """
        Read ``traces[i, window_start + s_i : window_end + s_i]`` for each trace i.

        A single contiguous slab [window_start : window_end + max(shifts)] is
        read from HDF5. Each trace's sub-window is then gathered with
        vectorized fancy indexing.

        Args:
            f: Open raw-traces HDF5 file.
            trace_range: (start_idx, end_idx) for trace selection.
            window_start: Nominal start of the POI window.
            window_end: Nominal end of the POI window.
            shifts: Per-trace non-negative shift values, one per selected trace.

        Returns:
            int8 array of shape (num_traces, window_end - window_start).

        Raises:
            ValueError: If the shifted window runs past the end of the traces,
                or the number of shifts does not match the traces read.
        """
        start, end = trace_range
        shifts = np.asarray(shifts)
        window_size = window_end - window_start
        max_shift = int(shifts.max()) if len(shifts) > 0 else 0

        # Load the extended window covering all possible shifts
        traces_wide = np.array(
            f["traces"][start:end, window_start:window_end + max_shift],
            dtype=np.int8,
        )
        n_loaded, n_samples = traces_wide.shape
        if n_loaded != len(shifts):
            raise ValueError(
                f"Got {len(shifts)} desync shifts but read {n_loaded} traces "
                f"for range [{start}:{end}]."
            )
        if n_samples < window_size + max_shift:
            raise ValueError(
                f"Shifted window [{window_start}:{window_end + max_shift}] "
                f"exceeds the available trace samples."
            )

        row_idx = np.arange(n_loaded)[:, None]
        col_idx = shifts[:, None] + np.arange(window_size)[None, :]
        return traces_wide[row_idx, col_idx]

    def _load_raw_window(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load a window of raw traces and their metadata (no desync)."""
        f = self._open_raw()
        start, end = trace_range
        traces = np.array(
            f["traces"][start:end, window_start:window_end], dtype=np.int8
        )
        metadata = np.array(f["metadata"][start:end])
        logger.info(
            "Loaded %d traces, window [%d:%d], shape=%s",
            traces.shape[0], window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_raw_window_desync(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        shifts: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a window of raw traces with per-trace desync shifts applied.

        Args:
            trace_range: (start_idx, end_idx) for trace selection.
            window_start: Nominal start of the POI window.
            window_end: Nominal end of the POI window.
            shifts: Per-trace shift values, shape (num_traces,).

        Returns:
            Tuple of (traces, metadata).
        """
        f = self._open_raw()
        start, end = trace_range
        traces = self._read_shifted_window(
            f, trace_range, window_start, window_end, shifts
        )
        metadata = np.array(f["metadata"][start:end])
        logger.info(
            "Loaded %d traces with desync=%d, window [%d:%d], shape=%s",
            traces.shape[0], self.desync, window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_window(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        seed: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load a raw-trace window, applying seeded desync shifts when desync > 0."""
        if self.desync == 0:
            return self._load_raw_window(trace_range, window_start, window_end)
        start, end = trace_range
        shifts = self._generate_desync_shifts(end - start, seed)
        return self._load_raw_window_desync(
            trace_range, window_start, window_end, shifts
        )

    def _load_preextracted(
        self,
        group: str,
        target_byte: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load from the pre-extracted HDF5 (byte 2, specific desync)."""
        f = self._open_preextracted()
        traces = np.array(f[group]["traces"], dtype=np.int8)
        metadata = np.array(f[group]["metadata"])
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    # ------------------------------------------------------------------
    # Public API: per-byte loading (for independent models)
    # ------------------------------------------------------------------

    def load_profiling(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load profiling (training) data for a single target byte.

        Byte 2, at any desync level, is read from the pre-extracted ASCAD
        HDF5 files, which already have desync applied. Every other byte is
        extracted from the raw traces, with seeded desync shifts applied.

        Args:
            target_byte: Index of the target key byte (0-15).

        Returns:
            Tuple of (traces, labels, metadata).
        """
        if target_byte == 2:
            return self._load_preextracted("Profiling_traces", target_byte)

        w_start, w_end = BYTE_POI_WINDOWS[target_byte]
        traces, metadata = self._load_window(
            DEFAULT_PROFILING_RANGE, w_start, w_end, _PROFILING_DESYNC_SEED
        )
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    def load_attack(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load attack (evaluation) data for a single target byte.

        Byte 2 comes from the pre-extracted files; all other bytes come from
        the raw traces, with seeded desync shifts applied.

        Args:
            target_byte: Index of the target key byte (0-15).

        Returns:
            Tuple of (traces, labels, metadata).
        """
        if target_byte == 2:
            return self._load_preextracted("Attack_traces", target_byte)

        w_start, w_end = BYTE_POI_WINDOWS[target_byte]
        traces, metadata = self._load_window(
            DEFAULT_ATTACK_RANGE, w_start, w_end, _ATTACK_DESYNC_SEED
        )
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    # ------------------------------------------------------------------
    # Public API: global window loading (for multi-task models)
    # ------------------------------------------------------------------

    def load_profiling_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load profiling traces using the global window covering all 16 POI regions.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
                traces: shape (num_profiling_traces, GLOBAL_WINDOW_SIZE), dtype int8
                    (50000 traces with the default range).
                metadata: structured numpy array.
        """
        return self._load_window(
            DEFAULT_PROFILING_RANGE,
            GLOBAL_WINDOW_START,
            GLOBAL_WINDOW_END,
            _PROFILING_DESYNC_SEED,
        )

    def load_attack_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load attack traces using the global window.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
                traces: shape (num_attack_traces, GLOBAL_WINDOW_SIZE), dtype int8
                    (10000 traces with the default range).
                metadata: structured numpy array.
        """
        return self._load_window(
            DEFAULT_ATTACK_RANGE,
            GLOBAL_WINDOW_START,
            GLOBAL_WINDOW_END,
            _ATTACK_DESYNC_SEED,
        )

    # ------------------------------------------------------------------
    # Public API: multi-input loading (for LMIC model)
    # ------------------------------------------------------------------

    def _load_multi_input(
        self, trace_range: Tuple[int, int], seed: int, phase: str
    ) -> Tuple[dict, np.ndarray]:
        """
        Load all 16 per-byte POI windows for *trace_range*.

        One shift per trace (drawn from *seed*) is shared by all 16 windows.
        *phase* ("profiling" / "attack") is used only in the log message.
        """
        f = self._open_raw()
        start, end = trace_range
        metadata = np.array(f["metadata"][start:end])

        # Generate desync shifts (same shift for all bytes within one trace)
        shifts = self._generate_desync_shifts(end - start, seed)

        traces_dict = {}
        for byte_idx in range(16):
            w_start, w_end = BYTE_POI_WINDOWS[byte_idx]
            if self.desync == 0:
                traces = np.array(
                    f["traces"][start:end, w_start:w_end], dtype=np.int8
                )
            else:
                traces = self._read_shifted_window(
                    f, trace_range, w_start, w_end, shifts
                )
            traces_dict[f"byte_{byte_idx}_input"] = traces
            logger.info(
                "Loaded byte %d " + phase + " traces: window [%d:%d], desync=%d, shape=%s",
                byte_idx, w_start, w_end, self.desync, traces.shape,
            )
        return traces_dict, metadata

    def load_profiling_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load profiling traces as 16 separate per-byte POI windows.

        Each byte gets its own POI window from BYTE_POI_WINDOWS, extracted
        from the raw traces. When desync > 0, the same per-trace shift is
        applied to all 16 byte windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
                traces_dict: {"byte_0_input": ..., ..., "byte_15_input": ...},
                    each (num_profiling_traces, window_size) int8, i.e.
                    (50000, 700) with the default constants.
                metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_PROFILING_RANGE, _PROFILING_DESYNC_SEED, "profiling"
        )

    def load_attack_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load attack traces as 16 separate per-byte POI windows.

        When desync > 0, the same per-trace shift is applied to all 16 byte
        windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
                traces_dict: {"byte_0_input": ..., ..., "byte_15_input": ...},
                    each (num_attack_traces, window_size) int8, i.e.
                    (10000, 700) with the default constants.
                metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_ATTACK_RANGE, _ATTACK_DESYNC_SEED, "attack"
        )

    # ------------------------------------------------------------------
    # Key retrieval
    # ------------------------------------------------------------------

    def get_real_key(self, target_byte: int) -> int:
        """
        Retrieve the real key byte value for a given target byte.

        The ASCAD dataset uses a fixed key across all traces, so the key is
        read from the metadata of the first raw trace.
        """
        f = self._open_raw()
        metadata = f["metadata"][0]
        return int(metadata["key"][target_byte])

    def get_all_real_keys(self) -> np.ndarray:
        """Retrieve all 16 real key byte values (from the first raw trace)."""
        f = self._open_raw()
        metadata = f["metadata"][0]
        return np.array(metadata["key"], dtype=np.uint8)