"""Data pipeline: builds train/val transforms and DataLoaders for DINOv2 fine-tuning."""
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import yaml
import os
def get_transforms(cfg):
    """
    Build the train and validation transform pipelines.

    DINOv2 expects ImageNet normalization; the training pipeline adds
    light augmentation (horizontal flip + mild color jitter) to reduce
    overfitting, while the validation pipeline is deterministic.

    Args:
        cfg: parsed config dict; reads ``cfg['data']['image_size']``.

    Returns:
        (train_transform, val_transform) tuple of ``transforms.Compose``.
    """
    size = cfg['data']['image_size']

    # Normalization is stateless, so one instance can be shared by both pipelines.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet / DINOv2 expected mean
        std=[0.229, 0.224, 0.225],   # ImageNet / DINOv2 expected std
    )

    # Training: resize + light augmentation, then tensor conversion + normalize.
    augmented_pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.RandomHorizontalFlip(p=0.5),                 # 50% chance to flip
        transforms.ColorJitter(brightness=0.1, contrast=0.1),   # slight color changes
        transforms.ToTensor(),
        normalize,
    ])

    # Validation/test: deterministic resize only, no augmentation.
    plain_pipeline = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize,
    ])

    return augmented_pipeline, plain_pipeline
def create_dataloaders(config_path="configs/config.yaml"):
    """
    Build train/val DataLoaders from the ImageFolder tree named in the config.

    Reads ``data.train_dir``, ``data.batch_size``, ``data.num_workers`` and
    ``data.image_size`` from the YAML config, performs a reproducible
    80/20 train/val split, and returns ``(train_loader, val_loader)``.

    Raises:
        FileNotFoundError: if ``config_path`` or the data directory is missing.
    """
    # Load config
    with open(config_path, 'r') as f:
        cfg = yaml.safe_load(f)

    train_transform, val_transform = get_transforms(cfg)
    data_dir = cfg['data']['train_dir']  # Should be "data/raw"

    # BUG FIX: the old code split a single ImageFolder and then assigned
    # `subset.dataset.transform` for train and val in turn. Both subsets share
    # the SAME underlying dataset, so the second assignment (val_transform)
    # silently overwrote train_transform — training ran with no augmentation.
    # Instead, build two ImageFolder views of the same directory, each carrying
    # its own transform, and give both splits the same index partition.
    train_base = datasets.ImageFolder(root=data_dir, transform=train_transform)
    val_base = datasets.ImageFolder(root=data_dir, transform=val_transform)

    # 80% Train, 20% Validation
    total_size = len(train_base)
    train_size = int(0.8 * total_size)
    val_size = total_size - train_size

    # Seeded generator -> the split is reproducible across runs, so the
    # validation set never leaks into training between restarts.
    split_gen = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(
        train_base, [train_size, val_size], generator=split_gen
    )
    # Re-point the val subset at the ImageFolder carrying val_transform,
    # keeping exactly the same sample indices.
    val_dataset = torch.utils.data.Subset(val_base, val_dataset.indices)

    # Create Loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg['data']['batch_size'],
        shuffle=True,
        num_workers=cfg['data']['num_workers']
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg['data']['batch_size'],
        shuffle=False,
        num_workers=cfg['data']['num_workers']
    )

    print(f"✅ Data Ready:")
    print(f"   - Train: {len(train_dataset)} images")
    print(f"   - Val: {len(val_dataset)} images")
    print(f"   - Classes: {train_base.class_to_idx}")

    return train_loader, val_loader
# Smoke-test entry point: build the loaders once and print the summary.
# (Removed a stray " |" artifact that made this line a SyntaxError.)
if __name__ == "__main__":
    create_dataloaders()