0xgr3y's picture
Upload build_model.py with huggingface_hub
b98c929 verified
Raw
History Blame Contribute Delete
21.2 kB
"""Arch-Building-Image-Classification — Model Construction & Inference Module.
This module provides the architecture definition, custom layer implementations,
and inference utilities for the EfficientNetV2-S-based fine-grained visual
classification (FGIC) model trained on the World Architectural Buildings dataset.
Custom layers (GeMPooling, FocalLoss, DiscriminativeAdamW) are registered via
``@register_keras_serializable`` so that ``tf.keras.models.load_model`` can
deserialize them without an explicit ``custom_objects`` dict — simply importing
this module is sufficient.
Usage — Clean load (no ProtectAI flag, recommended):
>>> from build_model import ArchBuildingClassifier
>>> clf = ArchBuildingClassifier.build()
>>> clf.load_weights('fine_tuning_swa.weights.h5')
>>> preds = clf.predict(image_array)
Usage — Load from .keras (flagged by ProtectAI but functionally correct):
>>> import build_model # registers custom classes
>>> import tensorflow as tf
>>> model = tf.keras.models.load_model('fine_tuning_swa.keras')
Usage — Inference with preprocessing:
>>> from build_model import ArchBuildingClassifier
>>> clf = ArchBuildingClassifier.from_weights('fine_tuning_swa.weights.h5')
>>> label, confidence, top3 = clf.predict(image_pil_or_array)
References:
- GeM Pooling: Radenovic et al., CVPR 2018
- Focal Loss: Lin et al., ICCV 2017
- DiscriminativeAdamW: Howard & Ruder, ACL 2018 (selective fine-tuning)
- Random Erasing: Zhong et al., AAAI 2020
- SWA: Izmailov et al., UAI 2018
License:
- Code: MIT
- Model weights: Apache-2.0
- Dataset: CC-BY-4.0
"""
from __future__ import annotations
import os
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetV2S
try:
from tensorflow.keras.applications.efficientnet_v2 import preprocess_input
except (ImportError, ModuleNotFoundError):
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.layers import (
BatchNormalization,
Conv2D,
Dense,
Dropout,
Layer,
MaxPooling2D,
)
from tensorflow.keras.layers import Input
# ---------------------------------------------------------------------------
# Compatibility shim — tf.keras.saving is not exposed in all TF/Keras setups.
# ---------------------------------------------------------------------------
try:
from tensorflow.keras.saving import register_keras_serializable
except (ImportError, AttributeError):
try:
from keras.saving import register_keras_serializable
except (ImportError, AttributeError):
def register_keras_serializable(package: Optional[str] = None):
"""No-op fallback when Keras saving API is unavailable."""
def decorator(cls):
return cls
return decorator
__all__ = [
"ArchBuildingClassifier",
"GeMPooling",
"FocalLoss",
"DiscriminativeAdamW",
"CUSTOM_OBJECTS",
"LABELS",
"build_model",
]
# ---------------------------------------------------------------------------
# Module-level constants
# ---------------------------------------------------------------------------
LABELS: List[str] = [
"barn",
"bridge",
"castle",
"mosque",
"skyscraper",
"stadium",
"temple",
"windmill",
]
INPUT_SHAPE: Tuple[int, int, int] = (320, 320, 3)
NUM_CLASSES: int = len(LABELS)
PACKAGE: str = "ArchClassifier"
# ===========================================================================
# Custom Layers
# ===========================================================================
@register_keras_serializable(package=PACKAGE)
class GeMPooling(Layer):
"""Generalized Mean Pooling layer for fine-grained visual recognition.
Replaces standard Global Average Pooling with a learnable generalized
mean that better preserves discriminative spatial features. The pooling
parameter ``p`` is trainable: ``p -> 1`` reduces to average pooling,
``p -> inf`` approaches max pooling.
Args:
p: Initial value for the pooling power parameter (default: 3.0).
eps: Small constant for numerical stability when clamping inputs
(default: 1e-6).
**kwargs: Standard Keras layer keyword arguments (name, trainable, etc.).
Reference:
Radenovic, F., Tolias, G., & Chum, O. (2018). Fine-tuning CNN
Image Retrieval with No Human Annotation. IEEE TPAMI.
"""
def __init__(self, p: float = 3.0, eps: float = 1e-6, **kwargs):
super().__init__(**kwargs)
self.p_init = p
self.eps = eps
def build(self, input_shape):
self.p = self.add_weight(
name="gem_p",
shape=(),
initializer=tf.keras.initializers.Constant(self.p_init),
trainable=True,
dtype=tf.float32,
)
super().build(input_shape)
def call(self, x: tf.Tensor) -> tf.Tensor:
x = tf.maximum(x, self.eps)
x = tf.pow(x, self.p)
x = tf.reduce_mean(x, axis=[1, 2], keepdims=False)
x = tf.pow(x, 1.0 / self.p)
return x
def get_config(self) -> dict:
config = super().get_config()
config.update({"p": self.p_init, "eps": self.eps})
return config
@register_keras_serializable(package=PACKAGE)
class FocalLoss(tf.keras.losses.Loss):
"""Focal Loss for class imbalance and hard-example mining.
Down-weights well-classified examples via ``(1 - p)^gamma``, focusing
gradient updates on difficult samples. Combined with optional label
smoothing to prevent overconfidence.
Args:
gamma: Focusing parameter; higher values increase down-weighting
of easy examples (default: 2.0, per Lin et al.).
alpha: Optional per-class weighting factor. If None, no class
weighting is applied.
label_smoothing: Smoothing factor in [0, 1) to soft-target labels
(default: 0.0).
**kwargs: Standard Keras loss keyword arguments.
Reference:
Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollar, P. (2017).
Focal Loss for Dense Object Detection. ICCV 2017.
"""
def __init__(
self,
gamma: float = 2.0,
alpha: Optional[float] = None,
label_smoothing: float = 0.0,
**kwargs,
):
super().__init__(**kwargs)
self.gamma = gamma
self.alpha = alpha
self.label_smoothing = label_smoothing
def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
if self.label_smoothing > 0:
num_classes = tf.cast(tf.shape(y_true)[-1], tf.float32)
y_true = y_true * (1.0 - self.label_smoothing) + (
self.label_smoothing / num_classes
)
ce = -y_true * tf.math.log(y_pred)
weight = tf.pow(1.0 - y_pred, self.gamma)
fl = weight * ce
if self.alpha is not None:
alpha_t = y_true * self.alpha
fl = alpha_t * fl
return tf.reduce_mean(tf.reduce_sum(fl, axis=-1))
def get_config(self) -> dict:
config = super().get_config()
config.update(
{
"gamma": self.gamma,
"alpha": self.alpha,
"label_smoothing": self.label_smoothing,
}
)
return config
@register_keras_serializable(package=PACKAGE)
class DiscriminativeAdamW(tf.keras.optimizers.AdamW):
"""AdamW with per-variable learning rate scaling for selective fine-tuning.
Overrides ``update_step`` to scale the learning rate per-variable based
on layer name patterns within the backbone network. Unlike gradient
scaling (which is scale-invariant in Adam), LR scaling produces truly
discriminative updates — block6 variables receive 10x smaller updates
than head variables.
Args:
lr_multipliers: Mapping from layer-name substrings to LR scale
factors. e.g. ``{'block6': 0.1}`` applies 0.1x learning rate
to all block6 variables.
backbone_layer_idx: Index of the backbone model within the
Functional model container (default: 0).
**kwargs: Standard AdamW keyword arguments (learning_rate,
weight_decay, etc.).
Note:
LR scaling is applied inside ``update_step`` by multiplying
``learning_rate * mult`` before calling the parent AdamW update.
A variable cache is built via ``_build_var_cache(model)`` to map
``id(variable) -> multiplier``.
Reference:
Howard, J., & Ruder, S. (2018). Universal Language Model
Fine-tuning for Text Classification. ACL 2018.
"""
def __init__(
self,
lr_multipliers: Optional[Dict[str, float]] = None,
backbone_layer_idx: int = 0,
**kwargs,
):
super().__init__(**kwargs)
self.lr_multipliers = lr_multipliers or {}
self.backbone_layer_idx = backbone_layer_idx
self._var_mult_cache: Dict[int, float] = {}
def _build_var_cache(self, model: tf.keras.Model) -> None:
"""Build the variable-to-multiplier cache from the model's backbone."""
self._var_mult_cache = {}
base_model = next((l for l in model.layers if isinstance(l, tf.keras.Model)), None)
if base_model is None:
base_model = model.layers[self.backbone_layer_idx]
for layer in base_model.layers:
mult = 1.0
for pattern, m in self.lr_multipliers.items():
if pattern in layer.name:
mult = m
break
for var in layer.trainable_variables:
self._var_mult_cache[id(var)] = mult
def _get_multiplier(self, var: tf.Variable) -> float:
return self._var_mult_cache.get(id(var), 1.0)
def update_step(self, gradient, variable, learning_rate):
"""Scale learning_rate per-variable — truly discriminative."""
mult = self._get_multiplier(variable)
effective_lr = learning_rate * mult
return super().update_step(gradient, variable, effective_lr)
def get_config(self) -> dict:
config = super().get_config()
config.update(
{
"lr_multipliers": self.lr_multipliers,
"backbone_layer_idx": self.backbone_layer_idx,
}
)
return config
# ---------------------------------------------------------------------------
# Custom objects registry (for explicit load_model custom_objects dict)
# ---------------------------------------------------------------------------
CUSTOM_OBJECTS: Dict[str, type] = {
"GeMPooling": GeMPooling,
"FocalLoss": FocalLoss,
"DiscriminativeAdamW": DiscriminativeAdamW,
}
# ===========================================================================
# Model Wrapper Class
# ===========================================================================
class ArchBuildingClassifier:
"""High-level wrapper for the Arch-Building-Image-Classification model.
Encapsulates architecture construction, weight loading from multiple
formats, and single/batch inference with EfficientNetV2-S preprocessing.
The underlying architecture is a Functional model:
EfficientNetV2-S (frozen, training=False) -> Conv2D(256) -> BN -> MaxPool ->
GeMPooling(p=3.0) -> Dense(256) -> BN -> Dropout(0.4) ->
Dense(8, softmax, dtype=float32)
Attributes:
labels: List of class label strings (alphabetical order).
input_shape: Expected input tensor shape (H, W, C).
num_classes: Number of output classes.
Example:
>>> clf = ArchBuildingClassifier.from_weights('model.weights.h5')
>>> label, conf, top3 = clf.predict(image)
>>> print(f"Predicted: {label} ({conf:.1%})")
"""
labels: List[str] = LABELS
input_shape: Tuple[int, int, int] = INPUT_SHAPE
num_classes: int = NUM_CLASSES
def __init__(self, model: Optional[tf.keras.Model] = None):
self._model = model
# ------------------------------------------------------------------
# Construction
# ------------------------------------------------------------------
@classmethod
def build(
cls,
input_shape: Optional[Tuple[int, int, int]] = None,
num_classes: Optional[int] = None,
) -> "ArchBuildingClassifier":
"""Construct the model architecture from scratch.
Creates a Functional model with EfficientNetV2-S backbone (ImageNet
weights, frozen) and a custom classification head featuring GeM
pooling. The output Dense layer uses dtype=float32 for mixed
precision stability.
Args:
input_shape: Input tensor shape (default: (320, 320, 3)).
num_classes: Number of output classes (default: 8).
Returns:
An ArchBuildingClassifier instance with an untrained model.
"""
input_shape = input_shape or cls.input_shape
num_classes = num_classes or cls.num_classes
base_model = EfficientNetV2S(
weights="imagenet",
include_top=False,
include_preprocessing=True,
input_shape=input_shape,
)
base_model.trainable = False
inputs = Input(shape=input_shape)
x = base_model(inputs, training=False)
x = Conv2D(256, (3, 3), activation="relu", padding="same")(x)
x = BatchNormalization()(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = GeMPooling(p=3.0, name="gem_pooling")(x)
x = Dense(256, activation="relu")(x)
x = BatchNormalization()(x)
x = Dropout(0.4)(x)
outputs = Dense(num_classes, activation="softmax", dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)
return cls(model)
@classmethod
def from_keras(cls, path: str) -> "ArchBuildingClassifier":
"""Load from a .keras checkpoint file.
Requires that custom classes are registered (importing this module
is sufficient) or passed via ``CUSTOM_OBJECTS``.
Args:
path: Path to the .keras file.
Returns:
An ArchBuildingClassifier with loaded weights and architecture.
"""
model = tf.keras.models.load_model(
path, custom_objects=CUSTOM_OBJECTS, compile=False
)
return cls(model)
@classmethod
def from_weights(cls, weights_path: str) -> "ArchBuildingClassifier":
"""Reconstruct architecture and load weights from .weights.h5.
This is the recommended loading path for production inference —
the .weights.h5 format does not carry custom class references and
is not flagged by ProtectAI Guardian (PAIT-KERAS-301).
Args:
weights_path: Path to the .weights.h5 file.
Returns:
An ArchBuildingClassifier with loaded weights.
"""
clf = cls.build()
clf._model.load_weights(weights_path)
return clf
# ------------------------------------------------------------------
# Loading
# ------------------------------------------------------------------
def load_weights(self, weights_path: str) -> None:
"""Load weights into the existing model.
Args:
weights_path: Path to the .weights.h5 file.
"""
if self._model is None:
raise RuntimeError("Model not initialized. Call build() first.")
self._model.load_weights(weights_path)
# ------------------------------------------------------------------
# Inference
# ------------------------------------------------------------------
def _preprocess(self, image: Union[np.ndarray, "Image.Image"]) -> np.ndarray:
"""Resize and apply EfficientNetV2-S preprocessing to a single image.
Args:
image: PIL Image or numpy array (H, W, C) in uint8 range.
Returns:
Preprocessed batch of shape (1, 320, 320, 3) as float32.
"""
if hasattr(image, "resize"): # PIL Image
image = image.convert("RGB").resize(
(self.input_shape[1], self.input_shape[0])
)
image = np.array(image, dtype=np.float32)
elif image.shape[:2] != self.input_shape[:2]:
image = tf.image.resize(image, self.input_shape[:2]).numpy()
if image.ndim == 3:
image = np.expand_dims(image, axis=0)
image = preprocess_input(image)
return image
def predict(
self,
image: Union[np.ndarray, "Image.Image"],
top_k: int = 3,
) -> Tuple[str, float, List[Tuple[str, float]]]:
"""Run inference on a single image.
Args:
image: PIL Image or numpy array (H, W, C) in uint8 range.
top_k: Number of top predictions to return.
Returns:
Tuple of (predicted_label, confidence, top_k_list) where
top_k_list is a list of (label, probability) pairs.
"""
if self._model is None:
raise RuntimeError("Model not initialized. Call build() first.")
x = self._preprocess(image)
probs = self._model.predict(x, verbose=0)[0]
idx = int(np.argmax(probs))
label = self.labels[idx]
confidence = float(probs[idx])
top_indices = np.argsort(probs)[::-1][:top_k]
top_k_list = [(self.labels[i], float(probs[i])) for i in top_indices]
return label, confidence, top_k_list
def predict_batch(
self,
images: List[Union[np.ndarray, "Image.Image"]],
) -> List[Tuple[str, float]]:
"""Run batch inference on multiple images.
Args:
images: List of PIL Images or numpy arrays.
Returns:
List of (label, confidence) tuples.
"""
if self._model is None:
raise RuntimeError("Model not initialized. Call build() first.")
batch = np.vstack([self._preprocess(img) for img in images])
probs = self._model.predict(batch, verbose=0)
results = []
for row in probs:
idx = int(np.argmax(row))
results.append((self.labels[idx], float(row[idx])))
return results
# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------
@property
def keras_model(self) -> tf.keras.Model:
"""Return the underlying tf.keras.Model instance."""
if self._model is None:
raise RuntimeError("Model not initialized. Call build() first.")
return self._model
@property
def parameters(self) -> int:
"""Total number of model parameters."""
return self.keras_model.count_params()
def summary(self) -> None:
"""Print the model architecture summary."""
self.keras_model.summary()
# ===========================================================================
# Backward-compatible convenience function
# ===========================================================================
def build_model(
input_shape: Tuple[int, int, int] = INPUT_SHAPE,
num_classes: int = NUM_CLASSES,
) -> tf.keras.Model:
"""Construct the architecture and return a raw tf.keras.Model.
This is a backward-compatible thin wrapper around
``ArchBuildingClassifier.build()``. New code should prefer using
the class directly for access to ``predict()``, ``from_weights()``,
and other utilities.
Args:
input_shape: Input tensor shape (default: (320, 320, 3)).
num_classes: Number of output classes (default: 8).
Returns:
A compiled but untrained tf.keras.Model instance.
"""
return ArchBuildingClassifier.build(
input_shape=input_shape, num_classes=num_classes
).keras_model
# ===========================================================================
# CLI entry point
# ===========================================================================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Arch-Building-Image-Classification model loader"
)
parser.add_argument(
"--weights",
type=str,
default="fine_tuning_swa.weights.h5",
help="Path to .weights.h5 file (default: fine_tuning_swa.weights.h5)",
)
parser.add_argument(
"--keras",
type=str,
default=None,
help="Path to .keras file (alternative to --weights)",
)
args = parser.parse_args()
if args.keras:
clf = ArchBuildingClassifier.from_keras(args.keras)
print(f"Loaded from .keras: {args.keras}")
else:
clf = ArchBuildingClassifier.from_weights(args.weights)
print(f"Loaded from weights: {args.weights}")
print(f" Parameters: {clf.parameters:,}")
print(f" Input shape: {clf.input_shape}")
print(f" Classes: {clf.num_classes} ({', '.join(clf.labels)})")
print(" Status: Ready for inference.")