MorphGuard / scripts /train_detector.py
juanquy's picture
Initial clean commit of modular MorphGuard
2978bba
Raw
History Blame Contribute Delete
12.3 kB
#!/usr/bin/env python3
"""
Train the MorphGuard detector model on real vs morph face images.
This is a complete training pipeline that supports:
- Multiple model architectures (ResNet, EfficientNet, Vision Transformer)
- Transfer learning with pretrained weights
- Data augmentation
- Learning rate scheduling
- Logging and visualization
- Early stopping
- Checkpointing
Usage:
python scripts/train_detector.py --data-dir data --epochs 50 --batch-size 32 \
--lr 1e-3 --save-path models/morph_detector.pth \
--stats-file training_stats.json
"""
import os
import sys
import json
import argparse
import time
import psycopg2
from datetime import datetime
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
try:
import config
except ImportError:
config = None
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import torchvision
from torchvision import datasets, transforms, models
import timm
# Try to import wandb for logging
try:
import wandb
wandb_available = True
except ImportError:
wandb_available = False
class MorphDataset(Dataset):
    """Dataset of real and morphed face images with on-the-fly transforms.

    Images are read from ``<data_dir>/<split>/real`` (label 0) and
    ``<data_dir>/<split>/morph`` (label 1). Only files whose names end in
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) are used.
    """

    # File name suffixes (compared in lower case) accepted as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_dir, split='train', transform=None):
        """
        Args:
            data_dir: Path to data directory with train/val/test splits
            split: One of 'train', 'val', 'test'
            transform: Optional callable applied to each image tensor after
                it has been converted to float and scaled to [0, 1]
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Get real and morph image paths
        self.real_dir = os.path.join(data_dir, split, 'real')
        self.morph_dir = os.path.join(data_dir, split, 'morph')
        self.real_images = self._list_images(self.real_dir)
        self.morph_images = self._list_images(self.morph_dir)

        # Combine paths and create labels (0 for real, 1 for morph)
        self.image_paths = self.real_images + self.morph_images
        self.labels = [0] * len(self.real_images) + [1] * len(self.morph_images)

        # Shuffle the training split once at construction time. Because the
        # listings are sorted, the resulting order depends only on the torch
        # RNG state, so a fixed seed reproduces it.
        if split == 'train':
            order = torch.randperm(len(self.image_paths)).tolist()
            self.image_paths = [self.image_paths[i] for i in order]
            self.labels = [self.labels[i] for i in order]

    @classmethod
    def _list_images(cls, directory):
        """Return the sorted paths of the image files directly inside ``directory``."""
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # dataset order independent of the filesystem.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        """Return the total number of images (real + morph)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Force RGB decoding so grayscale or RGBA files still yield three
        # channels and can be batched together with ordinary colour images.
        img = torchvision.io.read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)

        # Convert to float and scale to [0, 1]
        img = img.float() / 255.0

        # Apply transforms if specified
        if self.transform:
            img = self.transform(img)
        return img, label
def get_model(model_name, num_classes=2, pretrained=True):
    """Build a classification model by name with a ``num_classes``-way head.

    Args:
        model_name: Architecture name. Names starting with ``resnet`` must be
            one of resnet18/34/50/101 and are built from ``torchvision.models``;
            names starting with ``efficientnet`` or ``vit``, and any other
            name, are passed to ``timm.create_model``.
        num_classes: Number of output classes of the classifier head.
        pretrained: Whether to request pretrained weights from the library.

    Returns:
        The constructed model.

    Raises:
        ValueError: If the ResNet variant is not supported, if no classifier
            head can be found on an EfficientNet model, or if the fallback
            ``timm`` lookup fails (the underlying error is chained as the cause).
    """
    if model_name.startswith('resnet'):
        # Validate the name before touching torchvision so an unsupported
        # variant fails with a clear message.
        if model_name not in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
            raise ValueError(f"Unknown ResNet model: {model_name}")
        model = getattr(models, model_name)(pretrained=pretrained)
        # Replace classifier head
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if model_name.startswith('efficientnet'):
        # EfficientNet models (using timm)
        model = timm.create_model(model_name, pretrained=pretrained)
        # Replace whichever attribute holds the classifier head.
        for head_name in ('classifier', 'fc'):
            if hasattr(model, head_name):
                in_features = getattr(model, head_name).in_features
                setattr(model, head_name, nn.Linear(in_features, num_classes))
                return model
        raise ValueError(f"Could not find classifier head in {model_name}")

    if model_name.startswith('vit'):
        # Vision Transformer models (using timm); errors propagate unchanged.
        return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    # Try to load from timm as fallback
    try:
        return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    except Exception as exc:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # are not swallowed, and keep the original error as the cause.
        raise ValueError(f"Unknown model: {model_name}") from exc
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Returns:
        Tuple ``(mean_loss, accuracy)`` where both values are normalised by
        ``len(loader.dataset)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    # Gradients are only needed when the optimizer will step.
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns the batch mean, so weight it by batch size.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, total_correct / n_samples


