# ImageTrust-AI / src/data/loader.py
# Author: SiemonCha — initial deployment (commit d581b00)
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, random_split
from src.data.transforms import train_transforms, val_transforms
# Dataset location. Override via the ARTIFACT_DATASET_ROOT env var so the
# loader is not hard-wired to one developer's machine; the default preserves
# the original path for backward compatibility.
DATASET_ROOT = os.environ.get(
    "ARTIFACT_DATASET_ROOT", "/Users/siemoncha/Desktop/CS/datasets/artifact-dataset"
)
# Source directories containing real photographs (rows with target == 0).
REAL_SOURCES = ["coco", "ffhq", "lsun", "imagenet", "landscape", "afhq"]
# Source directories containing generated images (rows with target != 0).
FAKE_SOURCES = ["stable_diffusion", "stylegan2", "ddpm", "glide", "latent_diffusion"]
MAX_PER_CLASS = 15000  # 15k real + 15k fake = 30k total
class ArtiFact(Dataset):
    """Binary real-vs-fake image dataset assembled from per-source metadata.csv files.

    Scans each source directory under ``DATASET_ROOT``, collects
    ``(image_path, label)`` pairs (0 = real, 1 = fake), and truncates each
    class to ``MAX_PER_CLASS`` entries.
    """

    def __init__(self, transform=None):
        """
        Args:
            transform: optional callable applied to the PIL image in
                ``__getitem__`` (e.g. a torchvision transform pipeline).
        """
        self.transform = transform
        # List of (absolute image path, label) tuples, filled below.
        self.samples = []
        self._load_metadata()

    def _load_metadata(self):
        """Populate ``self.samples`` from every source's metadata.csv.

        Sources with no metadata.csv are skipped with a console warning.
        NOTE(review): truncation keeps the *first* MAX_PER_CLASS rows in
        source-list order, so the subsample is biased toward earlier sources —
        confirm this is intended (a shuffled subsample may be preferable).
        """
        real, fake = [], []
        for source in REAL_SOURCES + FAKE_SOURCES:
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                print(f"Skipping {source} - no metadata.csv")
                continue
            df = pd.read_csv(csv_path)
            # Vectorized split on the target column instead of iterrows(),
            # which is orders of magnitude slower on large CSVs. Row order
            # (and therefore the final sample list) is unchanged.
            is_real = df["target"] == 0
            real.extend(
                (os.path.join(DATASET_ROOT, source, p), 0)
                for p in df.loc[is_real, "image_path"]
            )
            fake.extend(
                (os.path.join(DATASET_ROOT, source, p), 1)
                for p in df.loc[~is_real, "image_path"]
            )
        # Balance and subsample: cap each class at MAX_PER_CLASS.
        real = real[:MAX_PER_CLASS]
        fake = fake[:MAX_PER_CLASS]
        self.samples = real + fake
        print(f"Real: {len(real)} | Fake: {len(fake)} | Total: {len(self.samples)}")

    def __len__(self):
        """Return the number of (path, label) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded with PIL and forced to RGB; ``transform`` is
        applied when set.
        """
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
class SampleDataset(Dataset):
    """Thin Dataset over a pre-built list of ``(image_path, label)`` pairs.

    Useful when the sample list is assembled externally (e.g. custom splits)
    and only image loading / transformation is needed here.
    """

    def __init__(self, samples, transform=None):
        """
        Args:
            samples: sequence of ``(image_path, label)`` tuples.
            transform: optional callable applied to each loaded PIL image.
        """
        self.samples = samples
        self.transform = transform

    def __len__(self):
        """Number of samples in the list."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at ``idx``."""
        path, target = self.samples[idx]
        img = Image.open(path).convert("RGB")
        return (self.transform(img), target) if self.transform else (img, target)
def get_dataloaders(batch_size=32):
    """Build train/val/test DataLoaders with a 75 / 12.5 / 12.5 random split.

    Args:
        batch_size: batch size for all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    from torch.utils.data import DataLoader

    # Load once without a transform; each split gets its own transform below.
    base = ArtiFact(transform=None)
    train_size = int(0.75 * len(base))
    val_size = int(0.125 * len(base))
    test_size = len(base) - train_size - val_size
    train_sub, val_sub, test_sub = random_split(base, [train_size, val_size, test_size])

    # Bug fix: the original set `val_set.dataset.transform = val_transforms`,
    # but random_split Subsets all share ONE underlying dataset, so that
    # assignment silently switched the *train* split to val_transforms too.
    # Materialize each split into its own SampleDataset so every split carries
    # an independent transform.
    def _as_dataset(subset, transform):
        # Wrap a Subset's selected samples in a standalone SampleDataset.
        return SampleDataset(
            [base.samples[i] for i in subset.indices], transform=transform
        )

    train_set = _as_dataset(train_sub, train_transforms)
    val_set = _as_dataset(val_sub, val_transforms)
    test_set = _as_dataset(test_sub, val_transforms)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader, test_loader
def get_cross_dataset_loaders(batch_size=32):
    """Build loaders for cross-generator generalization evaluation.

    Training fakes come from SEEN_FAKE generators; test fakes come from
    UNSEEN_FAKE generators. Real images are drawn from REAL_SOURCES for both.

    Args:
        batch_size: batch size for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    from torch.utils.data import DataLoader

    SEEN_FAKE = ["stable_diffusion", "stylegan2", "ddpm"]
    UNSEEN_FAKE = ["glide", "latent_diffusion"]

    def _collect(sources, want_real):
        # Gather (path, label) pairs for one class across the given sources.
        # The original duplicated this loop once per class and used iterrows();
        # a vectorized boolean filter is equivalent and far faster.
        label = 0 if want_real else 1
        out = []
        for source in sources:
            csv_path = os.path.join(DATASET_ROOT, source, "metadata.csv")
            if not os.path.exists(csv_path):
                continue
            df = pd.read_csv(csv_path)
            mask = (df["target"] == 0) if want_real else (df["target"] != 0)
            out.extend(
                (os.path.join(DATASET_ROOT, source, p), label)
                for p in df.loc[mask, "image_path"]
            )
        return out

    def load_sources(real_sources, fake_sources, max_per_class=10000):
        # Balanced sample list: up to max_per_class entries per class.
        real = _collect(real_sources, True)[:max_per_class]
        fake = _collect(fake_sources, False)[:max_per_class]
        return real + fake

    train_samples = load_sources(REAL_SOURCES, SEEN_FAKE)
    test_samples = load_sources(REAL_SOURCES, UNSEEN_FAKE, max_per_class=5000)
    print(f"Train samples: {len(train_samples)}")
    print(f"Test samples: {len(test_samples)}")

    train_set = SampleDataset(train_samples, transform=train_transforms)
    test_set = SampleDataset(test_samples, transform=val_transforms)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader