#!/usr/bin/env python3
"""
Train the MorphGuard detector model on real vs morph face images.

What this script does:
- Fine-tunes an ImageNet-pretrained backbone selected with --model
  (torchvision ResNets, or EfficientNet / Vision Transformer / any other
  architecture that timm can create)
- Resizes inputs to 224x224 and applies a random horizontal flip during training
- Trains with Adam and cross-entropy, evaluating on the validation split
  after every epoch
- Writes per-epoch metrics to a JSON stats file and, when the optional
  project ``config`` module and ``psycopg2`` are importable, to TimescaleDB
- Saves the final model weights (state_dict) to --save-path

Usage:
    python scripts/train_detector.py --data-dir data --epochs 50 --batch-size 32 \
        --model efficientnet_b0 --lr 1e-3 --save-path models/morph_detector.pth \
        --stats-file training_stats.json
"""

import os
import sys
import json
import argparse
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
)

# The database driver is optional: metric logging to TimescaleDB is
# best-effort, so a missing driver must not prevent training.
try:
    import psycopg2
except ImportError:
    psycopg2 = None

# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
try:
    import config
except ImportError:
    config = None

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
import torchvision
from torchvision import datasets, transforms, models
import timm

# Try to import wandb for logging
try:
    import wandb
    wandb_available = True
except ImportError:
    wandb_available = False


# File suffixes (compared case-insensitively) that MorphDataset treats as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Statement used to record one epoch of metrics in TimescaleDB.
_INSERT_METRICS_SQL = (
    """INSERT INTO training_metrics (time, model_name, epoch, loss, accuracy, val_loss, val_accuracy, learning_rate, batch_size, training_session_id) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"""
)