def _insert_epoch_metrics(args, epoch, train_loss, train_acc, val_loss, val_acc):
    """Insert one row of epoch metrics into the ``training_metrics`` table."""
    conn = psycopg2.connect(
        dbname=config.DB_NAME,
        user=config.DB_USER,
        password=config.DB_PASS,
        host=config.DB_HOST,
        port=config.DB_PORT
    )
    try:
        # NOTE(review): naive UTC timestamp; confirm the `time` column's type
        # and time zone expectations before switching to an aware datetime.
        ts = datetime.utcnow()
        with conn.cursor() as cur:
            cur.execute(
                """INSERT INTO training_metrics
                (time, model_name, epoch, loss, accuracy, val_loss, val_accuracy, learning_rate, batch_size, training_session_id)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
                (ts, 'morph_detector', epoch, train_loss, train_acc, val_loss, val_acc, args.lr, args.batch_size, args.job_id)
            )
        conn.commit()
    finally:
        # Always release the connection, even when the insert fails.
        conn.close()


def train_detector(args):
    """Train the detector and save its weights.

    Args:
        args: Namespace providing ``seed``, ``data_dir``, ``batch_size``,
            ``lr``, ``epochs``, ``stats_file``, ``job_id`` and ``save_path``.
            An optional ``model`` attribute selects the architecture passed
            to ``get_model``; it defaults to ``'resnet18'``.

    Side effects:
        Rewrites ``args.stats_file`` with the accumulated metrics after every
        epoch, inserts one row per epoch into TimescaleDB when the ``config``
        module was imported, and saves the final ``state_dict`` to
        ``args.save_path``. Stats-file and database failures only print a
        warning.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Device setup: use GPU if available, else CPU
    cuda_available = torch.cuda.is_available()
    # Log PyTorch and CUDA versions
    print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {cuda_available}")
    if cuda_available:
        device = torch.device('cuda')
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device('cpu')
        print("Warning: CUDA not detected or PyTorch built without CUDA support; using CPU.")

    # Data transforms
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Prepare datasets and loaders from <data_dir>/train and <data_dir>/val
    train_ds = datasets.ImageFolder(os.path.join(args.data_dir, 'train'), transform=train_transforms)
    val_ds = datasets.ImageFolder(os.path.join(args.data_dir, 'val'), transform=val_transforms)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Model: pretrained two-class network; ResNet18 unless args.model says otherwise
    model = get_model(getattr(args, 'model', 'resnet18'), num_classes=2, pretrained=True)
    model = model.to(device)

    # Test GPU compatibility and fallback to CPU if unsupported
    if device.type == 'cuda':
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, 224, 224, device=device)
                model.eval()
                model(dummy)
            model.train()
        except RuntimeError as e:
            msg = str(e)
            if 'no kernel image is available' in msg or 'not compatible' in msg:
                print(f"Warning: GPU {torch.cuda.get_device_name(0)} not compatible with this PyTorch build; using CPU instead.")
                device = torch.device('cpu')
                model = model.to(device)
            else:
                raise

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Prepare metrics
    stats = {'epochs': [], 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    # Training loop
    for ep in range(1, args.epochs + 1):
        epoch_train_loss, epoch_train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion, device)

        # Log metrics
        print(f"Epoch {ep}/{args.epochs} | train_loss={epoch_train_loss:.4f} | train_acc={epoch_train_acc:.4f} | val_loss={epoch_val_loss:.4f} | val_acc={epoch_val_acc:.4f}")
        stats['epochs'].append(ep)
        stats['train_loss'].append(epoch_train_loss)
        stats['train_acc'].append(epoch_train_acc)
        stats['val_loss'].append(epoch_val_loss)
        stats['val_acc'].append(epoch_val_acc)

        # Write stats to file (best effort: a failure must not abort training)
        try:
            with open(args.stats_file, 'w') as f:
                json.dump(stats, f)
        except Exception as e:
            print(f"Warning: could not write stats file: {e}")

        # Insert metrics into TimescaleDB (best effort, only when config was imported)
        if config:
            try:
                _insert_epoch_metrics(args, ep, epoch_train_loss, epoch_train_acc, epoch_val_loss, epoch_val_acc)
            except Exception as e:
                print(f"Warning: could not write to TimescaleDB: {e}")

    # Save final model. dirname() is '' for a bare file name, and
    # os.makedirs('') raises, so only create the directory when there is one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), args.save_path)
    print(f"Training complete. Model saved to {args.save_path}")
if __name__ == '__main__':
    # Command-line entry point: declare the options as data, register them
    # in order, then hand the parsed namespace to train_detector().
    cli = argparse.ArgumentParser(description='Train MorphGuard Detector')
    option_specs = (
        ('--data-dir', dict(type=str, default='data', help='Root data directory (with train/, val/ subdirs)')),
        ('--epochs', dict(type=int, default=20, help='Number of epochs')),
        ('--batch-size', dict(type=int, default=32, help='Batch size')),
        ('--lr', dict(type=float, default=1e-3, help='Learning rate')),
        ('--save-path', dict(type=str, default='models/morph_detector.pth', help='Path to save model')),
        ('--stats-file', dict(type=str, default='training_stats.json', help='JSON file to write metrics')),
        ('--job-id', dict(type=str, default='detector', help='Job ID for metrics in TimescaleDB')),
        # Registered after --stats-file and writing to the same destination.
        ('--metrics-path', dict(type=str, dest='stats_file', help='Alias for --stats-file')),
        ('--seed', dict(type=int, default=42, help='Random seed for reproducibility')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    train_detector(cli.parse_args())