Section 43-45: dataset.py per-trace shift fix, orchestrator DELETE FK fix, queue rebuild, clean MLP/CNN resubmission
8c414b1 | """ | |
| ASCAD Dataset Module | |
| ==================== | |
| Handles downloading, loading, and preprocessing of the ASCAD database. | |
| Supports per-byte POI windows extracted from raw traces via SNR analysis, | |
| as well as a global window mode for multi-task learning. | |
| Desynchronization is applied by shifting POI windows by a per-trace random | |
| offset in [0, desync], matching the official ASCAD benchmark protocol. | |
| """ | |
| import os | |
| import logging | |
| import zipfile | |
| from typing import Optional, Tuple | |
| import h5py | |
| import numpy as np | |
| import requests | |
| from tqdm import tqdm | |
| from .constants import ( | |
| AES_SBOX, | |
| BYTE_POI_WINDOWS, | |
| DEFAULT_PROFILING_RANGE, | |
| DEFAULT_ATTACK_RANGE, | |
| DESYNC_FILES, | |
| ASCAD_DOWNLOAD_URL, | |
| ASCAD_ZIP_FILENAME, | |
| ASCAD_DB_SUBDIR, | |
| ASCAD_RAW_FILENAME, | |
| GLOBAL_WINDOW_START, | |
| GLOBAL_WINDOW_END, | |
| ) | |
logger = logging.getLogger(__name__)

