import os
import pandas as pd
import random
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from src.data.transforms import train_transforms, val_transforms
# Root directory containing one subdirectory per image source, each with a metadata.csv.
DATASET_ROOT = "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset"
# Manual class mapping based on architecture knowledge.
# Maps each source-directory name to a 4-way class index (see CLASS_NAMES).
SOURCE_CLASS_MAP = {
# Class 0 - Real
"coco": 0,
"ffhq": 0,
"lsun": 0,
"imagenet": 0,
"landscape": 0,
"afhq": 0,
"celebahq": 0,
"metfaces": 0,
# Class 1 - GAN
"stylegan1": 1,
"stylegan2": 1,
"stylegan3": 1,
"pro_gan": 1,
"big_gan": 1,
"star_gan": 1,
"cycle_gan": 1,
"gansformer": 1,
"generative_inpainting": 1,
"lama": 1,
"mat": 1,
"sfhq": 1,
"cips": 1,
"projected_gan": 1,
"gau_gan": 1,
# Class 2 - Diffusion
"stable_diffusion": 2,
"ddpm": 2,
"glide": 2,
"latent_diffusion": 2,
"vq_diffusion": 2,
"denoising_diffusion_gan": 2,
"diffusion_gan": 2,
"palette": 2,
# Class 3 - Other
"taming_transformer": 3,
"face_synthetics": 3,
}
# Human-readable label per class index.
CLASS_NAMES = {0: "Real", 1: "GAN", 2: "Diffusion", 3: "Other"}
# Per-class cap applied when balancing in load_generator_samples().
MAX_PER_CLASS = 10000
class GeneratorDataset(Dataset):
    """Dataset over (image_path, class_label) pairs.

    Args:
        samples: list of (image_path, label) tuples.
        transform: optional callable applied to the loaded PIL image
            (e.g. torchvision transforms); when None the raw RGB image
            is returned.
    """

    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # PIL opens files lazily; without a context manager the file handle
        # stays open until GC, which can exhaust descriptors with many
        # workers. convert("RGB") forces the load, so closing here is safe.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def load_generator_samples():
    """Collect (image_path, class_label) samples from every known source.

    For each source in SOURCE_CLASS_MAP, reads <DATASET_ROOT>/<source>/metadata.csv
    (skipping sources without one), buckets samples by class index, truncates
    each bucket to MAX_PER_CLASS, and concatenates the buckets.

    Returns:
        list of (absolute_image_path, class_index) tuples, grouped by class.
    """
    class_samples = {0: [], 1: [], 2: [], 3: []}
    for source, cls in SOURCE_CLASS_MAP.items():
        source_dir = os.path.join(DATASET_ROOT, source)
        csv_path = os.path.join(source_dir, "metadata.csv")
        if not os.path.exists(csv_path):
            print(f"Skipping {source} - no metadata.csv")
            continue
        df = pd.read_csv(csv_path)
        # Direct column access instead of df.iterrows(): iterrows builds a
        # Series per row and is orders of magnitude slower on large CSVs.
        class_samples[cls].extend(
            (os.path.join(source_dir, rel_path), cls)
            for rel_path in df["image_path"].tolist()
        )
    # Balance classes by truncating each bucket to MAX_PER_CLASS.
    # NOTE(review): truncation keeps CSV order, so later sources mapped to a
    # full class are dropped entirely — presumably intentional; confirm.
    for cls in class_samples:
        class_samples[cls] = class_samples[cls][:MAX_PER_CLASS]
        print(f"Class {cls} ({CLASS_NAMES[cls]}): {len(class_samples[cls])} samples")
    all_samples = []
    for cls in class_samples:
        all_samples.extend(class_samples[cls])
    print(f"Total: {len(all_samples)}")
    return all_samples
def get_generator_dataloaders(batch_size=32, seed=None):
    """Build train/val/test DataLoaders with a 75 / 12.5 / 12.5 split.

    Args:
        batch_size: per-loader batch size.
        seed: optional int. When given, the pre-split shuffle uses a private
            random.Random(seed) so the split is reproducible across runs
            (important for comparable val/test metrics). Default None keeps
            the previous unseeded behavior.

    Returns:
        (train_loader, val_loader, test_loader) tuple of DataLoaders.
    """
    all_samples = load_generator_samples()
    # Shuffle before splitting so every split draws from all sources/classes.
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(all_samples)
    train_size = int(0.75 * len(all_samples))
    val_size = int(0.125 * len(all_samples))
    train_samples = all_samples[:train_size]
    val_samples = all_samples[train_size:train_size + val_size]
    # Test set takes the remainder, absorbing any rounding slack.
    test_samples = all_samples[train_size + val_size:]
    train_set = GeneratorDataset(train_samples, transform=train_transforms)
    val_set = GeneratorDataset(val_samples, transform=val_transforms)
    test_set = GeneratorDataset(test_samples, transform=val_transforms)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader, test_loader