| import torch |
| import gradio as gr |
| import numpy as np |
| import nltk |
| nltk.download('wordnet') |
| nltk.download('omw-1.4') |
| from PIL import Image |
| from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample, |
| save_as_images, display_in_terminal) |
# Import-time defaults: the architecture whose weights are loaded right away,
# and a sample class name. (The functions below take parameters with these
# same names, which shadow these globals inside their bodies.)
initial_archi, initial_class = 'biggan-deep-128', 'dog'

# Loaded once when the module is imported; generate_images reads this
# module-level model.
gan_model = BigGAN.from_pretrained(initial_archi)


# Models already loaded, keyed by architecture name. Seeded with the model
# loaded at import time so the default architecture is never loaded twice.
_GAN_MODELS = {initial_archi: gan_model}


def _get_gan_model(archi):
    """Return the BigGAN model for *archi*, loading and caching it on first use.

    A falsy *archi* (e.g. the dropdown was left empty) falls back to the
    model loaded at import time.
    """
    if not archi:
        return gan_model
    model = _GAN_MODELS.get(archi)
    if model is None:
        model = BigGAN.from_pretrained(archi)
        _GAN_MODELS[archi] = model
    return model


def generate_images(initial_archi, initial_class, batch_size, truncation=0.4, seed=None):
    """Generate a batch of BigGAN images for a single class name.

    Params:
        initial_archi: pretrained architecture name passed to
            BigGAN.from_pretrained (e.g. 'biggan-deep-256'). A falsy value
            selects the model loaded at import time.
        initial_class: class name resolved by one_hot_from_names
            (e.g. 'dog'); surrounding whitespace is ignored.
        batch_size: number of images to generate.
        truncation: truncation value shared by the noise sampler and the
            model (previously hard-coded to 0.4, which remains the default).
        seed: optional seed forwarded to truncated_noise_sample; None draws
            fresh noise on every call.
    Output:
        the model output for the batch, moved to the CPU.
    Raises:
        ValueError: if initial_class cannot be resolved to a class.
    """
    truncation = float(truncation)

    # Validate the class before loading any model so bad input fails fast.
    class_name = (initial_class or "").strip()
    # one_hot_from_names returns None for names it cannot resolve (per the
    # pytorch_pretrained_biggan utils; verify against the installed version).
    class_vector = one_hot_from_names(class_name, batch_size=batch_size) if class_name else None
    if class_vector is None:
        raise ValueError(f"Could not map class name {initial_class!r} to a BigGAN class")

    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=batch_size, seed=seed)

    noise_vector = torch.from_numpy(noise_vector)
    class_vector = torch.from_numpy(class_vector)

    # Honor the requested architecture instead of always using gan_model.
    model = _get_gan_model(initial_archi)
    with torch.no_grad():
        output = model(noise_vector, class_vector, truncation)

    output = output.to('cpu')
    # NOTE(review): save_as_images also writes the batch to disk (upstream's
    # default name pattern is output_<i>.png in the working directory — verify);
    # concurrent requests would overwrite the same files.
    save_as_images(output)
    return output

def convert_to_images(obj):
    """Convert an output tensor from BigGAN into a list of images.

    Params:
        obj: tensor or numpy array of shape (batch_size, channels, height, width).
            Values are presumably in [-1, 1] (the scaling below assumes it);
            anything outside that range is clipped.
    Output:
        list of Pillow Images, one per batch element, each height x width
        pixels (empty list for an empty batch).
    Raises:
        ImportError: if Pillow is not installed.
    """
    try:
        # Import the submodule explicitly: a bare ``import PIL`` does not
        # guarantee that ``PIL.Image`` has been loaded.
        from PIL import Image
    except ImportError as err:
        raise ImportError("Please install Pillow to use images: pip install Pillow") from err

    if not isinstance(obj, np.ndarray):
        # Move tensors to host memory before converting; .cpu() is a no-op
        # for tensors already on the CPU.
        obj = obj.detach().cpu().numpy()

    # Channels-first -> channels-last, the layout Image.fromarray expects.
    obj = obj.transpose((0, 2, 3, 1))
    # Map [-1, 1] onto [0, 256] and clip to 255; the uint8 cast below
    # truncates, giving 256 equal-width bins over the input range.
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)

    return [Image.fromarray(np.uint8(frame)) for frame in obj]

def inference(initial_archi, initial_class):
    """Gradio callback: produce one BigGAN image for the chosen inputs.

    Params:
        initial_archi: architecture name selected in the interface.
        initial_class: class name entered in the interface.
    Output:
        the first Pillow image from a generated batch of size 1.
    """
    images = convert_to_images(
        generate_images(initial_archi, initial_class, batch_size=1)
    )
    return images[0]




title = "BigGAN"
description = "BigGAN using various architecture models to generate images."
article = "Coming soon"

# Single source of truth for the architectures offered in the UI; both the
# dropdown and the examples are built from it.
ARCHITECTURES = ["biggan-deep-128", "biggan-deep-256", "biggan-deep-512"]

examples = [[archi, "dog"] for archi in ARCHITECTURES]

# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy namespaces
# (deprecated in 3.x, removed in 4.x). Kept here for compatibility with the
# Gradio version this app was written against; migrate to gr.Dropdown /
# gr.Textbox / gr.Image once the pinned version is confirmed.
demo = gr.Interface(
    inference,
    inputs=[
        # A default keeps the architecture from arriving as None when the
        # user submits without touching the dropdown.
        gr.inputs.Dropdown(ARCHITECTURES, default=ARCHITECTURES[0], label="architecture"),
        gr.inputs.Textbox(default="dog", label="class"),
    ],
    outputs=[gr.outputs.Image(type="pil", label="output")],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True)