Spaces:
Sleeping
Sleeping
File size: 5,879 Bytes
6276d4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | """
PyTorch Dataset and DataLoader factory for Saudi date fruit images.
Handles:
- Loading images from CSV manifests (train.csv, val.csv, test.csv)
- Albumentations augmentation pipelines (train vs val/test)
- DataLoader creation with proper config
"""
from pathlib import Path
import albumentations as A
import cv2
import numpy as np
import pandas as pd
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from src.utils import load_config
# ImageNet normalization stats (used with pretrained models).
# Three per-channel values each; both are passed to A.Normalize in the
# train and val/test pipelines defined in this module.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_train_transforms(config: dict) -> A.Compose:
    """Assemble the augmentation pipeline applied to training images.

    Reads ``config["data"]["image_size"]`` for the output side length and the
    ``config["augmentation"]`` section for flip probabilities, rotation limit,
    colour-jitter strengths and the Gaussian-noise variance limit.

    Args:
        config: Configuration dict containing ``augmentation`` and ``data``
            sections.

    Returns:
        An ``A.Compose`` that crops/resizes, augments, normalizes with the
        ImageNet statistics and finally converts to a tensor.
    """
    aug_cfg = config["augmentation"]
    side = config["data"]["image_size"]

    # Geometric augmentations: random crop to the target size, flips, rotation.
    geometric = [
        A.RandomResizedCrop(side, side, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        A.HorizontalFlip(p=aug_cfg["horizontal_flip"]),
        A.VerticalFlip(p=aug_cfg["vertical_flip"]),
        A.Rotate(limit=aug_cfg["rotation_limit"], p=0.5),
    ]

    # Photometric augmentations: colour jitter, sensor-like noise, mild blur.
    jitter = A.ColorJitter(
        brightness=aug_cfg["color_jitter_brightness"],
        contrast=aug_cfg["color_jitter_contrast"],
        saturation=aug_cfg["color_jitter_saturation"],
        hue=aug_cfg["color_jitter_hue"],
        p=0.5,
    )
    photometric = [
        jitter,
        A.GaussNoise(var_limit=aug_cfg["gaussian_noise_var_limit"], p=0.3),
        A.GaussianBlur(blur_limit=(3, 5), p=0.1),
    ]

    # Normalization and tensor conversion always run last.
    finalize = [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]

    return A.Compose([*geometric, *photometric, *finalize])
def get_val_transforms(config: dict) -> A.Compose:
    """Assemble the deterministic pipeline used for validation and test images.

    No random augmentation is applied: the image is resized to 32 pixels more
    than ``config["data"]["image_size"]`` on each side, center-cropped back to
    that size, normalized with the ImageNet statistics and converted to a
    tensor.

    Args:
        config: Configuration dict containing ``data.image_size``.

    Returns:
        An ``A.Compose`` with the resize, crop, normalize and to-tensor steps.
    """
    crop_side = config["data"]["image_size"]
    resize_side = crop_side + 32  # resize slightly larger than the final crop

    steps = [
        A.Resize(resize_side, resize_side),
        A.CenterCrop(crop_side, crop_side),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps)
class DateFruitDataset(Dataset):
    """
    PyTorch Dataset for Saudi date fruit images listed in a CSV manifest.

    Each manifest row describes one sample and is read through the columns
    ``image_path``, ``label_idx`` and ``variety``.

    Args:
        csv_path: Path to the CSV manifest (train.csv, val.csv, or test.csv)
        transform: Albumentations transform pipeline. When None, images are
            converted to a float tensor scaled to [0, 1] instead.

    Raises:
        ValueError: If the manifest contains no rows.
        FileNotFoundError: If the first image listed in the manifest does not
            exist on disk.
    """

    def __init__(self, csv_path: str, transform: A.Compose | None = None):
        self.df = pd.read_csv(csv_path)
        self.transform = transform

        # Fail with a clear message instead of the opaque IndexError that
        # ``iloc[0]`` would raise on an empty manifest.
        if self.df.empty:
            raise ValueError(f"Manifest contains no samples: {csv_path}")

        # Verify at least some images exist. Only the first row is probed so
        # construction stays cheap; other missing files surface in __getitem__.
        sample_path = Path(self.df.iloc[0]["image_path"])
        if not sample_path.exists():
            raise FileNotFoundError(
                f"Image not found: {sample_path}\n"
                "Make sure the dataset is extracted to data/raw/"
            )

    def __len__(self) -> int:
        """Return the number of rows (samples) in the manifest."""
        return len(self.df)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, str]:
        """
        Load and transform the sample at positional index ``idx``.

        Returns:
            image: Transformed image; with the pipelines in this module, or
                with no transform, a tensor of shape (3, H, W)
            label: Integer class index
            variety: String variety name

        Raises:
            RuntimeError: If OpenCV cannot read the image file.
        """
        row = self.df.iloc[idx]
        # Coerce to str so a malformed cell (e.g. a blank parsed as NaN) ends
        # up in the RuntimeError below rather than an OpenCV argument error.
        image_path = str(row["image_path"])
        label = int(row["label_idx"])
        variety = row["variety"]

        # Load image with OpenCV (Albumentations uses numpy/cv2)
        image = cv2.imread(image_path)
        if image is None:
            raise RuntimeError(f"Failed to load image: {image_path}")
        # OpenCV decodes to BGR; the rest of the pipeline works on RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply transforms. Compare against None explicitly: a Compose is a
        # sized container, so a truthiness test would treat an empty pipeline
        # as if no transform had been supplied.
        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]
        else:
            # Fallback: HWC uint8 -> CHW float tensor scaled to [0, 1]
            image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0

        return image, label, variety

    @property
    def class_names(self) -> list[str]:
        """Return the unique variety names, sorted alphabetically.

        NOTE(review): the order is alphabetical, not by ``label_idx``; this
        only lines up with the label indices if the manifest assigned them in
        the same alphabetical order — TODO confirm against the manifest
        builder.
        """
        return sorted(self.df["variety"].unique().tolist())

    @property
    def num_classes(self) -> int:
        """Return number of unique ``label_idx`` values."""
        return self.df["label_idx"].nunique()

    @property
    def class_counts(self) -> dict[str, int]:
        """Return dict of {variety: count}, ordered by variety name."""
        counts = self.df["variety"].value_counts().sort_index()
        # Convert to native ints so the result matches the annotation and is
        # directly JSON-serializable (dict(Series) would keep numpy integers).
        return {name: int(count) for name, count in counts.items()}
