# Sketch2ColourDemo / app/scratch.py
# Nikhil Mudhalwadkar · commit c6d5483 · "added other files"
class GANInference:
    """Run a trained Pix2Pix generator on a single image file.

    The image at ``img_file`` is resized to 256x256, normalized to [-1, 1]
    and passed through ``model``. :meth:`run` returns the first output image,
    min-max rescaled to [0, 1] and permuted to channels-last for plotting.
    """

    def __init__(
        self,
        model: Pix2PixLitModule,
        img_file: str = "/Users/nimud/Downloads/thesis_test2.png",
    ) -> None:
        """
        Args:
            model: Model that is called directly on a (1, 3, 256, 256) float
                tensor by :meth:`run`.
            img_file: Path of the input image. NOTE(review): the default is a
                developer-local path and will not exist on other machines.
        """
        self.img_file = img_file
        self.model = model

    def _get_image_from_path(self) -> torch.Tensor:
        """Load ``self.img_file`` as a normalized (1, 3, 256, 256) tensor."""
        # The context manager closes the underlying file handle. convert("RGB")
        # forces exactly three channels, so RGBA PNGs and greyscale ("L")
        # sketches match the 3-element mean/std below instead of failing
        # inside Normalize.
        with Image.open(self.img_file) as pil_image:
            image = np.array(pil_image.convert("RGB"))
        inference_transform = A.Compose([
            A.Resize(width=256, height=256),
            # (x / 255 - 0.5) / 0.5 maps pixel values from [0, 255] to [-1, 1].
            A.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5], max_pixel_value=255.0),
            al_pytorch.ToTensorV2(),  # HWC ndarray -> CHW tensor
        ])
        # Add a leading batch dimension of size 1.
        return inference_transform(image=image)['image'].unsqueeze(0)

    def _create_grid(self, result: torch.Tensor) -> torch.Tensor:
        """Turn the first image of a batched output into a plottable HWC tensor.

        Assumes ``result`` is batched channels-first, i.e. (N, C, H, W) —
        TODO confirm against the model's forward().
        """
        # make_grid expects (C, H, W) images, so permute to (H, W, C) only
        # afterwards. normalize=True min-max rescales the values to [0, 1].
        grid = torchvision.utils.make_grid([result[0].detach()], normalize=True)
        return grid.permute(1, 2, 0)

    def run(self) -> torch.Tensor:
        """Run the model on ``self.img_file`` and return an (H, W, C) image in [0, 1]."""
        inference_img = self._get_image_from_path()
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            result = self.model(inference_img)
        return self._create_grid(result=result)