from torchvision import datasets from pathlib import Path import os import pytorch_lightning as pl from torch.utils.data import DataLoader, Subset from torchvision import datasets from torchvision import transforms as T import numpy as np import torchvision from torchvision.datasets import Food101 from torch.utils.data import DataLoader, Dataset from typing import Dict, Tuple, Any import random def get_model_components( model_name: str, return_classifier: bool = False, augmentation_level: str = "default" ) -> Dict[str, Any]: """ Retrieves pre-trained model components from torchvision. This function fetches the appropriate weights and transforms for a given model. It supports different levels of training data augmentation. Args: model_name (str): The name of the model to get components for. Supported models include "EfficientNet_V2_S" and "EfficientNet_B2". return_classifier (bool, optional): If True, the model's classifier head is also returned. Defaults to False. augmentation_level (str, optional): The level of data augmentation to use for the training set. Can be "default" or "strong". Defaults to "default". Returns: Dict[str, Any]: A dictionary containing the requested components. Always includes 'train_transforms' and 'val_transforms'. Includes 'classifier' if return_classifier is True. Raises: ValueError: If model_name or augmentation_level is not supported. """ model_registry = { "EfficientNet_V2_S": ( torchvision.models.efficientnet_v2_s, torchvision.models.EfficientNet_V2_S_Weights.DEFAULT ), "EfficientNet_B2": ( torchvision.models.efficientnet_b2, torchvision.models.EfficientNet_B2_Weights.DEFAULT ) } if model_name not in model_registry: raise ValueError(f"Model '{model_name}' is not supported. " f"Supported models are: {list(model_registry.keys())}") # 1. Look up the model and weights classes model_class, weights_class = model_registry[model_name] weights = weights_class val_transforms = weights.transforms() # 2. Create the training transforms based on the desired level if augmentation_level == "default": train_transforms = T.Compose([ T.TrivialAugmentWide(), val_transforms # val_transforms includes ToTensor and Normalize ]) elif augmentation_level == "strong": # Note: We don't need to add ToTensor() or Normalize() here because # they are already included inside the 'val_transforms' pipeline. train_transforms = T.Compose([ T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)), T.RandomHorizontalFlip(p=0.5), T.RandAugment(num_ops=2, magnitude=9), # RandomErasing should be applied to a tensor, so we apply it after # val_transforms, which handles the PIL -> Tensor conversion. val_transforms, T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random') ]) else: raise ValueError(f"Augmentation level '{augmentation_level}' is not supported. " f"Choose from 'default' or 'strong'.") # 3. Prepare the dictionary to be returned components = { "train_transforms": train_transforms, "val_transforms": val_transforms } # 4. Optionally, instantiate the model to get the classifier if return_classifier: model = model_class(weights=weights) components["classifier"] = model.classifier return components class CustomFood101(Dataset): """A PyTorch Dataset for Food101 with conditional downloading and subset support. This class wraps the torchvision Food101 dataset. It only downloads the data if the specified directory doesn't already exist. It can also create a reproducible, shuffled subset of the data for faster experimentation. Args: split (str): The dataset split, either "train" or "test". transform (callable, optional): A function/transform to apply to the images. data_dir (str, optional): The directory to store the data. Defaults to "data". subset_fraction (float, optional): The fraction of the dataset to use. Defaults to 1.0 (using the full dataset). """ def __init__(self, split, transform=None, data_dir="data", subset_fraction: float = 0.1): # Check if the dataset already exists before setting the download flag. dataset_path = os.path.join(data_dir, "food-101") should_download = not os.path.isdir(dataset_path) # 1. Load the full dataset metadata with the conditional flag self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download) self.classes = self.full_dataset.classes # 2. Create a reproducible subset of indices if subset_fraction < 1.0: num_samples = int(len(self.full_dataset) * subset_fraction) all_indices = list(range(len(self.full_dataset))) # Shuffle with a fixed seed for reproducibility random.Random(42).shuffle(all_indices) self.indices = all_indices[:num_samples] else: self.indices = list(range(len(self.full_dataset))) def __len__(self): """Returns the total number of samples in the subset.""" return len(self.indices) def __getitem__(self, idx): """ Fetches the sample for the given subset index and applies the transform. """ # Map the subset index to the actual index in the full dataset original_idx = self.indices[idx] image, label = self.full_dataset[original_idx] return image, label class Food101DataModule(pl.LightningDataModule): """A PyTorch Lightning DataModule for the Food101 dataset. This module encapsulates all data-related logic, including downloading, processing, and creating DataLoaders for the training, validation, and test sets. It uses the CustomFood101 dataset internally and allows for controlling the fraction of data used in the training and validation splits. Args: data_dir (str, optional): Root directory for the data. Defaults to "data". batch_size (int, optional): The batch size for DataLoaders. Defaults to 32. num_workers (int, optional): Number of workers for data loading. Defaults to 2. train_transforms (callable, optional): Transformations for the training set. val_transforms (callable, optional): Transformations for the validation/test set. subset_fraction (float, optional): The fraction of data to use for training and validation. Defaults to 1.0. """ def __init__(self, data_dir="data", batch_size=32, num_workers=2, train_transforms=None, val_transforms=None, subset_fraction: float = 0.5): super().__init__() self.data_dir = data_dir self.batch_size = batch_size self.num_workers = num_workers self.train_transforms = train_transforms self.val_transforms = val_transforms self.subset_fraction = subset_fraction self.classes = [] def prepare_data(self): """Downloads data if needed.""" CustomFood101(split='train', data_dir=self.data_dir) CustomFood101(split='test', data_dir=self.data_dir) def setup(self, stage=None): """Assigns datasets, passing the subset_fraction.""" if stage == 'fit' or stage is None: self.train_dataset = CustomFood101(split='train', transform=self.train_transforms, data_dir=self.data_dir, subset_fraction=self.subset_fraction) self.val_dataset = CustomFood101(split='test', transform=self.val_transforms, data_dir=self.data_dir, subset_fraction=self.subset_fraction) self.classes = self.train_dataset.classes if stage == 'test' or stage is None: self.test_dataset = CustomFood101(split='test', transform=self.val_transforms, data_dir=self.data_dir, subset_fraction=1.0) # Use full test set if not self.classes: self.classes = self.test_dataset.classes def train_dataloader(self): return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) def val_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) def test_dataloader(self): return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) if __name__ == "__main__": # Define configuration for the script DATA_DIR = "data" MODEL_NAME = "EfficientNet_V2_S" BATCH_SIZE = 32 print(f"Running data preparation script for model: {MODEL_NAME}") # 1. Get model-specific transforms components = get_model_components(MODEL_NAME) train_transforms = components["train_transforms"] val_transforms = components["val_transforms"] # 2. Instantiate the DataModule datamodule = Food101DataModule( data_dir=DATA_DIR, batch_size=BATCH_SIZE, train_transforms=train_transforms, val_transforms=val_transforms, subset_fraction=0.1 # Use a small subset for quick verification ) # 3. Trigger download and setup datamodule.prepare_data() datamodule.setup(stage='fit') # 4. (Optional) Verification Step print("\n--- Verifying Dataloader ---") # Get one batch from the training dataloader train_dl = datamodule.train_dataloader() images, labels = next(iter(train_dl)) print(f"Number of classes: {len(datamodule.classes)}") print(f"Image batch shape: {images.shape}") print(f"Label batch shape: {labels.shape}") print("--- Verification Complete ---")