def create_dataloaders(
    config: dict | None = None,
    *,
    data_dir: str | Path = "data",
) -> tuple[DataLoader, DataLoader, DataLoader, list[str]]:
    """
    Create train, val, and test DataLoaders from CSV manifests.

    The manifests ``train.csv``, ``val.csv`` and ``test.csv`` are read from
    ``data_dir``. The training loader shuffles and drops the last incomplete
    batch; the val/test loaders keep manifest order and every sample.

    Args:
        config: Configuration dict. If None, loads from default.yaml.
            ``config["data"]["batch_size"]`` and
            ``config["data"]["num_workers"]`` are used for all three loaders.
        data_dir: Directory containing the three CSV manifests. Defaults to
            ``"data"``, the location that was previously hard-coded.

    Returns:
        train_loader, val_loader, test_loader, class_names
        (``class_names`` is taken from the training dataset).
    """
    if config is None:
        config = load_config()

    # Build transform pipelines
    train_transform = get_train_transforms(config)
    val_transform = get_val_transforms(config)

    # Create datasets; val and test share the augmentation-free pipeline.
    manifest_dir = Path(data_dir)
    train_dataset = DateFruitDataset(
        str(manifest_dir / "train.csv"), transform=train_transform
    )
    val_dataset = DateFruitDataset(
        str(manifest_dir / "val.csv"), transform=val_transform
    )
    test_dataset = DateFruitDataset(
        str(manifest_dir / "test.csv"), transform=val_transform
    )

    # Settings shared by all three loaders.
    data_cfg = config["data"]
    loader_kwargs = {
        "batch_size": data_cfg["batch_size"],
        "num_workers": data_cfg["num_workers"],
        "pin_memory": True,
    }

    # Create DataLoaders
    train_loader = DataLoader(
        train_dataset,
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    class_names = train_dataset.class_names
    print(
        f"DataLoaders ready: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}"
    )
    print(f"Classes ({len(class_names)}): {class_names}")
    return train_loader, val_loader, test_loader, class_names
|