"""
training/train_vit.py
----------------------
ViT Branch Training Script — ViT-B/16 via PyTorch + timm
STATUS: COMPLETE
Usage:
cd ImageForensics-Detect/
python training/train_vit.py [--epochs 20] [--batch_size 16] [--lr 1e-4]
Training strategy:
- AdamW optimizer with cosine LR schedule and warmup
- Label smoothing (0.1) for better calibration
- Mixed-precision training (if CUDA available)
- Gradient clipping (max norm = 1.0)
Saves:
- Best model weights → models/vit_branch.pth
- Training history → outputs/vit_training_history.json
"""
import argparse
import json
import sys
from pathlib import Path
# Make the project root importable so `training.*` and `branches.*` resolve
# when this script is run directly (e.g. `python training/train_vit.py`).
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
from training.dataset_loader import discover_dataset, split_dataset, make_torch_dataloader
# Output locations for checkpoints and training artifacts; mkdir is
# idempotent (exist_ok), but the parent ROOT must already exist.
MODELS_DIR = ROOT / "models"
OUTPUTS_DIR = ROOT / "outputs"
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)
def train(epochs: int = 20, batch_size: int = 16, lr: float = 1e-4):
    """Train the ViT branch and checkpoint the best-validation weights.

    Strategy (per the module header): AdamW with a linear warmup into a
    cosine LR decay, label smoothing (0.1) for calibration, mixed
    precision on CUDA, and gradient clipping at max-norm 1.0.

    Args:
        epochs: Total number of training epochs.
        batch_size: Mini-batch size for both train and val loaders.
        lr: Peak learning rate reached after warmup.

    Side effects:
        Writes best weights to models/vit_branch.pth and per-epoch
        metrics to outputs/vit_training_history.json; exits with
        status 1 if the dataset directories contain no images.
    """
    # Heavy dependencies are imported lazily so the module can be imported
    # (and --help shown) without torch/sklearn installed.
    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
    from sklearn.metrics import accuracy_score
    from branches.vit_branch import _build_vit_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'='*55}")
    print("  ImageForensics-Detect — ViT Branch Training")
    print(f"{'='*55}")
    print(f"  Device: {device} | Epochs: {epochs} | Batch: {batch_size}")
    # ── 1. Load dataset ──────────────────────────────────────────
    paths, labels = discover_dataset()
    if len(paths) == 0:
        print("\n✗ No images found. Populate data/raw/real/ and data/raw/fake/ first.")
        sys.exit(1)
    splits = split_dataset(paths, labels)
    train_loader = make_torch_dataloader(splits["train"][0], splits["train"][1],
                                         batch_size=batch_size, augment=True)
    val_loader = make_torch_dataloader(splits["val"][0], splits["val"][1],
                                       batch_size=batch_size, augment=False)
    # ── 2. Model, loss, optimizer, schedule ──────────────────────
    model = _build_vit_model().to(device)
    # Label smoothing loss for better calibration
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    # Linear warmup chained into cosine decay. The module docstring promises
    # "cosine LR schedule and warmup"; the warmup was previously missing.
    warmup_epochs = min(3, max(1, epochs // 10)) if epochs > 1 else 0
    if warmup_epochs:
        scheduler = SequentialLR(
            optimizer,
            schedulers=[
                LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs),
                CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs,
                                  eta_min=1e-6),
            ],
            milestones=[warmup_epochs],
        )
    else:
        scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    # NOTE(review): torch.cuda.amp.* is deprecated in torch>=2.4 in favor of
    # torch.amp.*(..., "cuda"); kept here for compatibility with older installs.
    scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None
    best_val_acc = 0.0
    model_save = str(MODELS_DIR / "vit_branch.pth")
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(1, epochs + 1):
        # ── Train ──
        model.train()
        train_losses, train_preds, train_targets = [], [], []
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            if scaler:
                with torch.cuda.amp.autocast():
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured in fp32.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(imgs)
                loss = criterion(logits, lbls)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            train_losses.append(loss.item())
            train_preds.extend(logits.argmax(dim=1).cpu().numpy())
            train_targets.extend(lbls.cpu().numpy())
        scheduler.step()
        # ── Validate ──
        model.eval()
        val_losses, val_preds, val_targets = [], [], []
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                logits = model(imgs)
                loss = criterion(logits, lbls)
                val_losses.append(loss.item())
                val_preds.extend(logits.argmax(dim=1).cpu().numpy())
                val_targets.extend(lbls.cpu().numpy())
        # Guard against crashes (ZeroDivisionError / sklearn ValueError) when a
        # loader yields no batches, e.g. a tiny dataset with an empty val split.
        train_acc = accuracy_score(train_targets, train_preds) if train_targets else 0.0
        val_acc = accuracy_score(val_targets, val_preds) if val_targets else 0.0
        t_loss = sum(train_losses) / max(1, len(train_losses))
        v_loss = sum(val_losses) / max(1, len(val_losses))
        print(f"Epoch [{epoch:02d}/{epochs}] "
              f"Train Loss={t_loss:.4f} Acc={train_acc:.4f} | "
              f"Val Loss={v_loss:.4f} Acc={val_acc:.4f}")
        history["train_loss"].append(t_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(v_loss)
        history["val_acc"].append(val_acc)
        # ── Save best ──
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save)
            print(f"  ✓ Best model saved (val_acc={val_acc:.4f})")
    print(f"\n✓ Training complete. Best val accuracy: {best_val_acc:.4f}")
    print(f"✓ Model saved → {model_save}")
    hist_path = OUTPUTS_DIR / "vit_training_history.json"
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)
    print(f"✓ Training history saved → {hist_path}")
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters and launch training.
    cli = argparse.ArgumentParser(description="Train ViT Branch")
    cli.add_argument("--epochs", type=int, default=20, help="Total epochs")
    cli.add_argument("--batch_size", type=int, default=16, help="Batch size")
    cli.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    opts = cli.parse_args()
    train(epochs=opts.epochs, batch_size=opts.batch_size, lr=opts.lr)
|