""" training/train_cnn.py ---------------------- CNN Branch Training Script — EfficientNet-B0 STATUS: COMPLETE Usage: cd ImageForensics-Detect/ python training/train_cnn.py [--epochs 30] [--batch_size 32] [--lr 1e-4] Training strategy: Phase 1 (epochs 1–10) : Only train classification head (backbone frozen) Phase 2 (epochs 11–30) : Unfreeze top 20% of backbone layers + fine-tune Saves: - Best model weights → models/cnn_branch.h5 - Training history → outputs/cnn_training_history.json """ import argparse import json import sys from pathlib import Path # ── Add project root to path ────────────────────────────────────── ROOT = Path(__file__).parent.parent sys.path.insert(0, str(ROOT)) from training.dataset_loader import discover_dataset, split_dataset, make_tf_dataset MODELS_DIR = ROOT / "models" OUTPUTS_DIR = ROOT / "outputs" MODELS_DIR.mkdir(exist_ok=True) OUTPUTS_DIR.mkdir(exist_ok=True) def build_callbacks(model_save_path: str): """Build training callbacks: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau.""" import tensorflow as tf return [ tf.keras.callbacks.ModelCheckpoint( filepath=model_save_path, monitor="val_accuracy", save_best_only=True, save_weights_only=True, verbose=1, ), tf.keras.callbacks.EarlyStopping( monitor="val_loss", patience=7, restore_best_weights=True, verbose=1, ), tf.keras.callbacks.ReduceLROnPlateau( monitor="val_loss", factor=0.5, patience=3, min_lr=1e-7, verbose=1, ), ] def train(epochs: int = 30, batch_size: int = 32, lr: float = 1e-4): import tensorflow as tf from branches.cnn_branch import _build_model print(f"\n{'='*55}") print(" ImageForensics-Detect — CNN Branch Training") print(f"{'='*55}") print(f" Epochs: {epochs} | Batch: {batch_size} | LR: {lr}") # ── 1. Load Dataset ────────────────────────────────────────── paths, labels = discover_dataset() if len(paths) == 0: print("\n❌ No images found in data/raw/real/ and data/raw/fake/") print(" Please populate the dataset first. See README.md → Dataset Setup.") sys.exit(1) splits = split_dataset(paths, labels) train_ds = make_tf_dataset(splits["train"][0], splits["train"][1], batch_size=batch_size, augment=True, shuffle=True) val_ds = make_tf_dataset(splits["val"][0], splits["val"][1], batch_size=batch_size, augment=False, shuffle=False) print(f"\n Train: {len(splits['train'][0])} | " f"Val: {len(splits['val'][0])} | Test: {len(splits['test'][0])}") # ── 2. Build Model ──────────────────────────────────────────── model = _build_model() model.summary(line_length=80) model_save = str(MODELS_DIR / "cnn_branch.h5") callbacks = build_callbacks(model_save) # ── 3. Phase 1: Head-only training ─────────────────────────── print("\n[Phase 1] Training classification head only (backbone frozen)...") for layer in model.get_layer("efficientnetb0").layers: layer.trainable = False model.compile( optimizer=tf.keras.optimizers.Adam(lr), loss="binary_crossentropy", metrics=["accuracy"], ) phase1_epochs = min(10, epochs // 3) history1 = model.fit( train_ds, validation_data=val_ds, epochs=phase1_epochs, callbacks=callbacks, verbose=1, ) # ── 4. Phase 2: Fine-tune top 20% of backbone ───────────────── if epochs > phase1_epochs: print("\n[Phase 2] Fine-tuning top 20% of EfficientNet layers...") base = model.get_layer("efficientnetb0") n_unfreeze = max(1, int(0.20 * len(base.layers))) for layer in base.layers[-n_unfreeze:]: layer.trainable = True model.compile( optimizer=tf.keras.optimizers.Adam(lr * 0.1), # Lower LR for fine-tune loss="binary_crossentropy", metrics=["accuracy"], ) history2 = model.fit( train_ds, validation_data=val_ds, epochs=epochs - phase1_epochs, callbacks=callbacks, verbose=1, ) else: history2 = None # ── 5. Save weights ─────────────────────────────────────────── model.save_weights(model_save) print(f"\n✓ Best model weights saved → {model_save}") # ── 6. Save history ─────────────────────────────────────────── history = {k: v + (history2.history.get(k, []) if history2 else []) for k, v in history1.history.items()} hist_path = OUTPUTS_DIR / "cnn_training_history.json" with open(hist_path, "w") as f: json.dump({k: [float(x) for x in v] for k, v in history.items()}, f, indent=2) print(f"✓ Training history saved → {hist_path}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train CNN Branch") parser.add_argument("--epochs", type=int, default=30, help="Total epochs") parser.add_argument("--batch_size", type=int, default=32, help="Batch size") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") args = parser.parse_args() train(epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)