Bachstelze
readd keras models
73f28de
"""Model factories for Issue #10.
Two architectures are provided per problem:
* :func:`build_dense` -- multi-layer perceptron over the flattened sequence.
* :func:`build_cnn` -- small Conv2D-over-(time, joint) network. The default
hyper-parameters were chosen so that the CNN has at most ~20 % of the
parameters of the Dense baseline (verified by :func:`assert_param_budget`).
All models output a single sigmoid probability (good=1 / bad=0) and are compiled
with ``binary_crossentropy`` plus the metrics required by issue #10:
True/False Positives & Negatives, AUC, BinaryAccuracy, Precision, Recall.
"""
from __future__ import annotations
from typing import Sequence
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
# --------------------------------------------------------------------------- #
# Metrics & compile helper #
# --------------------------------------------------------------------------- #
def make_metrics() -> list[tf.keras.metrics.Metric]:
    """Return freshly constructed metric objects for model compilation.

    The returned list contains, in this order: true positives, false
    positives, true negatives, false negatives, binary accuracy, precision,
    recall and AUC. Each metric is given an explicit short name
    (``tp``, ``fp``, ``tn``, ``fn``, ``accuracy``, ``precision``,
    ``recall``, ``auc``).
    """
    keras_metrics = tf.keras.metrics
    # (metric class, reported name) pairs; their order is the result order.
    named_metric_classes = (
        (keras_metrics.TruePositives, "tp"),
        (keras_metrics.FalsePositives, "fp"),
        (keras_metrics.TrueNegatives, "tn"),
        (keras_metrics.FalseNegatives, "fn"),
        (keras_metrics.BinaryAccuracy, "accuracy"),
        (keras_metrics.Precision, "precision"),
        (keras_metrics.Recall, "recall"),
        (keras_metrics.AUC, "auc"),
    )
    return [metric_cls(name=label) for metric_cls, label in named_metric_classes]
def compile_model(model: tf.keras.Model, learning_rate: float = 1e-3) -> tf.keras.Model:
    """Compile ``model`` for binary classification and hand it back.

    Args:
        model: Keras model to compile; it is compiled in place.
        learning_rate: Learning rate passed to the Adam optimizer.

    Returns:
        The same ``model`` object, compiled with the ``binary_crossentropy``
        loss and the metrics produced by :func:`make_metrics`.
    """
    adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    tracked_metrics = make_metrics()
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=tracked_metrics)
    return model
# --------------------------------------------------------------------------- #
# Architectures #
# --------------------------------------------------------------------------- #
def build_dense(
    input_dim: int,
    hidden_units: Sequence[int] = (128, 64, 32),
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "dense",
) -> tf.keras.Model:
    """Build and compile the MLP used for flattened sequences (Dense approach).

    The network is: input -> batch normalisation -> one ReLU ``Dense`` layer
    per entry of ``hidden_units`` (each followed by ``Dropout`` when
    ``dropout`` is truthy) -> a single sigmoid unit named ``prob``.

    Args:
        input_dim: Length of the flattened feature vector.
        hidden_units: Width of each hidden layer, in order.
        dropout: Dropout rate after every hidden layer; a falsy value
            disables dropout.
        l2: L2 factor for the hidden layers' kernels; a falsy value disables
            regularisation.
        learning_rate: Learning rate forwarded to :func:`compile_model`.
        name: Name given to the Keras model.

    Returns:
        The compiled Keras model.
    """
    kernel_reg = None
    if l2:
        kernel_reg = regularizers.l2(l2)

    features = layers.Input(shape=(input_dim,), name="features")
    hidden = layers.BatchNormalization()(features)

    # Hidden layers are named fc1, fc2, ... in stacking order.
    for layer_no, width in enumerate(hidden_units, start=1):
        fully_connected = layers.Dense(
            width,
            activation="relu",
            kernel_regularizer=kernel_reg,
            name=f"fc{layer_no}",
        )
        hidden = fully_connected(hidden)
        if not dropout:
            continue
        hidden = layers.Dropout(dropout)(hidden)

    prob = layers.Dense(1, activation="sigmoid", name="prob")(hidden)
    network = models.Model(features, prob, name=name)
    return compile_model(network, learning_rate)
