|
|
import os |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
|
|
|
from PIL import Image |
|
|
|
|
|
from model.vae import VAE |
|
|
from model import tidev2 |
|
|
|
|
|
|
|
|
def init_vae_model(model_name, latent_dim, input_shape):
    """Construct and build a VAE for the named architecture.

    Args:
        model_name: Architecture identifier. Only ``'tidev2'`` is handled here.
        latent_dim: Latent-space size, passed to both the encoder and decoder.
        input_shape: Image shape. ``input_shape[:2]`` becomes the decoder's
            ``image_dims`` and ``input_shape[-1]`` its ``out_channels``
            (presumably (height, width, channels) — confirm against callers).

    Returns:
        The ``VAE`` instance, already built for inputs of shape
        ``(None, *input_shape)``.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    if model_name != 'tidev2':
        # Without this guard an unknown name reached ``return vae_model`` with
        # the variable never assigned and surfaced as an UnboundLocalError.
        raise ValueError(
            f"Unknown model_name {model_name!r}; supported models: 'tidev2'"
        )

    encoder = tidev2.ConvNeXtEncoderTiny(latent_dim=latent_dim)
    decoder = tidev2.ConvNeXtDecoderTiny(
        latent_dim=latent_dim,
        image_dims=input_shape[:2],
        out_channels=input_shape[-1],
    )
    vae_model = VAE(encoder, decoder)

    # Build eagerly with an unspecified batch dimension, presumably so the
    # model's variables exist before weights are restored into them.
    vae_model.build((None, *input_shape))
    return vae_model
|
|
|
|
|
|
|
|
def load_weights(vae, weights_path):
    """Restore weights into ``vae`` and return it.

    The loading mechanism is chosen by substrings of ``weights_path``
    (checked in this order):

    * contains ``"ckpt-"``: restored through ``tf.train.Checkpoint(vae=vae)``;
      ``expect_partial()`` silences warnings about checkpoint values that are
      not matched by the object graph.
    * contains ``".TF"``: loaded with ``vae.load_weights(weights_path,
      by_name=True)``.

    Args:
        vae: The model to restore into (as returned by ``init_vae_model``).
        weights_path: Path to a checkpoint prefix or a saved-weights file.

    Returns:
        The same ``vae`` object.

    Raises:
        ValueError: If ``weights_path`` matches neither pattern. Previously
            this case silently returned ``None``.
    """
    print(f"Loading weights from {weights_path}")

    if "ckpt-" in weights_path:
        ckpt = tf.train.Checkpoint(vae=vae)
        ckpt.restore(weights_path).expect_partial()
        return vae

    if ".TF" in weights_path:
        # NOTE(review): tf.keras documents ``by_name=True`` as supported only
        # for HDF5 weight files; TF-format weights loaded this way may raise
        # NotImplementedError — confirm which format these files use.
        vae.load_weights(weights_path, by_name=True)
        return vae

    raise ValueError(
        f"Unrecognized weights path {weights_path!r}: expected a checkpoint "
        f"path containing 'ckpt-' or a weights file containing '.TF'"
    )
|
|
|
|
|
|
|
|
def get_noise_seeded(noise_shape, seed=0):
    """Return standard-normal noise that is reproducible from ``seed``.

    A private ``np.random.RandomState`` seeded with ``seed`` is used. It yields
    exactly the same values as ``np.random.seed(seed)`` followed by
    ``np.random.normal(0, 1, noise_shape)``, but does not reset NumPy's global
    random state as a side effect.

    Args:
        noise_shape: Shape of the array to draw.
        seed: Seed for the generator.

    Returns:
        A float64 ``np.ndarray`` of shape ``noise_shape`` with samples drawn
        from N(0, 1).
    """
    rng = np.random.RandomState(seed)
    return rng.normal(0, 1, noise_shape)
|
|
|
|
|
def decode_noise(trained_vae, noise, return_list=False, batch_size=1):
    """Decode latent vectors with the VAE's decoder and scale them by 255.

    Args:
        trained_vae: Object exposing a ``decoder`` with a Keras-style
            ``predict(x, batch_size=...)`` method.
        noise: Latent vectors to decode, one per output image.
        return_list: If True, return a list with one array per image
            (split along the first axis) instead of a single array.
        batch_size: Batch size passed to ``predict``. Defaults to 1, matching
            the previous hard-coded value; larger values trade memory for
            throughput.

    Returns:
        The decoded images multiplied by 255, either as one ``np.ndarray`` or
        as a list of per-image arrays when ``return_list`` is True.
    """
    print("Generating synthetic images ...")

    pred = trained_vae.decoder.predict(noise, batch_size=batch_size)

    # Scale to pixel range (presumably the decoder emits values in [0, 1] —
    # confirm). Not done in place, so an integer-typed prediction cannot raise
    # a casting error; float32 input stays float32.
    pred = np.asarray(pred) * 255.0

    if return_list:
        return list(pred)
    return pred
|
|
|
|
|
|
|
|
def save_images(save_folder, images, seed=None):
    """Write images to ``save_folder`` as JPEG files.

    Args:
        save_folder: Destination directory; created if it does not exist.
        images: A list of image arrays (non-list inputs skip the loop below).
            Each image is clipped to [0, 255] and cast to uint8. A trailing
            channel axis of size 1 is squeezed away so grayscale images can
            be saved.
        seed: Optional label for the file names. When None, files are named
            ``image-{i}.jpg``. When set, a single image is saved as
            ``image-{seed}.jpg``; several images are saved as
            ``image-{seed}-{i}.jpg`` so they do not overwrite one another.
    """
    print(f"Saving synthetic images into {save_folder}")
    os.makedirs(save_folder, exist_ok=True)

    if isinstance(images, list):
        for i, image in enumerate(images):
            # Clip before casting: converting out-of-range floats to uint8
            # wraps around (or is platform-dependent) instead of saturating.
            image = np.clip(image, 0, 255).astype(np.uint8)

            if image.shape[-1] == 1:
                image = np.squeeze(image, axis=-1)

            if seed is None:
                save_filename = f"image-{i}.jpg"
            elif len(images) == 1:
                save_filename = f"image-{seed}.jpg"
            else:
                # Include the index so images sharing one seed label don't
                # all overwrite the same file.
                save_filename = f"image-{seed}-{i}.jpg"

            Image.fromarray(image).save(os.path.join(save_folder, save_filename))