# budijuarto's picture
# Deploy Indonesian Herbal Plants Classifier
# fa49101 verified
"""
Data loading and preprocessing utilities
"""
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
from typing import Tuple, List
import config
class HerbalPlantsDataset(Dataset):
    """Map-style dataset of Indonesian herbal plant images.

    Each item is an ``(image, label)`` pair. The image is read from disk on
    access, converted to RGB, and passed through ``transform`` when one is set.

    Args:
        image_paths: Paths of the image files, one per sample.
        labels: Class index for each entry of ``image_paths`` (same order).
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths: List[str], labels: List[int], transform=None):
        # Fail at construction time instead of with an IndexError (or a
        # silently ignored tail of labels) somewhere in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        # convert() returns a new image, so the handle opened by Image.open
        # can be released deterministically as soon as the block exits.
        with Image.open(self.image_paths[idx]) as source:
            image = source.convert('RGB')
        label = self.labels[idx]
        # Compare against None explicitly so a transform object that happens
        # to be falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def get_transforms(is_training: bool = True) -> transforms.Compose:
    """Build the preprocessing pipeline for one dataset split.

    Args:
        is_training: When true, return the augmenting pipeline; otherwise
            return the deterministic resize/normalize pipeline used for
            validation and test data.

    Returns:
        A ``transforms.Compose`` producing a normalized tensor whose spatial
        size is ``config.IMAGE_SIZE`` x ``config.IMAGE_SIZE``.
    """
    side = config.IMAGE_SIZE

    # Tensor conversion + channel normalization are common to both pipelines.
    # NOTE(review): the mean/std values look like the usual ImageNet
    # statistics -- confirm they match the backbone's pretraining.
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if not is_training:
        return transforms.Compose([transforms.Resize((side, side)), *tensor_steps])

    # Resize 32 px larger than the target so the random crop has room to move.
    enlarged = side + 32
    augmentations = [
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomCrop(side),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    # RandomErasing comes last because it is applied to the normalized tensor.
    erasing = [transforms.RandomErasing(p=0.2)]
    return transforms.Compose(augmentations + tensor_steps + erasing)
def load_dataset() -> Tuple[List[str], List[int], List[str]]:
    """Collect image paths and labels from the dataset directory.

    Every sub-directory of ``config.DATA_DIR`` is treated as one class; class
    indices follow the alphabetical order of the directory names. Files with
    a ``.jpg``, ``.jpeg`` or ``.png`` extension (any letter case) directly
    inside a class directory are its samples.

    Side effects:
        Prints a per-class summary and stores the class names in
        ``config.CLASS_NAMES``.

    Returns:
        ``(image_paths, labels, class_names)`` where ``image_paths[i]`` is a
        string path and ``labels[i]`` is the index of its class in
        ``class_names``.
    """
    extension_patterns = ("*.[jJ][pP][gG]", "*.[jJ][pP][eE][gG]", "*.[pP][nN][gG]")

    image_paths: List[str] = []
    labels: List[int] = []
    class_names = sorted(d.name for d in config.DATA_DIR.iterdir() if d.is_dir())
    print(f"Found {len(class_names)} classes:")
    for idx, class_name in enumerate(class_names):
        class_dir = config.DATA_DIR / class_name
        # glob() yields entries in filesystem order, which varies between
        # machines; sorting keeps the sample order -- and therefore the seeded
        # train/val/test split built from it -- reproducible.
        class_images = sorted(
            img_path
            for pattern in extension_patterns
            for img_path in class_dir.glob(pattern)
        )
        print(f" [{idx:2d}] {class_name}: {len(class_images)} images")
        image_paths.extend(str(img_path) for img_path in class_images)
        labels.extend([idx] * len(class_images))
    # Update global class names
    config.CLASS_NAMES = class_names
    return image_paths, labels, class_names
def create_data_loaders() -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Create the train, validation and test data loaders.

    The samples returned by ``load_dataset()`` are split in two stratified
    steps seeded with ``config.RANDOM_SEED``: first the combined
    validation+test share is held out, then that hold-out is divided into
    validation and test parts. Training data uses the augmenting transforms;
    validation and test data use the deterministic ones.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``. Only the
        training loader shuffles.
    """
    # Load dataset
    image_paths, labels, class_names = load_dataset()

    # Split data: carve off val+test together, then divide that hold-out.
    holdout_fraction = config.VAL_SPLIT + config.TEST_SPLIT
    X_train, X_temp, y_train, y_temp = train_test_split(
        image_paths, labels,
        test_size=holdout_fraction,
        stratify=labels,
        random_state=config.RANDOM_SEED,
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        # Express the test share relative to the hold-out, not the full set.
        test_size=config.TEST_SPLIT / holdout_fraction,
        stratify=y_temp,
        random_state=config.RANDOM_SEED,
    )
    print("\nDataset splits:")
    print(f" Train: {len(X_train)} images")
    print(f" Val: {len(X_val)} images")
    print(f" Test: {len(X_test)} images")

    # Create datasets
    eval_transform = get_transforms(is_training=False)
    train_dataset = HerbalPlantsDataset(X_train, y_train, transform=get_transforms(is_training=True))
    val_dataset = HerbalPlantsDataset(X_val, y_val, transform=eval_transform)
    test_dataset = HerbalPlantsDataset(X_test, y_test, transform=eval_transform)

    # Create data loaders. Pinned host memory only helps host-to-GPU copies,
    # and requesting it without CUDA makes DataLoader emit a warning, so it is
    # enabled only when a CUDA device is actually available.
    loader_options = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)
    return train_loader, val_loader, test_loader, class_names
if __name__ == "__main__":
    # Smoke test: build every loader, then pull a single training batch and
    # report the shapes of its image and label tensors.
    loaders = create_data_loaders()
    train_loader, val_loader, test_loader, class_names = loaders
    first_batch = next(iter(train_loader))
    images, labels = first_batch
    print(f"\nBatch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")