Spaces:
Sleeping
Sleeping
File size: 4,865 Bytes
d581b00 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split
from src.data.transforms import train_transforms, val_transforms
# Absolute path to the root of the ArtiFact dataset (one subdirectory per source,
# each containing a metadata.csv). NOTE(review): machine-specific path — consider
# making this configurable via env var.
DATASET_ROOT = "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset"
# Source folders whose metadata rows with target == 0 contribute REAL images.
REAL_SOURCES = ["coco", "ffhq", "lsun", "imagenet", "landscape", "afhq"]
# Generator folders whose metadata rows with target != 0 contribute FAKE images.
FAKE_SOURCES = ["stable_diffusion", "stylegan2", "ddpm", "glide", "latent_diffusion"]
# Per-class cap used by ArtiFact to keep the dataset balanced and tractable.
MAX_PER_CLASS = 15000 # 15k real + 15k fake = 30k total
class ArtiFact(Dataset):
    """Balanced real-vs-fake image dataset built from ArtiFact metadata CSVs.

    Walks every source directory under ``DATASET_ROOT``, reads its
    ``metadata.csv`` (columns ``image_path`` and ``target``), and collects up
    to ``MAX_PER_CLASS`` real (label 0) and fake (label 1) samples. Items are
    ``(PIL-image-after-transform, int_label)`` pairs.
    """

    def __init__(self, transform=None):
        """
        Args:
            transform: optional callable applied to the PIL image in
                ``__getitem__`` (e.g. a torchvision transform pipeline).
        """
        self.transform = transform
        self.samples = []  # list of (absolute_image_path, label) tuples
        self._load_metadata()

    def _load_metadata(self):
        """Populate ``self.samples`` from each source's metadata.csv.

        Uses vectorized boolean-mask filtering instead of per-row
        ``df.iterrows()`` — same resulting lists (row order is preserved by
        pandas filtering), dramatically faster on large CSVs.
        """
        real, fake = [], []
        for source in REAL_SOURCES + FAKE_SOURCES:
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                print(f"Skipping {source} - no metadata.csv")
                continue
            df = pd.read_csv(csv_path)
            base = os.path.join(DATASET_ROOT, source)
            is_real = df["target"] == 0
            real.extend((os.path.join(base, p), 0) for p in df.loc[is_real, "image_path"])
            fake.extend((os.path.join(base, p), 1) for p in df.loc[~is_real, "image_path"])
        # Balance and subsample: keep the first MAX_PER_CLASS of each class
        # (source order, matching the original behavior).
        real = real[:MAX_PER_CLASS]
        fake = fake[:MAX_PER_CLASS]
        self.samples = real + fake
        print(f"Real: {len(real)} | Fake: {len(fake)} | Total: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the idx-th image, force RGB, apply the transform if any."""
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
class SampleDataset(Dataset):
    """Thin Dataset wrapper over a pre-built list of (image_path, label) pairs.

    Unlike ArtiFact, it performs no metadata scanning itself — the caller
    supplies the sample list, which makes it suitable for wrapping custom
    splits.
    """

    def __init__(self, samples, transform=None):
        # samples: sequence of (path_to_image, int_label) tuples.
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, target = self.samples[idx]
        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, target
def get_dataloaders(batch_size=32, seed=None):
    """Build train/val/test DataLoaders with a 75/12.5/12.5 random split.

    Bug fixed versus the original: ``random_split`` returns ``Subset`` views
    over ONE shared ``ArtiFact`` instance, so assigning
    ``val_set.dataset.transform = val_transforms`` also changed the transform
    seen by the training subset — training silently ran with val transforms.
    Each split is now re-wrapped in its own ``SampleDataset`` so the three
    splits carry independent transforms.

    Args:
        batch_size: per-loader batch size (default 32, unchanged).
        seed: optional int for a reproducible split; ``None`` keeps torch's
            default RNG, matching the original (non-deterministic) behavior.

    Returns:
        (train_loader, val_loader, test_loader) tuple of DataLoaders.
    """
    import torch
    from torch.utils.data import DataLoader

    # Load once with no transform; transforms are applied per-split below.
    dataset = ArtiFact(transform=None)
    train_size = int(0.75 * len(dataset))
    val_size = int(0.125 * len(dataset))
    test_size = len(dataset) - train_size - val_size

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    train_split, val_split, test_split = random_split(
        dataset, [train_size, val_size, test_size], generator=generator
    )

    # Re-wrap each subset's samples so every split owns its transform.
    train_set = SampleDataset(
        [dataset.samples[i] for i in train_split.indices], transform=train_transforms
    )
    val_set = SampleDataset(
        [dataset.samples[i] for i in val_split.indices], transform=val_transforms
    )
    test_set = SampleDataset(
        [dataset.samples[i] for i in test_split.indices], transform=val_transforms
    )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader, test_loader
def get_cross_dataset_loaders(batch_size=32):
    """Cross-generator evaluation loaders: train on seen fakes, test on unseen.

    Train: real sources + fakes from {stable_diffusion, stylegan2, ddpm}.
    Test:  real sources + fakes from {glide, latent_diffusion}, capped at
    5000 per class. Real images are rows with ``target == 0``; fakes are
    rows with ``target != 0``.

    Improvements over the original: the duplicated real/fake scan loops are
    collapsed into one helper, per-row ``iterrows`` is replaced with
    vectorized filtering, and CSV reading stops early once a class quota is
    filled — all behavior-identical, since the original also kept only the
    first ``max_per_class`` samples in source order.

    Returns:
        (train_loader, test_loader) tuple of DataLoaders.
    """
    from torch.utils.data import DataLoader

    SEEN_FAKE = ["stable_diffusion", "stylegan2", "ddpm"]
    UNSEEN_FAKE = ["glide", "latent_diffusion"]

    def _collect(sources, label, limit):
        """Gather up to `limit` (path, label) pairs from the given sources."""
        out = []
        for source in sources:
            if len(out) >= limit:
                break  # quota already filled; skip reading further CSVs
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                continue
            df = pd.read_csv(csv_path)
            mask = (df["target"] == 0) if label == 0 else (df["target"] != 0)
            base = os.path.join(DATASET_ROOT, source)
            out.extend((os.path.join(base, p), label) for p in df.loc[mask, "image_path"])
        return out[:limit]

    def load_sources(real_sources, fake_sources, max_per_class=10000):
        """Return a combined real+fake sample list, capped per class."""
        return _collect(real_sources, 0, max_per_class) + _collect(fake_sources, 1, max_per_class)

    train_samples = load_sources(REAL_SOURCES, SEEN_FAKE)
    test_samples = load_sources(REAL_SOURCES, UNSEEN_FAKE, max_per_class=5000)
    print(f"Train samples: {len(train_samples)}")
    print(f"Test samples: {len(test_samples)}")

    train_set = SampleDataset(train_samples, transform=train_transforms)
    test_set = SampleDataset(test_samples, transform=val_transforms)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader