# data_pipeline.py — PlantVillage dataset loading, augmentation, and DataLoader setup.
import torch
import torch.utils.data
from datasets import load_dataset
from torch.utils.data import DataLoader, default_collate
from torchvision import transforms

# ImageNet channel statistics used for normalization, and the fixed
# square resolution every image is resized to.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 256

# Pieces shared by every pipeline: deterministic resize up front,
# tensor conversion + normalization at the end.
_resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_MEAN, IMAGE_STD),
]

# Training pipeline: geometric then photometric augmentation, sandwiched
# between the shared resize and the final tensor conversion.
train_transform = transforms.Compose([
    _resize,
    # geometric augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),
    # color/appearance augmentations
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),
    *_to_normalized_tensor,
])

# Validation/test pipeline: no augmentation, fully deterministic.
val_test_transform = transforms.Compose([_resize, *_to_normalized_tensor])
def apply_transforms(batch, transform_pipeline):
    """Transform every image in *batch* and tensor-ify the labels.

    Mutates and returns *batch*: each image is forced to RGB and pushed
    through ``transform_pipeline``, and the label list is converted to a
    tensor so the default collate function can batch it.
    """
    def _prepare(img):
        # Force 3 channels before transforming (source images may be
        # grayscale/palettized — presumably PIL; verify against caller).
        return transform_pipeline(img.convert("RGB"))

    batch['image'] = list(map(_prepare, batch['image']))
    # Labels must be tensors for batching to stack them correctly.
    batch['label'] = torch.tensor(batch['label'])
    return batch
def get_dataloaders(batch_size=32, use_prototype=True, seed=42):
    """Load, split, and prepare the PlantVillage dataset; return DataLoaders.

    Args:
        batch_size: Batch size for all three loaders.
        use_prototype: If True, work on a reproducible 20% subsample of the
            dataset (fast iteration); if False, use the full dataset.
        seed: Seed passed to every train_test_split call, so the subsample
            and the 70/15/15 split are reproducible (default 42, matching
            the previous hard-coded behavior).

    Returns:
        (train_loader, val_loader, test_loader).

    NOTE TO TEAM: The dataloaders yield a dictionary.
    Access batches using:
        batch = next(iter(loader))
        images = batch['image']
        labels = batch['label']
    """
    print("Loading and preparing dataset...")
    # The HF dataset ships as a single 'train' split; we carve our own
    # train/val/test splits out of it below.
    full_dataset = load_dataset("DScomp380/plant_village", split='train')

    if use_prototype:
        # Keep 20% for prototyping: test_size=0.8 leaves 20% in 'train'.
        print(f"Using 20% prototype dataset (approx {len(full_dataset) * 0.2:.0f} images)...")
        data_subset = full_dataset.train_test_split(test_size=0.8, seed=seed)['train']
    else:
        print(f"Using 100% full dataset ({len(full_dataset)} images)...")
        data_subset = full_dataset

    # 70/15/15 split: hold out 30%, then halve it into val and test.
    train_val_test_split = data_subset.train_test_split(test_size=0.3, seed=seed)
    train_dataset = train_val_test_split['train']
    val_test_split = train_val_test_split['test'].train_test_split(test_size=0.5, seed=seed)
    val_dataset = val_test_split['train']
    test_dataset = val_test_split['test']

    # FIX: the old message claimed "prototype" even on full-dataset runs.
    print(f"Total images in use: {len(data_subset)}")
    print(f"Training images: {len(train_dataset)}")
    print(f"Validation images: {len(val_dataset)}")
    print(f"Test images: {len(test_dataset)}")
    print("--------------------")

    # On-the-fly per-batch transforms: augmentation for training only.
    train_dataset.set_transform(lambda batch: apply_transforms(batch, train_transform))
    val_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))
    test_dataset.set_transform(lambda batch: apply_transforms(batch, val_test_transform))

    # default_collate stacks the per-example image/label tensors into batches.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,  # only the training stream is shuffled
        collate_fn=default_collate,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=default_collate,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=default_collate,
    )
    return train_loader, val_loader, test_loader
if __name__ == "__main__":
    print("Running data_pipeline.py as a standalone script...")
    # Smoke-test the pipeline with a tiny batch size.
    train_loader, val_loader, test_loader = get_dataloaders(batch_size=4, use_prototype=True)

    def _smoke_test(name, loader):
        # Pull one batch (a dict) and verify the stacked tensor shapes.
        print(f"\n--- Testing {name} Loader ---")
        try:
            batch = next(iter(loader))
            images = batch['image']
            labels = batch['label']
            print(f"Image batch shape: {images.shape}")
            print(f"Label batch shape: {labels.shape}")
            assert images.shape == (4, 3, IMAGE_SIZE, IMAGE_SIZE)
            assert labels.shape == (4,)
            print(f"{name} loader test PASSED.")
        except Exception as e:
            print(f"{name} loader test FAILED: {e}")

    _smoke_test("Train", train_loader)
    _smoke_test("Validation", val_loader)
    print("\nData pipeline script finished.")