Spaces:
Running
Running
Bachstelze committed on
Commit ·
73f28de
1
Parent(s): 94ac6b0
readd keras models
Browse files
- A13/A13_DeepLearning_Report.ipynb +3 -0
- A13/dl_models/__init__.py +1 -0
- A13/dl_models/data_loader.py +108 -0
- A13/dl_models/evaluate.py +68 -0
- A13/dl_models/models.py +121 -0
- A13/dl_models/predict.py +120 -0
- A13/dl_models/saved/A_CNN.keras +3 -0
- A13/dl_models/saved/A_Dense.keras +3 -0
- A13/dl_models/saved/B_CNN.keras +3 -0
- A13/dl_models/saved/B_Dense.keras +3 -0
- A13/dl_models/saved/cv_summary.json +98 -0
- A13/dl_models/saved/training_summary.json +56 -0
- A13/dl_models/train.py +187 -0
A13/A13_DeepLearning_Report.ipynb
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2fc1b5f814669ca0e4161b02a6aa29bede0e249da41695bd08473f7ce8088640
|
| 3 |
+
size 100045
|
A13/dl_models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Deep Learning models for Issue #10 (problems A and B, Dense and CNN)."""
|
A13/dl_models/data_loader.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Load the prepared classification data produced in Issue #9.
|
| 2 |
+
|
| 3 |
+
The directory ``A13/classification_problems/prepared_data`` contains:
|
| 4 |
+
|
| 5 |
+
* ``{P}_{M}_train_X.npy`` original train features
|
| 6 |
+
* ``{P}_{M}_train_y.npy`` original train labels
|
| 7 |
+
* ``{P}_{M}_train_aug_X.npy`` augmented train features (incl. originals)
|
| 8 |
+
* ``{P}_{M}_train_aug_y.npy`` augmented train labels
|
| 9 |
+
* ``{P}_{M}_test_X.npy`` held-out test features
|
| 10 |
+
* ``{P}_{M}_test_y.npy`` held-out test labels
|
| 11 |
+
* ``{P}_{M}_*_filenames.npy`` the source clip name (used to keep all
|
| 12 |
+
augmentations of one clip in the same CV fold)
|
| 13 |
+
|
| 14 |
+
with ``P in {A, B}`` and ``M in {Dense, CNN}``.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
# Resolve the prepared_data directory relative to this file so that the
|
| 25 |
+
# package works no matter from where the notebook / script is launched.
|
| 26 |
+
_THIS_DIR = Path(__file__).resolve().parent
|
| 27 |
+
DATA_DIR = (_THIS_DIR.parent / "classification_problems" / "prepared_data").resolve()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class Dataset:
    """Bundle of every array that belongs to one (problem, model) pairing."""

    problem: str  # "A" or "B"
    model_kind: str  # "Dense" or "CNN"
    X_train: np.ndarray  # original (un-augmented) train features
    y_train: np.ndarray
    X_train_aug: np.ndarray  # augmented train features (used for fitting)
    y_train_aug: np.ndarray
    train_groups: np.ndarray  # source-clip id per augmented train sample
    X_test: np.ndarray
    y_test: np.ndarray
    test_filenames: np.ndarray

    @property
    def input_shape(self) -> tuple[int, ...]:
        """Shape of one augmented training sample (leading sample axis dropped)."""
        full_shape = self.X_train_aug.shape
        return full_shape[1:]

    def summary(self) -> str:
        """Return a one-line overview of array shapes and positive-label counts."""
        train_positives = int(self.y_train_aug.sum())
        train_total = len(self.y_train_aug)
        test_positives = int(self.y_test.sum())
        test_total = len(self.y_test)
        pieces = [
            f"Problem {self.problem} / {self.model_kind}: ",
            f"train_aug={self.X_train_aug.shape}, ",
            f"test={self.X_test.shape}, ",
            f"pos_train={train_positives}/{train_total}, ",
            f"pos_test={test_positives}/{test_total}",
        ]
        return "".join(pieces)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _load(name: str) -> np.ndarray:
    """Read ``<name>.npy`` from ``DATA_DIR`` and return what ``np.load`` yields."""
    target = DATA_DIR / f"{name}.npy"
    # NOTE(review): allow_pickle=True permits unpickling of object arrays, so
    # this must only ever be pointed at trusted, locally prepared files.
    contents = np.load(target, allow_pickle=True)
    return contents
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Augmentation suffixes appended to source-clip filenames in the prepared data.
|
| 65 |
+
# The CV must group all augmented copies of one source clip together, so we
|
| 66 |
+
# strip these suffixes to recover the original clip id (e.g. ``A1_mirror`` -> ``A1``).
|
| 67 |
+
_AUG_SUFFIXES = ("_mirror", "_rotate_pos", "_rotate_neg", "_stretch")


