"""Build, compile and fine-tune a MobileNetV2-based binary image classifier."""

from __future__ import annotations

import tensorflow as tf
from tensorflow.keras import Sequential, layers
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.optimizers import Adam

from .config import IMAGE_SIZE, SEED


def compile_model(model: tf.keras.Model, learning_rate: float) -> None:
    """Compile *model* in place for binary classification.

    Uses the Adam optimizer with binary cross-entropy loss and tracks
    accuracy, precision, recall and AUC.

    Args:
        model: The Keras model to compile.
        learning_rate: Learning rate passed to the Adam optimizer.
    """
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss="binary_crossentropy",
        metrics=[
            "accuracy",
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tf.keras.metrics.AUC(name="auc"),
        ],
    )


def build_model() -> tf.keras.Model:
    """Create and compile the classifier with a frozen MobileNetV2 backbone.

    The global random seed is set from ``SEED`` before any layer is created.
    The backbone is loaded with ImageNet weights and without its
    classification top, then frozen. A small head (global average pooling,
    dropout, a 128-unit ReLU layer, dropout, and a single sigmoid unit) is
    stacked on top, and the model is compiled with a learning rate of 1e-4.

    Returns:
        The compiled ``Sequential`` model, with the backbone as layer 0.
    """
    # Seed first so the weight initialisation of the new head is repeatable.
    tf.keras.utils.set_random_seed(SEED)

    # NOTE(review): the ImageNet MobileNetV2 weights expect inputs scaled to
    # [-1, 1] (`mobilenet_v2.preprocess_input`). No such step is part of this
    # model, so presumably the input pipeline applies it -- confirm.
    base_model = MobileNetV2(
        input_shape=(*IMAGE_SIZE, 3),  # assumes IMAGE_SIZE is (height, width)
        include_top=False,
        weights="imagenet",
    )
    base_model.trainable = False

    model = Sequential(
        [
            base_model,
            layers.GlobalAveragePooling2D(),
            layers.Dropout(0.3),
            layers.Dense(128, activation="relu"),
            layers.Dropout(0.2),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
    compile_model(model, learning_rate=1e-4)
    return model


def unfreeze_for_fine_tuning(
    model: tf.keras.Model, trainable_layers: int = 30
) -> None:
    """Unfreeze the last layers of the backbone and recompile for fine-tuning.

    The backbone is taken to be ``model.layers[0]``, which is how
    ``build_model`` assembles the model. Its last ``trainable_layers`` layers
    are made trainable, except for ``BatchNormalization`` layers, which always
    stay frozen. Every earlier layer is frozen. The model is then recompiled
    with a learning rate of 1e-5, which is required for the changed
    ``trainable`` flags to take effect in training.

    Args:
        model: A model whose first layer is the backbone to fine-tune.
        trainable_layers: Number of trailing backbone layers to unfreeze.
            ``0`` keeps the whole backbone frozen; a value larger than the
            number of backbone layers unfreezes all of them.

    Raises:
        ValueError: If ``trainable_layers`` is negative.
    """
    if trainable_layers < 0:
        raise ValueError(
            f"trainable_layers must be non-negative, got {trainable_layers}"
        )

    base_model = model.layers[0]
    # Mark everything trainable first, then freeze what must stay fixed.
    base_model.trainable = True

    backbone_layers = base_model.layers
    # Use an explicit index rather than negative slicing: `layers[:-0]` is
    # empty and `layers[-0:]` is the whole list, so a count of 0 would
    # otherwise unfreeze the entire backbone instead of none of it.
    first_trainable = max(len(backbone_layers) - trainable_layers, 0)

    for index, layer in enumerate(backbone_layers):
        # BatchNormalization layers stay frozen even inside the unfrozen tail,
        # as the Keras fine-tuning guide recommends, to avoid disturbing the
        # pretrained normalisation statistics.
        if index < first_trainable or isinstance(
            layer, tf.keras.layers.BatchNormalization
        ):
            layer.trainable = False

    compile_model(model, learning_rate=1e-5)