import os
import random
from typing import Any, Dict

import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.datasets import Food101

def get_model_components(
    model_name: str,
    return_classifier: bool = False,
    augmentation_level: str = "default"
) -> Dict[str, Any]:
    """
    Retrieves pre-trained model components from torchvision.

    This function fetches the appropriate weights and transforms for a given
    model. It supports different levels of training data augmentation.

    Args:
        model_name (str): The name of the model to get components for.
            Supported models include "EfficientNet_V2_S" and "EfficientNet_B2".
        return_classifier (bool, optional): If True, the model's classifier
            head is also returned. Defaults to False.
        augmentation_level (str, optional): The level of data augmentation to use
            for the training set. Can be "default" or "strong".
            Defaults to "default".

    Returns:
        Dict[str, Any]: A dictionary containing the requested components.
            Always includes 'train_transforms' and 'val_transforms'.
            Includes 'classifier' if return_classifier is True.

    Raises:
        ValueError: If model_name or augmentation_level is not supported.
    """
    model_registry = {
        "EfficientNet_V2_S": (
            torchvision.models.efficientnet_v2_s,
            torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        ),
        "EfficientNet_B2": (
            torchvision.models.efficientnet_b2,
            torchvision.models.EfficientNet_B2_Weights.DEFAULT
        )
    }
    if model_name not in model_registry:
        raise ValueError(f"Model '{model_name}' is not supported. "
                         f"Supported models are: {list(model_registry.keys())}")

    # 1. Look up the model constructor and its pretrained weights
    model_class, weights = model_registry[model_name]
    val_transforms = weights.transforms()

    # 2. Create the training transforms based on the desired level
    if augmentation_level == "default":
        train_transforms = T.Compose([
            T.TrivialAugmentWide(),
            # val_transforms already handles resizing, tensor conversion,
            # and normalization, so nothing else is needed here.
            val_transforms
        ])
    elif augmentation_level == "strong":
        train_transforms = T.Compose([
            T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandAugment(num_ops=2, magnitude=9),
            # RandomErasing must operate on a tensor, so it is applied after
            # val_transforms, which handles the PIL -> Tensor conversion.
            val_transforms,
            T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
        ])
    else:
        raise ValueError(f"Augmentation level '{augmentation_level}' is not supported. "
                         f"Choose from 'default' or 'strong'.")

    # 3. Prepare the dictionary to be returned
    components = {
        "train_transforms": train_transforms,
        "val_transforms": val_transforms
    }

    # 4. Optionally instantiate the model to extract its classifier head
    if return_classifier:
        model = model_class(weights=weights)
        components["classifier"] = model.classifier

    return components
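
# A minimal, illustrative sketch (not used elsewhere in this module) of how the
# returned classifier head could be swapped for fine-tuning. The 101-class
# output size and the Dropout/Linear head layout are assumptions based on
# torchvision's EfficientNet models, not part of the function's contract.
def _example_swap_classifier():
    import torch.nn as nn

    comps = get_model_components("EfficientNet_B2", return_classifier=True)
    # torchvision EfficientNet heads are Sequential(Dropout, Linear)
    in_features = comps["classifier"][1].in_features
    return nn.Sequential(nn.Dropout(p=0.3), nn.Linear(in_features, 101))
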
class CustomFood101(Dataset):
    """A PyTorch Dataset for Food101 with conditional downloading and subset support.

    This class wraps the torchvision Food101 dataset. It only downloads the data
    if the specified directory doesn't already exist. It can also create a
    reproducible, shuffled subset of the data for faster experimentation.

    Args:
        split (str): The dataset split, either "train" or "test".
        transform (callable, optional): A function/transform to apply to the images.
        data_dir (str, optional): The directory to store the data. Defaults to "data".
        subset_fraction (float, optional): The fraction of the dataset to use.
            Defaults to 0.1; pass 1.0 to use the full dataset.
    """
    def __init__(self, split, transform=None, data_dir="data", subset_fraction: float = 0.1):
        # Check if the dataset already exists before setting the download flag.
        dataset_path = os.path.join(data_dir, "food-101")
        should_download = not os.path.isdir(dataset_path)

        # 1. Load the full dataset with the conditional download flag
        self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)
        self.classes = self.full_dataset.classes

        # 2. Create a reproducible subset of indices
        if subset_fraction < 1.0:
            num_samples = int(len(self.full_dataset) * subset_fraction)
            all_indices = list(range(len(self.full_dataset)))
            # Shuffle with a fixed seed for reproducibility
            random.Random(42).shuffle(all_indices)
            self.indices = all_indices[:num_samples]
        else:
            self.indices = list(range(len(self.full_dataset)))

    def __len__(self):
        """Returns the total number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Fetches the sample for the given subset index. The transform is
        applied by the wrapped Food101 dataset."""
        # Map the subset index to the actual index in the full dataset
        original_idx = self.indices[idx]
        image, label = self.full_dataset[original_idx]
        return image, label
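
# Illustrative usage sketch (the helper below is not used elsewhere):
# CustomFood101 works with a plain DataLoader, independently of the
# LightningDataModule defined next. The fixed 224x224 resize and the 5%
# subset fraction here are arbitrary choices for demonstration.
def _example_direct_dataset():
    # Note: instantiating triggers a download if data/food-101 is missing.
    dataset = CustomFood101(
        split="train",
        transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        subset_fraction=0.05,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    image, label = dataset[0]  # a single (image, label) pair from the subset
    return loader, image, label
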
class Food101DataModule(pl.LightningDataModule):
    """A PyTorch Lightning DataModule for the Food101 dataset.

    This module encapsulates all data-related logic, including downloading,
    processing, and creating DataLoaders for the training, validation, and
    test sets. It uses the CustomFood101 dataset internally and allows for
    controlling the fraction of data used in the training and validation splits.

    Args:
        data_dir (str, optional): Root directory for the data. Defaults to "data".
        batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.
        num_workers (int, optional): Number of workers for data loading. Defaults to 2.
        train_transforms (callable, optional): Transformations for the training set.
        val_transforms (callable, optional): Transformations for the validation/test set.
        subset_fraction (float, optional): The fraction of data to use for training
            and validation. Defaults to 0.5. The test set always uses the full split.
    """
    def __init__(self, data_dir="data", batch_size=32, num_workers=2,
                 train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.subset_fraction = subset_fraction
        self.classes = []

    def prepare_data(self):
        """Downloads data if needed."""
        CustomFood101(split='train', data_dir=self.data_dir)
        CustomFood101(split='test', data_dir=self.data_dir)

    def setup(self, stage=None):
        """Assigns datasets, passing the subset_fraction."""
        if stage == 'fit' or stage is None:
            self.train_dataset = CustomFood101(split='train', transform=self.train_transforms,
                                               data_dir=self.data_dir, subset_fraction=self.subset_fraction)
            self.val_dataset = CustomFood101(split='test', transform=self.val_transforms,
                                             data_dir=self.data_dir, subset_fraction=self.subset_fraction)
            self.classes = self.train_dataset.classes
        if stage == 'test' or stage is None:
            self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,
                                              data_dir=self.data_dir, subset_fraction=1.0)  # Use full test set
            if not self.classes:
                self.classes = self.test_dataset.classes

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
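
# A minimal sketch of how this DataModule plugs into a pl.Trainer. The
# LightningModule is taken as a parameter because no model class is defined
# in this file; the max_epochs and accelerator settings are illustrative.
def _example_trainer_wiring(model_module: pl.LightningModule):
    comps = get_model_components("EfficientNet_V2_S")
    dm = Food101DataModule(train_transforms=comps["train_transforms"],
                           val_transforms=comps["val_transforms"])
    trainer = pl.Trainer(max_epochs=3, accelerator="auto")
    trainer.fit(model_module, datamodule=dm)
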
if __name__ == "__main__":
    # Define configuration for the script
    DATA_DIR = "data"
    MODEL_NAME = "EfficientNet_V2_S"
    BATCH_SIZE = 32

    print(f"Running data preparation script for model: {MODEL_NAME}")

    # 1. Get model-specific transforms
    components = get_model_components(MODEL_NAME)
    train_transforms = components["train_transforms"]
    val_transforms = components["val_transforms"]

    # 2. Instantiate the DataModule
    datamodule = Food101DataModule(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        subset_fraction=0.1  # Use a small subset for quick verification
    )

    # 3. Trigger download and setup
    datamodule.prepare_data()
    datamodule.setup(stage='fit')

    # 4. (Optional) Verification step
    print("\n--- Verifying Dataloader ---")
    # Get one batch from the training dataloader
    train_dl = datamodule.train_dataloader()
    images, labels = next(iter(train_dl))
    print(f"Number of classes: {len(datamodule.classes)}")
    print(f"Image batch shape: {images.shape}")
    print(f"Label batch shape: {labels.shape}")
    print("--- Verification Complete ---")