File size: 1,200 Bytes
c6d5483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35

class GANInference:
    """Run a single-image forward pass through a Pix2Pix model.

    Loads one image from disk, applies the inference preprocessing (resize
    to 256x256, per-channel normalization with mean 0.5 / std 0.5), feeds it
    through ``model`` and turns the first output of the batch into a
    min-max normalized, channels-last tensor that can be plotted.
    """

    def __init__(
        self,
        model: Pix2PixLitModule,
        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
    ) -> None:
        """
        Args:
            model: Module that is called on the preprocessed image batch.
            img_file: Path of the input image. NOTE(review): the default is a
                machine-specific absolute path; callers should pass their own.
        """
        self.img_file = img_file
        self.model = model

    def _get_image_from_path(self) -> torch.Tensor:
        """Load ``self.img_file`` as a (1, 3, 256, 256) float tensor.

        The image is converted to RGB first, so RGBA, grayscale and palette
        files all yield the 3 channels that the 3-element mean/std below
        require. The file handle is closed as soon as the pixels are copied
        into a numpy array.
        """
        with Image.open(self.img_file) as img:
            image = np.array(img.convert("RGB"))
        inference_transform = A.Compose([
            A.Resize(width=256, height=256),
            # (x - 0.5 * 255) / (0.5 * 255): maps uint8 [0, 255] to [-1, 1].
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            # HWC numpy array -> CHW torch tensor.
            al_pytorch.ToTensorV2(),
        ])
        # Add a leading batch dimension of size 1 before calling the model.
        return inference_transform(image=image)['image'].unsqueeze(0)

    def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
        """Turn the first image in ``result`` into a plottable HWC tensor.

        ``make_grid`` receives the image in the channels-first layout it
        expects; the permute to channels-last happens afterwards. With
        ``normalize=True`` the values are min-max scaled to [0, 1] over the
        whole image. The tensor is detached and moved to the CPU so it can be
        plotted directly (e.g. with matplotlib's ``imshow``).
        """
        grid = torchvision.utils.make_grid(result[0].detach(), normalize=True)
        return grid.permute(1, 2, 0).cpu()

    def run(self) -> torch.Tensor:
        """Run inference on ``self.img_file`` and return a plottable image.

        Returns:
            A channels-last (H, W, C) CPU tensor with values in [0, 1].
        """
        inference_img = self._get_image_from_path()
        # Gradients are never needed here; disabling autograd avoids building
        # a graph and keeping intermediate activations alive.
        with torch.no_grad():
            result = self.model(inference_img)
        return self._create_grid(result=result)