File size: 4,865 Bytes
d581b00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split
from src.data.transforms import train_transforms, val_transforms

# Root of the ArtiFact dataset on disk.
# NOTE(review): machine-specific absolute path — consider an env var for portability.
DATASET_ROOT = "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset"

# Source sub-directories scanned for metadata.csv. Rows are ultimately classified
# by the metadata "target" column (0 = real, otherwise fake), not by which list
# the source appears in; these lists just name which directories get scanned.
REAL_SOURCES = ["coco", "ffhq", "lsun", "imagenet", "landscape", "afhq"]
FAKE_SOURCES = ["stable_diffusion", "stylegan2", "ddpm", "glide", "latent_diffusion"]

# Per-class cap applied after loading all metadata.
MAX_PER_CLASS = 15000  # 15k real + 15k fake = 30k total


class ArtiFact(Dataset):
    """Binary real-vs-fake image dataset built from per-source metadata.csv files.

    Scans every directory in REAL_SOURCES + FAKE_SOURCES under DATASET_ROOT,
    classifies each metadata row by its ``target`` column (0 -> real, label 0;
    anything else -> fake, label 1), and keeps at most MAX_PER_CLASS samples of
    each class.
    """

    def __init__(self, transform=None):
        # transform: optional callable applied to each PIL image in __getitem__
        self.transform = transform
        # samples: list of (absolute_image_path, label) tuples
        self.samples = []
        self._load_metadata()

    def _load_metadata(self):
        """Populate self.samples from every source's metadata.csv."""
        real, fake = [], []

        for source in REAL_SOURCES + FAKE_SOURCES:
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                print(f"Skipping {source} - no metadata.csv")
                continue
            df = pd.read_csv(csv_path)
            # zip over the two used columns instead of DataFrame.iterrows():
            # same rows in the same order, but without per-row Series
            # construction — substantially faster on large CSVs.
            for rel_path, target in zip(df["image_path"], df["target"]):
                img_path = os.path.join(DATASET_ROOT, source, rel_path)
                if target == 0:
                    real.append((img_path, 0))
                else:
                    fake.append((img_path, 1))

        # Balance and subsample
        real = real[:MAX_PER_CLASS]
        fake = fake[:MAX_PER_CLASS]
        self.samples = real + fake

        print(f"Real: {len(real)} | Fake: {len(fake)} | Total: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # Force RGB so grayscale/RGBA files yield a consistent 3-channel image.
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

class SampleDataset(Dataset):
    """Thin Dataset over a pre-built list of (image_path, label) pairs.

    Unlike ArtiFact, this does no metadata scanning — it simply serves
    whatever sample list it is handed, applying an optional transform.
    """

    def __init__(self, samples, transform=None):
        # samples: sequence of (image_path, label) tuples
        self.samples = samples
        # transform: optional callable applied to each loaded PIL image
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        # Normalize to 3-channel RGB regardless of the file's native mode.
        img = Image.open(path).convert("RGB")
        img = self.transform(img) if self.transform else img
        return img, target


def get_dataloaders(batch_size=32):
    """Build train/val/test DataLoaders with a 75/12.5/12.5 random split.

    Bug fixed: the original code mutated ``val_set.dataset.transform`` and
    ``test_set.dataset.transform`` after ``random_split()``, but all three
    Subsets share the SAME underlying ArtiFact instance — so the last
    assignment (val_transforms) silently applied to the training split too,
    and training augmentation was never actually used. Each split is now
    wrapped in its own SampleDataset carrying its own transform.

    Args:
        batch_size: batch size used for all three loaders.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    from torch.utils.data import DataLoader

    # Load the metadata once with no transform; per-split transforms are
    # attached by the SampleDataset wrappers below.
    dataset = ArtiFact(transform=None)

    train_size = int(0.75 * len(dataset))
    val_size = int(0.125 * len(dataset))
    test_size = len(dataset) - train_size - val_size  # remainder absorbs rounding

    train_split, val_split, test_split = random_split(dataset, [train_size, val_size, test_size])

    def _wrap(split, transform):
        # Materialize the split's (path, label) pairs into an independent
        # dataset so each split owns its transform.
        return SampleDataset([dataset.samples[i] for i in split.indices], transform=transform)

    train_set = _wrap(train_split, train_transforms)
    val_set = _wrap(val_split, val_transforms)
    test_set = _wrap(test_split, val_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader


def get_cross_dataset_loaders(batch_size=32):
    """Cross-generator split: train on fakes from seen generators, test on unseen.

    Real images come from REAL_SOURCES in both splits; fake images come from
    SEEN_FAKE for training and UNSEEN_FAKE for testing, to measure
    generalization to generators never seen in training.

    NOTE(review): both splits take their reals from the same REAL_SOURCES
    prefix, so the 5k test reals are a subset of the 10k train reals —
    real-image leakage across splits. Preserved here; confirm it is intended.

    Args:
        batch_size: batch size used for both loaders.

    Returns:
        (train_loader, test_loader)
    """
    from torch.utils.data import DataLoader

    SEEN_FAKE = ["stable_diffusion", "stylegan2", "ddpm"]
    UNSEEN_FAKE = ["glide", "latent_diffusion"]

    def _collect(sources, want_real, max_count):
        """Gather up to max_count (path, label) pairs of one class.

        Deduplicates the previously copy-pasted real/fake loading loops;
        rows are kept when (target == 0) matches want_real. zip over the two
        used columns replaces DataFrame.iterrows() (same rows, same order,
        much faster). Sources whose metadata.csv is missing are skipped.
        """
        label = 0 if want_real else 1
        out = []
        for source in sources:
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                continue
            df = pd.read_csv(csv_path)
            for rel_path, target in zip(df["image_path"], df["target"]):
                if (target == 0) == want_real:
                    out.append((os.path.join(DATASET_ROOT, source, rel_path), label))
        return out[:max_count]

    def load_sources(real_sources, fake_sources, max_per_class=10000):
        # Reals first, then fakes — matches the original concatenation order.
        return (_collect(real_sources, True, max_per_class)
                + _collect(fake_sources, False, max_per_class))

    train_samples = load_sources(REAL_SOURCES, SEEN_FAKE)
    test_samples = load_sources(REAL_SOURCES, UNSEEN_FAKE, max_per_class=5000)

    print(f"Train samples: {len(train_samples)}")
    print(f"Test samples: {len(test_samples)}")

    train_set = SampleDataset(train_samples, transform=train_transforms)
    test_set = SampleDataset(test_samples, transform=val_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, test_loader