unionpoint's picture
Upload folder using huggingface_hub
5d2fa0b verified
import json
from pathlib import Path
import numpy as np
import torch
import torchvision.transforms.v2 as T
from PIL import Image
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
class PlantDiseaseDataset(Dataset):
    """Image dataset read from class folders named ``"<plant> <disease>"``.

    Every immediate sub-directory of ``root_dir`` is treated as one class
    folder. Its lower-cased name is split into a known plant prefix (see
    ``PLANTS``) and the remaining text, which is used as the disease name.
    Images are collected recursively from each folder whose disease is present
    in the label map; other folders are skipped with a printed warning.

    Attributes:
        image_paths (list[str]): Paths of all collected images.
        labels (list[int]): Disease index for each image.
        plant_labels (list[str]): Plant name for each image.
        disease_to_idx (dict): Disease name -> integer label.
        classes (list[str]): Disease names ordered by their integer label.
    """

    # Plant names recognised as folder-name prefixes.
    PLANTS = (
        "apple",
        "banana",
        "bean",
        "bell pepper",
        "blueberry",
        "basil",
        "broccoli",
        "cabbage",
        "cauliflower",
        "celery",
        "cherry",
        "citrus",
        "coffee",
        "corn",
        "cucumber",
        "garlic",
        "ginger",
        "grape",
        "lettuce",
        "maple",
        "peach",
        "plum",
        "potato",
        "raspberry",
        "rice",
        "soybean",
        "squash",
        "strawberry",
        "tobacco",
        "tomato",
        "wheat",
        "zucchini",
    )

    # File suffixes (compared lower-cased) that count as images.
    IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".webp"})

    def __init__(self, root_dir, label_map=None, transform=None):
        """
        Args:
            root_dir (str): Path to the root directory of the split.
            label_map (dict, optional): A dictionary mapping disease names to integers.
                Crucial for consistency across splits. When omitted, the map is
                built from the folders found under ``root_dir``.
            transform (callable, optional): PyTorch transforms.
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.plant_labels = []
        # Longest names first so the most specific prefix is tried first.
        self.plants = sorted(self.PLANTS, key=len, reverse=True)

        # Resolve the label map before scanning so that `disease_to_idx` and
        # `classes` always exist, even when the root directory is missing.
        if label_map is None:
            self.disease_to_idx = self._build_label_map()
        else:
            self.disease_to_idx = label_map
        # Order by label index so that classes[i] names the class with index i
        # regardless of the key order of a caller-supplied map.
        self.classes = sorted(self.disease_to_idx, key=self.disease_to_idx.get)

        for folder in self._class_dirs():
            disease, plant = self._split_plant_disease(folder)
            if disease not in self.disease_to_idx:
                print(
                    f"WARNING: Skipping '{folder.name}': Disease '{disease}' not found in label_map"
                )
                continue
            disease_idx = self.disease_to_idx[disease]
            # Sorted so the sample order does not depend on directory listing order.
            for img_path in sorted(folder.glob("**/*")):
                if (
                    img_path.is_file()
                    and img_path.suffix.lower() in self.IMAGE_EXTENSIONS
                ):
                    self.image_paths.append(str(img_path))
                    self.labels.append(disease_idx)
                    self.plant_labels.append(plant)

    def _class_dirs(self):
        """Return the sorted sub-directories of the root, or [] if it is not a directory."""
        if not self.root_dir.is_dir():
            return []
        return sorted(d for d in self.root_dir.iterdir() if d.is_dir())

    def _build_label_map(self):
        """Map every disease name found under the root to a sorted, 0-based index."""
        all_diseases = set()
        for folder in self._class_dirs():
            disease, _ = self._split_plant_disease(folder)
            # None means no known plant prefix matched; "" is a valid (empty) disease.
            if disease is not None:
                all_diseases.add(disease)
        return {disease: i for i, disease in enumerate(sorted(all_diseases))}

    def _split_plant_disease(self, folder):
        """Split a class folder into ``(disease, plant)``.

        Args:
            folder: Object with a ``name`` attribute (e.g. a ``Path``).

        Returns:
            ``(disease, plant)`` for the first plant in ``self.plants`` that
            prefixes the lower-cased folder name, or ``(None, None)`` if none does.
        """
        folder_name = folder.name.lower()
        for plant in self.plants:
            if folder_name.startswith(plant):
                disease = folder_name[len(plant) :].strip()
                return disease, plant
        return None, None

    def __len__(self):
        """Return the number of collected images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample as ``(image, label)``.

        The image is opened as RGB and passed through ``self.transform`` when
        one is set. If the file cannot be loaded, the error is printed and
        ``(None, None)`` is returned instead of raising.
        NOTE(review): consumers of this dataset must handle that sentinel
        (for example with a collate function that filters it out).
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            return None, None
        if self.transform:
            image = self.transform(image)
        return image, label
