import os

import torch
import yaml
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms


def get_transforms(cfg):
    """Build (train_transform, val_transform) for DINOv2-style inputs.

    DINOv2 expects ImageNet normalization. The training pipeline adds
    light augmentation (flip + slight color jitter) to reduce overfitting;
    the validation pipeline is deterministic (resize + normalize only).

    Args:
        cfg: parsed config dict; reads cfg['data']['image_size'].

    Returns:
        (train_transform, val_transform) torchvision Compose pipelines.
    """
    img_size = cfg['data']['image_size']

    # ImageNet statistics expected by DINOv2 backbones.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Training transforms (with augmentation).
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),            # 50% chance to flip
        transforms.ColorJitter(brightness=0.1, contrast=0.1),  # slight color changes
        transforms.ToTensor(),
        normalize,
    ])

    # Validation/test transforms (no augmentation).
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        normalize,
    ])

    return train_transform, val_transform


def create_dataloaders(config_path="configs/config.yaml"):
    """Create train/val DataLoaders from an ImageFolder directory.

    BUG FIX vs. the previous version: the old code called random_split()
    and then assigned `train_dataset.dataset.transform` and
    `val_dataset.dataset.transform`. Both Subsets share ONE underlying
    ImageFolder, so the second assignment overwrote the first and BOTH
    splits silently used the validation transform (no augmentation).
    We now build two ImageFolder views of the same directory — one per
    transform — and split them with a single seeded index permutation,
    which also makes the train/val membership reproducible across runs.

    Args:
        config_path: path to a YAML config providing
            data.train_dir, data.image_size, data.batch_size,
            data.num_workers, and optionally data.train_split
            (fraction of data used for training; default 0.8).

    Returns:
        (train_loader, val_loader) DataLoader pair.
    """
    # Load config.
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    train_transform, val_transform = get_transforms(cfg)
    data_dir = cfg['data']['train_dir']  # Should be "data/raw"

    # 1. Two views of the same dataset (REAL + FAKE), each carrying its
    #    own transform. Same root => identical sample ordering, so one
    #    index permutation is valid for both.
    train_base = datasets.ImageFolder(root=data_dir, transform=train_transform)
    val_base = datasets.ImageFolder(root=data_dir, transform=val_transform)

    # 2. Split: train_split (default 80%) train, remainder validation.
    total_size = len(train_base)
    train_frac = cfg['data'].get('train_split', 0.8)
    train_size = int(train_frac * total_size)

    # Deterministic shuffle so the split is reproducible and the two
    # base datasets never leak samples between train and val.
    generator = torch.Generator().manual_seed(42)
    perm = torch.randperm(total_size, generator=generator).tolist()
    train_dataset = Subset(train_base, perm[:train_size])
    val_dataset = Subset(val_base, perm[train_size:])

    # 3. Create loaders.
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg['data']['batch_size'],
        shuffle=True,
        num_workers=cfg['data']['num_workers']
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg['data']['batch_size'],
        shuffle=False,
        num_workers=cfg['data']['num_workers']
    )

    print(f"✅ Data Ready:")
    print(f"   - Train: {len(train_dataset)} images")
    print(f"   - Val: {len(val_dataset)} images")
    print(f"   - Classes: {train_base.class_to_idx}")

    return train_loader, val_loader


if __name__ == "__main__":
    create_dataloaders()