Spaces:
Sleeping
Sleeping
| import abc | |
| from typing import List, Optional, Tuple | |
| from torch import nn, Tensor | |
| from PIL import Image | |
class BaseDataAugmentation(nn.Module):
    """Common interface for data-augmentation modules.

    Subclasses override :meth:`forward` to turn PIL inputs (an image, its
    depth map and optional ground truths) into tensors, optionally applying
    augmentation.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        image: Image.Image,
        depth: Image.Image,
        gt: Optional[Image.Image] = None,
        ranking_gt: Optional[Image.Image] = None,
        multi_gts: Optional[List[Image.Image]] = None,
        is_transform: bool = True,  # is augmented?
        is_debug: bool = False,
    ) -> Tuple[Tensor, ...]:
        """Convert (and optionally augment) one sample into tensors.

        Args:
            image: Input image.
            depth: Depth map paired with ``image``.
            gt: Optional ground-truth map.
            ranking_gt: Optional ranking ground truth. When provided, ``gt``
                is ignored.
            multi_gts: Optional list of additional ground-truth images.
                NOTE(review): exact semantics are defined by subclasses;
                confirm.
            is_transform: Whether augmentation should be applied.
            is_debug: Whether to return extra debugging outputs.

        Returns:
            Usual case:
                - ``[image, depth, gt]`` if ``gt`` is provided.
                - ``[image, depth]`` otherwise.
                - ``[image, depth, ranking_gt]`` when ``ranking_gt`` is
                  provided (``gt`` is ignored).
            For debugging:
                ``[image, depth, gt|ranking_gt, unnormalized,
                Optional[ranking_gts]]``

        Raises:
            NotImplementedError: Always, in this base class. Subclasses must
                override this method. (A bare ``pass`` here would silently
                return ``None`` and hide a missing override.)
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )