Spaces:
Build error
Build error
| """ | |
| Data loading and preprocessing utilities | |
| """ | |
| import os | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms | |
| from PIL import Image | |
| from pathlib import Path | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| from typing import Tuple, List | |
| import config | |
class HerbalPlantsDataset(Dataset):
    """Custom Dataset for Indonesian Herbal Plants.

    Pairs image file paths with integer class labels and loads each image
    from disk when it is indexed.

    Args:
        image_paths: Paths of the image files, one per sample.
        labels: Integer class label for each entry of ``image_paths``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            the sample is returned.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths: List[str], labels: List[int], transform=None):
        # Fail at construction time: a mismatch would otherwise only show up
        # as an IndexError (or silently ignored labels) deep inside a
        # DataLoader worker.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is opened, converted to RGB and, when a transform was
        supplied, passed through it.
        """
        # ``convert`` returns a new, fully loaded image, so the underlying
        # file can be closed as soon as the ``with`` block ends.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def get_transforms(is_training: bool = True) -> transforms.Compose:
    """Build the preprocessing pipeline for one dataset split.

    Args:
        is_training: When true, return the augmenting pipeline (oversized
            resize, random crop, flips, rotation, colour jitter, translation
            and random erasing). Otherwise return the plain
            resize / to-tensor / normalise pipeline.

    Returns:
        A ``transforms.Compose`` whose final image size is
        ``config.IMAGE_SIZE`` on both sides.
    """
    # Both pipelines end with the same normalisation step.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if not is_training:
        eval_steps = [
            transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
            transforms.ToTensor(),
            normalize,
        ]
        return transforms.Compose(eval_steps)

    # Resize 32 px larger than the target so the random crop has room to move.
    oversized = config.IMAGE_SIZE + 32
    train_steps = [
        transforms.Resize((oversized, oversized)),
        transforms.RandomCrop(config.IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.ToTensor(),
        normalize,
        # RandomErasing operates on tensors, so it comes after ToTensor.
        transforms.RandomErasing(p=0.2),
    ]
    return transforms.Compose(train_steps)
def load_dataset() -> Tuple[List[str], List[int], List[str]]:
    """Load all image paths and labels from the dataset directory.

    Every sub-directory of ``config.DATA_DIR`` is treated as one class. Class
    indices follow the sorted order of the directory names. Within a class,
    regular files whose extension is ``.jpg``, ``.jpeg`` or ``.png`` (any
    letter case) are collected in sorted order, so the result is identical
    across runs and filesystems.

    Side effects:
        Prints a per-class summary and stores the class names in
        ``config.CLASS_NAMES``.

    Returns:
        A tuple ``(image_paths, labels, class_names)`` where ``labels[i]`` is
        the index into ``class_names`` of ``image_paths[i]``.
    """
    image_suffixes = {".jpg", ".jpeg", ".png"}
    image_paths: List[str] = []
    labels: List[int] = []

    class_names = sorted(d.name for d in config.DATA_DIR.iterdir() if d.is_dir())
    print(f"Found {len(class_names)} classes:")

    for idx, class_name in enumerate(class_names):
        class_dir = config.DATA_DIR / class_name
        # Directory listings come back in arbitrary order; sorting keeps any
        # seeded split built on this list reproducible. ``is_file`` skips
        # directories that merely carry an image-like extension.
        class_images = sorted(
            p for p in class_dir.iterdir()
            if p.is_file() and p.suffix.lower() in image_suffixes
        )
        print(f" [{idx:2d}] {class_name}: {len(class_images)} images")
        image_paths.extend(str(p) for p in class_images)
        labels.extend([idx] * len(class_images))

    # Update global class names
    config.CLASS_NAMES = class_names
    return image_paths, labels, class_names
def create_data_loaders() -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Create train, validation, and test data loaders.

    Loads the dataset via ``load_dataset()``, performs two stratified splits
    seeded with ``config.RANDOM_SEED`` (train vs. hold-out, then hold-out into
    validation and test), wraps each split in a ``HerbalPlantsDataset`` and
    builds one ``DataLoader`` per split. Only the training loader shuffles and
    uses the training transforms.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.
    """
    # Load dataset
    image_paths, labels, class_names = load_dataset()

    # First split: carve the combined validation + test share out of the data.
    holdout_fraction = config.VAL_SPLIT + config.TEST_SPLIT
    X_train, X_temp, y_train, y_temp = train_test_split(
        image_paths, labels,
        test_size=holdout_fraction,
        stratify=labels,
        random_state=config.RANDOM_SEED,
    )
    # Second split: TEST_SPLIT is expressed relative to the full dataset, so
    # rescale it to a fraction of the hold-out portion.
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=config.TEST_SPLIT / holdout_fraction,
        stratify=y_temp,
        random_state=config.RANDOM_SEED,
    )

    print("\nDataset splits:")
    print(f" Train: {len(X_train)} images")
    print(f" Val: {len(X_val)} images")
    print(f" Test: {len(X_test)} images")

    # Create datasets
    train_dataset = HerbalPlantsDataset(X_train, y_train, transform=get_transforms(is_training=True))
    val_dataset = HerbalPlantsDataset(X_val, y_val, transform=get_transforms(is_training=False))
    test_dataset = HerbalPlantsDataset(X_test, y_test, transform=get_transforms(is_training=False))

    # Settings shared by all three loaders. Pinned memory only helps
    # host-to-GPU copies; requesting it without CUDA just triggers a warning.
    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader, class_names
if __name__ == "__main__":
    # Smoke test: build every loader, then pull one training batch and
    # report the shapes of its image and label tensors.
    loaders = create_data_loaders()
    train_loader, val_loader, test_loader, class_names = loaders

    first_batch = next(iter(train_loader))
    images, labels = first_batch

    print(f"\nBatch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")