# Desync shifts come from fixed seeds so every run reproduces the same
# per-trace offsets. The profiling and attack splits use different seeds
# (attack = profiling + 1) so the two splits get different shift sequences.
_PROFILING_DESYNC_SEED = 20260503
_ATTACK_DESYNC_SEED = _PROFILING_DESYNC_SEED + 1
class ASCADDataset:
    """
    Manages the ASCAD database: downloading, loading, and preprocessing.

    Supports three loading modes:
        1. Pre-extracted mode (byte 2 only): Uses ASCAD.h5 / desync variants.
        2. Raw trace mode (per-byte): Extracts 700-sample POI windows from
           ATMega8515_raw_traces.h5 using SNR-derived windows.
        3. Global window mode (multi-task): Extracts a single large window
           covering all 16 POI regions for multi-task models.

    When desync > 0, raw-trace loads shift each trace's window by a per-trace
    random offset in [0, desync]. Offsets are drawn from a fixed per-split
    seed, so the same trace index always receives the same shift.

    HDF5 files are opened lazily and kept open. Call ``close()`` or use the
    instance as a context manager to release them.
    """

    def __init__(self, data_dir: str = "/tmp/ascad_data", desync: int = 0) -> None:
        """
        Args:
            data_dir: Root directory where the ASCAD data is stored.
            desync: Desynchronization level. Must be a key of DESYNC_FILES
                (documented as 0, 50, or 100).

        Raises:
            ValueError: If desync is not a key of DESYNC_FILES.
        """
        if desync not in DESYNC_FILES:
            raise ValueError(
                f"Invalid desync level: {desync}. Must be one of {list(DESYNC_FILES.keys())}."
            )
        self.data_dir = data_dir
        self.desync = desync
        self._preextracted_path = os.path.join(
            data_dir, ASCAD_DB_SUBDIR, DESYNC_FILES[desync]
        )
        self._raw_path = os.path.join(data_dir, ASCAD_DB_SUBDIR, ASCAD_RAW_FILENAME)
        self._h5_file: Optional[h5py.File] = None
        self._raw_h5_file: Optional[h5py.File] = None

    # ------------------------------------------------------------------
    # Context manager support
    # ------------------------------------------------------------------
    def __enter__(self) -> "ASCADDataset":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returning None lets any exception from the with-body propagate.
        self.close()

    # ------------------------------------------------------------------
    # Download & extract
    # ------------------------------------------------------------------
    def download(self, force: bool = False) -> None:
        """
        Download and extract the ASCAD dataset if not already present.

        Args:
            force: Re-download the archive and re-extract it even if the raw
                traces file or the archive already exists.
        """
        zip_path = os.path.join(self.data_dir, ASCAD_ZIP_FILENAME)
        if os.path.isfile(self._raw_path) and not force:
            logger.info("Raw traces already exist at %s", self._raw_path)
            return
        os.makedirs(self.data_dir, exist_ok=True)
        # The archive is only ever written atomically by _download_file, so an
        # existing zip is a complete one and can be reused.
        if not os.path.isfile(zip_path) or force:
            logger.info("Downloading ASCAD dataset (~4.2 GB) ...")
            self._download_file(ASCAD_DOWNLOAD_URL, zip_path)
        logger.info("Extracting %s ...", zip_path)
        with zipfile.ZipFile(zip_path, "r") as zf:
            zf.extractall(self.data_dir)
        logger.info("Extraction complete.")
        if not os.path.isfile(self._raw_path):
            logger.warning(
                "Raw traces not found at %s after extraction; check the archive layout.",
                self._raw_path,
            )

    @staticmethod
    def _download_file(url: str, dest: str) -> None:
        """
        Download ``url`` to ``dest`` with a progress bar.

        The body is streamed into ``dest + ".part"`` and renamed to ``dest``
        only after the transfer finishes. An interrupted or failed download
        therefore never leaves a truncated file at ``dest``. This matters
        because ``download()`` skips the fetch whenever ``dest`` exists.

        Raises:
            requests.HTTPError: If the server returns an error status.
        """
        tmp_path = dest + ".part"
        try:
            with requests.get(url, stream=True, timeout=600) as resp:
                resp.raise_for_status()
                total = int(resp.headers.get("content-length", 0))
                with open(tmp_path, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc="Downloading"
                ) as pbar:
                    for chunk in resp.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
                        pbar.update(len(chunk))
            os.replace(tmp_path, dest)
        finally:
            # After a successful os.replace the temp file is gone; otherwise
            # drop the partial download so the next attempt starts clean.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    # ------------------------------------------------------------------
    # HDF5 access (lazy open)
    # ------------------------------------------------------------------
    def _open_preextracted(self) -> "h5py.File":
        """
        Open (once) and return the pre-extracted ASCAD HDF5 file for this desync.

        Raises:
            FileNotFoundError: If the file does not exist on disk.
        """
        if self._h5_file is None:
            if not os.path.isfile(self._preextracted_path):
                raise FileNotFoundError(
                    f"ASCAD HDF5 not found at {self._preextracted_path}. "
                    "Call download() first."
                )
            self._h5_file = h5py.File(self._preextracted_path, "r")
        return self._h5_file

    def _open_raw(self) -> "h5py.File":
        """
        Open (once) and return the raw traces HDF5 file.

        Raises:
            FileNotFoundError: If the file does not exist on disk.
        """
        if self._raw_h5_file is None:
            if not os.path.isfile(self._raw_path):
                raise FileNotFoundError(
                    f"Raw traces not found at {self._raw_path}. "
                    "Call download() first."
                )
            self._raw_h5_file = h5py.File(self._raw_path, "r")
        return self._raw_h5_file

    def close(self) -> None:
        """Close all open HDF5 file handles (safe to call repeatedly)."""
        for attr in ("_h5_file", "_raw_h5_file"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.close()
                setattr(self, attr, None)

    # ------------------------------------------------------------------
    # Desync shift generation
    # ------------------------------------------------------------------
    def _generate_desync_shifts(
        self, num_traces: int, seed: int
    ) -> np.ndarray:
        """
        Generate per-trace random desync shifts.

        When desync=0, returns zeros (no shift). Otherwise, draws uniform
        integer shifts in [0, desync] (inclusive) from a generator seeded
        with ``seed``, so the result is reproducible.

        Args:
            num_traces: Number of traces to generate shifts for.
            seed: Random seed for reproducibility.

        Returns:
            int32 array of shape (num_traces,).
        """
        if self.desync == 0:
            return np.zeros(num_traces, dtype=np.int32)
        rng = np.random.default_rng(seed=seed)
        return rng.integers(0, self.desync + 1, size=num_traces, dtype=np.int32)

    @staticmethod
    def _extract_shifted_windows(
        traces_wide: np.ndarray, shifts: np.ndarray, window_size: int
    ) -> np.ndarray:
        """
        Cut one ``window_size``-sample window out of each row of ``traces_wide``.

        Row ``i`` of the result is
        ``traces_wide[i, shifts[i] : shifts[i] + window_size]``. The slicing
        is done with a single vectorized fancy-index.

        Args:
            traces_wide: 2-D array (num_traces, window_size + max_shift) or wider.
            shifts: Per-row start offsets, shape (num_traces,).
            window_size: Width of each extracted window.

        Returns:
            Array of shape (num_traces, window_size) with traces_wide's dtype.

        Raises:
            ValueError: If the row count and shift count differ (e.g. the
                requested trace range exceeds the stored traces), or if the
                widest shifted window runs past the available samples.
        """
        num_rows = traces_wide.shape[0]
        if shifts.shape[0] != num_rows:
            raise ValueError(
                f"Got {shifts.shape[0]} shifts for {num_rows} traces; the requested "
                "trace range may exceed the number of stored traces."
            )
        max_shift = int(shifts.max()) if num_rows else 0
        if traces_wide.shape[1] < window_size + max_shift:
            raise ValueError(
                f"A {window_size}-sample window with shifts up to {max_shift} needs "
                f"{window_size + max_shift} samples, but only {traces_wide.shape[1]} "
                "are available; the shifted window runs past the end of the traces."
            )
        row_idx = np.arange(num_rows)[:, None]
        col_idx = shifts[:, None] + np.arange(window_size)[None, :]
        return traces_wide[row_idx, col_idx]

    # ------------------------------------------------------------------
    # Label computation
    # ------------------------------------------------------------------
    @staticmethod
    def compute_labels(metadata: np.ndarray, target_byte: int) -> np.ndarray:
        """
        Compute Sbox(plaintext[byte] XOR key[byte]) labels.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key' fields.
                Any other sequence of records indexable by those field names
                is also accepted, via a slower per-record path.
            target_byte: Index of the target key byte (0-15).

        Returns:
            Array of uint8 labels, shape (N,).
        """
        if isinstance(metadata, np.ndarray) and metadata.dtype.names:
            # Structured array: read the (N, 16) field columns in one shot.
            plaintexts = np.asarray(metadata["plaintext"], dtype=np.uint8)[:, target_byte]
            keys = np.asarray(metadata["key"], dtype=np.uint8)[:, target_byte]
        else:
            plaintexts = np.array(
                [m["plaintext"][target_byte] for m in metadata], dtype=np.uint8
            )
            keys = np.array(
                [m["key"][target_byte] for m in metadata], dtype=np.uint8
            )
        return AES_SBOX[plaintexts ^ keys]

    # ------------------------------------------------------------------
    # Internal loading helpers
    # ------------------------------------------------------------------
    def _load_raw_window(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load a window of raw traces (as int8) and their metadata, with no desync."""
        f = self._open_raw()
        start, end = trace_range
        traces = np.array(
            f["traces"][start:end, window_start:window_end], dtype=np.int8
        )
        metadata = np.array(f["metadata"][start:end])
        logger.info(
            "Loaded %d traces, window [%d:%d], shape=%s",
            traces.shape[0], window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_raw_window_desync(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        shifts: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a window of raw traces with per-trace desync shifts applied.

        Reads a wider window [window_start : window_end + max_shift] once,
        then extracts each trace's sub-window starting at its own shift.

        Args:
            trace_range: (start_idx, end_idx) for trace selection.
            window_start: Nominal start of the POI window.
            window_end: Nominal end of the POI window.
            shifts: Per-trace shift values, shape (num_traces,).

        Returns:
            Tuple of (traces, metadata).

        Raises:
            ValueError: If the shifted windows do not fit in the stored traces
                or the shift count does not match the loaded trace count.
        """
        f = self._open_raw()
        start, end = trace_range
        window_size = window_end - window_start
        max_shift = int(shifts.max()) if len(shifts) > 0 else 0
        # One extended read that covers every trace's shifted sub-window.
        traces_wide = np.array(
            f["traces"][start:end, window_start:window_end + max_shift],
            dtype=np.int8,
        )
        metadata = np.array(f["metadata"][start:end])
        traces = self._extract_shifted_windows(traces_wide, shifts, window_size)
        logger.info(
            "Loaded %d traces with desync=%d, window [%d:%d], shape=%s",
            traces.shape[0], self.desync, window_start, window_end, traces.shape,
        )
        return traces, metadata

    def _load_window(
        self,
        trace_range: Tuple[int, int],
        window_start: int,
        window_end: int,
        seed: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load a raw-trace window, applying seeded per-trace shifts when desync > 0.

        Returns:
            Tuple of (traces, metadata).
        """
        if self.desync == 0:
            return self._load_raw_window(trace_range, window_start, window_end)
        start, end = trace_range
        shifts = self._generate_desync_shifts(end - start, seed)
        return self._load_raw_window_desync(
            trace_range, window_start, window_end, shifts
        )

    def _load_preextracted(
        self,
        group: str,
        target_byte: int,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load one group of the pre-extracted HDF5 (byte 2, this desync level).

        Returns:
            Tuple of (traces as int8, labels, metadata).
        """
        f = self._open_preextracted()
        traces = np.array(f[group]["traces"], dtype=np.int8)
        metadata = np.array(f[group]["metadata"])
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    def _load_multi_input(
        self, trace_range: Tuple[int, int], seed: int, split_name: str
    ) -> Tuple[dict, np.ndarray]:
        """
        Load all 16 per-byte POI windows for one trace range.

        A single shift per trace is shared by all 16 byte windows, so every
        window of a given trace is displaced by the same amount.

        Args:
            trace_range: (start_idx, end_idx) for trace selection.
            seed: Seed for the per-trace desync shifts.
            split_name: Split label used only in log messages.

        Returns:
            Tuple of ({"byte_<i>_input": traces}, metadata).
        """
        f = self._open_raw()
        start, end = trace_range
        metadata = np.array(f["metadata"][start:end])
        shifts = self._generate_desync_shifts(end - start, seed)
        max_shift = int(shifts.max()) if shifts.size else 0
        traces_dict = {}
        for byte_idx in range(16):
            w_start, w_end = BYTE_POI_WINDOWS[byte_idx]
            if self.desync == 0:
                traces = np.array(
                    f["traces"][start:end, w_start:w_end], dtype=np.int8
                )
            else:
                traces_wide = np.array(
                    f["traces"][start:end, w_start:w_end + max_shift],
                    dtype=np.int8,
                )
                traces = self._extract_shifted_windows(
                    traces_wide, shifts, w_end - w_start
                )
            traces_dict[f"byte_{byte_idx}_input"] = traces
            logger.info(
                "Loaded byte %d %s traces: window [%d:%d], desync=%d, shape=%s",
                byte_idx, split_name, w_start, w_end, self.desync, traces.shape,
            )
        return traces_dict, metadata

    # ------------------------------------------------------------------
    # Public API: per-byte loading (for independent models)
    # ------------------------------------------------------------------
    def load_profiling(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load profiling (training) data for a single target byte.

        For byte 2 at any desync level, uses the pre-extracted ASCAD HDF5
        files, which already have desync applied. For all other bytes,
        extracts from raw traces and applies seeded desync shifts.

        Returns:
            Tuple of (traces, labels, metadata).
        """
        if target_byte == 2:
            return self._load_preextracted("Profiling_traces", target_byte)
        w_start, w_end = BYTE_POI_WINDOWS[target_byte]
        traces, metadata = self._load_window(
            DEFAULT_PROFILING_RANGE, w_start, w_end, _PROFILING_DESYNC_SEED
        )
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    def load_attack(
        self, target_byte: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Load attack (evaluation) data for a single target byte.

        Byte 2 comes from the pre-extracted HDF5 ("Attack_traces" group).
        All other bytes are extracted from raw traces with seeded desync shifts.

        Returns:
            Tuple of (traces, labels, metadata).
        """
        if target_byte == 2:
            return self._load_preextracted("Attack_traces", target_byte)
        w_start, w_end = BYTE_POI_WINDOWS[target_byte]
        traces, metadata = self._load_window(
            DEFAULT_ATTACK_RANGE, w_start, w_end, _ATTACK_DESYNC_SEED
        )
        labels = self.compute_labels(metadata, target_byte)
        return traces, labels, metadata

    # ------------------------------------------------------------------
    # Public API: global window loading (for multi-task models)
    # ------------------------------------------------------------------
    def load_profiling_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load profiling traces using the global window covering all 16 POI regions.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
                traces: shape (num_profiling_traces, GLOBAL_WINDOW_END -
                    GLOBAL_WINDOW_START), dtype int8.
                metadata: structured numpy array.
        """
        return self._load_window(
            DEFAULT_PROFILING_RANGE,
            GLOBAL_WINDOW_START,
            GLOBAL_WINDOW_END,
            _PROFILING_DESYNC_SEED,
        )

    def load_attack_global(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Load attack traces using the global window.

        When desync > 0, applies per-trace random shifts to the global window.

        Returns:
            Tuple of (traces, metadata).
                traces: shape (num_attack_traces, GLOBAL_WINDOW_END -
                    GLOBAL_WINDOW_START), dtype int8.
                metadata: structured numpy array.
        """
        return self._load_window(
            DEFAULT_ATTACK_RANGE,
            GLOBAL_WINDOW_START,
            GLOBAL_WINDOW_END,
            _ATTACK_DESYNC_SEED,
        )

    def compute_all_labels(
        self, metadata: np.ndarray
    ) -> np.ndarray:
        """
        Compute labels for all 16 bytes simultaneously.

        Args:
            metadata: Structured numpy array with 'plaintext' and 'key' fields.

        Returns:
            Array of shape (N, 16), dtype uint8, where column i contains
            the Sbox labels for byte i.
        """
        all_labels = np.empty((len(metadata), 16), dtype=np.uint8)
        for byte_idx in range(16):
            all_labels[:, byte_idx] = self.compute_labels(metadata, byte_idx)
        return all_labels

    # ------------------------------------------------------------------
    # Public API: multi-input loading (for LMIC model)
    # ------------------------------------------------------------------
    def load_profiling_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load profiling traces as 16 separate per-byte POI windows.

        Each byte gets its own POI window extracted from the raw traces.
        When desync > 0, the same per-trace shift is applied to all 16 byte
        windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
                traces_dict: {"byte_0_input": (N, window), ..., "byte_15_input": (N, window)}
                metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_PROFILING_RANGE, _PROFILING_DESYNC_SEED, "profiling"
        )

    def load_attack_multi_input(self) -> Tuple[dict, np.ndarray]:
        """
        Load attack traces as 16 separate per-byte POI windows.

        When desync > 0, the same per-trace shift is applied to all 16
        byte windows (simulating uniform clock jitter).

        Returns:
            Tuple of (traces_dict, metadata).
                traces_dict: {"byte_0_input": (N, window), ..., "byte_15_input": (N, window)}
                metadata: structured numpy array.
        """
        return self._load_multi_input(
            DEFAULT_ATTACK_RANGE, _ATTACK_DESYNC_SEED, "attack"
        )

    def get_real_key(self, target_byte: int) -> int:
        """
        Return the key byte at ``target_byte``, read from the first raw trace's metadata.

        This assumes the key is the same for every trace (documented as a
        fixed key for ASCAD).
        """
        f = self._open_raw()
        metadata = f["metadata"][0]
        return int(metadata["key"][target_byte])

    def get_all_real_keys(self) -> np.ndarray:
        """Return all 16 key bytes (uint8) from the first raw trace's metadata."""
        f = self._open_raw()
        metadata = f["metadata"][0]
        return np.array(metadata["key"], dtype=np.uint8)