import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from utils.config import load_config
# Load project configuration and derive the runtime settings used below.
config = load_config()

batch_size = config["batch_size"]
num_workers = config["num_workers"]
mean_nm = config["normalize_mean"]  # per-channel normalization means
std_nm = config["normalize_std"]    # per-channel normalization stds
# Default to local execution when the key is absent from the config.
execute_remotely = config.get("execute_remotely", False)

# Prefer the first CUDA device when one is available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Load the dataset: from a ClearML-managed copy when executing remotely,
# otherwise straight from the Hugging Face hub with a local cache.
# Use the defensively-derived `execute_remotely` flag (config.get with a
# default) instead of re-indexing config, which would raise KeyError when
# the key is missing; this also matches the check further below.
if execute_remotely:
    from clearml import Dataset as ClearMLDataset

    clearml_dataset = ClearMLDataset.get(dataset_id="0c3de7af2d98482dacf41633a0587845")
    dataset_path = clearml_dataset.get_local_copy()
    # NOTE(review): load_dataset on a bare local path assumes the copy is a
    # saved dataset script/folder layout that `datasets` can auto-detect --
    # confirm against the ClearML dataset contents.
    dataset = load_dataset(dataset_path)
else:
    dataset = load_dataset("DScomp380/plant_village", cache_dir="./data_cache")

# Split into train (70%) and a 30% remainder for validation + test.
splits = dataset["train"].train_test_split(test_size=0.30, seed=42)
train_split = splits["train"]  # training set (70%)
remaining = splits["test"]

# Split the remaining 30% evenly into validation (15%) and test (15%).
val_test = remaining.train_test_split(test_size=0.5, seed=42)
val_split = val_test["train"]  # validation set
test_split = val_test["test"]  # test set
# Deterministic preprocessing pipeline: resize to the model's expected
# 224x224 input, convert PIL image -> float tensor, then normalize
# per channel with the configured mean/std.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_nm, std=std_nm),
]
preprocess_transform = transforms.Compose(_preprocess_steps)
def preprocess_batch(batch):
    """Turn a batch of PIL images into model-ready tensors.

    Adds a "pixel_values" entry holding one preprocessed tensor per image.
    Each image is converted to RGB first so that grayscale or palette-mode
    images do not break the 3-channel Normalize step (conversion is a
    cheap no-op copy for images that are already RGB).
    """
    batch["pixel_values"] = [
        preprocess_transform(img.convert("RGB")) for img in batch["image"]
    ]
    return batch
if execute_remotely:
    # Remote: preprocess lazily at item-access time via with_transform,
    # reusing preprocess_batch instead of a verbatim duplicate of it.
    # IMPORTANT: do NOT call set_format() on these splits afterwards --
    # set_format replaces the custom transform installed by with_transform,
    # so "pixel_values" would never be computed.
    train_split = train_split.with_transform(preprocess_batch)
    val_split = val_split.with_transform(preprocess_batch)
    test_split = test_split.with_transform(preprocess_batch)
else:
    def _map_preprocessed(split, cache_name):
        # Eagerly preprocess `split` once and cache the result as an arrow
        # file under ./data_cache so reruns skip the work.
        return split.map(
            preprocess_batch,
            batched=True,
            batch_size=100,
            remove_columns=["image"],
            cache_file_name=f"./data_cache/{cache_name}_preprocessed.arrow",
        )

    train_split = _map_preprocessed(train_split, "train")
    val_split = _map_preprocessed(val_split, "val")
    test_split = _map_preprocessed(test_split, "test")

    # Expose the preprocessed columns as torch tensors (set_format mutates
    # in place). Applied only in this branch -- see the note above.
    for _split in (train_split, val_split, test_split):
        _split.set_format(type="torch", columns=["pixel_values", "label"])
# Stochastic training-time augmentations, applied per sample inside
# train_collate_fn (never to validation/test data).
# NOTE(review): these run on already-normalized tensors produced by
# preprocess_transform; ColorJitter assumes inputs in [0, 1], so confirm
# this augment-after-normalize ordering is intentional.
_augment_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
        p=0.3,
    ),
]
train_augment = transforms.Compose(_augment_steps)
def train_collate_fn(batch):
    """Collate training samples, applying stochastic augmentation per image.

    Returns a dict with a stacked (B, C, H, W) "pixel_values" tensor and a
    1-D "labels" tensor.
    """
    augmented_images = torch.stack(
        [train_augment(sample["pixel_values"]) for sample in batch]
    )
    targets = torch.tensor([sample["label"] for sample in batch])
    return {"pixel_values": augmented_images, "labels": targets}
def val_collate_fn(batch):
    """Collate evaluation samples without augmentation.

    Stacks per-sample tensors into (B, C, H, W) "pixel_values" and gathers
    labels into a 1-D "labels" tensor.
    """
    images = []
    targets = []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}
# Build the train/val/test DataLoaders from one shared set of options.
# Workers are kept alive between epochs whenever any are configured.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,
)

# Only the training loader shuffles and augments (via its collate fn).
train_loader = DataLoader(
    train_split,
    shuffle=True,
    collate_fn=train_collate_fn,
    **_loader_kwargs,
)
val_loader = DataLoader(
    val_split,
    shuffle=False,
    collate_fn=val_collate_fn,
    **_loader_kwargs,
)
test_loader = DataLoader(
    test_split,
    shuffle=False,
    collate_fn=val_collate_fn,
    **_loader_kwargs,
)
if __name__ == "__main__":
print(f"Device: {device}")
print(f"Train samples: {len(train_split)}")
print(f"Val samples: {len(val_split)}")
print(f"Test samples: {len(test_split)}")
print(f"Batches per epoch: {len(train_loader)}")