def _source_clip_ids(filenames: np.ndarray) -> np.ndarray:
    """Map every filename to the id of the clip it was derived from.

    Each entry is converted with ``str`` and at most one trailing
    augmentation suffix is removed: the first entry of ``_AUG_SUFFIXES``
    (in declaration order) that the name ends with. Names without a known
    suffix are returned unchanged. The result is an object-dtype array with
    one id per input filename.
    """

    def _strip_one(label: str) -> str:
        # First matching suffix wins; ``None`` means nothing to strip.
        hit = next((suffix for suffix in _AUG_SUFFIXES if label.endswith(suffix)), None)
        return label if hit is None else label[: -len(hit)]

    clip_ids = np.empty(len(filenames), dtype=object)
    for position, raw_name in enumerate(filenames):
        clip_ids[position] = _strip_one(str(raw_name))
    return clip_ids
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def load_dataset(problem: str, model_kind: str) -> Dataset:
    """Assemble the ``Dataset`` for one problem / model-kind combination.

    Raises ``ValueError`` when ``problem`` is not ``"A"``/``"B"`` or when
    ``model_kind`` is not ``"Dense"``/``"CNN"``. Feature arrays are cast to
    ``float32`` and label arrays to ``int32``; the training groups are the
    source-clip ids derived from the augmented filenames.
    """
    if problem not in {"A", "B"}:
        raise ValueError(f"problem must be 'A' or 'B', got {problem!r}")
    if model_kind not in {"Dense", "CNN"}:
        raise ValueError(f"model_kind must be 'Dense' or 'CNN', got {model_kind!r}")

    prefix = f"{problem}_{model_kind}"

    def _features(stem: str) -> np.ndarray:
        return _load(stem).astype("float32")

    def _labels(stem: str) -> np.ndarray:
        return _load(stem).astype("int32")

    # Files are read in the same order as the dataclass fields are declared.
    train_X = _features(f"{prefix}_train_X")
    train_y = _labels(f"{prefix}_train_y")
    aug_X = _features(f"{prefix}_train_aug_X")
    aug_y = _labels(f"{prefix}_train_aug_y")
    groups = _source_clip_ids(_load(f"{prefix}_train_aug_filenames"))
    test_X = _features(f"{prefix}_test_X")
    test_y = _labels(f"{prefix}_test_y")
    test_names = _load(f"{prefix}_test_filenames")

    return Dataset(
        problem=problem,
        model_kind=model_kind,
        X_train=train_X,
        y_train=train_y,
        X_train_aug=aug_X,
        y_train_aug=aug_y,
        train_groups=groups,
        X_test=test_X,
        y_test=test_y,
        test_filenames=test_names,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def load_all() -> dict[tuple[str, str], Dataset]:
    """Load all four datasets and key them by ``(problem, model_kind)``."""
    bundles: dict[tuple[str, str], Dataset] = {}
    for problem in ("A", "B"):
        for kind in ("Dense", "CNN"):
            bundles[(problem, kind)] = load_dataset(problem, kind)
    return bundles
|
A13/dl_models/evaluate.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Evaluation helpers (confusion matrix, metrics tables, plots)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Iterable
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
from sklearn.metrics import (
|
| 10 |
+
accuracy_score,
|
| 11 |
+
confusion_matrix,
|
| 12 |
+
precision_score,
|
| 13 |
+
recall_score,
|
| 14 |
+
roc_auc_score,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
METRIC_KEYS = ["tp", "fp", "tn", "fn", "accuracy", "precision", "recall", "auc"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def predict_proba(model: tf.keras.Model, X: np.ndarray) -> np.ndarray:
    """Run ``model.predict`` silently and flatten the result to one dimension."""
    raw_output = model.predict(X, verbose=0)
    return raw_output.reshape(-1)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def metrics_from_predictions(
    y_true: np.ndarray, y_proba: np.ndarray, threshold: float = 0.5
) -> dict[str, float]:
    """Compute confusion counts and summary metrics from predicted probabilities.

    Probabilities at or above ``threshold`` count as class 1. The AUC is NaN
    when ``y_true`` contains a single class; precision and recall fall back
    to 0 when their denominator is zero.
    """
    hard_labels = (y_proba >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, hard_labels, labels=[0, 1]).ravel()

    both_classes_present = len(np.unique(y_true)) > 1
    auc = roc_auc_score(y_true, y_proba) if both_classes_present else float("nan")

    report = {"tp": int(tp), "fp": int(fp), "tn": int(tn), "fn": int(fn)}
    report["accuracy"] = accuracy_score(y_true, hard_labels)
    report["precision"] = precision_score(y_true, hard_labels, zero_division=0)
    report["recall"] = recall_score(y_true, hard_labels, zero_division=0)
    report["auc"] = auc
    return report
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def confusion(y_true: np.ndarray, y_proba: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return the 2x2 confusion matrix (labels 0, 1) after thresholding ``y_proba``."""
    hard_labels = (y_proba >= threshold).astype(int)
    return confusion_matrix(y_true, hard_labels, labels=[0, 1])
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def plot_confusion(cm: np.ndarray, title: str, ax=None):
    """Draw a 2x2 confusion matrix as an annotated heat map.

    Parameters
    ----------
    cm:
        2x2 matrix; rows are drawn as the "true" axis and columns as the
        "predicted" axis, both labelled ``bad`` / ``good``.
    title:
        Title placed above the axes.
    ax:
        Axes to draw on. When ``None`` a new 3x3 inch figure is created.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        # pyplot is only needed to create a figure, so import it lazily;
        # callers that supply their own axes never trigger the import.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(figsize=(3, 3))

    ax.imshow(cm, cmap="Blues")
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(["bad", "good"])
    ax.set_yticklabels(["bad", "good"])
    ax.set_xlabel("predicted")
    ax.set_ylabel("true")
    ax.set_title(title)

    # Cells above half of the largest count get white text for contrast on
    # the dark end of the colour map; computed once rather than per cell.
    half_max = cm.max() / 2
    for i in range(2):
        for j in range(2):
            ax.text(
                j, i, int(cm[i, j]), ha="center", va="center",
                color="white" if cm[i, j] > half_max else "black",
            )
    return ax
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def metrics_table(rows: Iterable[dict], index: Iterable[str]):
    """Build a DataFrame of the known metric columns, rounded to 4 decimals.

    Columns follow the order of ``METRIC_KEYS``; keys absent from ``rows``
    are skipped and any other columns are dropped.
    """
    import pandas as pd

    frame = pd.DataFrame(list(rows), index=list(index))
    available = frame.columns
    ordered = [key for key in METRIC_KEYS if key in available]
    return frame[ordered].round(4)
|
A13/dl_models/models.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model factories for Issue #10.
|
| 2 |
+
|
| 3 |
+
Two architectures are provided per problem:
|
| 4 |
+
|
| 5 |
+
* :func:`build_dense` -- multi-layer perceptron over the flattened sequence.
|
| 6 |
+
* :func:`build_cnn` -- small Conv2D-over-(time, joint) network. The default
|
| 7 |
+
hyper-parameters were chosen so that the CNN has at most ~20 % of the
|
| 8 |
+
parameters of the Dense baseline (verified by :func:`assert_param_budget`).
|
| 9 |
+
|
| 10 |
+
All models output a single sigmoid probability (good=1 / bad=0) and are compiled
|
| 11 |
+
with ``binary_crossentropy`` plus the metrics required by issue #10:
|
| 12 |
+
True/False Positives & Negatives, AUC, BinaryAccuracy, Precision, Recall.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Sequence
|
| 18 |
+
|
| 19 |
+
import tensorflow as tf
|
| 20 |
+
from tensorflow.keras import layers, models, regularizers
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# --------------------------------------------------------------------------- #
|
| 24 |
+
# Metrics & compile helper #
|
| 25 |
+
# --------------------------------------------------------------------------- #
|
| 26 |
+
def make_metrics() -> list[tf.keras.metrics.Metric]:
    """Create a fresh list of the Keras metrics tracked for every model."""
    keras_metrics = tf.keras.metrics
    # (metric class, reported name) in the order they are attached to the model.
    catalogue = (
        (keras_metrics.TruePositives, "tp"),
        (keras_metrics.FalsePositives, "fp"),
        (keras_metrics.TrueNegatives, "tn"),
        (keras_metrics.FalseNegatives, "fn"),
        (keras_metrics.BinaryAccuracy, "accuracy"),
        (keras_metrics.Precision, "precision"),
        (keras_metrics.Recall, "recall"),
        (keras_metrics.AUC, "auc"),
    )
    return [metric_cls(name=label) for metric_cls, label in catalogue]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def compile_model(model: tf.keras.Model, learning_rate: float = 1e-3) -> tf.keras.Model:
    """Compile ``model`` in place with Adam, binary cross-entropy and the
    metrics from ``make_metrics``, then return the same model object."""
    adam = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    tracked = make_metrics()
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=tracked)
    return model
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# --------------------------------------------------------------------------- #
|
| 49 |
+
# Architectures #
|
| 50 |
+
# --------------------------------------------------------------------------- #
|
| 51 |
+
def build_dense(
    input_dim: int,
    hidden_units: Sequence[int] = (128, 64, 32),
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "dense",
) -> tf.keras.Model:
    """Build and compile the fully connected network for flattened inputs.

    The input is batch-normalised, then passed through one ReLU ``Dense``
    layer per entry of ``hidden_units`` (each followed by ``Dropout`` when
    ``dropout`` is truthy) and finally a single sigmoid unit. A falsy ``l2``
    disables kernel regularisation.
    """
    weight_penalty = regularizers.l2(l2) if l2 else None

    features = layers.Input(shape=(input_dim,), name="features")
    hidden = layers.BatchNormalization()(features)
    for layer_no, width in enumerate(hidden_units, start=1):
        fully_connected = layers.Dense(
            width,
            activation="relu",
            kernel_regularizer=weight_penalty,
            name=f"fc{layer_no}",
        )
        hidden = fully_connected(hidden)
        if dropout:
            hidden = layers.Dropout(dropout)(hidden)

    probability = layers.Dense(1, activation="sigmoid", name="prob")(hidden)
    network = models.Model(features, probability, name=name)
    return compile_model(network, learning_rate)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def build_cnn(
    input_shape: tuple[int, int, int],
    filters: Sequence[int] = (8, 16),
    kernel_size: tuple[int, int] = (3, 3),
    dense_units: int = 16,
    dropout: float = 0.3,
    l2: float = 1e-4,
    learning_rate: float = 1e-3,
    name: str = "cnn",
) -> tf.keras.Model:
    """Build and compile the small 2D convolutional network.

    After batch normalisation, each entry of ``filters`` adds a same-padded
    ReLU ``Conv2D`` followed by a ``(2, 1)`` max-pool. Global average
    pooling, an optional ReLU ``Dense`` layer (``dense_units``), optional
    ``Dropout`` and a single sigmoid unit complete the model. A falsy ``l2``
    disables kernel regularisation.
    """
    weight_penalty = regularizers.l2(l2) if l2 else None

    sequence = layers.Input(shape=input_shape, name="sequence")
    feature_map = layers.BatchNormalization()(sequence)
    for stage, n_filters in enumerate(filters, start=1):
        convolution = layers.Conv2D(
            n_filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
            kernel_regularizer=weight_penalty,
            name=f"conv{stage}",
        )
        feature_map = convolution(feature_map)
        # Pool along the first spatial axis only; the second is left intact
        # (the original comment notes the joint axis is small, 13).
        downsample = layers.MaxPool2D(pool_size=(2, 1), name=f"pool{stage}")
        feature_map = downsample(feature_map)

    head = layers.GlobalAveragePooling2D(name="gap")(feature_map)
    if dense_units:
        head = layers.Dense(
            dense_units, activation="relu", kernel_regularizer=weight_penalty, name="fc"
        )(head)
    if dropout:
        head = layers.Dropout(dropout)(head)

    probability = layers.Dense(1, activation="sigmoid", name="prob")(head)
    network = models.Model(sequence, probability, name=name)
    return compile_model(network, learning_rate)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# --------------------------------------------------------------------------- #
|
| 108 |
+
# Parameter budget #
|
| 109 |
+
# --------------------------------------------------------------------------- #
|
| 110 |
+
def count_params(model: tf.keras.Model) -> int:
    """Return ``model.count_params()`` coerced to a plain ``int``."""
    total = model.count_params()
    return int(total)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def assert_param_budget(dense: tf.keras.Model, cnn: tf.keras.Model, ratio: float = 0.20) -> None:
    """Raise ``AssertionError`` when the CNN has more than ``ratio`` times
    the parameter count of the Dense model; otherwise return ``None``."""
    dense_total = count_params(dense)
    cnn_total = count_params(cnn)
    over_budget = cnn_total > ratio * dense_total
    if over_budget:
        message = (
            f"CNN has {cnn_total} parameters which exceeds {ratio:.0%} of Dense's {dense_total} "
            f"({cnn_total / dense_total:.1%}). Reduce CNN filters/dense_units."
        )
        raise AssertionError(message)
|
A13/dl_models/predict.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference helpers and a small CLI so the model is easy to (re)use.
|
| 2 |
+
|
| 3 |
+
Examples
|
| 4 |
+
--------
|
| 5 |
+
Train + save all four models::
|
| 6 |
+
|
| 7 |
+
python -m A13.dl_models.predict train --out A13/dl_models/saved
|
| 8 |
+
|
| 9 |
+
Predict on a NumPy array of features::
|
| 10 |
+
|
| 11 |
+
python -m A13.dl_models.predict run --model A13/dl_models/saved/A_Dense.keras \\
|
| 12 |
+
--X my_features.npy
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import tensorflow as tf
|
| 23 |
+
|
| 24 |
+
from .data_loader import load_all, load_dataset
|
| 25 |
+
from .models import build_dense, build_cnn, count_params, assert_param_budget
|
| 26 |
+
from .train import train_final
|
| 27 |
+
from .evaluate import predict_proba, metrics_from_predictions
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
SAVED_DIR = Path(__file__).resolve().parent / "saved"
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _builders(dataset):
    """Return a zero-argument factory that builds the model matching ``dataset``.

    Datasets whose ``model_kind`` is ``"Dense"`` get a ``build_dense``
    factory; every other kind gets a ``build_cnn`` factory.
    """

    def _make_dense():
        return build_dense(input_dim=dataset.input_shape[0], name=f"{dataset.problem}_Dense")

    def _make_cnn():
        return build_cnn(input_shape=dataset.input_shape, name=f"{dataset.problem}_CNN")

    return _make_dense if dataset.model_kind == "Dense" else _make_cnn
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def train_all(out_dir: Path = SAVED_DIR, epochs: int = 120, verbose: int = 1) -> dict:
    """Train and save the Dense and CNN models for problems A and B.

    Before training, one Dense and one CNN model are built per problem to
    check that the CNN stays within 20% of the Dense parameter count
    (``assert_param_budget`` raises otherwise). Parameter counts and the
    test metrics of every trained model are collected into a summary dict,
    which is written to ``training_summary.json`` in ``out_dir`` and
    returned.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    datasets = load_all()
    summary: dict[str, dict] = {}

    # --- parameter budget check ------------------------------------------------
    for problem in ("A", "B"):
        dense_probe = build_dense(input_dim=datasets[(problem, "Dense")].input_shape[0])
        cnn_probe = build_cnn(input_shape=datasets[(problem, "CNN")].input_shape)
        assert_param_budget(dense_probe, cnn_probe, ratio=0.20)
        dense_total = count_params(dense_probe)
        cnn_total = count_params(cnn_probe)
        summary[f"{problem}_param_counts"] = {
            "dense": dense_total,
            "cnn": cnn_total,
            "ratio": cnn_total / dense_total,
        }

    # --- train + save ----------------------------------------------------------
    for key, dataset in datasets.items():
        problem, kind = key
        if verbose:
            print(f"== training {problem} / {kind} == {dataset.summary()}")
        destination = out_dir / f"{problem}_{kind}.keras"
        result = train_final(
            dataset,
            _builders(dataset),
            epochs=epochs,
            verbose=verbose,
            save_path=destination,
        )
        summary[f"{problem}_{kind}_test_metrics"] = result.test_metrics

    report_file = out_dir / "training_summary.json"
    report_file.write_text(json.dumps(summary, indent=2))
    return summary
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def predict(model_path: Path | str, X: np.ndarray, threshold: float = 0.5):
    """Load a saved Keras model and score ``X`` with it.

    Returns ``(proba, labels)`` where ``labels`` is 1 wherever the
    probability is at or above ``threshold`` and 0 elsewhere.
    """
    loaded = tf.keras.models.load_model(model_path)
    scores = predict_proba(loaded, X)
    decisions = (scores >= threshold).astype(int)
    return scores, decisions
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def evaluate_saved(model_path: Path | str, problem: str, model_kind: str) -> dict:
    """Score a saved model on the held-out test split of the given dataset
    and return the metrics computed from its predicted probabilities."""
    dataset = load_dataset(problem, model_kind)
    scores = predict(model_path, dataset.X_test)[0]
    return metrics_from_predictions(dataset.y_test, scores)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# --------------------------------------------------------------------------- #
|
| 86 |
+
# CLI #
|
| 87 |
+
# --------------------------------------------------------------------------- #
|
| 88 |
+
def _cli() -> None:
    """Parse the command line and run the ``train``, ``eval`` or ``run`` command."""
    parser = argparse.ArgumentParser(description="Train / use Issue #10 models.")
    commands = parser.add_subparsers(dest="cmd", required=True)

    train_cmd = commands.add_parser("train", help="Train all four models.")
    train_cmd.add_argument("--out", default=str(SAVED_DIR))
    train_cmd.add_argument("--epochs", type=int, default=120)

    eval_cmd = commands.add_parser("eval", help="Evaluate a saved model on its test set.")
    eval_cmd.add_argument("--model", required=True)
    eval_cmd.add_argument("--problem", required=True, choices=["A", "B"])
    eval_cmd.add_argument("--kind", required=True, choices=["Dense", "CNN"])

    run_cmd = commands.add_parser("run", help="Run inference on a .npy feature array.")
    run_cmd.add_argument("--model", required=True)
    run_cmd.add_argument("--X", required=True)
    run_cmd.add_argument("--threshold", type=float, default=0.5)

    args = parser.parse_args()

    if args.cmd == "train":
        summary = train_all(Path(args.out), epochs=args.epochs)
        print(json.dumps(summary, indent=2))
        return

    if args.cmd == "eval":
        metrics = evaluate_saved(args.model, args.problem, args.kind)
        print(json.dumps(metrics, indent=2))
        return

    if args.cmd == "run":
        features = np.load(args.X)
        proba, pred = predict(args.model, features, threshold=args.threshold)
        # One line per sample: probability, tab, thresholded label.
        for score, label in zip(proba, pred):
            print(f"{score:.4f}\t{int(label)}")
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
if __name__ == "__main__":
|
| 120 |
+
_cli()
|
A13/dl_models/saved/A_CNN.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f2b94a87774a05ecd05cacf4bc3c04d3636ed9853d1467c1926bc86f6a362e6d
|
| 3 |
+
size 71378
|
A13/dl_models/saved/A_Dense.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58e5873c14fb6f1b30c3df7e1f3a14b60700b1c66fd34fc1463d2d0513caf683
|
| 3 |
+
size 785808
|
A13/dl_models/saved/B_CNN.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5f4cc63783b17d5c406ba06c5bc5de7b06dd2254b43a8f7037080c715552eae
|
| 3 |
+
size 70482
|
A13/dl_models/saved/B_Dense.keras
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:01835e1f426ddd3b160a3162d601f93d969ba7789edd2323c7d502e6fdb1c809
|
| 3 |
+
size 581971
|
A13/dl_models/saved/cv_summary.json
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"A_Dense": {
|
| 3 |
+
"mean": {
|
| 4 |
+
"accuracy": 1.0,
|
| 5 |
+
"auc": 1.0,
|
| 6 |
+
"fn": 0.0,
|
| 7 |
+
"fp": 0.0,
|
| 8 |
+
"loss": 0.029864558018743992,
|
| 9 |
+
"precision": 1.0,
|
| 10 |
+
"recall": 1.0,
|
| 11 |
+
"tn": 17.0,
|
| 12 |
+
"tp": 28.5
|
| 13 |
+
},
|
| 14 |
+
"std": {
|
| 15 |
+
"accuracy": 0.0,
|
| 16 |
+
"auc": 0.0,
|
| 17 |
+
"fn": 0.0,
|
| 18 |
+
"fp": 0.0,
|
| 19 |
+
"loss": 0.0003843123911400311,
|
| 20 |
+
"precision": 0.0,
|
| 21 |
+
"recall": 0.0,
|
| 22 |
+
"tn": 1.0,
|
| 23 |
+
"tp": 0.9219544457292888
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
"A_CNN": {
|
| 27 |
+
"mean": {
|
| 28 |
+
"accuracy": 0.9274879336357117,
|
| 29 |
+
"auc": 0.9561430215835571,
|
| 30 |
+
"fn": 3.3,
|
| 31 |
+
"fp": 0.0,
|
| 32 |
+
"loss": 0.19682206511497496,
|
| 33 |
+
"precision": 1.0,
|
| 34 |
+
"recall": 0.8842802405357361,
|
| 35 |
+
"tn": 17.0,
|
| 36 |
+
"tp": 25.2
|
| 37 |
+
},
|
| 38 |
+
"std": {
|
| 39 |
+
"accuracy": 0.01719623343302091,
|
| 40 |
+
"auc": 0.025368922288748794,
|
| 41 |
+
"fn": 0.7810249675906654,
|
| 42 |
+
"fp": 0.0,
|
| 43 |
+
"loss": 0.028594990244631205,
|
| 44 |
+
"precision": 0.0,
|
| 45 |
+
"recall": 0.026408634136469537,
|
| 46 |
+
"tn": 1.0,
|
| 47 |
+
"tp": 1.0770329614269007
|
| 48 |
+
}
|
| 49 |
+
},
|
| 50 |
+
"B_Dense": {
|
| 51 |
+
"mean": {
|
| 52 |
+
"accuracy": 0.997826087474823,
|
| 53 |
+
"auc": 1.0,
|
| 54 |
+
"fn": 0.1,
|
| 55 |
+
"fp": 0.0,
|
| 56 |
+
"loss": 0.03609825950115919,
|
| 57 |
+
"precision": 1.0,
|
| 58 |
+
"recall": 0.9965517222881317,
|
| 59 |
+
"tn": 17.0,
|
| 60 |
+
"tp": 28.4
|
| 61 |
+
},
|
| 62 |
+
"std": {
|
| 63 |
+
"accuracy": 0.006521737575531006,
|
| 64 |
+
"auc": 0.0,
|
| 65 |
+
"fn": 0.30000000000000004,
|
| 66 |
+
"fp": 0.0,
|
| 67 |
+
"loss": 0.014322467489723906,
|
| 68 |
+
"precision": 0.0,
|
| 69 |
+
"recall": 0.010344833135604858,
|
| 70 |
+
"tn": 1.0,
|
| 71 |
+
"tp": 0.9165151389911681
|
| 72 |
+
}
|
| 73 |
+
},
|
| 74 |
+
"B_CNN": {
|
| 75 |
+
"mean": {
|
| 76 |
+
"accuracy": 0.9274879336357117,
|
| 77 |
+
"auc": 0.9576378405094147,
|
| 78 |
+
"fn": 3.3,
|
| 79 |
+
"fp": 0.0,
|
| 80 |
+
"loss": 0.19666245728731155,
|
| 81 |
+
"precision": 1.0,
|
| 82 |
+
"recall": 0.8842802405357361,
|
| 83 |
+
"tn": 17.0,
|
| 84 |
+
"tp": 25.2
|
| 85 |
+
},
|
| 86 |
+
"std": {
|
| 87 |
+
"accuracy": 0.01719623343302091,
|
| 88 |
+
"auc": 0.02392501041855642,
|
| 89 |
+
"fn": 0.7810249675906654,
|
| 90 |
+
"fp": 0.0,
|
| 91 |
+
"loss": 0.027519047326329205,
|
| 92 |
+
"precision": 0.0,
|
| 93 |
+
"recall": 0.026408634136469537,
|
| 94 |
+
"tn": 1.0,
|
| 95 |
+
"tp": 1.0770329614269007
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
A13/dl_models/saved/training_summary.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"A_param_counts": {
|
| 3 |
+
"dense": 61977,
|
| 4 |
+
"cnn": 1693,
|
| 5 |
+
"ratio": 0.02731658518482663
|
| 6 |
+
},
|
| 7 |
+
"B_param_counts": {
|
| 8 |
+
"dense": 44817,
|
| 9 |
+
"cnn": 1617,
|
| 10 |
+
"ratio": 0.03608005890621862
|
| 11 |
+
},
|
| 12 |
+
"A_Dense_test_metrics": {
|
| 13 |
+
"accuracy": 0.9130434989929199,
|
| 14 |
+
"auc": 0.9365079402923584,
|
| 15 |
+
"fn": 1.0,
|
| 16 |
+
"fp": 1.0,
|
| 17 |
+
"loss": 0.5623016953468323,
|
| 18 |
+
"precision": 0.9285714030265808,
|
| 19 |
+
"recall": 0.9285714030265808,
|
| 20 |
+
"tn": 8.0,
|
| 21 |
+
"tp": 13.0
|
| 22 |
+
},
|
| 23 |
+
"A_CNN_test_metrics": {
|
| 24 |
+
"accuracy": 0.95652174949646,
|
| 25 |
+
"auc": 0.964285671710968,
|
| 26 |
+
"fn": 1.0,
|
| 27 |
+
"fp": 0.0,
|
| 28 |
+
"loss": 0.15785124897956848,
|
| 29 |
+
"precision": 1.0,
|
| 30 |
+
"recall": 0.9285714030265808,
|
| 31 |
+
"tn": 9.0,
|
| 32 |
+
"tp": 13.0
|
| 33 |
+
},
|
| 34 |
+
"B_Dense_test_metrics": {
|
| 35 |
+
"accuracy": 0.9130434989929199,
|
| 36 |
+
"auc": 0.9365079402923584,
|
| 37 |
+
"fn": 1.0,
|
| 38 |
+
"fp": 1.0,
|
| 39 |
+
"loss": 0.6157440543174744,
|
| 40 |
+
"precision": 0.9285714030265808,
|
| 41 |
+
"recall": 0.9285714030265808,
|
| 42 |
+
"tn": 8.0,
|
| 43 |
+
"tp": 13.0
|
| 44 |
+
},
|
| 45 |
+
"B_CNN_test_metrics": {
|
| 46 |
+
"accuracy": 0.95652174949646,
|
| 47 |
+
"auc": 0.9722222685813904,
|
| 48 |
+
"fn": 1.0,
|
| 49 |
+
"fp": 0.0,
|
| 50 |
+
"loss": 0.15502068400382996,
|
| 51 |
+
"precision": 1.0,
|
| 52 |
+
"recall": 0.9285714030265808,
|
| 53 |
+
"tn": 9.0,
|
| 54 |
+
"tp": 13.0
|
| 55 |
+
}
|
| 56 |
+
}
|
A13/dl_models/train.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training and cross-validation utilities for Issue #10.
|
| 2 |
+
|
| 3 |
+
Supports:
|
| 4 |
+
|
| 5 |
+
* a single train/test fit (``train_final``)
|
| 6 |
+
* 10-fold *grouped* cross-validation that keeps all augmentations of the same
|
| 7 |
+
original clip in the same fold (``cross_validate``)
|
| 8 |
+
* small grid search over a few hyper-parameter combinations (``grid_search``)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from itertools import product
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Callable, Iterable
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import tensorflow as tf
|
| 20 |
+
from sklearn.model_selection import GroupKFold
|
| 21 |
+
|
| 22 |
+
from .data_loader import Dataset
|
| 23 |
+
from . import models as _models
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# --------------------------------------------------------------------------- #
|
| 27 |
+
# Common training helpers #
|
| 28 |
+
# --------------------------------------------------------------------------- #
|
| 29 |
+
def _callbacks(patience: int = 15) -> list[tf.keras.callbacks.Callback]:
    """Return the early-stopping and learning-rate callbacks used for training.

    Both monitor ``val_loss``. Early stopping waits ``patience`` epochs and
    restores the best weights; the learning rate is halved after
    ``max(3, patience // 3)`` stagnant epochs, never going below ``1e-5``.
    """
    lr_patience = max(3, patience // 3)
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=patience,
        restore_best_weights=True,
    )
    lr_decay = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=lr_patience,
        min_lr=1e-5,
    )
    return [early_stop, lr_decay]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def class_weight(y: np.ndarray) -> dict[int, float]:
    """Compute balanced weights for classes 0 and 1 from a label array.

    ``y.sum()`` is taken as the number of positives and the rest of ``y`` as
    negatives. Each class is weighted ``total / (2 * count)``. When either
    class is absent, both weights are 1.0.
    """
    n_pos = float(y.sum())
    n_neg = float(len(y) - n_pos)
    if n_pos and n_neg:
        n_all = n_pos + n_neg
        return {0: n_all / (2 * n_neg), 1: n_all / (2 * n_pos)}
    # One class is missing, so there is nothing to balance.
    return {0: 1.0, 1: 1.0}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _evaluate(model: tf.keras.Model, X: np.ndarray, y: np.ndarray) -> dict[str, float]:
    """Evaluate ``model`` silently and return its metrics as plain floats."""
    raw_metrics = model.evaluate(X, y, verbose=0, return_dict=True)
    as_floats: dict[str, float] = {}
    for metric_name, value in raw_metrics.items():
        as_floats[metric_name] = float(value)
    return as_floats
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# --------------------------------------------------------------------------- #
|
| 55 |
+
# Grouped cross-validation #
|
| 56 |
+
# --------------------------------------------------------------------------- #
|
| 57 |
+
@dataclass
class CVResult:
    """Outcome of one grouped cross-validation run (see ``cross_validate``)."""

    # One metrics dict per fold, as returned by ``_evaluate``; each dict also
    # carries an integer ``"fold"`` entry holding the 1-based fold number.
    fold_metrics: list[dict[str, float]]
    # Per-metric average over the folds (the ``"fold"`` entry is excluded).
    mean: dict[str, float]
    # Per-metric ``np.std`` over the folds (the ``"fold"`` entry is excluded).
    std: dict[str, float]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def cross_validate(
    dataset: Dataset,
    build_fn: Callable[[], tf.keras.Model],
    n_splits: int = 10,
    epochs: int = 80,
    batch_size: int = 32,
    use_class_weight: bool = True,
    verbose: int = 0,
) -> CVResult:
    """Grouped K-fold cross-validation over ``dataset.X_train_aug``.

    Folds are drawn with ``GroupKFold`` on ``dataset.train_groups`` so that
    all augmented copies of one original clip land in the same fold, as
    required by issue #10.

    Args:
        dataset: Provides ``X_train_aug``, ``y_train_aug`` and ``train_groups``.
        build_fn: Zero-argument factory called once per fold for a new model.
        n_splits: Requested fold count; capped at the number of distinct groups.
        epochs: Upper bound on epochs per fold (early stopping may end sooner).
        batch_size: Mini-batch size passed to ``fit``.
        use_class_weight: When true, weight classes via ``class_weight`` computed
            on the fold's training labels.
        verbose: Forwarded to ``fit``; any truthy value also prints a per-fold
            summary line.

    Returns:
        A ``CVResult`` with the per-fold metrics and their mean / std.
    """
    features = dataset.X_train_aug
    labels = dataset.y_train_aug
    group_ids = dataset.train_groups

    # Never ask for more folds than there are distinct groups.
    fold_count = min(n_splits, len(np.unique(group_ids)))
    splitter = GroupKFold(n_splits=fold_count)

    per_fold: list[dict[str, float]] = []
    fold_splits = splitter.split(features, labels, group_ids)
    for fold_no, (fit_rows, held_rows) in enumerate(fold_splits, start=1):
        # Drop the previous fold's graph/state before building a new model.
        tf.keras.backend.clear_session()
        net = build_fn()

        weights = None
        if use_class_weight:
            weights = class_weight(labels[fit_rows])

        net.fit(
            features[fit_rows],
            labels[fit_rows],
            validation_data=(features[held_rows], labels[held_rows]),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=_callbacks(),
            class_weight=weights,
            verbose=verbose,
        )

        scores = _evaluate(net, features[held_rows], labels[held_rows])
        scores["fold"] = fold_no
        per_fold.append(scores)
        if verbose:
            print(f" fold {fold_no:2d}: auc={scores['auc']:.3f} acc={scores['accuracy']:.3f}")

    # Aggregate every reported metric; "fold" is bookkeeping, not a metric.
    metric_names = [name for name in per_fold[0] if name != "fold"]
    mean: dict[str, float] = {}
    std: dict[str, float] = {}
    for name in metric_names:
        column = [scores[name] for scores in per_fold]
        mean[name] = float(np.mean(column))
        std[name] = float(np.std(column))
    return CVResult(fold_metrics=per_fold, mean=mean, std=std)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# --------------------------------------------------------------------------- #
|
| 109 |
+
# Final fit on augmented training data (grouped hold-out for early stopping) #
|
| 110 |
+
# --------------------------------------------------------------------------- #
|
| 111 |
+
@dataclass
class TrainResult:
    """Outcome of ``train_final``: the fitted model plus its training record."""

    # The model returned by ``build_fn`` after fitting.
    model: tf.keras.Model
    # Keras ``History.history`` with every per-epoch value converted to float.
    history: dict[str, list[float]]
    # Metrics from ``_evaluate`` on ``dataset.X_test`` / ``dataset.y_test``.
    test_metrics: dict[str, float]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def train_final(
    dataset: Dataset,
    build_fn: Callable[[], tf.keras.Model],
    epochs: int = 120,
    batch_size: int = 32,
    val_fraction: float = 0.15,
    use_class_weight: bool = True,
    verbose: int = 0,
    save_path: Path | str | None = None,
) -> TrainResult:
    """Train one model on the augmented training set and score it on the test set.

    A single grouped split of ``dataset.X_train_aug`` is held out as validation
    data for the early-stopping / LR callbacks; the remaining rows are used
    for fitting.

    Args:
        dataset: Provides ``X_train_aug``, ``y_train_aug``, ``train_groups``,
            ``X_test`` and ``y_test``.
        build_fn: Zero-argument factory returning the model to train.
        epochs: Upper bound on training epochs.
        batch_size: Mini-batch size passed to ``fit``.
        val_fraction: Approximate share of groups to hold out for validation.
        use_class_weight: When true, weight classes via ``class_weight`` computed
            on the fitting labels.
        verbose: Forwarded to ``fit``.
        save_path: If given, the model is saved there; missing parent
            directories are created first.

    Returns:
        A ``TrainResult`` with the model, its float-converted training history
        and the test-set metrics.
    """
    features = dataset.X_train_aug
    labels = dataset.y_train_aug
    group_ids = dataset.train_groups

    # Hold out a single grouped validation split for early stopping: pick a
    # fold count whose first fold holds roughly ``val_fraction`` of the groups.
    group_count = len(np.unique(group_ids))
    held_groups = max(1, int(round(group_count * val_fraction)))
    fold_count = max(2, group_count // held_groups)
    splitter = GroupKFold(n_splits=fold_count)
    fit_rows, held_rows = next(iter(splitter.split(features, labels, group_ids)))

    tf.keras.backend.clear_session()
    net = build_fn()

    weights = None
    if use_class_weight:
        weights = class_weight(labels[fit_rows])

    fit_log = net.fit(
        features[fit_rows],
        labels[fit_rows],
        validation_data=(features[held_rows], labels[held_rows]),
        epochs=epochs,
        batch_size=batch_size,
        callbacks=_callbacks(),
        class_weight=weights,
        verbose=verbose,
    )
    test_metrics = _evaluate(net, dataset.X_test, dataset.y_test)

    if save_path is not None:
        target = Path(save_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        net.save(target)

    # Convert the per-epoch values to built-in floats.
    curves = {
        name: list(map(float, values))
        for name, values in fit_log.history.items()
    }
    return TrainResult(model=net, history=curves, test_metrics=test_metrics)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# --------------------------------------------------------------------------- #
|
| 157 |
+
# Tiny grid search #
|
| 158 |
+
# --------------------------------------------------------------------------- #
|
| 159 |
+
def grid_search(
    dataset: Dataset,
    build_fn_factory: Callable[..., Callable[[], tf.keras.Model]],
    grid: dict[str, Iterable],
    n_splits: int = 5,
    epochs: int = 60,
    batch_size: int = 32,
    verbose: int = 0,
) -> list[dict]:
    """Exhaustive hyper-parameter search scored by grouped cross-validation.

    Every combination in the Cartesian product of ``grid``'s value lists is
    evaluated with ``cross_validate``.

    Args:
        dataset: Passed through to ``cross_validate``.
        build_fn_factory: ``build_fn_factory(**hp)`` must return a zero-arg
            builder of a fresh model.
        grid: Maps each hyper-parameter name to the values to try.
        n_splits: Fold count for each cross-validation run.
        epochs: Epoch limit for each fold.
        batch_size: Mini-batch size for each fold.
        verbose: When truthy, print each combination before it is evaluated
            (the cross-validation runs themselves are always silent).

    Returns:
        One ``{"hp", "mean", "std"}`` dict per combination, sorted by mean
        validation AUC (best first).
    """
    names = list(grid.keys())
    value_lists = [grid[name] for name in names]

    ranking = []
    for values in product(*value_lists):
        hp = dict(zip(names, values))
        if verbose:
            print(f"-> {hp}")
        outcome = cross_validate(
            dataset,
            build_fn_factory(**hp),
            n_splits=n_splits,
            epochs=epochs,
            batch_size=batch_size,
            verbose=0,
        )
        ranking.append({"hp": hp, "mean": outcome.mean, "std": outcome.std})

    return sorted(ranking, key=lambda row: row["mean"]["auc"], reverse=True)
|