File size: 1,829 Bytes
5e96bc9
c3d45c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e96bc9
 
 
c3d45c0
5e96bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from torch.utils.data import Subset,Dataset
import torch
import os
import numpy as np
import cv2

def collate_fn(batch):
    """Collate a list of ``(image, label)`` samples for a DataLoader.

    Images are returned untouched as a Python list (they are not stacked);
    labels are gathered into a single tensor via ``torch.tensor``.

    Args:
        batch: Sequence of 2-tuples ``(image, label)``.

    Returns:
        ``(images, labels)`` where ``images`` is a list in batch order and
        ``labels`` is a 1-D tensor of the labels in the same order.
    """
    images = []
    targets = []
    for image, target in batch:
        images.append(image)
        targets.append(target)
    return images, torch.tensor(targets)


class ImageDataset(Dataset):
    """Image-classification dataset laid out as ``root_path/<class_name>/<file>``.

    Every immediate sub-directory of ``root_path`` is one class and every
    regular file inside it is one sample. Class names are sorted before
    indices are assigned, so the name -> index mapping is deterministic
    (``os.listdir`` returns entries in arbitrary, filesystem-dependent order).

    Args:
        root_path: Directory containing one sub-directory per class.
        img_size: Passed to ``cv2.resize`` as ``dsize``, i.e. interpreted as
            ``(width, height)``.

    Attributes:
        classes: Sorted class (sub-directory) names; position == class index.
        data: One dict per sample with keys ``"image_path"``, ``"label"``
            (class name) and ``"id"`` (class index).
    """

    def __init__(self, root_path: str, img_size=(256, 256)):
        self.img_size = img_size
        # Keep only directories (stray files such as .DS_Store would make the
        # inner listdir raise) and sort so class ids are reproducible.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_path)
            if os.path.isdir(os.path.join(root_path, entry))
        )
        data = []
        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(root_path, class_name)
            for file_name in sorted(os.listdir(class_path)):
                file_path = os.path.join(class_path, file_name)
                if not os.path.isfile(file_path):
                    continue  # skip nested directories
                data.append({"image_path": file_path, "label": class_name, "id": class_idx})
        self.data = data

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``.

        The image is read with OpenCV, resized to ``img_size``, converted from
        BGR to RGB and scaled to float32 by dividing by 255.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        entry = self.data[idx]
        img_path = entry["image_path"]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image file: {img_path}")
        img = cv2.resize(img, self.img_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0
        return img, entry["id"]

def simple_augment(img):
    """Apply a random horizontal flip followed by a small random rotation.

    With probability 0.5 the image is mirrored left-right (``cv2.flip`` with
    code 1). It is then rotated about its centre by an angle drawn from
    ``np.random.uniform(-15, 15)`` degrees, keeping the original size and
    filling uncovered areas with reflected pixels. Randomness comes from
    NumPy's global RNG.

    Args:
        img: Image array whose first two dimensions are (height, width).

    Returns:
        The augmented image, same height and width as the input.
    """
    should_flip = np.random.rand() > 0.5
    if should_flip:
        img = cv2.flip(img, 1)

    height, width = img.shape[:2]
    centre = (width / 2, height / 2)
    rotation = cv2.getRotationMatrix2D(centre, np.random.uniform(-15, 15), 1.0)
    return cv2.warpAffine(img, rotation, (width, height), borderMode=cv2.BORDER_REFLECT)


class AugmentedSubset(Subset):
    """A ``Subset`` view that applies ``augment_fn`` to each sample's image.

    It shares the wrapped subset's underlying dataset and indices, so e.g. a
    training split can be augmented without affecting another split built on
    the same dataset.

    Args:
        subset: An existing ``Subset``; its ``dataset`` and ``indices`` are reused.
        augment_fn: Optional callable applied to the image of every
            ``(image, label)`` sample; ``None`` disables augmentation.
    """

    def __init__(self, subset, augment_fn=None):
        super().__init__(subset.dataset, subset.indices)
        self.augment_fn = augment_fn

    def __getitem__(self, idx):
        """Return ``(image, label)`` at position ``idx``, augmenting the image."""
        img, label = super().__getitem__(idx)
        if self.augment_fn is not None:
            img = self.augment_fn(img)
        return img, label

    def __getitems__(self, indices):
        """Fetch a batch of samples, routing each one through ``__getitem__``.

        Recent PyTorch versions define ``Subset.__getitems__``, which the
        DataLoader prefers for batched fetching and which reads the parent
        dataset directly -- bypassing the ``__getitem__`` override above and
        silently skipping augmentation. Overriding it keeps augmentation on
        that path too.
        """
        return [self[idx] for idx in indices]