from torch.utils.data import Dataset import torch import os import numpy as np import cv2 def collate_fn(batch): imgs = [img for img, _ in batch] labels = torch.tensor([label for _, label in batch]) return imgs, labels class ImageDataset(Dataset): def __init__(self,root_path : str,img_size=(256,256)): classes = os.listdir(root_path) self.img_size = img_size self.classes = classes data = [] for idx,class_name in enumerate(classes): class_path = os.path.join(root_path,class_name) files = os.listdir(class_path) for file in files: filepath = os.path.join(class_path,file) data.append({"image_path":filepath,"label":class_name,"id":idx}) self.data = data def __len__(self): return len(self.data) def __getitem__(self,idx): curr = self.data[idx] label = curr['id'] img_path = curr['image_path'] img = cv2.imread(img_path) img = cv2.resize(img,(self.img_size)) img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) img = img.astype(np.float32) / 255.0 return img,label