File size: 3,526 Bytes
dda3973
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import pandas as pd
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from src.data.transforms import train_transforms, val_transforms

# Root directory holding one sub-directory per image source; each source dir
# is expected to contain a metadata.csv with relative image paths.
DATASET_ROOT = "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset"

# Manual class mapping based on architecture knowledge
# Maps each source directory name under DATASET_ROOT to its 4-way label:
# 0=Real, 1=GAN-based, 2=Diffusion-based, 3=Other generative models.
SOURCE_CLASS_MAP = {
    # Class 0 - Real
    "coco": 0,
    "ffhq": 0,
    "lsun": 0,
    "imagenet": 0,
    "landscape": 0,
    "afhq": 0,
    "celebahq": 0,
    "metfaces": 0,

    # Class 1 - GAN
    "stylegan1": 1,
    "stylegan2": 1,
    "stylegan3": 1,
    "pro_gan": 1,
    "big_gan": 1,
    "star_gan": 1,
    "cycle_gan": 1,
    "gansformer": 1,
    "generative_inpainting": 1,
    "lama": 1,
    "mat": 1,
    "sfhq": 1,
    "cips": 1,
    "projected_gan": 1,
    "gau_gan": 1,

    # Class 2 - Diffusion
    "stable_diffusion": 2,
    "ddpm": 2,
    "glide": 2,
    "latent_diffusion": 2,
    "vq_diffusion": 2,
    "denoising_diffusion_gan": 2,
    "diffusion_gan": 2,
    "palette": 2,

    # Class 3 - Other
    "taming_transformer": 3,
    "face_synthetics": 3,
}

# Human-readable names for the integer labels above.
CLASS_NAMES = {0: "Real", 1: "GAN", 2: "Diffusion", 3: "Other"}
# Per-class sample cap applied when balancing in load_generator_samples.
MAX_PER_CLASS = 10000


class GeneratorDataset(Dataset):
    """Dataset over pre-collected (image_path, label) pairs.

    Images are loaded lazily from disk on access, converted to RGB, and
    optionally passed through a transform before being returned.
    """

    def __init__(self, samples, transform=None):
        # samples: sequence of (image_path, label) tuples; transform is any
        # callable applied to the PIL image (e.g. torchvision transforms).
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert("RGB")
        # Apply the transform only when one was supplied.
        return (self.transform(img) if self.transform else img), label


def load_generator_samples():
    """Collect (image_path, class_label) pairs for all known sources.

    For each source in SOURCE_CLASS_MAP, reads <DATASET_ROOT>/<source>/metadata.csv
    and buckets its images by class. Each class bucket is then capped at
    MAX_PER_CLASS (keeping the first entries encountered, in map order).
    Sources missing a metadata.csv are skipped with a printed notice.

    Returns:
        list[tuple[str, int]]: flattened list of (absolute image path, label),
        ordered class 0 through 3.
    """
    buckets = {0: [], 1: [], 2: [], 3: []}

    for source, cls in SOURCE_CLASS_MAP.items():
        csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
        if not os.path.exists(csv_path):
            print(f"Skipping {source} - no metadata.csv")
            continue
        metadata = pd.read_csv(csv_path)
        # metadata.csv stores paths relative to the source directory.
        buckets[cls].extend(
            (os.path.join(DATASET_ROOT, source, rel), cls)
            for rel in metadata["image_path"]
        )

    # Balance by truncating each class to the cap, then flatten in class order.
    all_samples = []
    for cls, entries in buckets.items():
        kept = entries[:MAX_PER_CLASS]
        print(f"Class {cls} ({CLASS_NAMES[cls]}): {len(kept)} samples")
        all_samples.extend(kept)

    print(f"Total: {len(all_samples)}")
    return all_samples


def get_generator_dataloaders(batch_size=32, seed=None):
    """Build train/val/test DataLoaders over the generator dataset.

    The full sample list is shuffled once, then split 75% / 12.5% / ~12.5%
    (test takes the remainder, so every sample lands in exactly one split).
    Training data uses train_transforms; val and test use val_transforms.

    Args:
        batch_size: batch size for all three loaders.
        seed: optional int. When given, the pre-split shuffle is seeded so
            the train/val/test partition is reproducible across runs.
            Default None keeps the previous unseeded behavior. Without a
            seed, re-running reshuffles and the test set changes between
            runs — pass a seed for stable evaluation.

    Returns:
        tuple: (train_loader, val_loader, test_loader).
    """
    all_samples = load_generator_samples()

    # Shuffle before splitting so each split mixes all sources/classes.
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(all_samples)

    train_size = int(0.75 * len(all_samples))
    val_size = int(0.125 * len(all_samples))

    train_samples = all_samples[:train_size]
    val_samples = all_samples[train_size:train_size + val_size]
    test_samples = all_samples[train_size + val_size:]  # remainder

    train_set = GeneratorDataset(train_samples, transform=train_transforms)
    val_set = GeneratorDataset(val_samples, transform=val_transforms)
    test_set = GeneratorDataset(test_samples, transform=val_transforms)

    # Only the training loader reshuffles each epoch.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    return train_loader, val_loader, test_loader