"""Arch-Building-Image-Classification — Model Construction & Inference Module. This module provides the architecture definition, custom layer implementations, and inference utilities for the EfficientNetV2-S-based fine-grained visual classification (FGIC) model trained on the World Architectural Buildings dataset. Custom layers (GeMPooling, FocalLoss, DiscriminativeAdamW) are registered via ``@register_keras_serializable`` so that ``tf.keras.models.load_model`` can deserialize them without an explicit ``custom_objects`` dict — simply importing this module is sufficient. Usage — Clean load (no ProtectAI flag, recommended): >>> from build_model import ArchBuildingClassifier >>> clf = ArchBuildingClassifier.build() >>> clf.load_weights('fine_tuning_swa.weights.h5') >>> preds = clf.predict(image_array) Usage — Load from .keras (flagged by ProtectAI but functionally correct): >>> import build_model # registers custom classes >>> import tensorflow as tf >>> model = tf.keras.models.load_model('fine_tuning_swa.keras') Usage — Inference with preprocessing: >>> from build_model import ArchBuildingClassifier >>> clf = ArchBuildingClassifier.from_weights('fine_tuning_swa.weights.h5') >>> label, confidence, top3 = clf.predict(image_pil_or_array) References: - GeM Pooling: Radenovic et al., CVPR 2018 - Focal Loss: Lin et al., ICCV 2017 - DiscriminativeAdamW: Howard & Ruder, ACL 2018 (selective fine-tuning) - Random Erasing: Zhong et al., AAAI 2020 - SWA: Izmailov et al., UAI 2018 License: - Code: MIT - Model weights: Apache-2.0 - Dataset: CC-BY-4.0 """ from __future__ import annotations import os from typing import Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf from tensorflow.keras.applications import EfficientNetV2S try: from tensorflow.keras.applications.efficientnet_v2 import preprocess_input except (ImportError, ModuleNotFoundError): from tensorflow.keras.applications.efficientnet import preprocess_input from tensorflow.keras.layers import ( BatchNormalization, Conv2D, Dense, Dropout, Layer, MaxPooling2D, ) from tensorflow.keras.layers import Input # --------------------------------------------------------------------------- # Compatibility shim — tf.keras.saving is not exposed in all TF/Keras setups. # --------------------------------------------------------------------------- try: from tensorflow.keras.saving import register_keras_serializable except (ImportError, AttributeError): try: from keras.saving import register_keras_serializable except (ImportError, AttributeError): def register_keras_serializable(package: Optional[str] = None): """No-op fallback when Keras saving API is unavailable.""" def decorator(cls): return cls return decorator __all__ = [ "ArchBuildingClassifier", "GeMPooling", "FocalLoss", "DiscriminativeAdamW", "CUSTOM_OBJECTS", "LABELS", "build_model", ] # --------------------------------------------------------------------------- # Module-level constants # --------------------------------------------------------------------------- LABELS: List[str] = [ "barn", "bridge", "castle", "mosque", "skyscraper", "stadium", "temple", "windmill", ] INPUT_SHAPE: Tuple[int, int, int] = (320, 320, 3) NUM_CLASSES: int = len(LABELS) PACKAGE: str = "ArchClassifier" # =========================================================================== # Custom Layers # =========================================================================== @register_keras_serializable(package=PACKAGE) class GeMPooling(Layer): """Generalized Mean Pooling layer for fine-grained visual recognition. Replaces standard Global Average Pooling with a learnable generalized mean that better preserves discriminative spatial features. The pooling parameter ``p`` is trainable: ``p -> 1`` reduces to average pooling, ``p -> inf`` approaches max pooling. Args: p: Initial value for the pooling power parameter (default: 3.0). eps: Small constant for numerical stability when clamping inputs (default: 1e-6). **kwargs: Standard Keras layer keyword arguments (name, trainable, etc.). Reference: Radenovic, F., Tolias, G., & Chum, O. (2018). Fine-tuning CNN Image Retrieval with No Human Annotation. IEEE TPAMI. """ def __init__(self, p: float = 3.0, eps: float = 1e-6, **kwargs): super().__init__(**kwargs) self.p_init = p self.eps = eps def build(self, input_shape): self.p = self.add_weight( name="gem_p", shape=(), initializer=tf.keras.initializers.Constant(self.p_init), trainable=True, dtype=tf.float32, ) super().build(input_shape) def call(self, x: tf.Tensor) -> tf.Tensor: x = tf.maximum(x, self.eps) x = tf.pow(x, self.p) x = tf.reduce_mean(x, axis=[1, 2], keepdims=False) x = tf.pow(x, 1.0 / self.p) return x def get_config(self) -> dict: config = super().get_config() config.update({"p": self.p_init, "eps": self.eps}) return config @register_keras_serializable(package=PACKAGE) class FocalLoss(tf.keras.losses.Loss): """Focal Loss for class imbalance and hard-example mining. Down-weights well-classified examples via ``(1 - p)^gamma``, focusing gradient updates on difficult samples. Combined with optional label smoothing to prevent overconfidence. Args: gamma: Focusing parameter; higher values increase down-weighting of easy examples (default: 2.0, per Lin et al.). alpha: Optional per-class weighting factor. If None, no class weighting is applied. label_smoothing: Smoothing factor in [0, 1) to soft-target labels (default: 0.0). **kwargs: Standard Keras loss keyword arguments. Reference: Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollar, P. (2017). Focal Loss for Dense Object Detection. ICCV 2017. """ def __init__( self, gamma: float = 2.0, alpha: Optional[float] = None, label_smoothing: float = 0.0, **kwargs, ): super().__init__(**kwargs) self.gamma = gamma self.alpha = alpha self.label_smoothing = label_smoothing def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor: y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7) if self.label_smoothing > 0: num_classes = tf.cast(tf.shape(y_true)[-1], tf.float32) y_true = y_true * (1.0 - self.label_smoothing) + ( self.label_smoothing / num_classes ) ce = -y_true * tf.math.log(y_pred) weight = tf.pow(1.0 - y_pred, self.gamma) fl = weight * ce if self.alpha is not None: alpha_t = y_true * self.alpha fl = alpha_t * fl return tf.reduce_mean(tf.reduce_sum(fl, axis=-1)) def get_config(self) -> dict: config = super().get_config() config.update( { "gamma": self.gamma, "alpha": self.alpha, "label_smoothing": self.label_smoothing, } ) return config @register_keras_serializable(package=PACKAGE) class DiscriminativeAdamW(tf.keras.optimizers.AdamW): """AdamW with per-variable learning rate scaling for selective fine-tuning. Overrides ``update_step`` to scale the learning rate per-variable based on layer name patterns within the backbone network. Unlike gradient scaling (which is scale-invariant in Adam), LR scaling produces truly discriminative updates — block6 variables receive 10x smaller updates than head variables. Args: lr_multipliers: Mapping from layer-name substrings to LR scale factors. e.g. ``{'block6': 0.1}`` applies 0.1x learning rate to all block6 variables. backbone_layer_idx: Index of the backbone model within the Functional model container (default: 0). **kwargs: Standard AdamW keyword arguments (learning_rate, weight_decay, etc.). Note: LR scaling is applied inside ``update_step`` by multiplying ``learning_rate * mult`` before calling the parent AdamW update. A variable cache is built via ``_build_var_cache(model)`` to map ``id(variable) -> multiplier``. Reference: Howard, J., & Ruder, S. (2018). Universal Language Model Fine-tuning for Text Classification. ACL 2018. """ def __init__( self, lr_multipliers: Optional[Dict[str, float]] = None, backbone_layer_idx: int = 0, **kwargs, ): super().__init__(**kwargs) self.lr_multipliers = lr_multipliers or {} self.backbone_layer_idx = backbone_layer_idx self._var_mult_cache: Dict[int, float] = {} def _build_var_cache(self, model: tf.keras.Model) -> None: """Build the variable-to-multiplier cache from the model's backbone.""" self._var_mult_cache = {} base_model = next((l for l in model.layers if isinstance(l, tf.keras.Model)), None) if base_model is None: base_model = model.layers[self.backbone_layer_idx] for layer in base_model.layers: mult = 1.0 for pattern, m in self.lr_multipliers.items(): if pattern in layer.name: mult = m break for var in layer.trainable_variables: self._var_mult_cache[id(var)] = mult def _get_multiplier(self, var: tf.Variable) -> float: return self._var_mult_cache.get(id(var), 1.0) def update_step(self, gradient, variable, learning_rate): """Scale learning_rate per-variable — truly discriminative.""" mult = self._get_multiplier(variable) effective_lr = learning_rate * mult return super().update_step(gradient, variable, effective_lr) def get_config(self) -> dict: config = super().get_config() config.update( { "lr_multipliers": self.lr_multipliers, "backbone_layer_idx": self.backbone_layer_idx, } ) return config # --------------------------------------------------------------------------- # Custom objects registry (for explicit load_model custom_objects dict) # --------------------------------------------------------------------------- CUSTOM_OBJECTS: Dict[str, type] = { "GeMPooling": GeMPooling, "FocalLoss": FocalLoss, "DiscriminativeAdamW": DiscriminativeAdamW, } # =========================================================================== # Model Wrapper Class # =========================================================================== class ArchBuildingClassifier: """High-level wrapper for the Arch-Building-Image-Classification model. Encapsulates architecture construction, weight loading from multiple formats, and single/batch inference with EfficientNetV2-S preprocessing. The underlying architecture is a Functional model: EfficientNetV2-S (frozen, training=False) -> Conv2D(256) -> BN -> MaxPool -> GeMPooling(p=3.0) -> Dense(256) -> BN -> Dropout(0.4) -> Dense(8, softmax, dtype=float32) Attributes: labels: List of class label strings (alphabetical order). input_shape: Expected input tensor shape (H, W, C). num_classes: Number of output classes. Example: >>> clf = ArchBuildingClassifier.from_weights('model.weights.h5') >>> label, conf, top3 = clf.predict(image) >>> print(f"Predicted: {label} ({conf:.1%})") """ labels: List[str] = LABELS input_shape: Tuple[int, int, int] = INPUT_SHAPE num_classes: int = NUM_CLASSES def __init__(self, model: Optional[tf.keras.Model] = None): self._model = model # ------------------------------------------------------------------ # Construction # ------------------------------------------------------------------ @classmethod def build( cls, input_shape: Optional[Tuple[int, int, int]] = None, num_classes: Optional[int] = None, ) -> "ArchBuildingClassifier": """Construct the model architecture from scratch. Creates a Functional model with EfficientNetV2-S backbone (ImageNet weights, frozen) and a custom classification head featuring GeM pooling. The output Dense layer uses dtype=float32 for mixed precision stability. Args: input_shape: Input tensor shape (default: (320, 320, 3)). num_classes: Number of output classes (default: 8). Returns: An ArchBuildingClassifier instance with an untrained model. """ input_shape = input_shape or cls.input_shape num_classes = num_classes or cls.num_classes base_model = EfficientNetV2S( weights="imagenet", include_top=False, include_preprocessing=True, input_shape=input_shape, ) base_model.trainable = False inputs = Input(shape=input_shape) x = base_model(inputs, training=False) x = Conv2D(256, (3, 3), activation="relu", padding="same")(x) x = BatchNormalization()(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = GeMPooling(p=3.0, name="gem_pooling")(x) x = Dense(256, activation="relu")(x) x = BatchNormalization()(x) x = Dropout(0.4)(x) outputs = Dense(num_classes, activation="softmax", dtype="float32")(x) model = tf.keras.Model(inputs, outputs) return cls(model) @classmethod def from_keras(cls, path: str) -> "ArchBuildingClassifier": """Load from a .keras checkpoint file. Requires that custom classes are registered (importing this module is sufficient) or passed via ``CUSTOM_OBJECTS``. Args: path: Path to the .keras file. Returns: An ArchBuildingClassifier with loaded weights and architecture. """ model = tf.keras.models.load_model( path, custom_objects=CUSTOM_OBJECTS, compile=False ) return cls(model) @classmethod def from_weights(cls, weights_path: str) -> "ArchBuildingClassifier": """Reconstruct architecture and load weights from .weights.h5. This is the recommended loading path for production inference — the .weights.h5 format does not carry custom class references and is not flagged by ProtectAI Guardian (PAIT-KERAS-301). Args: weights_path: Path to the .weights.h5 file. Returns: An ArchBuildingClassifier with loaded weights. """ clf = cls.build() clf._model.load_weights(weights_path) return clf # ------------------------------------------------------------------ # Loading # ------------------------------------------------------------------ def load_weights(self, weights_path: str) -> None: """Load weights into the existing model. Args: weights_path: Path to the .weights.h5 file. """ if self._model is None: raise RuntimeError("Model not initialized. Call build() first.") self._model.load_weights(weights_path) # ------------------------------------------------------------------ # Inference # ------------------------------------------------------------------ def _preprocess(self, image: Union[np.ndarray, "Image.Image"]) -> np.ndarray: """Resize and apply EfficientNetV2-S preprocessing to a single image. Args: image: PIL Image or numpy array (H, W, C) in uint8 range. Returns: Preprocessed batch of shape (1, 320, 320, 3) as float32. """ if hasattr(image, "resize"): # PIL Image image = image.convert("RGB").resize( (self.input_shape[1], self.input_shape[0]) ) image = np.array(image, dtype=np.float32) elif image.shape[:2] != self.input_shape[:2]: image = tf.image.resize(image, self.input_shape[:2]).numpy() if image.ndim == 3: image = np.expand_dims(image, axis=0) image = preprocess_input(image) return image def predict( self, image: Union[np.ndarray, "Image.Image"], top_k: int = 3, ) -> Tuple[str, float, List[Tuple[str, float]]]: """Run inference on a single image. Args: image: PIL Image or numpy array (H, W, C) in uint8 range. top_k: Number of top predictions to return. Returns: Tuple of (predicted_label, confidence, top_k_list) where top_k_list is a list of (label, probability) pairs. """ if self._model is None: raise RuntimeError("Model not initialized. Call build() first.") x = self._preprocess(image) probs = self._model.predict(x, verbose=0)[0] idx = int(np.argmax(probs)) label = self.labels[idx] confidence = float(probs[idx]) top_indices = np.argsort(probs)[::-1][:top_k] top_k_list = [(self.labels[i], float(probs[i])) for i in top_indices] return label, confidence, top_k_list def predict_batch( self, images: List[Union[np.ndarray, "Image.Image"]], ) -> List[Tuple[str, float]]: """Run batch inference on multiple images. Args: images: List of PIL Images or numpy arrays. Returns: List of (label, confidence) tuples. """ if self._model is None: raise RuntimeError("Model not initialized. Call build() first.") batch = np.vstack([self._preprocess(img) for img in images]) probs = self._model.predict(batch, verbose=0) results = [] for row in probs: idx = int(np.argmax(row)) results.append((self.labels[idx], float(row[idx]))) return results # ------------------------------------------------------------------ # Utilities # ------------------------------------------------------------------ @property def keras_model(self) -> tf.keras.Model: """Return the underlying tf.keras.Model instance.""" if self._model is None: raise RuntimeError("Model not initialized. Call build() first.") return self._model @property def parameters(self) -> int: """Total number of model parameters.""" return self.keras_model.count_params() def summary(self) -> None: """Print the model architecture summary.""" self.keras_model.summary() # =========================================================================== # Backward-compatible convenience function # =========================================================================== def build_model( input_shape: Tuple[int, int, int] = INPUT_SHAPE, num_classes: int = NUM_CLASSES, ) -> tf.keras.Model: """Construct the architecture and return a raw tf.keras.Model. This is a backward-compatible thin wrapper around ``ArchBuildingClassifier.build()``. New code should prefer using the class directly for access to ``predict()``, ``from_weights()``, and other utilities. Args: input_shape: Input tensor shape (default: (320, 320, 3)). num_classes: Number of output classes (default: 8). Returns: A compiled but untrained tf.keras.Model instance. """ return ArchBuildingClassifier.build( input_shape=input_shape, num_classes=num_classes ).keras_model # =========================================================================== # CLI entry point # =========================================================================== if __name__ == "__main__": import argparse parser = argparse.ArgumentParser( description="Arch-Building-Image-Classification model loader" ) parser.add_argument( "--weights", type=str, default="fine_tuning_swa.weights.h5", help="Path to .weights.h5 file (default: fine_tuning_swa.weights.h5)", ) parser.add_argument( "--keras", type=str, default=None, help="Path to .keras file (alternative to --weights)", ) args = parser.parse_args() if args.keras: clf = ArchBuildingClassifier.from_keras(args.keras) print(f"Loaded from .keras: {args.keras}") else: clf = ArchBuildingClassifier.from_weights(args.weights) print(f"Loaded from weights: {args.weights}") print(f" Parameters: {clf.parameters:,}") print(f" Input shape: {clf.input_shape}") print(f" Classes: {clf.num_classes} ({', '.join(clf.labels)})") print(" Status: Ready for inference.")