Spaces:
Runtime error
Runtime error
File size: 1,200 Bytes
c6d5483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
class GANInference:
    """Run a trained Pix2Pix generator on a single image file.

    The image at ``img_file`` is loaded as RGB, resized to 256x256,
    normalized to [-1, 1] and passed through ``model``. The first output of
    the batch is rescaled to [0, 1] and returned channels-last, so it can be
    handed straight to a plotting function such as ``plt.imshow``.
    """

    def __init__(
        self,
        model: Pix2PixLitModule,
        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
    ) -> None:
        """Store the generator and the path of the image to translate.

        Args:
            model: Generator module called as ``model(batch)`` in ``run``.
            img_file: Path of the input image.
        """
        # NOTE(review): the default path is machine-specific; callers on other
        # machines need to pass ``img_file`` explicitly.
        self.img_file = img_file
        self.model = model

    def _get_image_from_path(self) -> torch.Tensor:
        """Load ``self.img_file`` and return it as a (1, 3, 256, 256) tensor.

        The image is converted to RGB first so that PNGs with an alpha channel
        (and grayscale / palette images) match the 3-channel mean/std used by
        the normalization step; without this they fail to broadcast.
        Pixel values are mapped from [0, 255] to [-1, 1].
        """
        # The context manager closes the file handle; ``convert`` returns a
        # fully loaded copy, so the pixels stay valid after the block exits.
        with Image.open(self.img_file) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        inference_transform = A.Compose([
            A.Resize(width=256, height=256),
            # (x - 0.5 * 255) / (0.5 * 255) -> values in [-1, 1]
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            # HWC numpy array -> CHW tensor
            al_pytorch.ToTensorV2(),
        ])
        # unsqueeze(0) adds the batch dimension expected by the model.
        return inference_transform(image=image)['image'].unsqueeze(0)

    def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
        """Turn the first image of ``result`` into a plottable (H, W, C) tensor.

        Assumes ``result`` is a channels-first batch (B, C, H, W), as produced
        by the generator for the (1, 3, H, W) input built above.
        ``make_grid`` rescales the values to [0, 1] (``normalize=True``).
        It expects channels-first input, so the permute to channels-last is
        applied to its output rather than to its input.
        """
        grid = torchvision.utils.make_grid(result[0].detach(), normalize=True)
        return grid.permute(1, 2, 0)

    def run(self) -> torch.Tensor:
        """Translate ``self.img_file`` with the generator.

        Returns:
            The first generated image as an (H, W, C) tensor scaled to [0, 1].
        """
        inference_img = self._get_image_from_path()
        # Inference only: skip autograd bookkeeping. The model's train/eval
        # mode is left exactly as the caller configured it.
        with torch.no_grad():
            result = self.model(inference_img)
        return self._create_grid(result=result)
|