"""Datasets and DataLoader factories for real-vs-fake image classification.

Each source directory under ``DATASET_ROOT`` is expected to contain a
``metadata.csv`` with an ``image_path`` column (path relative to the source
directory) and a ``target`` column (``0`` is treated as real, anything else
as fake).
"""

import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split

from src.data.transforms import train_transforms, val_transforms

# The default keeps the original hard-coded location; set the
# ARTIFACT_DATASET_ROOT environment variable to use the module elsewhere.
DATASET_ROOT = os.environ.get(
    "ARTIFACT_DATASET_ROOT",
    "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset",
)

REAL_SOURCES = ["coco", "ffhq", "lsun", "imagenet", "landscape", "afhq"]
FAKE_SOURCES = [
    "stable_diffusion",
    "stylegan2",
    "ddpm",
    "glide",
    "latent_diffusion",
]

MAX_PER_CLASS = 15000  # 15k real + 15k fake = 30k total


def _read_source_rows(source):
    """Read one source's metadata.csv into ``(image_path, label)`` tuples.

    Args:
        source: Name of a sub-directory of ``DATASET_ROOT``.

    Returns:
        A list of ``(absolute_image_path, label)`` tuples where ``label`` is
        ``0`` when the row's ``target`` equals 0 and ``1`` otherwise, or
        ``None`` when the source has no ``metadata.csv``.
    """
    source_dir = os.path.join(DATASET_ROOT, source)
    csv_path = os.path.join(source_dir, "metadata.csv")
    if not os.path.exists(csv_path):
        return None
    df = pd.read_csv(csv_path)
    # Zipping the two columns avoids building a Series per row (iterrows).
    return [
        (os.path.join(source_dir, rel_path), 0 if target == 0 else 1)
        for rel_path, target in zip(df["image_path"], df["target"])
    ]


def _collect_samples(sources, label):
    """Gather all samples with the given label from ``sources``, in order.

    Sources without a ``metadata.csv`` are skipped silently.
    """
    samples = []
    for source in sources:
        rows = _read_source_rows(source)
        if rows is None:
            continue
        samples.extend(row for row in rows if row[1] == label)
    return samples


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader with this module's shared settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


class SampleDataset(Dataset):
    """Dataset over an explicit list of ``(image_path, label)`` samples."""

    def __init__(self, samples, transform=None):
        """Store the sample list and the optional per-image transform."""
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB, transform it, return with label."""
        img_path, label = self.samples[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes it once convert() has produced an independent RGB copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label


class ArtiFact(SampleDataset):
    """Real/fake dataset built from every source in REAL_SOURCES + FAKE_SOURCES.

    At most ``MAX_PER_CLASS`` real and ``MAX_PER_CLASS`` fake samples are kept.
    """

    def __init__(self, transform=None):
        """Scan the metadata files and populate ``self.samples``."""
        super().__init__([], transform)
        self._load_metadata()

    def _load_metadata(self):
        """Fill ``self.samples`` with up to MAX_PER_CLASS samples per class."""
        real, fake = [], []
        for source in REAL_SOURCES + FAKE_SOURCES:
            rows = _read_source_rows(source)
            if rows is None:
                print(f"Skipping {source} - no metadata.csv")
                continue
            # The class is decided by the row's target, not by which source
            # list the directory belongs to.
            for sample in rows:
                (fake if sample[1] else real).append(sample)

        # NOTE(review): this keeps the *first* MAX_PER_CLASS rows in source
        # order, so if early sources are large the later ones may contribute
        # nothing. Consider shuffling before truncating - confirm intent.
        real = real[:MAX_PER_CLASS]
        fake = fake[:MAX_PER_CLASS]

        self.samples = real + fake
        print(f"Real: {len(real)} | Fake: {len(fake)} | Total: {len(self.samples)}")


def get_dataloaders(batch_size=32, num_workers=2, seed=None):
    """Create train/val/test loaders from a random 75% / 12.5% / 12.5% split.

    Args:
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader.
        seed: Optional integer seed for the split. ``None`` uses torch's
            global RNG, so the split differs between runs.

    Returns:
        ``(train_loader, val_loader, test_loader)``. The train loader applies
        ``train_transforms`` and is shuffled; the other two apply
        ``val_transforms`` and are not shuffled.
    """
    dataset = ArtiFact()
    total = len(dataset)
    train_size = int(0.75 * total)
    val_size = int(0.125 * total)
    test_size = total - train_size - val_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_part, val_part, test_part = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    # All Subsets returned by random_split share one underlying dataset, so
    # assigning `subset.dataset.transform` would change the transform for
    # every split (including train). Build an independent dataset per split.
    def materialize(part, transform):
        picked = [dataset.samples[i] for i in part.indices]
        return SampleDataset(picked, transform=transform)

    train_set = materialize(train_part, train_transforms)
    val_set = materialize(val_part, val_transforms)
    test_set = materialize(test_part, val_transforms)

    train_loader = _make_loader(train_set, batch_size, True, num_workers)
    val_loader = _make_loader(val_set, batch_size, False, num_workers)
    test_loader = _make_loader(test_set, batch_size, False, num_workers)

    return train_loader, val_loader, test_loader


def get_cross_dataset_loaders(batch_size=32, num_workers=2):
    """Create loaders that train on some fake sources and test on others.

    Training uses up to 10000 real images plus up to 10000 fakes from the
    "seen" fake sources; testing uses up to 5000 real images plus up to 5000
    fakes from the "unseen" fake sources.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    seen_fake = ["stable_diffusion", "stylegan2", "ddpm"]
    unseen_fake = ["glide", "latent_diffusion"]
    train_cap, test_cap = 10000, 5000

    real_pool = _collect_samples(REAL_SOURCES, 0)
    train_real = real_pool[:train_cap]
    # Take test reals from *after* the training slice. Slicing both from the
    # start of the same list would make every real test image a training
    # image as well.
    test_real = real_pool[train_cap:train_cap + test_cap]
    if len(test_real) < test_cap:
        print(
            f"Warning: only {len(test_real)} held-out real images available "
            "for the test set"
        )

    train_fake = _collect_samples(seen_fake, 1)[:train_cap]
    test_fake = _collect_samples(unseen_fake, 1)[:test_cap]

    train_samples = train_real + train_fake
    test_samples = test_real + test_fake

    print(f"Train samples: {len(train_samples)}")
    print(f"Test samples: {len(test_samples)}")

    train_set = SampleDataset(train_samples, transform=train_transforms)
    test_set = SampleDataset(test_samples, transform=val_transforms)

    train_loader = _make_loader(train_set, batch_size, True, num_workers)
    test_loader = _make_loader(test_set, batch_size, False, num_workers)

    return train_loader, test_loader