def get_transforms(image_size=384, is_train=True):
    """Build the torchvision v2 preprocessing pipeline for one split.

    Args:
        image_size (int): Target side length of the square output image.
        is_train (bool): If True, use random crop / flip / TrivialAugmentWide
            augmentation; otherwise only resize to ``image_size x image_size``.

    Returns:
        A ``T.Compose`` that ends by converting to a float32 image scaled by
        ``ToDtype(..., scale=True)`` and normalising per channel.
    """
    if is_train:
        leading_steps = [
            T.RandomResizedCrop(image_size, scale=(0.7, 1.0), antialias=True),
            T.RandomHorizontalFlip(),
            T.TrivialAugmentWide(),
        ]
    else:
        leading_steps = [T.Resize((image_size, image_size), antialias=True)]

    # Shared tail of both pipelines.
    # NOTE(review): the mean/std constants look like the usual ImageNet statistics.
    tensor_steps = [
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(leading_steps + tensor_steps)
def build_weighted_sampler(dataset, max_weight=10.0):
    """Create a sampler that up-weights rare diseases and rare (disease, plant) pairs.

    Each sample's raw weight is ``1 / sqrt(n_disease * n_pair)``, where
    ``n_disease`` is the number of samples sharing its disease label and
    ``n_pair`` the number sharing both its disease label and its plant label.
    Weights are then rescaled so the smallest weight equals 1.0 and capped at
    ``max_weight``.

    The rescaling matters: raw weights never exceed 1.0, so clamping them
    directly against a cap such as 10.0 would never change anything. After
    rescaling, ``max_weight`` bounds how many times more likely the rarest
    sample is to be drawn than the most common one.

    Args:
        dataset: Sized object exposing ``labels`` (sequence of int) and
            ``plant_labels`` (sequence of plant names), one entry per sample.
        max_weight (float, optional): Upper bound on the rescaled weights.
            ``None`` or a non-positive value disables the cap.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset)`` samples with
        replacement, or ``None`` when the dataset is empty.

    Raises:
        ValueError: If ``dataset`` lacks ``labels`` or ``plant_labels``.
    """
    if not hasattr(dataset, "labels") or not hasattr(dataset, "plant_labels"):
        raise ValueError("Dataset must have 'labels' and 'plant_labels'")
    if len(dataset) == 0:
        return None

    disease = torch.tensor(dataset.labels, dtype=torch.long)
    # Encode plant names as integer ids so they can be paired with disease ids.
    _, plant_indices = np.unique(dataset.plant_labels, return_inverse=True)
    plant = torch.tensor(plant_indices, dtype=torch.long)

    disease_counts = torch.bincount(disease)
    # One group per distinct (disease, plant) combination.
    pairs = torch.stack([disease, plant], dim=1)
    _, group_id = torch.unique(pairs, return_inverse=True, dim=0)
    group_counts = torch.bincount(group_id)

    d_count = disease_counts[disease].double()
    g_count = group_counts[group_id].double()
    weights = 1.0 / torch.sqrt(d_count * g_count)

    # Sampling probabilities depend only on weight ratios, so rescaling is
    # free; it puts the weights on a scale where the cap is meaningful.
    weights = weights / weights.min()
    if max_weight is not None and max_weight > 0:
        weights = torch.clamp(weights, max=max_weight)

    return WeightedRandomSampler(
        weights=weights,
        num_samples=len(dataset),
        replacement=True,
    )
def get_dataloaders(config):
    """Build the training and validation data loaders.

    Reads from ``config.data``: ``train_dir``, ``val_dir``, ``image_size``,
    ``weighted_sampling``, ``max_weight``, ``batch_size``, ``num_workers``
    and ``pin_memory``.

    Side effects:
        * Creates the train and validation directories if they are missing.
        * Writes ``label_map.json`` next to the train directory whenever the
          training dataset has a non-empty label map.

    Returns:
        tuple: ``(train_loader, val_loader, num_classes)``. ``num_classes`` is
        0 when no training images were found.
    """
    train_dir = Path(config.data.train_dir)
    val_dir = Path(config.data.val_dir)
    train_dir.mkdir(parents=True, exist_ok=True)
    val_dir.mkdir(parents=True, exist_ok=True)

    # Reuse a previously saved label map if there is one; otherwise the
    # training dataset builds it from its class folders.
    label_map_path = train_dir.parent / "label_map.json"
    label_map = None
    if label_map_path.exists():
        with open(label_map_path, encoding="utf-8") as f:
            # An empty saved map would make the dataset skip every folder,
            # so treat it the same as "no map saved".
            label_map = json.load(f) or None

    # Both splits share the training label map so indices stay consistent.
    train_dataset = PlantDiseaseDataset(
        config.data.train_dir,
        label_map=label_map,
        transform=get_transforms(config.data.image_size, is_train=True),
    )
    val_dataset = PlantDiseaseDataset(
        config.data.val_dir,
        label_map=train_dataset.disease_to_idx,
        transform=get_transforms(config.data.image_size, is_train=False),
    )

    # Persist the map for future runs, but never an empty one: it would be
    # loaded next time and hide class folders added later.
    if train_dataset.disease_to_idx:
        with open(label_map_path, "w", encoding="utf-8") as f:
            json.dump(train_dataset.disease_to_idx, f, indent=2)

    if len(train_dataset) == 0:
        print("Warning: No train data found. Dataloader might fail.")
    num_classes = len(train_dataset.classes) if len(train_dataset) > 0 else 0

    train_sampler = None
    if config.data.weighted_sampling:
        train_sampler = build_weighted_sampler(
            train_dataset, max_weight=config.data.max_weight
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        sampler=train_sampler,
        # A sampler and shuffle=True are mutually exclusive.
        shuffle=(train_sampler is None),
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
        # Only drop the ragged last batch when there is more than one batch.
        drop_last=len(train_dataset) > config.data.batch_size,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=config.data.pin_memory,
    )
    return train_loader, val_loader, num_classes