# Hugging Face Hub upload metadata (scrape residue): uploaded by dk2430098
# via huggingface_hub, commit 928b74f (verified).
"""
training/train_vit.py
----------------------
ViT Branch Training Script β€” ViT-B/16 via PyTorch + timm
STATUS: COMPLETE
Usage:
cd ImageForensics-Detect/
python training/train_vit.py [--epochs 20] [--batch_size 16] [--lr 1e-4]
Training strategy:
- AdamW optimizer with cosine LR schedule and warmup
- Label smoothing (0.1) for better calibration
- Mixed-precision training (if CUDA available)
- Gradient clipping (max norm = 1.0)
Saves:
- Best model weights β†’ models/vit_branch.pth
- Training history β†’ outputs/vit_training_history.json
"""
import argparse
import json
import sys
from pathlib import Path
# Make the repo root importable so `training.*` and `branches.*` resolve
# when this script is run directly (python training/train_vit.py).
ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(ROOT))
from training.dataset_loader import discover_dataset, split_dataset, make_torch_dataloader
# Output locations, created up-front so later saves never fail on a missing dir.
MODELS_DIR = ROOT / "models"
OUTPUTS_DIR = ROOT / "outputs"
MODELS_DIR.mkdir(exist_ok=True)
OUTPUTS_DIR.mkdir(exist_ok=True)
def _mean(values):
    """Arithmetic mean of a list of floats; NaN for an empty list.

    Guards the per-epoch loss averaging against a ZeroDivisionError when a
    dataloader yields no batches (e.g. a split smaller than one batch).
    """
    return sum(values) / len(values) if values else float("nan")


def train(epochs: int = 20, batch_size: int = 16, lr: float = 1e-4):
    """Train the ViT branch classifier and save the best checkpoint.

    Args:
        epochs: Total training epochs (also the cosine schedule length).
        batch_size: Mini-batch size for both train and val loaders.
        lr: Peak AdamW learning rate.

    Side effects:
        Writes best weights to models/vit_branch.pth and the per-epoch
        metric history to outputs/vit_training_history.json.
        Exits with status 1 if no images are found on disk.
    """
    # Heavy deps are imported lazily so import errors surface only when
    # training is actually requested.
    import torch
    import torch.nn as nn
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from sklearn.metrics import accuracy_score
    from branches.vit_branch import _build_vit_model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n{'='*55}")
    print(" ImageForensics-Detect — ViT Branch Training")
    print(f"{'='*55}")
    print(f" Device: {device} | Epochs: {epochs} | Batch: {batch_size}")

    # ── 1. Load Dataset ──────────────────────────────────────────
    paths, labels = discover_dataset()
    if len(paths) == 0:
        print("\n❌ No images found. Populate data/raw/real/ and data/raw/fake/ first.")
        sys.exit(1)
    splits = split_dataset(paths, labels)
    train_loader = make_torch_dataloader(splits["train"][0], splits["train"][1],
                                         batch_size=batch_size, augment=True)
    val_loader = make_torch_dataloader(splits["val"][0], splits["val"][1],
                                       batch_size=batch_size, augment=False)

    # ── 2. Model & Optimizer ──────────────────────────────────────
    model = _build_vit_model().to(device)
    # Label smoothing (0.1) yields better-calibrated probabilities.
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
    # NOTE: torch.cuda.amp.* is the legacy AMP namespace (superseded by
    # torch.amp in torch>=2.4) but is kept for compatibility with older installs.
    scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

    best_val_acc = 0.0
    model_save = str(MODELS_DIR / "vit_branch.pth")
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        # ── Train ──
        model.train()
        train_losses, train_preds, train_targets = [], [], []
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            # set_to_none=True frees grad tensors instead of zero-filling them.
            optimizer.zero_grad(set_to_none=True)
            if scaler:
                with torch.cuda.amp.autocast():
                    logits = model(imgs)
                    loss = criterion(logits, lbls)
                scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured in true units.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(imgs)
                loss = criterion(logits, lbls)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            train_losses.append(loss.item())
            train_preds.extend(logits.argmax(dim=1).cpu().numpy())
            train_targets.extend(lbls.cpu().numpy())
        scheduler.step()  # one cosine step per epoch (T_max=epochs)

        # ── Validate ──
        model.eval()
        val_losses, val_preds, val_targets = [], [], []
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                logits = model(imgs)
                loss = criterion(logits, lbls)
                val_losses.append(loss.item())
                val_preds.extend(logits.argmax(dim=1).cpu().numpy())
                val_targets.extend(lbls.cpu().numpy())

        train_acc = accuracy_score(train_targets, train_preds)
        val_acc = accuracy_score(val_targets, val_preds)
        # _mean returns NaN (not a crash) if a loader produced zero batches.
        t_loss = _mean(train_losses)
        v_loss = _mean(val_losses)
        print(f"Epoch [{epoch:02d}/{epochs}] "
              f"Train Loss={t_loss:.4f} Acc={train_acc:.4f} | "
              f"Val Loss={v_loss:.4f} Acc={val_acc:.4f}")
        history["train_loss"].append(t_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(v_loss)
        history["val_acc"].append(val_acc)

        # ── Save best ── (checkpoint only when validation accuracy improves)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save)
            print(f"  ✓ Best model saved (val_acc={val_acc:.4f})")

    print(f"\n✓ Training complete. Best val accuracy: {best_val_acc:.4f}")
    print(f"✓ Model saved → {model_save}")
    hist_path = OUTPUTS_DIR / "vit_training_history.json"
    with open(hist_path, "w") as f:
        json.dump(history, f, indent=2)
    print(f"✓ Training history saved → {hist_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train ViT Branch")
parser.add_argument("--epochs", type=int, default=20, help="Total epochs")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
args = parser.parse_args()
train(epochs=args.epochs, batch_size=args.batch_size, lr=args.lr)