class MorphDataset(Dataset):
    """Dataset of real/morph face images read from ``<data_dir>/<split>/{real,morph}``.

    Labels are 0 for real and 1 for morph.

    NOTE(review): ``train_detector`` below does not use this class; it uses
    ``torchvision.datasets.ImageFolder``, which assigns class indices from the
    sorted sub-directory names. If the sub-directories are ``morph`` and
    ``real`` that yields morph=0 / real=1, the opposite of this class --
    confirm which convention downstream inference expects.
    """

    def __init__(self, data_dir, split='train', transform=None):
        """
        Args:
            data_dir: Path to data directory with train/val/test splits
            split: One of 'train', 'val', 'test'
            transform: Optional transform applied to the float image tensor
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform

        # Get real and morph image paths
        self.real_dir = os.path.join(data_dir, split, 'real')
        self.morph_dir = os.path.join(data_dir, split, 'morph')
        self.real_images = self._list_images(self.real_dir)
        self.morph_images = self._list_images(self.morph_dir)

        # Combine paths and create labels (0 for real, 1 for morph)
        self.image_paths = self.real_images + self.morph_images
        self.labels = [0] * len(self.real_images) + [1] * len(self.morph_images)

        # Only the training split is shuffled; the permutation comes from
        # torch's RNG, so it follows torch.manual_seed.
        if split == 'train':
            order = torch.randperm(len(self.image_paths)).tolist()
            self.image_paths = [self.image_paths[i] for i in order]
            self.labels = [self.labels[i] for i in order]

    @staticmethod
    def _list_images(directory):
        """Return the image file paths in ``directory`` in sorted (deterministic) order."""
        # os.listdir order is arbitrary; sorting keeps the dataset order
        # reproducible across runs and filesystems.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at ``idx``."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        # Force 3-channel RGB so grayscale or RGBA files do not produce
        # tensors with a different channel count than the rest of the batch.
        img = torchvision.io.read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)

        # Convert to float and scale to [0, 1]
        img = img.float() / 255.0

        # Apply transforms if specified
        if self.transform:
            img = self.transform(img)

        return img, label


def get_model(model_name, num_classes=2, pretrained=True):
    """Build a classification model by name.

    Args:
        model_name: 'resnet18' / 'resnet34' / 'resnet50' / 'resnet101'
            (torchvision), or a name passed to ``timm.create_model``
            (names starting with 'efficientnet' or 'vit', or any other name
            as a fallback).
        num_classes: Size of the output layer.
        pretrained: Whether to load pretrained weights.

    Returns:
        The model with its classifier head sized to ``num_classes``.

    Raises:
        ValueError: If the name is an unsupported ResNet variant, if no
            classifier head is found on an EfficientNet model, or if the
            timm fallback cannot create the model.
    """
    if model_name.startswith('resnet'):
        resnet_builders = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
        }
        builder = resnet_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown ResNet model: {model_name}")
        model = builder(pretrained=pretrained)

        # Replace classifier head
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    elif model_name.startswith('efficientnet'):
        # EfficientNet models (using timm)
        model = timm.create_model(model_name, pretrained=pretrained)

        # Replace classifier head, whichever attribute name the model uses
        if hasattr(model, 'classifier'):
            model.classifier = nn.Linear(model.classifier.in_features, num_classes)
        elif hasattr(model, 'fc'):
            model.fc = nn.Linear(model.fc.in_features, num_classes)
        else:
            raise ValueError(f"Could not find classifier head in {model_name}")

    elif model_name.startswith('vit'):
        # Vision Transformer models (using timm)
        model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    else:
        # Try to load from timm as fallback
        try:
            model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
        except Exception as err:
            # Keep the original cause attached; a bare ``except`` would also
            # swallow KeyboardInterrupt / SystemExit.
            raise ValueError(f"Unknown model: {model_name}") from err

    return model


def _select_device():
    """Pick the CUDA device when available, else the CPU, and print what was chosen."""
    cuda_available = torch.cuda.is_available()
    # Log PyTorch and CUDA versions
    print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {cuda_available}")
    if cuda_available:
        device = torch.device('cuda')
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device('cpu')
        print("Warning: CUDA not detected or PyTorch built without CUDA support; using CPU.")
    return device


def _build_loaders(data_dir, batch_size):
    """Create the train/val DataLoaders from ``<data_dir>/train`` and ``<data_dir>/val``."""
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=train_transforms)
    val_ds = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=val_transforms)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)
    return train_loader, val_loader


def _ensure_gpu_compatible(model, device):
    """Run one dummy forward pass on the GPU and fall back to the CPU if it is unsupported.

    Returns:
        ``(model, device)`` -- unchanged when the probe succeeds (or the
        device is already the CPU), otherwise the model moved to the CPU.

    Raises:
        RuntimeError: Any probe failure whose message does not indicate a
            GPU/build incompatibility is re-raised.
    """
    if device.type != 'cuda':
        return model, device
    try:
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224, device=device)
            model.eval()
            model(dummy)
            model.train()
    except RuntimeError as e:
        msg = str(e)
        if 'no kernel image is available' in msg or 'not compatible' in msg:
            print(f"Warning: GPU {torch.cuda.get_device_name(0)} not compatible with this PyTorch build; using CPU instead.")
            device = torch.device('cpu')
            model = model.to(device)
        else:
            raise
    return model, device


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and
    updated; otherwise it is put in eval mode and gradients are disabled.
    Both values are averaged over ``len(loader.dataset)``.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # criterion returns the batch mean; weight by batch size so the
            # epoch average is per-sample even with a smaller final batch.
            loss_sum += loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()

    n_samples = len(loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples


def _write_stats_file(stats_file, stats):
    """Best-effort dump of the accumulated metrics to ``stats_file`` as JSON."""
    try:
        with open(stats_file, 'w') as f:
            json.dump(stats, f)
    except Exception as e:
        # Metrics persistence must never abort a training run.
        print(f"Warning: could not write stats file: {e}")


def _log_metrics_to_db(args, epoch, train_loss, train_acc, val_loss, val_acc):
    """Best-effort insert of one epoch's metrics into TimescaleDB.

    Does nothing when the project ``config`` module or ``psycopg2`` is not
    importable. Any failure is reported as a warning and otherwise ignored.
    """
    if not config or psycopg2 is None:
        return
    try:
        conn = psycopg2.connect(
            dbname=config.DB_NAME,
            user=config.DB_USER,
            password=config.DB_PASS,
            host=config.DB_HOST,
            port=config.DB_PORT,
        )
        try:
            with conn.cursor() as cur:
                # NOTE(review): naive UTC timestamp; how it is interpreted
                # depends on the column type / session time zone -- confirm.
                ts = datetime.utcnow()
                cur.execute(
                    _INSERT_METRICS_SQL,
                    (ts, 'morph_detector', epoch, train_loss, train_acc,
                     val_loss, val_acc, args.lr, args.batch_size, args.job_id),
                )
            conn.commit()
        finally:
            # Always release the connection, including when execute/commit fails.
            conn.close()
    except Exception as e:
        print(f"Warning: could not write to TimescaleDB: {e}")


def train_detector(args):
    """Main training function.

    Args:
        args: Namespace with ``seed``, ``data_dir``, ``batch_size``, ``lr``,
            ``epochs``, ``stats_file``, ``job_id`` and ``save_path``. An
            optional ``model`` attribute selects the architecture passed to
            ``get_model``; when absent or empty, 'resnet18' is used.

    Side effects:
        Prints progress, rewrites ``args.stats_file`` after every epoch,
        optionally inserts metrics into TimescaleDB, and saves the final
        ``state_dict`` to ``args.save_path``.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    device = _select_device()
    train_loader, val_loader = _build_loaders(args.data_dir, args.batch_size)

    # Callers that build their own Namespace without ``model`` keep the
    # previous behaviour (pretrained ResNet18 with a 2-class head).
    model_name = getattr(args, 'model', None) or 'resnet18'
    model = get_model(model_name, num_classes=2, pretrained=True)
    model = model.to(device)

    # Test GPU compatibility and fallback to CPU if unsupported
    model, device = _ensure_gpu_compatible(model, device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Prepare metrics
    stats = {'epochs': [], 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    # Training loop
    for ep in range(1, args.epochs + 1):
        epoch_train_loss, epoch_train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion, device)

        # Log metrics
        print(f"Epoch {ep}/{args.epochs} | train_loss={epoch_train_loss:.4f} | train_acc={epoch_train_acc:.4f} | val_loss={epoch_val_loss:.4f} | val_acc={epoch_val_acc:.4f}")
        stats['epochs'].append(ep)
        stats['train_loss'].append(epoch_train_loss)
        stats['train_acc'].append(epoch_train_acc)
        stats['val_loss'].append(epoch_val_loss)
        stats['val_acc'].append(epoch_val_acc)

        # The stats file is rewritten every epoch so partial runs leave data.
        _write_stats_file(args.stats_file, stats)
        _log_metrics_to_db(
            args, ep, epoch_train_loss, epoch_train_acc, epoch_val_loss, epoch_val_acc
        )

    # Save final model. os.path.dirname returns '' for a bare filename, and
    # os.makedirs('') raises, so only create the directory when there is one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), args.save_path)
    print(f"Training complete. Model saved to {args.save_path}")


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='Train MorphGuard Detector')
    parser.add_argument('--data-dir', type=str, default='data', help='Root data directory (with train/, val/ subdirs)')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--model', type=str, default='resnet18', help='Model architecture (resnet18/34/50/101 or any timm model name)')
    parser.add_argument('--save-path', type=str, default='models/morph_detector.pth', help='Path to save model')
    parser.add_argument('--stats-file', type=str, default='training_stats.json', help='JSON file to write metrics')
    parser.add_argument('--job-id', type=str, default='detector', help='Job ID for metrics in TimescaleDB')
    # Shares dest with --stats-file; must stay after it so the --stats-file
    # default is the one argparse applies.
    parser.add_argument('--metrics-path', type=str, dest='stats_file', help='Alias for --stats-file')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    return parser


def main(argv=None):
    """Parse command-line arguments (``sys.argv`` when ``argv`` is None) and train."""
    args = _build_arg_parser().parse_args(argv)
    train_detector(args)


if __name__ == '__main__':
    main()