# data_pipeline.py — Team1
# Restored app files (commit 361cbfe, author: michaela299)
import torch
from torch.utils.data import DataLoader, default_collate
from torchvision import transforms
from datasets import load_dataset
import torch.utils.data
# ImageNet channel statistics used for input normalization — the standard
# per-channel values for models pretrained on ImageNet (order: R, G, B).
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
# Every image is resized to IMAGE_SIZE x IMAGE_SIZE before batching.
IMAGE_SIZE = 256
# Training-time preprocessing: fixed resize, random geometric and
# photometric augmentation, then tensor conversion + ImageNet normalization.
_train_ops = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    # Geometric augmentations.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),  # Added vertical flip
    transforms.RandomRotation(30),  # Increased rotation range
    # Color / appearance augmentations.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Increased intensity
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),  # Added blur
    # Final conversion to a normalized float tensor.
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
]
train_transform = transforms.Compose(_train_ops)
# Evaluation-time preprocessing: deterministic resize + normalization only,
# so validation/test metrics are not affected by random augmentation.
val_test_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
    ]
)
def apply_transforms(batch, transform_pipeline):
    """Run *transform_pipeline* over every image in *batch* and tensorize labels.

    Args:
        batch: dict with an 'image' list (PIL-like objects exposing
            ``.convert``) and a 'label' list of integer class ids.
        transform_pipeline: callable applied to each RGB-converted image.

    Returns:
        The same dict, mutated in place: 'image' holds transformed images and
        'label' holds a torch tensor (required so default_collate can batch).
    """
    processed = []
    for img in batch['image']:
        rgb = img.convert("RGB")
        processed.append(transform_pipeline(rgb))
    batch['image'] = processed
    batch['label'] = torch.tensor(batch['label'])
    return batch
def get_dataloaders(batch_size=32, use_prototype=True, num_workers=0, seed=42):
    """
    Loads, splits, and prepares the PlantVillage dataset, returning DataLoaders.

    Args:
        batch_size: Batch size used by all three DataLoaders.
        use_prototype: If True, keep only 20% of the data for fast iteration.
        num_workers: Worker processes per DataLoader (0 = main process).
        seed: Seed for every split so train/val/test are reproducible.

    Returns:
        (train_loader, val_loader, test_loader)

    NOTE TO TEAM: The dataloaders yield a dictionary.
    Access batches using:
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
    """
    print("Loading and preparing dataset...")
    # Load the full dataset from Hugging Face (downloads/caches on first use)
    full_dataset = load_dataset("DScomp380/plant_village", split='train')
    if use_prototype:
        # Keep 20% of data for prototyping (train_test_split discards the rest)
        print(f"Using 20% prototype dataset (approx {len(full_dataset) * 0.2:.0f} images)...")
        data_subset = full_dataset.train_test_split(test_size=0.8, seed=seed)['train']
    else:
        print(f"Using 100% full dataset ({len(full_dataset)} images)...")
        data_subset = full_dataset
    # 70/15/15 split for train/val/test (30% held out, then halved)
    train_val_test_split = data_subset.train_test_split(test_size=0.3, seed=seed)
    train_dataset = train_val_test_split['train']
    val_test_split = train_val_test_split['test'].train_test_split(test_size=0.5, seed=seed)
    val_dataset = val_test_split['train']
    test_dataset = val_test_split['test']
    # FIX: old message said "prototype" even when the full dataset was in use
    print(f"Total images in use: {len(data_subset)}")
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")
    print(f"Test images: {len(test_dataset)}")
    print("--------------------")
    # Apply the correct transforms to each split: augmentation for training
    # only, deterministic preprocessing for validation/test.
    train_dataset.set_transform(lambda batch: apply_transforms(batch, train_transform))
    val_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))
    test_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))
    # default_collate stacks the dict-of-tensors samples into batched tensors
    collate_fn = default_collate
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # shuffle only the training split
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn
    )
    return train_loader, val_loader, test_loader
def _smoke_test_loader(loader, name, expected_batch_size):
    """Pull one batch from *loader* and verify image/label tensor shapes."""
    print(f"\n--- Testing {name} Loader ---")
    try:
        # Loaders yield dicts; access tensors by key.
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
        print(f"Image batch shape: {images.shape}")
        print(f"Label batch shape: {labels.shape}")
        # Assert correct shapes: (B, C, H, W) images and (B,) labels.
        assert images.shape == (expected_batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
        assert labels.shape == (expected_batch_size,)
        print(f"{name} loader test PASSED.")
    except Exception as e:
        print(f"{name} loader test FAILED: {e}")


if __name__ == "__main__":
    print("Running data_pipeline.py as a standalone script...")
    # Smoke-test the pipeline with a small batch size on the prototype subset.
    batch_size = 4
    train_loader, val_loader, test_loader = get_dataloaders(
        batch_size=batch_size, use_prototype=True
    )
    # FIX: deduplicated the per-loader test code and added the previously
    # untested test_loader to the smoke test.
    _smoke_test_loader(train_loader, "Train", batch_size)
    _smoke_test_loader(val_loader, "Validation", batch_size)
    _smoke_test_loader(test_loader, "Test", batch_size)
    print("\nData pipeline script finished.")