Spaces:
Running
Running
| """Model factories for Issue #10. | |
| Two architectures are provided per problem: | |
| * :func:`build_dense` -- multi-layer perceptron over the flattened sequence. | |
| * :func:`build_cnn` -- small Conv2D-over-(time, joint) network. The default | |
| hyper-parameters were chosen so that the CNN has at most ~20 % of the | |
| parameters of the Dense baseline (verified by :func:`assert_param_budget`). | |
| All models output a single sigmoid probability (good=1 / bad=0) and are compiled | |
| with ``binary_crossentropy`` plus the metrics required by issue #10: | |
| True/False Positives & Negatives, AUC, BinaryAccuracy, Precision, Recall. | |
| """ | |
| from __future__ import annotations | |
| from typing import Sequence | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models, regularizers | |
| # --------------------------------------------------------------------------- # | |
| # Metrics & compile helper # | |
| # --------------------------------------------------------------------------- # | |
def make_metrics() -> list[tf.keras.metrics.Metric]:
    """Return a fresh list of the Keras metrics attached to every model.

    Each call constructs new metric objects, in this order: true positives,
    false positives, true negatives, false negatives, binary accuracy,
    precision, recall and AUC. The short ``name`` given to each metric is the
    key it is reported under.
    """
    km = tf.keras.metrics
    # (metric class, reported name) pairs; the order is the reporting order.
    specs = (
        (km.TruePositives, "tp"),
        (km.FalsePositives, "fp"),
        (km.TrueNegatives, "tn"),
        (km.FalseNegatives, "fn"),
        (km.BinaryAccuracy, "accuracy"),
        (km.Precision, "precision"),
        (km.Recall, "recall"),
        (km.AUC, "auc"),
    )
    return [metric_cls(name=label) for metric_cls, label in specs]
def compile_model(model: tf.keras.Model, learning_rate: float = 1e-3) -> tf.keras.Model:
    """Compile ``model`` in place for binary classification and return it.

    Args:
        model: The Keras model to compile.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The same ``model`` object, now compiled with Adam, the
        ``binary_crossentropy`` loss and the metrics from
        :func:`make_metrics`.
    """
    adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    tracked = make_metrics()
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=tracked)
    return model
| # --------------------------------------------------------------------------- # | |
| # Architectures # | |
| # --------------------------------------------------------------------------- # | |
def build_dense(
    input_dim: int,
    hidden_units: Sequence[int] = (128, 64, 32),
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "dense",
) -> tf.keras.Model:
    """Build and compile the MLP for flattened sequences (Dense approach).

    Layout: input of shape ``(input_dim,)`` -> BatchNormalization -> one ReLU
    ``Dense`` layer per entry of ``hidden_units`` (named ``fc1``, ``fc2``,
    ...) -> a single sigmoid unit named ``prob``.

    Args:
        input_dim: Length of the flattened feature vector.
        hidden_units: Width of each hidden layer, in order.
        dropout: Dropout rate; a falsy value (0 / None) disables dropout.
        l2: L2 kernel-regularisation factor for the hidden layers; a falsy
            value disables regularisation.
        learning_rate: Forwarded to :func:`compile_model`.
        name: Name given to the Keras model.

    Returns:
        The compiled Keras model.
    """
    # A single regulariser instance is shared by all hidden layers.
    kernel_reg = None if not l2 else regularizers.l2(l2)

    features = layers.Input(shape=(input_dim,), name="features")
    hidden = layers.BatchNormalization()(features)

    # NOTE(review): dropout is assumed to follow *each* hidden layer (i.e. to
    # sit inside this loop) -- confirm against the original layout.
    for idx, width in enumerate(hidden_units, start=1):
        fc_layer = layers.Dense(
            width,
            activation="relu",
            kernel_regularizer=kernel_reg,
            name=f"fc{idx}",
        )
        hidden = fc_layer(hidden)
        if dropout:
            hidden = layers.Dropout(dropout)(hidden)

    prob = layers.Dense(1, activation="sigmoid", name="prob")(hidden)
    network = models.Model(features, prob, name=name)
    return compile_model(network, learning_rate)
