File size: 2,398 Bytes
283a882
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Abstract base class for all ASCAD attack models.

Provides a common interface for building, compiling, and describing models
so that the trainer and evaluation modules can work with any architecture
without knowing its internals.
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import tensorflow as tf


class BaseModel(ABC):
    """
    Abstract base class for ASCAD side-channel attack models.

    Subclasses define the architecture in ``build()`` and expose its
    hyperparameters through ``get_config()``. This base class takes care of
    building the Keras model lazily, caching it, and compiling it.
    """

    def __init__(self, input_shape: tuple, num_classes: int = 256) -> None:
        """
        Args:
            input_shape: Shape of a single input sample (excluding batch dim).
            num_classes: Number of output classes (256 for AES S-Box).
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        # Stays None until the first call to `compile()` or the first access
        # of the `model` property, at which point `build()` is invoked once.
        self._model: Optional[tf.keras.Model] = None

    @abstractmethod
    def build(self) -> tf.keras.Model:
        """
        Construct and return the Keras model.

        Subclasses must implement this to define the architecture.
        The returned model should NOT be compiled yet.
        """
        raise NotImplementedError

    def _ensure_built(self) -> tf.keras.Model:
        """
        Return the cached model, calling ``build()`` first if necessary.

        Single place where the lazy build happens, so that ``compile()`` and
        the ``model`` property cannot drift apart.

        Raises:
            TypeError: If ``build()`` returns ``None``.
        """
        if self._model is None:
            built = self.build()
            # A forgotten `return` in a subclass would otherwise surface later
            # as an AttributeError on None, and trigger a rebuild on every
            # access because the cache would never be filled.
            if built is None:
                raise TypeError(
                    f"{self.name}.build() returned None; "
                    "it must return a tf.keras.Model"
                )
            self._model = built
        return self._model

    def compile(self, learning_rate: float = 1e-5) -> tf.keras.Model:
        """
        Build (if needed) and compile the model with standard SCA settings.

        The model is compiled with an RMSprop optimizer, the
        ``categorical_crossentropy`` loss and the ``accuracy`` metric.

        Args:
            learning_rate: Learning rate for the RMSprop optimizer.

        Returns:
            The compiled Keras model.

        Raises:
            TypeError: If ``build()`` returns ``None``.
        """
        model = self._ensure_built()
        model.compile(
            optimizer=tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
            loss="categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    @property
    def model(self) -> tf.keras.Model:
        """
        Access the underlying Keras model (builds if needed).

        Raises:
            TypeError: If ``build()`` returns ``None``.
        """
        return self._ensure_built()

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Return a dictionary of architecture hyperparameters.

        Used for logging, reproducibility, and results.json metadata.
        """
        raise NotImplementedError

    def summary(self) -> None:
        """Print the model summary (builds the model if needed)."""
        self.model.summary()

    @property
    def name(self) -> str:
        """Human-readable model name (the concrete class name)."""
        return self.__class__.__name__