def build_cnn(
    input_shape: tuple[int, int, int],
    filters: Sequence[int] = (8, 16),
    kernel_size: tuple[int, int] = (3, 3),
    dense_units: int = 16,
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "cnn",
) -> tf.keras.Model:
    """Build and compile a compact 2D CNN over (time, joint, coordinate) tensors.

    The network is: input -> batch normalisation -> one ``Conv2D`` +
    ``MaxPool2D`` stage per entry of ``filters`` -> global average pooling ->
    optional ReLU ``Dense`` layer -> optional ``Dropout`` -> a single sigmoid
    unit named ``prob``.

    The default ``filters`` and ``dense_units`` are meant to keep the
    parameter count below 20 % of the Dense baseline for both problem A and
    problem B.

    Args:
        input_shape: Shape of one input sample (without the batch axis).
        filters: Number of convolution filters for each stage, in order.
        kernel_size: Convolution kernel size shared by all stages.
        dense_units: Width of the dense layer after pooling; a falsy value
            omits that layer.
        dropout: Dropout rate applied before the output unit; a falsy value
            disables dropout.
        l2: L2 factor for convolution and dense kernels; a falsy value
            disables regularisation.
        learning_rate: Learning rate forwarded to :func:`compile_model`.
        name: Name given to the Keras model.

    Returns:
        The compiled Keras model.
    """
    kernel_reg = None
    if l2:
        kernel_reg = regularizers.l2(l2)

    sequence = layers.Input(shape=input_shape, name="sequence")
    feature_map = layers.BatchNormalization()(sequence)

    # Stages are named conv1/pool1, conv2/pool2, ... in stacking order.
    for stage, n_filters in enumerate(filters, start=1):
        conv = layers.Conv2D(
            n_filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            kernel_regularizer=kernel_reg,
            name=f"conv{stage}",
        )
        feature_map = conv(feature_map)
        # Pool along the time axis only; the joint axis is small (13).
        time_pool = layers.MaxPool2D(pool_size=(2, 1), name=f"pool{stage}")
        feature_map = time_pool(feature_map)

    pooled = layers.GlobalAveragePooling2D(name="gap")(feature_map)
    if dense_units:
        head = layers.Dense(
            dense_units, activation="relu", kernel_regularizer=kernel_reg, name="fc"
        )
        pooled = head(pooled)
    if dropout:
        pooled = layers.Dropout(dropout)(pooled)

    prob = layers.Dense(1, activation="sigmoid", name="prob")(pooled)
    network = models.Model(sequence, prob, name=name)
    return compile_model(network, learning_rate)
# --------------------------------------------------------------------------- #
# Parameter budget #
# --------------------------------------------------------------------------- #
def count_params(model: tf.keras.Model) -> int:
    """Return the value of ``model.count_params()`` converted to ``int``."""
    total = model.count_params()
    return int(total)
def assert_param_budget(dense: tf.keras.Model, cnn: tf.keras.Model, ratio: float = 0.20) -> None:
    """Raise if the CNN exceeds ``ratio`` × Dense parameter count.

    Args:
        dense: The Dense baseline model.
        cnn: The CNN model whose size is being checked.
        ratio: Largest allowed CNN/Dense parameter ratio (default 0.20).

    Raises:
        AssertionError: If the CNN parameter count is greater than
            ``ratio`` times the Dense parameter count. This also covers a
            Dense model that reports zero parameters.
    """
    dense_total, cnn_total = count_params(dense), count_params(cnn)
    if cnn_total > ratio * dense_total:
        # With a parameter-free Dense model the percentage is undefined;
        # dividing anyway would replace the budget error with a
        # ZeroDivisionError.
        share = f"{cnn_total / dense_total:.1%}" if dense_total else "n/a"
        raise AssertionError(
            f"CNN has {cnn_total} parameters which exceeds {ratio:.0%} of Dense's {dense_total} "
            f"({share}). Reduce CNN filters/dense_units."
        )