import os

import cv2
import numpy as np
import torch
from torch.utils.data import Subset, Dataset


def collate_fn(batch):
    """Collate ``(image, label)`` samples without stacking the images.

    Args:
        batch: Iterable of ``(img, label)`` pairs.

    Returns:
        A ``(imgs, labels)`` tuple where ``imgs`` is a plain Python list of
        the images in batch order and ``labels`` is a tensor built from the
        labels with ``torch.tensor``.
    """
    imgs = [img for img, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    return imgs, labels


class ImageDataset(Dataset):
    """Image-classification dataset read from a ``root/<class_name>/<file>`` tree.

    Each sub-directory of ``root_path`` is one class. Class ids are assigned
    by enumerating the *sorted* class-directory names, so the mapping is
    reproducible across runs and machines.
    """

    def __init__(self, root_path: str, img_size=(256, 256)):
        """Index every file found under the class directories of ``root_path``.

        Args:
            root_path: Directory whose sub-directories are the classes.
            img_size: Target size handed to ``cv2.resize`` as its ``dsize``
                argument, i.e. ``(width, height)``.
        """
        # os.listdir returns entries in arbitrary order; sort so that the
        # class -> id mapping is deterministic. Non-directory entries
        # (README, .DS_Store, ...) are not classes and are skipped.
        classes = sorted(
            name
            for name in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, name))
        )
        self.img_size = img_size
        self.classes = classes

        data = []
        for class_id, class_name in enumerate(classes):
            class_path = os.path.join(root_path, class_name)
            # Sorted for a reproducible sample order (matters for seeded splits).
            for file_name in sorted(os.listdir(class_path)):
                filepath = os.path.join(class_path, file_name)
                if not os.path.isfile(filepath):
                    # Nested directories cannot be decoded as images.
                    continue
                data.append(
                    {"image_path": filepath, "label": class_name, "id": class_id}
                )
        self.data = data

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resize and normalise the sample at position ``idx``.

        Returns:
            ``(img, label)`` where ``img`` is the image resized to
            ``self.img_size``, converted from BGR to RGB and cast to
            ``float32`` after division by 255, and ``label`` is the integer
            class id.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        record = self.data[idx]
        img_path = record["image_path"]

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            raise OSError(f"Could not read image file: {img_path!r}")

        img = cv2.resize(img, self.img_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return img, record["id"]


def simple_augment(img):
    """Apply a random horizontal flip and a small random rotation.

    With probability 0.5 the image is mirrored around its vertical axis.
    It is then always rotated about its centre by an angle drawn uniformly
    from [-15, 15) degrees, filling the exposed border by reflection.
    Randomness comes from NumPy's global random state.

    Args:
        img: Image array whose first two dimensions are height and width.

    Returns:
        The augmented image, with the same height and width as the input.
    """
    if np.random.rand() > 0.5:
        img = cv2.flip(img, 1)  # flipCode=1 -> horizontal flip

    angle = np.random.uniform(-15, 15)
    height, width = img.shape[:2]
    rotation = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
    return cv2.warpAffine(
        img, rotation, (width, height), borderMode=cv2.BORDER_REFLECT
    )


class AugmentedSubset(Subset):
    """A ``Subset`` that applies an optional augmentation to each image."""

    def __init__(self, subset, augment_fn=None):
        """Wrap ``subset`` and remember the augmentation callable.

        Args:
            subset: Object exposing ``dataset`` and ``indices`` attributes
                (e.g. a ``torch.utils.data.Subset``).
            augment_fn: Optional callable applied to the image of every
                sample; when falsy, samples are returned unchanged.
        """
        super().__init__(subset.dataset, subset.indices)
        self.augment_fn = augment_fn

    def __getitem__(self, idx):
        """Return the ``(img, label)`` sample at ``idx``, augmented if configured."""
        img, label = super().__getitem__(idx)
        if self.augment_fn:
            img = self.augment_fn(img)
        return img, label

    def __getitems__(self, indices):
        """Return the samples for a batch of indices, each one augmented.

        ``Subset`` provides a batched ``__getitems__`` that the DataLoader
        prefers over ``__getitem__`` when batching; without this override the
        inherited version would fetch straight from the wrapped dataset and
        the augmentation would silently never run.
        """
        return [self[i] for i in indices]