import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from PIL import Image

# Import the new V2 Architecture Modules
from core.v2_architecture import MultiModalDeepfakeSystemV2, CompoundLoss
from core.diffusion_latent import DiffusionErrorLoop

# Keys read from every pre-computed sample dict, in the order the model consumes them.
_BATCH_KEYS = ("spatial_tensor", "freq_tensor", "latent_tensor", "stat_tensor", "label")


class MultiModalDataset(datasets.DatasetFolder):
    """Folder dataset whose samples are pre-computed ``.pt`` files.

    Each file is loaded with ``torch.load`` and returned as-is; the class
    index that ``DatasetFolder`` derives from the directory name is ignored
    by ``__getitem__`` (it is still available through ``self.targets``).
    """

    def __init__(self, root):
        # Only look for .pt files.
        # NOTE(review): torch.load unpickles the file contents, so only point
        # this at trusted, locally generated data.
        super().__init__(root, loader=torch.load, extensions=('.pt',))

    def __getitem__(self, index):
        """Return the object stored in the ``index``-th ``.pt`` file."""
        path, _ = self.samples[index]
        # data is a dict: {spatial_tensor, freq_tensor, latent_tensor, stat_tensor, label}
        return self.loader(path)


def _batch_to_device(batch, device):
    """Return ``(spatial, freq, latent, stat, labels)`` from ``batch`` moved to ``device``."""
    return tuple(batch[key].to(device) for key in _BATCH_KEYS)


def _compute_pos_weight(targets, device):
    """Return the BCE ``pos_weight`` tensor (count of class 0 / count of class 1).

    Returns ``None`` when either of the first two classes has no samples, so
    the caller keeps the unweighted loss instead of using an infinite or zero
    weight.
    """
    class_counts = np.bincount(targets)
    if len(class_counts) > 1 and class_counts[0] > 0 and class_counts[1] > 0:
        # pos_weight = negative_samples / positive_samples (class 0 / class 1)
        return torch.tensor(
            [class_counts[0] / class_counts[1]], device=device, dtype=torch.float
        )
    return None


def validate(model, val_loader, device):
    """Return the accuracy of ``model`` on ``val_loader`` as a percentage.

    Args:
        model: Callable as ``model(spatial, freq, latent, stat)`` returning
            one logit per sample.
        val_loader: Iterable of batches; each batch is a mapping with the
            keys listed in ``_BATCH_KEYS``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        ``100 * correct / total``, or ``0.0`` when the loader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            spatial, freq, latent, stat, labels = _batch_to_device(batch, device)

            # Forward Main Architecture
            outputs = model(spatial, freq, latent, stat)

            # Flatten explicitly instead of squeeze(): squeeze() drops every
            # singleton dimension, which is fragile for a final batch of one.
            # Threshold logit at 0.0 (equivalent to prob > 0.5)
            predicted = (outputs.reshape(-1) > 0.0).long()
            targets = labels.reshape(-1)

            total += targets.numel()
            correct += (predicted == targets).sum().item()

    if total == 0:
        # An empty validation set would otherwise raise ZeroDivisionError.
        return 0.0
    return 100 * correct / total


def main():
    """Train ``MultiModalDeepfakeSystemV2`` on the pre-computed tensor datasets.

    Runs up to 10 epochs with gradient accumulation and mixed precision (on
    CUDA only), validates after every epoch, writes the best-scoring weights
    to ``model_best.pth``, stops early after 5 epochs without improvement,
    and always writes the final weights to ``model.pth``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using:", device)

    # Mixed precision and loss scaling are only enabled when running on CUDA.
    use_amp = device.type == "cuda"

    # Enable NVIDIA CuDNN Auto-Tuner: Drastically speeds up convolution math on fixed-size images
    if use_amp:
        torch.backends.cudnn.benchmark = True

    train_data = MultiModalDataset("dataset/processed_train")
    val_data = MultiModalDataset("dataset/processed_val")
    print("Class mapping:", train_data.class_to_idx)

    # VAE Error loop and Multi-Modal model are heavy, lowering batch size from 8 to 4
    batch_size = 4
    accumulation_steps = 8

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_data,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    print("Initializing Multi-Modal Deepfake System V2...")
    model = MultiModalDeepfakeSystemV2().to(device)

    # Calculate class weights to handle dataset imbalance.
    # DatasetFolder populates self.targets from the directory layout.
    # NOTE(review): the training labels come from each file's "label" entry,
    # so this assumes the folder-derived targets agree with them - confirm.
    pos_weight = _compute_pos_weight(train_data.targets, device)

    # We use our new Compound Loss (BCE + Contrastive)
    criterion = CompoundLoss(lambda_weight=0.35)
    if pos_weight is not None:
        criterion.bce_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

    epochs = 10
    best_val_acc = 0.0
    patience = 5
    patience_counter = 0

    # Tie the scaler to the selected device and disable it off-CUDA instead
    # of hard-coding 'cuda'.
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

    try:
        for epoch in range(epochs):
            print(f"Starting Epoch {epoch+1}...")
            model.train()
            total_loss = 0
            optimizer.zero_grad()
            num_batches = len(train_loader)

            for i, batch in enumerate(train_loader):
                if i == 0 and epoch == 0:
                    print("SUCCESS: Grabbed the first batch of pre-computed tensors. Processing...")

                spatial, freq, latent, stat, labels = _batch_to_device(batch, device)

                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    # Forward Main Architecture, returning features for Contrastive Loss
                    logits, features = model(spatial, freq, latent, stat, return_features=True)

                    # Compute Compound Loss
                    loss, loss_bce, loss_contrastive = criterion(logits, labels, features)
                    # Scale so the accumulated gradient matches one large batch.
                    # NOTE(review): a final, shorter accumulation window is
                    # still divided by the full accumulation_steps.
                    loss = loss / accumulation_steps

                scaler.scale(loss).backward()

                # Step every accumulation_steps batches and on the last batch.
                if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                # Undo the accumulation scaling for reporting.
                total_loss += loss.item() * accumulation_steps

                # Print periodic batch updates
                if i % 10 == 0:
                    print(f"Batch {i}/{len(train_loader)} - BCE: {loss_bce.item():.4f}, Contrastive: {loss_contrastive.item():.4f}")

            # Guard against an empty loader so the epoch summary cannot divide by zero.
            avg_loss = total_loss / max(num_batches, 1)
            val_acc = validate(model, val_loader, device)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), "model_best.pth")
                print("--> Best model checkpoint completely secured (model_best.pth).")
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs.")
                    break

            scheduler.step()
    except KeyboardInterrupt:
        print("\n[!] Training halted manually by user. The highest accuracy checkpoint is completely saved as 'model_best.pth'!")

    # Always persist the final weights, whether training finished, stopped
    # early, or was interrupted.
    torch.save(model.state_dict(), "model.pth")
    print("\nTraining procedure officially terminated. 'model.pth' securely written to disk!")


if __name__ == "__main__":
    main()