Spaces:
Running
Running
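#!/usr/bin/env python3
"""
Train the MorphGuard detector model on real vs. morph face images.

The command-line entry point fine-tunes an ImageNet-pretrained ResNet-18 with a
two-class head. It trains on torchvision ImageFolder datasets read from
<data-dir>/train and <data-dir>/val. The pipeline provides:
- Transfer learning from pretrained weights
- Data augmentation (random horizontal flip on the training set)
- Per-epoch train/val loss and accuracy, printed and written to a JSON stats file
- Optional per-epoch metric logging to TimescaleDB when a `config` module is importable
- CPU fallback when CUDA is unavailable or the GPU is unsupported by the PyTorch build
- Saving the final model weights

get_model() can also build ResNet, EfficientNet, Vision Transformer and other timm
architectures, but the command line does not expose a model choice. Learning-rate
scheduling, early stopping, checkpointing and wandb logging are not implemented yet.
The scheduler classes and wandb are imported but unused.

Usage:
    python scripts/train_detector.py --data-dir data --epochs 50 --batch-size 32 \\
        --lr 1e-3 --save-path models/morph_detector.pth \\
        --stats-file training_stats.json

Other options: --job-id (training session id stored in TimescaleDB),
--metrics-path (alias for --stats-file), --seed (random seed).
"""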
| import os | |
| import sys | |
| import json | |
| import argparse | |
| import time | |
| import psycopg2 | |
| from datetime import datetime | |
| from pathlib import Path | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix | |
| # Add project root to path for imports | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| try: | |
| import config | |
| except ImportError: | |
| config = None | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader, Dataset | |
| from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR | |
| import torchvision | |
| from torchvision import datasets, transforms, models | |
| import timm | |
| # Try to import wandb for logging | |
| try: | |
| import wandb | |
| wandb_available = True | |
| except ImportError: | |
| wandb_available = False | |
class MorphDataset(Dataset):
    """Dataset of real (label 0) and morph (label 1) face images.

    Expects ``<data_dir>/<split>/real`` and ``<data_dir>/<split>/morph``
    directories holding .jpg/.jpeg/.png files. Any augmentation is whatever
    ``transform`` applies when an item is loaded.
    """

    def __init__(self, data_dir, split='train', transform=None):
        """
        Args:
            data_dir: Path to data directory with train/val/test splits
            split: One of 'train', 'val', 'test'; only 'train' is shuffled
            transform: Optional callable applied to the float image tensor
        """
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.real_dir = os.path.join(data_dir, split, 'real')
        self.morph_dir = os.path.join(data_dir, split, 'morph')
        self.real_images = self._list_images(self.real_dir)
        self.morph_images = self._list_images(self.morph_dir)
        # Combine paths and create labels (0 for real, 1 for morph)
        self.image_paths = self.real_images + self.morph_images
        self.labels = [0] * len(self.real_images) + [1] * len(self.morph_images)
        if split == 'train':
            # Shuffle paths and labels with the same permutation so pairs stay aligned.
            order = torch.randperm(len(self.image_paths)).tolist()
            self.image_paths = [self.image_paths[i] for i in order]
            self.labels = [self.labels[i] for i in order]

    @staticmethod
    def _list_images(directory):
        """Return full paths of the .jpg/.jpeg/.png files in ``directory``, sorted by name.

        os.listdir order is arbitrary. Sorting makes the sample order, and a
        seeded shuffle over it, reproducible across machines.
        """
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.lower().endswith(('.jpg', '.jpeg', '.png'))
        ]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` with the image as a float tensor scaled to [0, 1]."""
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Force a 3-channel decode: grayscale or RGBA files would otherwise give
        # 1- or 4-channel tensors that break 3-channel normalisation and models.
        img = torchvision.io.read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)
        img = img.float() / 255.0
        if self.transform:
            img = self.transform(img)
        return img, label
def get_model(model_name, num_classes=2, pretrained=True):
    """Build a classifier by name with a ``num_classes``-way output layer.

    Args:
        model_name: One of 'resnet18', 'resnet34', 'resnet50', 'resnet101'
            (torchvision), a name starting with 'efficientnet' or 'vit' (timm),
            or any other name, which is passed to ``timm.create_model``.
        num_classes: Number of output logits.
        pretrained: Whether to load pretrained weights.

    Returns:
        An ``nn.Module`` whose final layer produces ``num_classes`` logits.

    Raises:
        ValueError: For an unknown ResNet variant, an EfficientNet without a
            ``classifier``/``fc`` head, or when the timm fallback fails. In the
            fallback case the underlying error is chained as ``__cause__``.
    """
    if model_name.startswith('resnet'):
        if model_name not in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
            raise ValueError(f"Unknown ResNet model: {model_name}")
        model = getattr(models, model_name)(pretrained=pretrained)
        # Replace the ImageNet classifier head with a fresh one.
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model

    if model_name.startswith('efficientnet'):
        model = timm.create_model(model_name, pretrained=pretrained)
        # Swap whichever head attribute exists, checking 'classifier' first.
        for head_attr in ('classifier', 'fc'):
            if hasattr(model, head_attr):
                in_features = getattr(model, head_attr).in_features
                setattr(model, head_attr, nn.Linear(in_features, num_classes))
                return model
        raise ValueError(f"Could not find classifier head in {model_name}")

    if model_name.startswith('vit'):
        # Vision Transformer models (timm builds the head directly)
        return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)

    # Fallback: let timm try any other name. Catch Exception rather than using a
    # bare except, so Ctrl-C still interrupts, and chain the original error so
    # download or config failures are not hidden behind "Unknown model".
    try:
        return timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    except Exception as err:
        raise ValueError(f"Unknown model: {model_name}") from err
def _select_device():
    """Print PyTorch/CUDA version info and return the CUDA device if available, else CPU."""
    cuda_available = torch.cuda.is_available()
    print(f"PyTorch version: {torch.__version__}, CUDA version: {torch.version.cuda}, GPU available: {cuda_available}")
    if cuda_available:
        print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
        return torch.device('cuda')
    print("Warning: CUDA not detected or PyTorch built without CUDA support; using CPU.")
    return torch.device('cpu')


def _fallback_to_cpu_if_unsupported(model, device):
    """Probe the GPU with one dummy forward pass and move to CPU if this build cannot run on it.

    Returns:
        ``(model, device)``, either unchanged or moved to CPU after a recognised
        incompatibility error. Any other RuntimeError is re-raised.
    """
    if device.type != 'cuda':
        return model, device
    try:
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224, device=device)
            model.eval()
            model(dummy)
    except RuntimeError as e:
        msg = str(e)
        if 'no kernel image is available' not in msg and 'not compatible' not in msg:
            raise
        print(f"Warning: GPU {torch.cuda.get_device_name(0)} not compatible with this PyTorch build; using CPU instead.")
        device = torch.device('cpu')
        model = model.to(device)
    finally:
        # The probe switched the model to eval mode; restore training mode.
        model.train()
    return model, device


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` this trains: zero_grad, forward, backward and step
    per batch. Without one it evaluates with gradients disabled. Both values
    are averaged over ``len(loader.dataset)`` samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Assumes criterion averages over the batch (true for the default
            # CrossEntropyLoss). Re-weighting by batch size gives a per-sample
            # mean even when the last batch is short.
            total_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, correct / n_samples


def _write_stats(path, stats):
    """Overwrite the JSON stats file on a best-effort basis; failures are only printed."""
    try:
        with open(path, 'w') as f:
            json.dump(stats, f)
    except Exception as e:
        print(f"Warning: could not write stats file: {e}")


def _log_epoch_to_db(args, ep, train_loss, train_acc, val_loss, val_acc):
    """Insert one epoch's metrics into ``training_metrics`` on a best-effort basis.

    Does nothing when the optional ``config`` module was not importable. Errors
    are printed, never raised, and the connection is always closed.
    """
    if not config:
        return
    from datetime import timezone
    # Naive UTC timestamp: the same value datetime.utcnow() returned
    # (utcnow is deprecated since Python 3.12).
    ts = datetime.now(timezone.utc).replace(tzinfo=None)
    conn = None
    try:
        conn = psycopg2.connect(
            dbname=config.DB_NAME,
            user=config.DB_USER,
            password=config.DB_PASS,
            host=config.DB_HOST,
            port=config.DB_PORT,
        )
        with conn.cursor() as cur:
            cur.execute(
                """INSERT INTO training_metrics
                (time, model_name, epoch, loss, accuracy, val_loss, val_accuracy, learning_rate, batch_size, training_session_id)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
                (ts, 'morph_detector', ep, train_loss, train_acc, val_loss, val_acc,
                 args.lr, args.batch_size, args.job_id),
            )
        conn.commit()
    except Exception as e:
        print(f"Warning: could not write to TimescaleDB: {e}")
    finally:
        # Previously the connection leaked whenever execute/commit raised.
        if conn is not None:
            conn.close()


def train_detector(args):
    """Fine-tune a two-class image classifier on real vs. morph images and save its weights.

    Args:
        args: Namespace with ``data_dir``, ``epochs``, ``batch_size``, ``lr``,
            ``save_path``, ``stats_file``, ``job_id`` and ``seed``. It may also
            carry ``model`` (name for ``get_model``, default ``'resnet18'``)
            and ``num_workers`` (DataLoader workers, default 4).

    Side effects:
        Prints progress and rewrites ``args.stats_file`` after every epoch.
        Inserts one TimescaleDB row per epoch when ``config`` is available.
        Saves the final ``state_dict`` to ``args.save_path``.

    Raises:
        ValueError: If the train split does not have exactly two class folders,
            or the val split's class folders differ from the train split's.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = _select_device()

    # Data transforms (ImageNet normalisation statistics)
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = datasets.ImageFolder(os.path.join(args.data_dir, 'train'), transform=train_transforms)
    val_ds = datasets.ImageFolder(os.path.join(args.data_dir, 'val'), transform=val_transforms)
    # ImageFolder numbers classes by sorted folder name. A val split missing a
    # class folder would therefore get different indices and be silently
    # mislabelled, so require identical two-class mappings.
    if len(train_ds.classes) != 2:
        raise ValueError(f"Expected exactly 2 class folders in train split, found {train_ds.classes}")
    if val_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"Class folders differ between train {train_ds.class_to_idx} and val {val_ds.class_to_idx}")
    # NOTE(review): with 'morph'/'real' folders this gives morph=0, real=1, the
    # reverse of MorphDataset's real=0/morph=1. Confirm which one inference expects.
    print(f"Class mapping: {train_ds.class_to_idx}")

    num_workers = getattr(args, 'num_workers', 4)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

    model = get_model(getattr(args, 'model', 'resnet18'), num_classes=2, pretrained=True).to(device)
    model, device = _fallback_to_cpu_if_unsupported(model, device)

    criterion = nn.CrossEntropyLoss()
    # Created after the possible CPU fallback so it tracks the final parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    stats = {'epochs': [], 'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for ep in range(1, args.epochs + 1):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f"Epoch {ep}/{args.epochs} | train_loss={train_loss:.4f} | train_acc={train_acc:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")
        for key, value in (('epochs', ep), ('train_loss', train_loss), ('train_acc', train_acc),
                           ('val_loss', val_loss), ('val_acc', val_acc)):
            stats[key].append(value)
        _write_stats(args.stats_file, stats)
        _log_epoch_to_db(args, ep, train_loss, train_acc, val_loss, val_acc)

    # A bare filename has an empty dirname, and os.makedirs('') raises FileNotFoundError.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), args.save_path)
    print(f"Training complete. Model saved to {args.save_path}")
if __name__ == '__main__':
    # Each CLI option as (flag, add_argument keyword arguments). Order matters:
    # --stats-file must be registered before its --metrics-path alias so that
    # its default is the one stored under the shared 'stats_file' destination.
    cli_options = [
        ('--data-dir', dict(type=str, default='data', help='Root data directory (with train/, val/ subdirs)')),
        ('--epochs', dict(type=int, default=20, help='Number of epochs')),
        ('--batch-size', dict(type=int, default=32, help='Batch size')),
        ('--lr', dict(type=float, default=1e-3, help='Learning rate')),
        ('--save-path', dict(type=str, default='models/morph_detector.pth', help='Path to save model')),
        ('--stats-file', dict(type=str, default='training_stats.json', help='JSON file to write metrics')),
        ('--job-id', dict(type=str, default='detector', help='Job ID for metrics in TimescaleDB')),
        ('--metrics-path', dict(type=str, dest='stats_file', help='Alias for --stats-file')),
        ('--seed', dict(type=int, default=42, help='Random seed for reproducibility')),
    ]
    cli = argparse.ArgumentParser(description='Train MorphGuard Detector')
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    train_detector(cli.parse_args())