TIDE-II / utils /plots.py
pgatoula's picture
Minor corrections
b79a585
import numpy as np
from PIL import Image
def visualize_from_latent_space(latent_dim, input_shape, vae, output_path, epoch="final", num_items=10,):
    """Decode random latent vectors and save them as a ``num_items x num_items`` collage.

    Each tile is produced by passing a vector drawn from a standard normal
    distribution through ``vae.decoder.predict``. The draws come from a private
    generator seeded with 42, so the collage is reproducible for a given decoder
    and NumPy's global random state is not modified.

    Args:
        latent_dim: Length of each latent vector fed to the decoder.
        input_shape: ``(height, width, channels)`` of a single decoded image;
            every decoder output is reshaped to this shape.
        vae: Object exposing ``decoder.predict(z)`` that maps an array of shape
            ``(batch, latent_dim)`` to ``batch`` images (presumably a Keras
            model — confirm). Decoder outputs are assumed to be intensities in
            [0, 1]; they are clipped to that range before scaling to 8-bit.
        output_path: Directory in which the collage is written.
        epoch: Tag inserted into the file name ``decoding-noise-ep{epoch}.jpg``.
        num_items: Number of tiles per row and per column.
    """
    height, width, img_channels = input_shape
    # Single source of truth for the file name, so the log message matches
    # the file that is actually written.
    collage_path = f"{output_path}/decoding-noise-ep{epoch}.jpg"

    # Local generator: yields the same draws as np.random.seed(42) followed by
    # np.random.normal, without resetting the caller's global RNG state.
    rng = np.random.RandomState(42)
    latents = rng.normal(0, 1, (num_items * num_items, latent_dim))

    # One batched predict call instead of num_items**2 single-sample calls.
    # float64 matches the precision of the original float64 canvas.
    decoded = np.asarray(vae.decoder.predict(latents), dtype=np.float64)
    # Sample k = row * num_items + col, i.e. rows are filled left to right.
    tiles = decoded.reshape((num_items, num_items, height, width, img_channels))

    # (row, col, h, w, c) -> (row, h, col, w, c) -> a single
    # (num_items*height, num_items*width, c) image.
    figure = tiles.transpose(0, 2, 1, 3, 4).reshape(
        num_items * height, num_items * width, img_channels
    )

    print(f'Saving collage in {collage_path}')
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or be platform-dependent for negatives) instead of saturating.
    figure = (np.clip(figure, 0.0, 1.0) * 255).astype('uint8')
    if img_channels == 1:
        # PIL expects a 2-D array for single-channel (grayscale) images.
        figure = np.squeeze(figure, axis=-1)
    Image.fromarray(figure).save(collage_path)