"""Data pipeline for the PlantVillage leaf-disease dataset.

Downloads the dataset from Hugging Face, splits it 70/15/15 into
train/val/test, attaches augmentation transforms, and wraps each split in
a DataLoader. Loaders yield dict batches::

    batch = next(iter(loader))
    images = batch['image']   # float tensor (B, 3, IMAGE_SIZE, IMAGE_SIZE)
    labels = batch['label']   # long tensor (B,)
"""

import torch
import torch.utils.data
from torch.utils.data import DataLoader, default_collate
from torchvision import transforms
from datasets import load_dataset

# ImageNet statistics — normalize to the distribution pretrained backbones expect.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 256

# Transforms for training data (with advanced augmentation).
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    # geometric augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),   # vertical flip added for leaf images
    transforms.RandomRotation(30),          # increased rotation range
    # color/appearance augmentations
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),
    # final conversion
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
])

# Transforms for validation/test data (deterministic: no augmentation).
val_test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
])


def apply_transforms(batch, transform_pipeline):
    """Applies a transform pipeline to a batch of images and converts labels.

    Images are forced to RGB first (the dataset may contain grayscale or
    paletted images); labels are converted to a tensor so ``default_collate``
    can stack them into a batch.
    """
    batch['image'] = [transform_pipeline(img.convert("RGB")) for img in batch['image']]
    # This line is crucial for converting labels to tensors for batching
    batch['label'] = torch.tensor(batch['label'])
    return batch


def get_dataloaders(batch_size=32, use_prototype=True):
    """
    Loads, splits, and prepares the PlantVillage dataset, returning DataLoaders.

    Args:
        batch_size: Batch size for all three loaders.
        use_prototype: When True, work on a 20% subsample for fast iteration.

    Returns:
        (train_loader, val_loader, test_loader) — each yields a dict.

    NOTE TO TEAM: The dataloaders yield a dictionary.
    Access batches using:
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
    """
    print("Loading and preparing dataset...")

    # Load the full dataset from Hugging Face.
    full_dataset = load_dataset("DScomp380/plant_village", split='train')

    if use_prototype:
        # Keep 20% of the data for prototyping (the 'train' side of an 80/20 split).
        print(f"Using 20% prototype dataset (approx {len(full_dataset) * 0.2:.0f} images)...")
        data_subset = full_dataset.train_test_split(test_size=0.8, seed=42)['train']
    else:
        print(f"Using 100% full dataset ({len(full_dataset)} images)...")
        data_subset = full_dataset

    # 70/15/15 split for train/val/test (the 30% holdout is halved below).
    train_val_test_split = data_subset.train_test_split(test_size=0.3, seed=42)
    train_dataset = train_val_test_split['train']
    val_test_split = train_val_test_split['test'].train_test_split(test_size=0.5, seed=42)
    val_dataset = val_test_split['train']
    test_dataset = val_test_split['test']

    print(f"Total images in prototype: {len(data_subset)}")
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")
    print(f"Test images: {len(test_dataset)}")
    print("--------------------")

    # Attach the correct transforms to each split; set_transform applies
    # lazily, per batch, at access time.
    train_dataset.set_transform(lambda batch: apply_transforms(batch, train_transform))
    val_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))
    test_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))

    # default_collate stacks the per-sample dicts into tensor batches.
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, collate_fn=default_collate
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, collate_fn=default_collate
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, collate_fn=default_collate
    )

    return train_loader, val_loader, test_loader


def _check_loader(name, loader, batch_size):
    """Smoke-test one loader: pull a batch and check tensor shapes.

    Prints PASSED/FAILED rather than raising, so the standalone run always
    reports on every loader.
    """
    print(f"\n--- Testing {name} ---")
    try:
        # Batches are dicts — access the data using keys.
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
        print(f"Image batch shape: {images.shape}")
        print(f"Label batch shape: {labels.shape}")
        # Assert correct shapes (assumes the split has >= batch_size samples).
        assert images.shape == (batch_size, 3, IMAGE_SIZE, IMAGE_SIZE)
        assert labels.shape == (batch_size,)
        print(f"{name} test PASSED.")
    except Exception as e:
        print(f"{name} test FAILED: {e}")


if __name__ == "__main__":
    print("Running data_pipeline.py as a standalone script...")
    # Exercise the pipeline with a small batch size.
    train_loader, val_loader, test_loader = get_dataloaders(batch_size=4, use_prototype=True)
    _check_loader("Train loader", train_loader, batch_size=4)
    _check_loader("Validation loader", val_loader, batch_size=4)
    print("\nData pipeline script finished.")