| """ |
| training/train_vit.py |
| ---------------------- |
| ViT Branch Training Script β ViT-B/16 via PyTorch + timm |
| STATUS: COMPLETE |
| |
| Usage: |
| cd ImageForensics-Detect/ |
| python training/train_vit.py [--epochs 20] [--batch_size 16] [--lr 1e-4] |
| |
| Training strategy: |
| - AdamW optimizer with cosine LR schedule and warmup |
| - Label smoothing (0.1) for better calibration |
| - Mixed-precision training (if CUDA available) |
| - Gradient clipping (max norm = 1.0) |
| |
| Saves: |
| - Best model weights β models/vit_branch.pth |
| - Training history β outputs/vit_training_history.json |
| """ |
|
|
| import argparse |
| import json |
| import sys |
| from pathlib import Path |
|
|
# Make the repo root importable so `training.*` and `branches.*` packages
# resolve when this script is run directly (e.g. `python training/train_vit.py`).
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
|
|
| from training.dataset_loader import discover_dataset, split_dataset, make_torch_dataloader |
|
|
# Output locations under the repo root, created up front if missing.
MODELS_DIR = ROOT / "models"    # best checkpoint destination
OUTPUTS_DIR = ROOT / "outputs"  # training-history JSON destination
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)
|
|
|
|
def train(epochs: int = 20, batch_size: int = 16, lr: float = 1e-4):
    """Train the ViT branch classifier and checkpoint the best weights.

    Args:
        epochs: Total number of training epochs.
        batch_size: Mini-batch size for both the train and val loaders.
        lr: Peak learning rate for AdamW (decayed by cosine schedule).

    Side effects:
        - Saves best model weights to models/vit_branch.pth
          (selected by highest validation accuracy).
        - Saves per-epoch metrics to outputs/vit_training_history.json.
        - Exits the process with code 1 if no dataset images are found.

    Fix note: the status glyphs in the printed messages were mojibake
    ("β") from a mis-decoded source file; restored to ✓/✗/→.
    """
    # Heavy imports are deferred so the module can be imported (and
    # `--help` shown) without torch/timm installed.
    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from sklearn.metrics import accuracy_score
    from branches.vit_branch import _build_vit_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'='*55}")
    print("  ImageForensics-Detect — ViT Branch Training")
    print(f"{'='*55}")
    print(f"  Device: {device} | Epochs: {epochs} | Batch: {batch_size}")

    # ---- Data -----------------------------------------------------------
    paths, labels = discover_dataset()
    if len(paths) == 0:
        print("\n✗ No images found. Populate data/raw/real/ and data/raw/fake/ first.")
        sys.exit(1)

    splits = split_dataset(paths, labels)
    train_loader = make_torch_dataloader(splits["train"][0], splits["train"][1],
                                         batch_size=batch_size, augment=True)
    val_loader = make_torch_dataloader(splits["val"][0], splits["val"][1],
                                       batch_size=batch_size, augment=False)

    # ---- Model & optimization -------------------------------------------
    model = _build_vit_model().to(device)
    # Label smoothing (0.1) for better-calibrated fake/real probabilities.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    # Mixed precision only pays off on CUDA; on CPU train in full fp32.
    scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

    def _step(imgs, lbls):
        """One optimizer step (AMP-aware). Returns (loss, logits)."""
        optimizer.zero_grad()
        if scaler:
            with torch.cuda.amp.autocast():
                logits = model(imgs)
                loss = criterion(logits, lbls)
            scaler.scale(loss).backward()
            # Unscale before clipping so the max-norm applies to true grads.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(imgs)
            loss = criterion(logits, lbls)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        return loss, logits

    def _run_epoch(loader, training):
        """One full pass over `loader`. Returns (mean_loss, accuracy)."""
        model.train(training)
        losses, preds, targets = [], [], []
        grad_ctx = torch.enable_grad() if training else torch.no_grad()
        with grad_ctx:
            for imgs, lbls in loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                if training:
                    loss, logits = _step(imgs, lbls)
                else:
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                losses.append(loss.item())
                preds.extend(logits.argmax(dim=1).cpu().numpy())
                targets.extend(lbls.cpu().numpy())
        return sum(losses) / len(losses), accuracy_score(targets, preds)

    # ---- Training loop ---------------------------------------------------
    best_val_acc = 0.0
    model_save = str(MODELS_DIR / "vit_branch.pth")
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        t_loss, train_acc = _run_epoch(train_loader, training=True)
        # Cosine schedule steps once per epoch (T_max == epochs).
        scheduler.step()
        v_loss, val_acc = _run_epoch(val_loader, training=False)

        print(f"Epoch [{epoch:02d}/{epochs}] "
              f"Train Loss={t_loss:.4f} Acc={train_acc:.4f} | "
              f"Val Loss={v_loss:.4f} Acc={val_acc:.4f}")

        history["train_loss"].append(t_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(v_loss)
        history["val_acc"].append(val_acc)

        # Checkpoint only on a new best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save)
            print(f"  ✓ Best model saved (val_acc={val_acc:.4f})")

    print(f"\n✓ Training complete. Best val accuracy: {best_val_acc:.4f}")
    print(f"✓ Model saved → {model_save}")

    hist_path = OUTPUTS_DIR / "vit_training_history.json"
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)
    print(f"✓ Training history saved → {hist_path}")
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Train ViT Branch") |
| parser.add_argument("--epochs", type=int, default=20, help="Total epochs") |
| parser.add_argument("--batch_size", type=int, default=16, help="Batch size") |
| parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") |
| args = parser.parse_args() |
| train(epochs=args.epochs, batch_size=args.batch_size, lr=args.lr) |
|
|