from config import config
from torchvision import transforms
import cv2 as cv
import torchvision.transforms.functional as TF
import random


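class JointTransformMethod:
    """Apply the same random augmentation to an (img, label) pair.

    Both inputs are converted to single-channel PIL images and augmented with
    identical random parameters, so they stay pixel-aligned. They are then
    resized to config.image_size x config.image_size and normalized from
    [0, 1] to [-1, 1].
    """

    def __call__(self, img, label):
        # Note: convert('L') assumes RGB channel order. If the arrays come from
        # cv.imread (BGR), convert them with cv.cvtColor before calling this.
        img = transforms.ToPILImage()(img).convert('L')
        label = transforms.ToPILImage()(label).convert('L')

        # Horizontal flip with 50% probability, applied to both images.
        if random.random() > 0.5:
            img = TF.hflip(img)
            label = TF.hflip(label)

        size = (config.image_size, config.image_size)
        if random.random() > 0.2:  # crop with 80% probability
            # Sample the crop box once and reuse it for img and label.
            i, j, h, w = transforms.RandomResizedCrop.get_params(
                img, scale=(0.8, 1.0), ratio=(0.9, 1.1))
            img = TF.resized_crop(img, i, j, h, w, size)
            label = TF.resized_crop(label, i, j, h, w, size)
        else:
            img = TF.resize(img, size)
            label = TF.resize(label, size)

        # PIL image -> float tensor of shape (1, H, W) with values in [0, 1].
        img = TF.to_tensor(img)
        label = TF.to_tensor(label)

        # Normalize from [0, 1] to [-1, 1].
        img = (img - 0.5) / 0.5
        label = (label - 0.5) / 0.5

        return img, label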


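class TestTransformMethod:
    """Resize, convert to grayscale, and normalize a single image for inference."""

    def __call__(self, img):
        # cv.resize takes dsize as (width, height); both equal image_size here.
        img = cv.resize(img, (config.image_size, config.image_size))
        # Match the training pipeline, which always produces one channel.
        # Assumes 3-channel input is BGR, as returned by cv.imread.
        if img.ndim == 3 and img.shape[2] == 3:
            img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        if img.ndim == 2:
            img = img[:, :, None]  # H, W -> H, W, 1

        img = transforms.ToTensor()(img)  # (1, H, W); uint8 input is scaled to [0, 1]
        img = (img - 0.5) / 0.5  # [0, 1] -> [-1, 1]
        return img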


myDiTTransform = {
    'trainTransform': JointTransformMethod(),
    'testTransform': TestTransformMethod()
}
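

# Inverse of the (x - 0.5) / 0.5 normalization used above. It maps model
# outputs in [-1, 1] back to [0, 1], e.g. before saving or displaying them.
# Hypothetical helper, not part of the original module; a minimal sketch.
# It uses only arithmetic operators, so it works on Python floats, NumPy
# arrays and torch tensors alike.
def denormalize(x):
    """Map values from [-1, 1] back to [0, 1] (inverse of (x - 0.5) / 0.5)."""
    return x * 0.5 + 0.5


# Example: denormalize(-1.0) -> 0.0, denormalize(0.0) -> 0.5, denormalize(1.0) -> 1.0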


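class myTransformMethod:
    """Resize to config.image_size and convert 3-channel BGR input to grayscale."""

    def __call__(self, img):
        img = cv.resize(img, (config.image_size, config.image_size))
        if img.ndim == 3 and img.shape[2] == 3:  # H, W, C with BGR channel order
            img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
        return img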


myTransform = {
    'trainTransform': transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
    'testTransform': transforms.Compose([
        myTransformMethod(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ]),
}