File size: 1,262 Bytes
c02d17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import random
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import RandomCrop, InterpolationMode


class CustomRandomResize(nn.Module):
    """Resize an image by a single random factor drawn uniformly from ``scale``.

    Height and width are multiplied by the same factor, so the aspect ratio is
    preserved up to integer truncation.

    Args:
        scale: Two bounds for the resize factor. Order does not matter.
        interpolation: Interpolation mode forwarded to ``F.resize``.
    """

    def __init__(self, scale=(0.5, 2.0), interpolation=InterpolationMode.BILINEAR):
        super().__init__()
        # min/max so the bounds may be given in either order.
        self.min_scale, self.max_scale = min(scale), max(scale)
        self.interpolation = interpolation

    def forward(self, img):
        """Return ``img`` resized by a random factor in ``[min_scale, max_scale]``.

        Args:
            img (PIL Image or Tensor): Image to resize. Tensors are expected in
                torchvision's ``[..., H, W]`` layout.

        Returns:
            PIL Image or Tensor: Resized image.
        """
        # F.get_image_size returns [width, height] for both PIL images and
        # tensors. For tensors it reads the last two dims. The previous
        # ``img.shape[:2]`` picked (C, H) from a CHW tensor instead.
        width, height = F.get_image_size(img)
        scale = random.uniform(self.min_scale, self.max_scale)
        # Keep each side at least 1 px so small images or factors never
        # request a zero-sized output.
        new_size = [max(1, int(height * scale)), max(1, int(width * scale))]
        return F.resize(img, new_size, self.interpolation)


class CustomRandomCrop(RandomCrop):
    """``RandomCrop`` variant that shrinks the crop window to fit the image.

    The requested crop height/width are clamped to the (optionally padded)
    image dimensions. An image smaller than ``size`` along an axis is therefore
    kept whole along that axis instead of being passed to ``get_params`` with
    an oversized crop.

    The parent's ``padding`` / ``pad_if_needed`` / ``fill`` / ``padding_mode``
    constructor options are honored before clamping. With the defaults
    (``padding=None``, ``pad_if_needed=False``) no padding is applied.
    """

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image, no larger than ``self.size``
            in either dimension.
        """
        # Apply the constructor's padding options as RandomCrop does. The
        # earlier override accepted them but silently ignored them.
        if self.padding is not None:
            img = F.pad(img, self.padding, self.fill, self.padding_mode)

        width, height = F.get_image_size(img)
        tar_h, tar_w = self.size

        if self.pad_if_needed and width < tar_w:
            img = F.pad(img, [tar_w - width, 0], self.fill, self.padding_mode)
            width, height = F.get_image_size(img)
        if self.pad_if_needed and height < tar_h:
            img = F.pad(img, [0, tar_h - height], self.fill, self.padding_mode)
            width, height = F.get_image_size(img)

        # Clamp so the crop window never exceeds the image.
        tar_h = min(tar_h, height)
        tar_w = min(tar_w, width)
        i, j, h, w = self.get_params(img, (tar_h, tar_w))

        return F.crop(img, i, j, h, w)