| """ |
| training/train_cnn.py |
| ---------------------- |
| CNN Branch Training Script β EfficientNet-B0 |
| STATUS: COMPLETE |
| |
| Usage: |
| cd ImageForensics-Detect/ |
| python training/train_cnn.py [--epochs 30] [--batch_size 32] [--lr 1e-4] |
| |
| Training strategy: |
| Phase 1 (epochs 1β10) : Only train classification head (backbone frozen) |
| Phase 2 (epochs 11β30) : Unfreeze top 20% of backbone layers + fine-tune |
| |
| Saves: |
| - Best model weights β models/cnn_branch.h5 |
| - Training history β outputs/cnn_training_history.json |
| """ |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
| |
# Make the project root importable so `training.*` and `branches.*` packages
# resolve when this script is run directly (e.g. `python training/train_cnn.py`).
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))


from training.dataset_loader import discover_dataset, split_dataset, make_tf_dataset


# Output locations for trained weights and the training-history JSON.
# Created up-front so save paths are always valid at training time.
MODELS_DIR = ROOT / "models"
OUTPUTS_DIR = ROOT / "outputs"
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)
|
|
|
|
def build_callbacks(model_save_path: str):
    """Return the standard callback set for CNN-branch training.

    The set contains:
      - ModelCheckpoint: writes best weights (by val_accuracy) to
        *model_save_path* (weights only).
      - EarlyStopping: stops after 7 stagnant val_loss epochs and restores
        the best weights seen.
      - ReduceLROnPlateau: halves the LR after 3 stagnant val_loss epochs,
        floored at 1e-7.
    """
    import tensorflow as tf

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=model_save_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=7,
        restore_best_weights=True,
        verbose=1,
    )
    lr_plateau = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1,
    )
    return [checkpoint, early_stop, lr_plateau]
|
|
|
|
def train(epochs: int = 30, batch_size: int = 32, lr: float = 1e-4):
    """Train the CNN branch (EfficientNet-B0) in two phases.

    Phase 1 trains only the classification head with the backbone frozen;
    Phase 2 unfreezes the top 20% of backbone layers and fine-tunes at a
    10x lower learning rate.

    Args:
        epochs: Total number of training epochs across both phases.
        batch_size: Mini-batch size for both train and validation datasets.
        lr: Phase-1 learning rate (Phase 2 uses lr * 0.1).

    Side effects:
        - Best weights (by val_accuracy) are written to models/cnn_branch.h5
          by the ModelCheckpoint callback during fit().
        - Merged training history is written to outputs/cnn_training_history.json.
        - Exits with status 1 if no dataset images are found.
    """
    import tensorflow as tf
    from branches.cnn_branch import _build_model

    print(f"\n{'='*55}")
    print(" ImageForensics-Detect — CNN Branch Training")
    print(f"{'='*55}")
    print(f" Epochs: {epochs} | Batch: {batch_size} | LR: {lr}")

    paths, labels = discover_dataset()
    if len(paths) == 0:
        print("\n✗ No images found in data/raw/real/ and data/raw/fake/")
        print("  Please populate the dataset first. See README.md → Dataset Setup.")
        sys.exit(1)

    splits = split_dataset(paths, labels)

    train_ds = make_tf_dataset(splits["train"][0], splits["train"][1],
                               batch_size=batch_size, augment=True, shuffle=True)
    val_ds = make_tf_dataset(splits["val"][0], splits["val"][1],
                             batch_size=batch_size, augment=False, shuffle=False)

    print(f"\n Train: {len(splits['train'][0])} | "
          f"Val: {len(splits['val'][0])} | Test: {len(splits['test'][0])}")

    model = _build_model()
    model.summary(line_length=80)

    model_save = str(MODELS_DIR / "cnn_branch.h5")
    callbacks = build_callbacks(model_save)

    # --- Phase 1: classification head only, backbone frozen -----------------
    print("\n[Phase 1] Training classification head only (backbone frozen)...")
    for layer in model.get_layer("efficientnetb0").layers:
        layer.trainable = False

    model.compile(
        optimizer=tf.keras.optimizers.Adam(lr),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    # BUGFIX: with very small --epochs (< 3) the old `min(10, epochs // 3)`
    # evaluated to 0, so Phase 1 silently trained for zero epochs and
    # history1 came back empty. Clamp to at least one epoch.
    phase1_epochs = max(1, min(10, epochs // 3))
    history1 = model.fit(
        train_ds, validation_data=val_ds,
        epochs=phase1_epochs, callbacks=callbacks, verbose=1,
    )

    # --- Phase 2: unfreeze top 20% of the backbone and fine-tune -------------
    if epochs > phase1_epochs:
        print("\n[Phase 2] Fine-tuning top 20% of EfficientNet layers...")
        base = model.get_layer("efficientnetb0")
        n_unfreeze = max(1, int(0.20 * len(base.layers)))
        for layer in base.layers[-n_unfreeze:]:
            layer.trainable = True

        # Recompile with a 10x lower LR so fine-tuning doesn't destroy the
        # pretrained backbone features.
        model.compile(
            optimizer=tf.keras.optimizers.Adam(lr * 0.1),
            loss="binary_crossentropy",
            metrics=["accuracy"],
        )
        history2 = model.fit(
            train_ds, validation_data=val_ds,
            epochs=epochs - phase1_epochs,
            callbacks=callbacks, verbose=1,
        )
    else:
        history2 = None

    # BUGFIX: the old code called model.save_weights(model_save) here, which
    # overwrote the best-val_accuracy checkpoint (written by ModelCheckpoint
    # during fit) with the *final* epoch's weights. EarlyStopping only restores
    # best weights when it actually triggers, so the final weights are not
    # guaranteed to be the best. The checkpoint file is authoritative.
    print(f"\n✓ Best model weights saved → {model_save}")

    # Merge Phase 1 + Phase 2 histories over the *union* of metric keys so a
    # metric recorded in only one phase is not dropped (the old merge iterated
    # history1's keys only).
    keys = set(history1.history)
    if history2 is not None:
        keys |= set(history2.history)
    history = {
        k: list(history1.history.get(k, [])) +
           (list(history2.history.get(k, [])) if history2 is not None else [])
        for k in keys
    }
    hist_path = OUTPUTS_DIR / "cnn_training_history.json"
    with open(hist_path, "w") as f:
        json.dump({k: [float(x) for x in v] for k, v in history.items()}, f, indent=2)
    print(f"✓ Training history saved → {hist_path}")
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: forwards hyper-parameters straight to train().
    _arg_spec = [
        ("--epochs", int, 30, "Total epochs"),
        ("--batch_size", int, 32, "Batch size"),
        ("--lr", float, 1e-4, "Learning rate"),
    ]
    cli = argparse.ArgumentParser(description="Train CNN Branch")
    for flag, flag_type, default, descr in _arg_spec:
        cli.add_argument(flag, type=flag_type, default=default, help=descr)
    opts = cli.parse_args()
    train(epochs=opts.epochs, batch_size=opts.batch_size, lr=opts.lr)
|
|