def build_cnn(
    input_shape: tuple[int, int, int],
    filters: Sequence[int] = (8, 16),
    kernel_size: tuple[int, int] = (3, 3),
    dense_units: int = 16,
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "cnn",
) -> tf.keras.Model:
    """Build and compile a compact 2D CNN over (time, joint, coordinate) tensors.

    Layout: input -> BatchNormalization -> for each entry of ``filters`` a
    same-padded ReLU ``Conv2D`` (``conv1``, ``conv2``, ...) followed by a
    ``(2, 1)`` max-pool (``pool1``, ``pool2``, ...) -> global average pooling
    -> optional ReLU ``Dense`` layer ``fc`` -> optional dropout -> a single
    sigmoid unit named ``prob``.

    The default ``filters`` and ``dense_units`` are intended to produce <20 %
    of the Dense baseline's parameters for both problem A and problem B.

    Args:
        input_shape: Shape of one sample, without the batch dimension.
        filters: Number of filters of each convolutional stage, in order.
        kernel_size: Kernel size used by every ``Conv2D`` layer.
        dense_units: Width of the ``fc`` layer; a falsy value skips it.
        dropout: Dropout rate; a falsy value disables dropout.
        l2: L2 kernel-regularisation factor for the conv and ``fc`` layers;
            a falsy value disables regularisation.
        learning_rate: Forwarded to :func:`compile_model`.
        name: Name given to the Keras model.

    Returns:
        The compiled Keras model.
    """
    # A single regulariser instance is shared by the conv and fc layers.
    kernel_reg = None if not l2 else regularizers.l2(l2)

    sequence = layers.Input(shape=input_shape, name="sequence")
    feat = layers.BatchNormalization()(sequence)

    for stage, n_filters in enumerate(filters, start=1):
        conv = layers.Conv2D(
            n_filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            kernel_regularizer=kernel_reg,
            name=f"conv{stage}",
        )
        feat = conv(feat)
        # Pool along the time axis only; the joint axis is small (13).
        time_pool = layers.MaxPool2D(pool_size=(2, 1), name=f"pool{stage}")
        feat = time_pool(feat)

    feat = layers.GlobalAveragePooling2D(name="gap")(feat)
    if dense_units:
        fc_layer = layers.Dense(
            dense_units,
            activation="relu",
            kernel_regularizer=kernel_reg,
            name="fc",
        )
        feat = fc_layer(feat)
    # NOTE(review): dropout is assumed to apply whether or not the fc layer is
    # present (i.e. not nested under ``if dense_units``) -- confirm against
    # the original layout.
    if dropout:
        feat = layers.Dropout(dropout)(feat)

    prob = layers.Dense(1, activation="sigmoid", name="prob")(feat)
    network = models.Model(sequence, prob, name=name)
    return compile_model(network, learning_rate)
| # --------------------------------------------------------------------------- # | |
| # Parameter budget # | |
| # --------------------------------------------------------------------------- # | |
def count_params(model: tf.keras.Model) -> int:
    """Return ``model.count_params()`` coerced to a built-in ``int``."""
    total = model.count_params()
    return int(total)
def assert_param_budget(dense: tf.keras.Model, cnn: tf.keras.Model, ratio: float = 0.20) -> None:
    """Raise if the CNN exceeds ``ratio`` × Dense parameter count.

    Args:
        dense: Baseline model; only its ``count_params()`` method is used.
        cnn: Candidate model; only its ``count_params()`` method is used.
        ratio: Largest allowed CNN/Dense parameter ratio (0.20 means 20 %).

    Raises:
        AssertionError: If the CNN's parameter count is strictly greater than
            ``ratio`` times the Dense model's parameter count.
    """
    d = int(dense.count_params())
    c = int(cnn.count_params())
    if c > ratio * d:
        # A Dense baseline reporting zero parameters has no meaningful
        # percentage; without this guard the message itself would raise
        # ZeroDivisionError and mask the budget violation.
        share = f"{c / d:.1%}" if d else "n/a"
        raise AssertionError(
            f"CNN has {c} parameters which exceeds {ratio:.0%} of Dense's {d} "
            f"({share}). Reduce CNN filters/dense_units."
        )