Spaces:
Runtime error
Runtime error
class GANInference:
    """Run one image through a Pix2Pix model and return a plottable result.

    The image at ``img_file`` is loaded as RGB, resized to 256x256, scaled to
    the ``[-1, 1]`` range, and passed through ``model``. The first element of
    the model output is then rescaled to ``[0, 1]`` for display.
    """

    def __init__(
        self,
        model: Pix2PixLitModule,
        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
    ) -> None:
        """Store the model and the path of the image to translate.

        Args:
            model: Module invoked as ``model(batch)`` in :meth:`run`.
            img_file: Path of the input image. NOTE(review): the default is a
                machine-specific path, kept only for backward compatibility;
                callers should pass their own path.
        """
        self.img_file = img_file
        self.model = model

    def _get_image_from_path(self) -> torch.Tensor:
        """Load ``self.img_file`` and return it as a normalized batch of one.

        Returns:
            A float tensor of shape ``(1, 3, 256, 256)``.
        """
        # Force three channels: Normalize below is configured with 3-element
        # mean/std, so RGBA, greyscale or palette images would otherwise fail
        # or yield the wrong channel count. The context manager closes the
        # file handle once the pixels have been copied into the array.
        with Image.open(self.img_file) as img:
            image = np.array(img.convert("RGB"))

        # use on inference
        inference_transform = A.Compose([
            A.Resize(width=256, height=256),
            # (x / 255 - 0.5) / 0.5 maps pixel values from [0, 255] to [-1, 1].
            A.Normalize(
                mean=[.5, .5, .5],
                std=[.5, .5, .5],
                max_pixel_value=255.0,
            ),
            al_pytorch.ToTensorV2(),
        ])
        # unsqueeze(0) adds the batch dimension.
        return inference_transform(image=image)['image'].unsqueeze(0)

    def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
        """Rescale the first image of ``result`` to ``[0, 1]`` for plotting.

        Args:
            result: Model output; only ``result[0]`` is used.
                Assumes a ``(batch, C, H, W)`` layout -- TODO confirm.

        Returns:
            A detached tensor with the channel axis moved last.
        """
        # NOTE(review): the image is permuted to channels-last *before*
        # make_grid. This relies on make_grid returning a single image
        # without tiling or padding, so the call acts only as a min/max
        # normalization here.
        return torchvision.utils.make_grid(
            [result[0].permute(1, 2, 0).detach()],
            normalize=True,
        )

    def run(self) -> torch.Tensor:
        """Return a plottable image produced by the model.

        Returns:
            The tensor produced by :meth:`_create_grid`.
        """
        inference_img = self._get_image_from_path()
        # Inference only: skip autograd bookkeeping during the forward pass.
        # The model's train/eval mode is left exactly as the caller set it.
        with torch.no_grad():
            result = self.model(inference_img)
        return self._create_grid(result=result)