"""
Model registry for ASCAD attack architectures.

Usage::

    from src.models import create_model

    model_wrapper = create_model("mlp")
    keras_model = model_wrapper.compile(learning_rate=1e-5)

Supported architectures:

- ``mlp``: Multi-Layer Perceptron (Prouff et al., 2019)
- ``cnn``: CNNbest 1-D convolutional network (Prouff et al., 2019)
- ``mtan``: Full SNR-MTAN with per-block attention (deprecated)
- ``hps``: Hard Parameter Sharing multi-task (Marquet & Oswald, 2024)
- ``mtan_lite``: Simplified MTAN with final-block attention (novel)
- ``lmic``: Localized Multi-Input CNN (16 per-byte POI inputs)
- ``lmic_tsbn``: LMIC with Task-Specific Batch Normalization (Suteu & Serban, 2025)
"""

from typing import Dict, Type

from .base import BaseModel
from .mlp import MLPBest
from .cnn import CNNBest
from .mtan import SNRMTAN
from .mtl import HPSModel, MTANLiteModel
from .lmic import LMICModel, LMICTSBNModel

# Registry mapping model type strings to classes. Mutated at runtime by
# ``register_model``; ``create_model`` looks classes up here by name.
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {
    "mlp": MLPBest,
    "cnn": CNNBest,
    "mtan": SNRMTAN,
    "hps": HPSModel,
    "mtan_lite": MTANLiteModel,
    "lmic": LMICModel,
    "lmic_tsbn": LMICTSBNModel,
}


def create_model(model_type: str, **kwargs) -> BaseModel:
    """
    Factory function to create a model by type name.

    Args:
        model_type: One of the registered model types (a key of
            ``MODEL_REGISTRY``).
        **kwargs: Additional keyword arguments passed unchanged to the
            model constructor.

    Returns:
        An instance of the requested model (not yet compiled).

    Raises:
        ValueError: If model_type is not registered. The message lists the
            currently registered names.
    """
    if model_type not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model type '{model_type}'. "
            f"Available: {list(MODEL_REGISTRY.keys())}"
        )
    return MODEL_REGISTRY[model_type](**kwargs)


def register_model(name: str, model_class: Type[BaseModel]) -> None:
    """
    Register a new model class in the registry.

    An existing entry with the same ``name`` is silently replaced.

    Args:
        name: String identifier for the model type; this is the key later
            passed to ``create_model``.
        model_class: The model class (must subclass BaseModel). Pass the
            class itself, not an instance.

    Raises:
        TypeError: If ``model_class`` is not a class, or is a class that
            does not subclass BaseModel.
    """
    # Check ``isinstance(..., type)`` first: ``issubclass`` raises its own,
    # less descriptive TypeError when handed a non-class (e.g. an instance).
    if not (isinstance(model_class, type) and issubclass(model_class, BaseModel)):
        raise TypeError(f"{model_class} must be a subclass of BaseModel")
    MODEL_REGISTRY[name] = model_class


__all__ = [
    "BaseModel",
    "MLPBest",
    "CNNBest",
    "SNRMTAN",
    "HPSModel",
    "MTANLiteModel",
    "LMICModel",
    "LMICTSBNModel",
    "create_model",
    "register_model",
    "MODEL_REGISTRY",
]