File size: 2,682 Bytes
26c2a4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import yaml
import os

def get_transforms(cfg):
    """

    DINOv2 expects ImageNet normalization.

    We also add some light augmentation to prevent overfitting.

    """
    img_size = cfg['data']['image_size']
    
    # Training Transforms (with Augmentation)
    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5), # 50% chance to flip
        transforms.ColorJitter(brightness=0.1, contrast=0.1), # Slight color changes
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], # DINOv2 Expected Mean
            std=[0.229, 0.224, 0.225]   # DINOv2 Expected Std
        )
    ])

    # Validation/Test Transforms (No Augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    
    return train_transform, val_transform

def create_dataloaders(config_path="configs/config.yaml"):
    # Load config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    train_transform, val_transform = get_transforms(cfg)
    data_dir = cfg['data']['train_dir'] # Should be "data/raw"

    # 1. Load the Entire Dataset (REAL + FAKE)
    full_dataset = datasets.ImageFolder(root=data_dir)
    
    # 2. Split: 80% Train, 20% Validation
    total_size = len(full_dataset)
    train_size = int(0.8 * total_size)
    val_size = total_size - train_size
    
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # Apply specific transforms
    train_dataset.dataset.transform = train_transform
    val_dataset.dataset.transform = val_transform

    # 3. Create Loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=cfg['data']['batch_size'], 
        shuffle=True, 
        num_workers=cfg['data']['num_workers']
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=cfg['data']['batch_size'], 
        shuffle=False, 
        num_workers=cfg['data']['num_workers']
    )

    print(f"✅ Data Ready:")
    print(f"   - Train: {len(train_dataset)} images")
    print(f"   - Val:   {len(val_dataset)} images")
    print(f"   - Classes: {full_dataset.class_to_idx}")
    
    return train_loader, val_loader

if __name__ == "__main__":
    create_dataloaders()