class GANInference:
    """Run the wrapped model on a single image file and return a plottable result."""

    # Side length (in pixels) the input image is resized to before inference.
    _INPUT_SIZE = 256

    def __init__(
        self,
        model: Pix2PixLitModule,
        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
    ) -> None:
        """
        Store the model and the path of the image to run it on.

        Args:
            model: Object invoked as ``model(batch)`` in :meth:`run`.
            img_file: Path of the input image.
                NOTE(review): the default is a machine-specific path kept only
                for backward compatibility; callers should pass their own path.
        """
        self.img_file = img_file
        self.model = model

    def _get_image_from_path(self) -> torch.Tensor:
        """
        Load ``self.img_file`` and turn it into a batch containing one image.

        The image is forced to RGB, resized to ``_INPUT_SIZE`` x ``_INPUT_SIZE``,
        normalised with mean 0.5 / std 0.5 per channel and converted to a tensor.

        Returns:
            The transformed image with a leading batch dimension of size 1.
        """
        # Force three channels: the Normalize step below uses 3-element
        # mean/std, so RGBA, palette or greyscale files would otherwise fail
        # or be misinterpreted. The context manager closes the file handle.
        with Image.open(self.img_file) as img:
            image = np.array(img.convert("RGB"))

        # use on inference
        inference_transform = A.Compose([
            A.Resize(width=self._INPUT_SIZE, height=self._INPUT_SIZE),
            A.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
                max_pixel_value=255.0,
            ),
            al_pytorch.ToTensorV2(),
        ])

        transformed = inference_transform(image=image)["image"]
        # The model is called on a batch, so add the batch axis up front.
        return transformed.unsqueeze(0)

    def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
        """
        Rescale the first image of ``result`` for display.

        Args:
            result: Model output; only ``result[0]`` is used and it must be
                3-dimensional (assumes channels-first — TODO confirm against
                the model).

        Returns:
            The tensor produced by ``torchvision.utils.make_grid`` with
            ``normalize=True``. Note this is a ``torch.Tensor``, not a NumPy
            array.
        """
        # Move the first axis last (channels-last, if the model output is
        # channels-first) so the result can be handed to an image plotter;
        # detach so the returned tensor carries no autograd history.
        first_image = result[0].permute(1, 2, 0).detach()
        return torchvision.utils.make_grid([first_image], normalize=True)

    def run(self) -> torch.Tensor:
        """
        Load the image, run the model on it and return a plottable image.

        Returns:
            The display-ready tensor built by :meth:`_create_grid`.
        """
        inference_img = self._get_image_from_path()
        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            result = self.model(inference_img)
        return self._create_grid(result=result)