Spaces:
Build error
Build error
| import torchvision.transforms as standard_transforms | |
| from .SHHA import SHHA | |
# DeNormalize reverses a per-channel normalization to recover the original images.
class DeNormalize:
    """Undo a per-channel ``(x - mean) / std`` normalization, in place.

    Each slice along the first dimension of the input is paired with the
    matching ``(mean, std)`` entry and rewritten as ``slice * std + mean``.
    Pairing stops at the shortest of the three sequences, as with ``zip``.
    """

    def __init__(self, mean, std):
        # Per-slice statistics; presumably one entry per channel (confirm with callers).
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Denormalize ``tensor`` in place and return the same object."""
        stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, stats):
            # In-place ops on each slice write through to ``tensor`` itself.
            channel.mul_(sigma)
            channel.add_(mu)
        return tensor
def loading_data(data_root):
    """Build the SHHA training and validation datasets under ``data_root``.

    Both splits share one pre-processing pipeline: ``ToTensor`` followed by
    per-channel ``Normalize`` with the fixed mean/std values below
    (presumably the usual ImageNet statistics; confirm they match the
    backbone's pretraining). The training split is additionally created
    with ``patch=True`` and ``flip=True``.

    Args:
        data_root: Root location passed straight through to ``SHHA``.

    Returns:
        A ``(train_set, val_set)`` tuple.
    """
    normalize = standard_transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # Shared pre-processing: convert to a tensor, then normalize per channel.
    transform = standard_transforms.Compose([standard_transforms.ToTensor(), normalize])

    # Training split first, then validation (the tuple is evaluated left to right).
    return (
        SHHA(data_root, train=True, transform=transform, patch=True, flip=True),
        SHHA(data_root, train=False, transform=transform),
    )