"""
training/train_vit.py
----------------------
ViT Branch Training Script — ViT-B/16 via PyTorch + timm

STATUS: COMPLETE

Usage:
    cd ImageForensics-Detect/
    python training/train_vit.py [--epochs 20] [--batch_size 16] [--lr 1e-4]

Training strategy:
  - AdamW optimizer with a per-epoch cosine-annealing LR schedule
    (no warmup phase is implemented)
  - Label smoothing (0.1) for better calibration
  - Mixed-precision training (if CUDA available)
  - Gradient clipping (max norm = 1.0)

Saves:
  - Best model weights  → models/vit_branch.pth
  - Training history    → outputs/vit_training_history.json
"""

import argparse
import json
import sys
from pathlib import Path

ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))

from training.dataset_loader import discover_dataset, split_dataset, make_torch_dataloader

MODELS_DIR = ROOT / "models"
OUTPUTS_DIR = ROOT / "outputs"
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)


def _train_one_epoch(model, loader, criterion, optimizer, scaler, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: The network being trained; switched to train mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scaler: A ``GradScaler`` for mixed precision, or ``None`` to train
            in full precision.
        device: Device string the batches are moved to.

    Returns:
        A ``(losses, preds, targets)`` tuple: the per-batch loss values, the
        arg-max class predictions, and the ground-truth labels.
    """
    import torch

    model.train()
    losses, preds, targets = [], [], []
    for imgs, lbls in loader:
        imgs, lbls = imgs.to(device), lbls.to(device)
        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                logits = model(imgs)
                loss = criterion(logits, lbls)
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the real
            # gradient magnitudes rather than the loss-scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(imgs)
            loss = criterion(logits, lbls)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        losses.append(loss.item())
        preds.extend(logits.argmax(dim=1).tolist())
        targets.extend(lbls.tolist())
    return losses, preds, targets


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without computing gradients.

    Args:
        model: The network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        device: Device string the batches are moved to.

    Returns:
        A ``(losses, preds, targets)`` tuple with the same layout as
        ``_train_one_epoch``.
    """
    import torch

    model.eval()
    losses, preds, targets = [], [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            logits = model(imgs)
            loss = criterion(logits, lbls)
            losses.append(loss.item())
            preds.extend(logits.argmax(dim=1).tolist())
            targets.extend(lbls.tolist())
    return losses, preds, targets


def train(epochs: int = 20, batch_size: int = 16, lr: float = 1e-4) -> None:
    """Train the ViT branch and persist the best checkpoint and history.

    The model is optimised with AdamW (weight decay 1e-2) under a cosine
    annealing schedule stepped once per epoch. After every epoch the model is
    scored on the validation split; whenever validation accuracy improves,
    the weights are written to ``models/vit_branch.pth``. The per-epoch
    loss/accuracy history is written to
    ``outputs/vit_training_history.json`` at the end.

    Args:
        epochs: Number of passes over the training split; also used as the
            cosine schedule period.
        batch_size: Batch size for both the training and validation loaders.
        lr: Initial learning rate handed to AdamW.

    Raises:
        SystemExit: With status 1 when no images are discovered or when the
            train or validation split is empty.
    """
    # Heavy dependencies are imported lazily, inside the function.
    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from sklearn.metrics import accuracy_score
    from branches.vit_branch import _build_vit_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'='*55}")
    print(" ImageForensics-Detect — ViT Branch Training")
    print(f"{'='*55}")
    print(f" Device: {device} | Epochs: {epochs} | Batch: {batch_size}")

    # ── 1. Load Dataset ──────────────────────────────────────────
    paths, labels = discover_dataset()
    if len(paths) == 0:
        print("\n❌ No images found. Populate data/raw/real/ and data/raw/fake/ first.")
        sys.exit(1)

    splits = split_dataset(paths, labels)
    # An empty split would otherwise surface much later as a division by
    # zero when averaging the (empty) list of batch losses.
    if len(splits["train"][0]) == 0 or len(splits["val"][0]) == 0:
        print("\n❌ Train or validation split is empty. "
              "Add more images to data/raw/real/ and data/raw/fake/.")
        sys.exit(1)

    train_loader = make_torch_dataloader(splits["train"][0], splits["train"][1],
                                         batch_size=batch_size, augment=True)
    val_loader = make_torch_dataloader(splits["val"][0], splits["val"][1],
                                       batch_size=batch_size, augment=False)

    # ── 2. Model & Optimizer ──────────────────────────────────────
    model = _build_vit_model().to(device)

    # Label smoothing loss for better calibration
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    # Loss scaling is only set up when training on CUDA.
    scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

    best_val_acc = 0.0
    checkpoint_saved = False
    model_save = str(MODELS_DIR / "vit_branch.pth")
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        # ── Train ──
        train_losses, train_preds, train_targets = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device
        )
        # The schedule is stepped per epoch (T_max is expressed in epochs).
        scheduler.step()

        # ── Validate ──
        val_losses, val_preds, val_targets = _evaluate(
            model, val_loader, criterion, device
        )

        train_acc = accuracy_score(train_targets, train_preds)
        val_acc = accuracy_score(val_targets, val_preds)
        t_loss = sum(train_losses) / len(train_losses)
        v_loss = sum(val_losses) / len(val_losses)

        print(f"Epoch [{epoch:02d}/{epochs}] "
              f"Train Loss={t_loss:.4f} Acc={train_acc:.4f} | "
              f"Val Loss={v_loss:.4f} Acc={val_acc:.4f}")

        history["train_loss"].append(t_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(v_loss)
        history["val_acc"].append(val_acc)

        # ── Save best ──
        # Always keep the first epoch's weights so a checkpoint exists even
        # if validation accuracy never rises above its 0.0 starting value.
        if not checkpoint_saved or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save)
            checkpoint_saved = True
            print(f" ✓ Best model saved (val_acc={val_acc:.4f})")

    print(f"\n✓ Training complete. Best val accuracy: {best_val_acc:.4f}")
    if checkpoint_saved:
        print(f"✓ Model saved → {model_save}")
    else:
        # Only reachable when the epoch loop never ran (epochs < 1).
        print("⚠ No checkpoint was written because no epochs were run.")

    hist_path = OUTPUTS_DIR / "vit_training_history.json"
    with open(hist_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)
    print(f"✓ Training history saved → {hist_path}")


def main() -> None:
    """Parse command-line options and launch training."""
    parser = argparse.ArgumentParser(description="Train ViT Branch")
    parser.add_argument("--epochs", type=int, default=20, help="Total epochs")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    args = parser.parse_args()
    train(epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)


if __name__ == "__main__":
    main()