""" PyTorch Dataset and DataLoader factory for Saudi date fruit images. Handles: - Loading images from CSV manifests (train.csv, val.csv, test.csv) - Albumentations augmentation pipelines (train vs val/test) - DataLoader creation with proper config """ from pathlib import Path import albumentations as A import cv2 import numpy as np import pandas as pd import torch from albumentations.pytorch import ToTensorV2 from torch.utils.data import DataLoader, Dataset from src.utils import load_config # ImageNet normalization stats (used with pretrained models) IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] def get_train_transforms(config: dict) -> A.Compose: """Build training augmentation pipeline.""" aug = config["augmentation"] size = config["data"]["image_size"] return A.Compose([ A.RandomResizedCrop(size, size, scale=(0.8, 1.0), ratio=(0.9, 1.1)), A.HorizontalFlip(p=aug["horizontal_flip"]), A.VerticalFlip(p=aug["vertical_flip"]), A.Rotate(limit=aug["rotation_limit"], p=0.5), A.ColorJitter( brightness=aug["color_jitter_brightness"], contrast=aug["color_jitter_contrast"], saturation=aug["color_jitter_saturation"], hue=aug["color_jitter_hue"], p=0.5, ), A.GaussNoise(var_limit=aug["gaussian_noise_var_limit"], p=0.3), A.GaussianBlur(blur_limit=(3, 5), p=0.1), A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2(), ]) def get_val_transforms(config: dict) -> A.Compose: """Build validation/test transform pipeline (no augmentation).""" size = config["data"]["image_size"] return A.Compose([ A.Resize(size + 32, size + 32), # Resize slightly larger A.CenterCrop(size, size), # Then center crop A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2(), ]) class DateFruitDataset(Dataset): """ PyTorch Dataset for Saudi date fruit images. Args: csv_path: Path to the CSV manifest (train.csv, val.csv, or test.csv) transform: Albumentations transform pipeline """ def __init__(self, csv_path: str, transform: A.Compose | None = None): self.df = pd.read_csv(csv_path) self.transform = transform # Verify at least some images exist sample_path = Path(self.df.iloc[0]["image_path"]) if not sample_path.exists(): raise FileNotFoundError( f"Image not found: {sample_path}\n" "Make sure the dataset is extracted to data/raw/" ) def __len__(self) -> int: return len(self.df) def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, str]: """ Returns: image: Tensor of shape (3, H, W) normalized label: Integer class index variety: String variety name """ row = self.df.iloc[idx] image_path = row["image_path"] label = int(row["label_idx"]) variety = row["variety"] # Load image with OpenCV (Albumentations uses numpy/cv2) image = cv2.imread(image_path) if image is None: raise RuntimeError(f"Failed to load image: {image_path}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Apply transforms if self.transform: transformed = self.transform(image=image) image = transformed["image"] else: # Fallback: just convert to tensor image = torch.from_numpy(image.transpose(2, 0, 1)).float() / 255.0 return image, label, variety @property def class_names(self) -> list[str]: """Return sorted list of variety names.""" return sorted(self.df["variety"].unique().tolist()) @property def num_classes(self) -> int: """Return number of unique classes.""" return self.df["label_idx"].nunique() @property def class_counts(self) -> dict[str, int]: """Return dict of {variety: count}.""" return dict(self.df["variety"].value_counts().sort_index()) def create_dataloaders( config: dict | None = None, ) -> tuple[DataLoader, DataLoader, DataLoader, list[str]]: """ Create train, val, and test DataLoaders from CSV manifests. Args: config: Configuration dict. If None, loads from default.yaml. Returns: train_loader, val_loader, test_loader, class_names """ if config is None: config = load_config() # Build transform pipelines train_transform = get_train_transforms(config) val_transform = get_val_transforms(config) # Create datasets train_dataset = DateFruitDataset("data/train.csv", transform=train_transform) val_dataset = DateFruitDataset("data/val.csv", transform=val_transform) test_dataset = DateFruitDataset("data/test.csv", transform=val_transform) # Create DataLoaders batch_size = config["data"]["batch_size"] num_workers = config["data"]["num_workers"] train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, ) class_names = train_dataset.class_names print(f"DataLoaders ready: train={len(train_dataset)}, val={len(val_dataset)}, test={len(test_dataset)}") print(f"Classes ({len(class_names)}): {class_names}") return train_loader, val_loader, test_loader, class_names