File size: 1,829 Bytes
5e96bc9 c3d45c0 5e96bc9 c3d45c0 5e96bc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from torch.utils.data import Subset,Dataset
import torch
import os
import numpy as np
import cv2
def collate_fn(batch):
    """Split a batch of ``(image, label)`` pairs for a ``DataLoader``.

    Images are NOT stacked: they are returned as a plain Python list in batch
    order. Labels are gathered into a single tensor with ``torch.tensor``.

    Args:
        batch: sequence of 2-item ``(image, label)`` samples.

    Returns:
        ``(images, labels)`` where ``images`` is a list and ``labels`` a tensor.
    """
    images = [image for image, _unused in batch]
    label_values = [target for _unused, target in batch]
    return images, torch.tensor(label_values)
class ImageDataset(Dataset):
    """Image-classification dataset laid out as ``root_path/<class_name>/<file>``.

    Every immediate sub-directory of ``root_path`` is one class; its integer
    label is the class's position in ``self.classes``. Class directories (and
    the files inside them) are sorted by name so the class -> index mapping and
    sample order are identical on every run and machine -- ``os.listdir``
    returns entries in arbitrary order.

    Args:
        root_path: directory whose sub-directories are the classes.
        img_size: target size handed to ``cv2.resize``, which interprets it as
            ``(width, height)``.

    Attributes:
        classes: sorted class-directory names; index == label id.
        data: list of dicts with keys ``"image_path"``, ``"label"`` (class
            name) and ``"id"`` (integer label).
    """

    def __init__(self, root_path: str, img_size=(256, 256)):
        # Only directories are classes: a stray file such as ".DS_Store" would
        # otherwise make os.listdir(class_path) raise NotADirectoryError.
        classes = sorted(
            entry
            for entry in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, entry))
        )
        self.img_size = img_size
        self.classes = classes
        data = []
        for idx, class_name in enumerate(classes):
            class_path = os.path.join(root_path, class_name)
            for file in sorted(os.listdir(class_path)):
                filepath = os.path.join(class_path, file)
                if not os.path.isfile(filepath):
                    continue  # nested directories are not samples
                data.append({"image_path": filepath, "label": class_name, "id": idx})
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_id)``.

        The file is read with OpenCV (BGR), resized to ``self.img_size``,
        converted to RGB and scaled from uint8 to float32 in ``[0, 1]``. The
        image is returned as an HxWx3 NumPy array (channels last).

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        curr = self.data[idx]
        label = curr["id"]
        img_path = curr["image_path"]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports missing/corrupt files by returning None rather
            # than raising; fail here with the path instead of an opaque
            # assertion error from cv2.resize.
            raise OSError(f"cv2.imread could not read image file: {img_path}")
        img = cv2.resize(img, self.img_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return img, label
def simple_augment(img):
    """Apply a random horizontal flip followed by a small random rotation.

    With probability 0.5 the image is mirrored left-right (``cv2.flip`` with
    flipCode 1). It is then rotated about its centre by an angle drawn from
    ``np.random.uniform(-15, 15)`` degrees at scale 1.0. Pixels uncovered by
    the rotation are filled by border reflection. Height and width are kept.

    Args:
        img: image array; its first two dimensions are taken as (height, width).

    Returns:
        The augmented image array.
    """
    # The two NumPy draws stay in this order (rand, then uniform) so the
    # random stream consumed per call is unchanged.
    should_flip = np.random.rand() > 0.5
    if should_flip:
        img = cv2.flip(img, 1)
    rotation_deg = np.random.uniform(-15, 15)
    height, width = img.shape[:2]
    centre = (width / 2, height / 2)
    rotation = cv2.getRotationMatrix2D(centre, rotation_deg, 1.0)
    return cv2.warpAffine(
        img, rotation, (width, height), borderMode=cv2.BORDER_REFLECT
    )
class AugmentedSubset(Subset):
    """A ``Subset`` view that applies ``augment_fn`` to every image it yields.

    Args:
        subset: an existing ``Subset``; its ``dataset`` and ``indices`` are
            reused, so no samples are copied.
        augment_fn: optional callable applied to the image of each
            ``(image, label)`` sample; the label passes through unchanged.
            ``None`` disables augmentation.
    """

    def __init__(self, subset, augment_fn=None):
        super().__init__(subset.dataset, subset.indices)
        self.augment_fn = augment_fn

    def __getitem__(self, idx):
        img, label = super().__getitem__(idx)
        if self.augment_fn is not None:
            img = self.augment_fn(img)
        return img, label

    def __getitems__(self, indices):
        """Fetch a batch of samples, augmenting each one.

        Newer PyTorch releases give ``Subset`` a ``__getitems__`` method, and
        ``DataLoader`` with automatic batching calls it instead of
        ``__getitem__``. The inherited version reads ``self.dataset`` directly,
        which would silently skip ``augment_fn``. Routing every index through
        ``__getitem__`` keeps augmentation applied in batched loading.
        """
        return [self[i] for i in indices]
|