# TEAM_7_GAP / data_prep.py
# Hosting-page header kept as a comment: fatimaxa's picture, "Upload 3 files", commit 5762bbf (verified).
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from utils.config import load_config
# Read the data-loading settings from the project configuration once, at import time.
config = load_config()

# Required keys: a missing entry raises KeyError right here.
batch_size, num_workers = config["batch_size"], config["num_workers"]
mean_nm, std_nm = config["normalize_mean"], config["normalize_std"]

# Optional key: treated as False when it is absent from the configuration.
execute_remotely = config.get("execute_remotely", False)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Load the dataset from ClearML when executing remotely, or from the Hugging Face Hub otherwise.
# The branch uses the `execute_remotely` flag (read with config.get(..., False)) rather than
# indexing config["execute_remotely"] again, so a configuration without that key falls back
# to the Hub download instead of raising KeyError.
if execute_remotely:
    # Imported inside the branch so clearml is only needed for remote runs.
    from clearml import Dataset as ClearMLDataset

    clearml_dataset = ClearMLDataset.get(dataset_id="0c3de7af2d98482dacf41633a0587845")
    dataset_path = clearml_dataset.get_local_copy()
    dataset = load_dataset(dataset_path)
else:
    dataset = load_dataset("DScomp380/plant_village", cache_dir="./data_cache")
# 70/15/15 split: hold out 30% of the "train" data first, then cut that hold-out in half.
# Both cuts use seed=42.
splits = dataset["train"].train_test_split(test_size=0.30, seed=42)
train_split, remaining = splits["train"], splits["test"]  # training set / 30% hold-out

val_test = remaining.train_test_split(test_size=0.5, seed=42)
val_split, test_split = val_test["train"], val_test["test"]  # validation set / test set
# Preprocessing applied to every image, whichever split it belongs to.
_preprocess_steps = [
    transforms.Resize((224, 224)),  # fixed 224x224 size
    transforms.ToTensor(),  # image -> tensor
    transforms.Normalize(mean=mean_nm, std=std_nm),  # statistics come from the config
]
preprocess_transform = transforms.Compose(_preprocess_steps)
def preprocess_batch(batch):
    """Add a "pixel_values" entry holding the preprocessed form of each image.

    Args:
        batch: Mapping whose "image" entry can be iterated image by image.

    Returns:
        The same ``batch`` object, with "pixel_values" set to a list of the
        ``preprocess_transform`` outputs, in the order of ``batch["image"]``.
    """
    pixel_values = []
    for image in batch["image"]:
        pixel_values.append(preprocess_transform(image))
    batch["pixel_values"] = pixel_values
    return batch
# Attach the preprocessing to each split.
#   * remote run: preprocess lazily, each time samples are read (with_transform);
#   * local run: preprocess once via map() and store the result in ./data_cache.
if execute_remotely:
    def train_transform_batch(batch):
        """On-access transform: add "pixel_values" computed from the batch's "image" entries."""
        batch["pixel_values"] = [preprocess_transform(img) for img in batch["image"]]
        return batch

    # The same transform is used for all three splits.
    train_split, val_split, test_split = (
        _split.with_transform(train_transform_batch)
        for _split in (train_split, val_split, test_split)
    )
else:
    # map() arguments common to the three splits; only the cache file differs.
    _map_options = dict(batched=True, batch_size=100, remove_columns=["image"])

    train_split = train_split.map(
        preprocess_batch,
        cache_file_name="./data_cache/train_preprocessed.arrow",
        **_map_options,
    )
    val_split = val_split.map(
        preprocess_batch,
        cache_file_name="./data_cache/val_preprocessed.arrow",
        **_map_options,
    )
    test_split = test_split.map(
        preprocess_batch,
        cache_file_name="./data_cache/test_preprocessed.arrow",
        **_map_options,
    )

    # NOTE(review): the original indentation was lost; these set_format calls are assumed
    # to belong to this branch, because only here is "pixel_values" a stored column.
    # Please confirm against the upstream file.
    for _split in (train_split, val_split, test_split):
        _split.set_format(type="torch", columns=["pixel_values", "label"])
# Training-time augmentations.
#
# train_collate_fn applies this pipeline to "pixel_values", which preprocess_transform has
# already normalized with mean_nm / std_nm. torchvision's ColorJitter works on float
# tensors in the [0, 1] range and clamps its output to that range, so running it directly
# on normalized data (which contains negative values) distorts the images. The pipeline
# therefore undoes the normalization first, augments in [0, 1] space, and re-applies the
# normalization as its final step.
_norm_mean = torch.as_tensor(mean_nm, dtype=torch.float32)
_norm_std = torch.as_tensor(std_nm, dtype=torch.float32)

train_augment = transforms.Compose([
    # Inverse of Normalize(mean, std): (x - (-mean/std)) / (1/std) == x * std + mean.
    transforms.Normalize(
        mean=(-_norm_mean / _norm_std).tolist(),
        std=(1.0 / _norm_std).tolist(),
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
    # Back to the normalized space the rest of the pipeline works in.
    transforms.Normalize(mean=mean_nm, std=std_nm),
])
def train_collate_fn(batch):
    """Collate training samples into one batch, augmenting every image on the fly.

    Args:
        batch: Sequence of samples, each providing "pixel_values" and "label".

    Returns:
        Dict with "pixel_values" (the ``train_augment`` outputs stacked along a new
        first dimension) and "labels" (a tensor built from the samples' labels).
    """
    images = [sample["pixel_values"] for sample in batch]
    targets = [sample["label"] for sample in batch]

    # Augmentation happens here, at batch-assembly time, so it only affects training.
    stacked = torch.stack([train_augment(image) for image in images])
    return {"pixel_values": stacked, "labels": torch.tensor(targets)}
def val_collate_fn(batch):
    """Collate evaluation samples into one batch without any augmentation.

    Args:
        batch: Sequence of samples, each providing "pixel_values" and "label".

    Returns:
        Dict with "pixel_values" (the samples' tensors stacked along a new first
        dimension) and "labels" (a tensor built from the samples' labels).
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample["pixel_values"])
        targets.append(sample["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}
# Build the DataLoaders for the train, val, and test sets.
# Everything except the dataset, shuffling and collate function is shared.
_loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    # persistent_workers is only enabled when worker processes are actually used
    persistent_workers=num_workers > 0,
)

# Training: shuffled, with augmentation applied in the collate function.
train_loader = DataLoader(
    train_split, shuffle=True, collate_fn=train_collate_fn, **_loader_options
)

# Validation and test: fixed order, no augmentation.
val_loader = DataLoader(
    val_split, shuffle=False, collate_fn=val_collate_fn, **_loader_options
)
test_loader = DataLoader(
    test_split, shuffle=False, collate_fn=val_collate_fn, **_loader_options
)
if __name__ == "__main__":
    # Quick sanity report when the module is run directly as a script.
    report_lines = (
        f"Device: {device}",
        f"Train samples: {len(train_split)}",
        f"Val samples: {len(val_split)}",
        f"Test samples: {len(test_split)}",
        f"Batches per epoch: {len(train_loader)}",
    )
    for report_line in report_lines:
        print(report_line)