"""
training/train_cnn.py
----------------------
CNN Branch Training Script – EfficientNet-B0
STATUS: COMPLETE
Usage:
cd ImageForensics-Detect/
python training/train_cnn.py [--epochs 30] [--batch_size 32] [--lr 1e-4]
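Example (illustrative smoke-test run with smaller values):
python training/train_cnn.py --epochs 5 --batch_size 16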
Training strategy:
Phase 1 (epochs 1–10) : Only train classification head (backbone frozen)
Phase 2 (epochs 11–30) : Unfreeze top 20% of backbone layers + fine-tune
Saves:
- Best model weights → models/cnn_branch.h5
- Training history → outputs/cnn_training_history.json
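(The JSON maps each Keras metric name, e.g. "loss" / "val_accuracy", to a
per-epoch list covering both training phases.)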
"""
import argparse
import json
import sys
from pathlib import Path
# ── Add project root to path ──────────────────────────────────────
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
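# Makes `training.*` and `branches.*` importable when this file is run directly.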
from training.dataset_loader import discover_dataset, split_dataset, make_tf_dataset
MODELS_DIR = ROOT / "models"
OUTPUTS_DIR = ROOT / "outputs"
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)
def build_callbacks(model_save_path: str):
"""Build training callbacks: ModelCheckpoint, EarlyStopping, ReduceLROnPlateau."""
import tensorflow as tf
return [
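# Keep only the weights from the epoch with the best validation accuracy.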
tf.keras.callbacks.ModelCheckpoint(
filepath=model_save_path,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
verbose=1,
),
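# Stop once val_loss has not improved for 7 epochs and roll back to the best weights seen.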
tf.keras.callbacks.EarlyStopping(
monitor="val_loss",
patience=7,
restore_best_weights=True,
verbose=1,
),
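# Halve the learning rate after 3 stagnant epochs, never dropping below 1e-7.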
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss",
factor=0.5,
patience=3,
min_lr=1e-7,
verbose=1,
),
]
def train(epochs: int = 30, batch_size: int = 32, lr: float = 1e-4):
import tensorflow as tf
from branches.cnn_branch import _build_model
print(f"\n{'='*55}")
print(" ImageForensics-Detect β€” CNN Branch Training")
print(f"{'='*55}")
print(f" Epochs: {epochs} | Batch: {batch_size} | LR: {lr}")
# ── 1. Load Dataset ──────────────────────────────────────────
paths, labels = discover_dataset()
if len(paths) == 0:
print("\n❌ No images found in data/raw/real/ and data/raw/fake/")
print(" Please populate the dataset first. See README.md β†’ Dataset Setup.")
sys.exit(1)
splits = split_dataset(paths, labels)
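# split_dataset returns {'train' | 'val' | 'test': (image_paths, labels)};
# only the training split is shuffled and augmented below.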
train_ds = make_tf_dataset(splits["train"][0], splits["train"][1],
batch_size=batch_size, augment=True, shuffle=True)
val_ds = make_tf_dataset(splits["val"][0], splits["val"][1],
batch_size=batch_size, augment=False, shuffle=False)
print(f"\n Train: {len(splits['train'][0])} | "
f"Val: {len(splits['val'][0])} | Test: {len(splits['test'][0])}")
# ── 2. Build Model ────────────────────────────────────────────
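# NOTE: _build_model() (branches/cnn_branch.py) is expected to return an EfficientNet-B0
# backbone (layer name "efficientnetb0") topped with a binary classification head;
# the binary_crossentropy loss below implies a single sigmoid output.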
model = _build_model()
model.summary(line_length=80)
model_save = str(MODELS_DIR / "cnn_branch.h5")
callbacks = build_callbacks(model_save)
# ── 3. Phase 1: Head-only training ───────────────────────────
print("\n[Phase 1] Training classification head only (backbone frozen)...")
for layer in model.get_layer("efficientnetb0").layers:
layer.trainable = False
model.compile(
optimizer=tf.keras.optimizers.Adam(lr),
loss="binary_crossentropy",
metrics=["accuracy"],
)
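# Head-only phase: capped at 10 epochs (roughly a third of the total budget).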
phase1_epochs = max(1, min(10, epochs // 3))
history1 = model.fit(
train_ds, validation_data=val_ds,
epochs=phase1_epochs, callbacks=callbacks, verbose=1,
)
# ── 4. Phase 2: Fine-tune top 20% of backbone ─────────────────
if epochs > phase1_epochs:
print("\n[Phase 2] Fine-tuning top 20% of EfficientNet layers...")
base = model.get_layer("efficientnetb0")
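# Unfreeze only the last ~20% of backbone layers (closest to the head);
# earlier layers keep their pretrained features.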
n_unfreeze = max(1, int(0.20 * len(base.layers)))
for layer in base.layers[-n_unfreeze:]:
layer.trainable = True
model.compile(
optimizer=tf.keras.optimizers.Adam(lr * 0.1), # Lower LR for fine-tune
loss="binary_crossentropy",
metrics=["accuracy"],
)
history2 = model.fit(
train_ds, validation_data=val_ds,
epochs=epochs - phase1_epochs,
callbacks=callbacks, verbose=1,
)
else:
history2 = None
# ── 5. Save weights ───────────────────────────────────────────
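# ModelCheckpoint already wrote the best-val_accuracy weights during training; this call
# overwrites that file with the weights currently in memory (the best-val_loss epoch if
# early stopping restored them, otherwise the final epoch).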
model.save_weights(model_save)
print(f"\nβœ“ Best model weights saved β†’ {model_save}")
# ── 6. Save history ───────────────────────────────────────────
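# Concatenate the Phase 1 and Phase 2 curves metric-by-metric into one history.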
history = {k: v + (history2.history.get(k, []) if history2 else [])
for k, v in history1.history.items()}
hist_path = OUTPUTS_DIR / "cnn_training_history.json"
with open(hist_path, "w") as f:
json.dump({k: [float(x) for x in v] for k, v in history.items()}, f, indent=2)
print(f"βœ“ Training history saved β†’ {hist_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train CNN Branch")
parser.add_argument("--epochs", type=int, default=30, help="Total epochs")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
args = parser.parse_args()
train(epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)