DeBoneDiT / code /transform.py
diaoquesang's picture
Upload 15 files
6107278 verified
Raw
History Blame Contribute Delete
2.17 kB
from config import config
from torchvision import transforms
import cv2 as cv
import torchvision.transforms.functional as TF
import random
class JointTransformMethod:
def __call__(self, img, label):
img = transforms.ToPILImage()(img).convert('L')
label = transforms.ToPILImage()(label).convert('L')
if random.random() > 0.5:
img = TF.hflip(img)
label = TF.hflip(label)
if random.random() > 0.2: # 80%็š„ๆฆ‚็އๅš่ฃๅˆ‡
i, j, h, w = transforms.RandomResizedCrop.get_params(
img, scale=(0.8, 1.0), ratio=(0.9, 1.1))
img = TF.resized_crop(img, i, j, h, w, (config.image_size, config.image_size))
label = TF.resized_crop(label, i, j, h, w, (config.image_size, config.image_size))
else:
img = TF.resize(img, (config.image_size, config.image_size))
label = TF.resize(label, (config.image_size, config.image_size))
img = TF.to_tensor(img)
label = TF.to_tensor(label)
img = (img - 0.5) / 0.5
label = (label - 0.5) / 0.5
return img, label
class TestTransformMethod:
def __call__(self, img):
img = cv.resize(img, (config.image_size, config.image_size))
if len(img.shape) == 2:
img = img[:, :, None] # H,W,1
img = transforms.ToTensor()(img)
img = (img - 0.5) / 0.5
return img
myDiTTransform = {
'trainTransform': JointTransformMethod(),
'testTransform': TestTransformMethod()
}
class myTransformMethod():
def __call__(self, img):
img = cv.resize(img, (config.image_size, config.image_size))
if img.shape[-1] == 3: # HWC
img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
return img
myTransform = {
'trainTransform': transforms.Compose([
myTransformMethod(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
]),
'testTransform': transforms.Compose([
myTransformMethod(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
]),
}