| | import gradio as gr |
| | import dnnlib |
| | import numpy as np |
| | import PIL.Image |
| | import torch |
| |
|
| | import legacy |
| | import pickle |
| |
|
| | import torchvision.transforms as transforms |
| | from PIL import Image |
| | import os |
| |
|
# Network pickles for the two generators served by this app.
network_pkl_a = 'stylegan3-r-afhqv2-512x512.pkl'
network_pkl_d = 'network-snapshot.pkl'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device = ", device)


def _load_generator(path):
    """Load the ``G_ema`` generator from a StyleGAN network pickle onto ``device``.

    Uses ``legacy.load_network_pkl`` (the loader shipped with the StyleGAN
    code this file already imports) rather than a bare ``pickle.load``, so
    older pickle formats are also accepted.

    Args:
        path: filesystem path of the ``.pkl`` network snapshot.

    Returns:
        The ``G_ema`` module, moved to ``device``.
    """
    # NOTE(review): unpickling can execute arbitrary code -- only ever point
    # these paths at trusted network snapshots.
    with open(path, 'rb') as f:
        return legacy.load_network_pkl(f)['G_ema'].to(device)


# Loaded in the same order as before: the "SHOW_TEXT" model first, then the
# default AFHQv2 model.
G_d = _load_generator(network_pkl_d)
G_a = _load_generator(network_pkl_a)

# Prompt that switches generation to G_d; None when SHOW_TEXT is unset.
cl_text = os.getenv('SHOW_TEXT')
| |
|
| |
|
| | |
def gen_image(text):
    """Generate an image; the prompt only selects which generator is used.

    If ``text`` matches the ``SHOW_TEXT`` environment value (``cl_text``)
    case-insensitively and ignoring surrounding whitespace, ``G_d`` is
    sampled; otherwise ``G_a`` is sampled. The prompt is not otherwise used
    as conditioning.

    Args:
        text: prompt string from the Gradio textbox (``None`` is treated as
            an empty prompt).

    Returns:
        The PIL image produced by ``gen_image_helper``.
    """
    model = G_d if matches_show_text(text, cl_text) else G_a
    return gen_image_helper(model)


def matches_show_text(text, show_text):
    """Return True if ``text`` equals ``show_text`` after ``strip().lower()``.

    Both sides are normalized the same way; previously only the user text was,
    so a ``SHOW_TEXT`` value containing uppercase letters or padding could
    never match. A ``show_text`` of ``None`` (env var unset) never matches.
    """
    if show_text is None:
        return False
    return (text or "").strip().lower() == show_text.strip().lower()
| |
|
def gen_image_helper(model, scale=2):
    """Sample one random image from a generator and upscale it bilinearly.

    Args:
        model: generator loaded from a network pickle. It is read for
            ``z_dim`` (and ``c_dim`` when present) and called as
            ``model(z, c)``.
        scale: upscaling factor applied to both height and width
            (default 2, the previous hard-coded value).

    Returns:
        An RGB ``PIL.Image.Image`` ``scale`` times the generator's output size.
    """
    # Inference only: no_grad avoids building (and keeping in memory) an
    # autograd graph for every request.
    with torch.no_grad():
        z = torch.randn([1, model.z_dim], device=device)
        # Unconditional models (c_dim == 0) keep c=None as before; for a
        # class-conditional snapshot pass an all-zero label vector instead of
        # None, mirroring the label construction in StyleGAN3's gen_images.py.
        c_dim = getattr(model, 'c_dim', 0)
        c = torch.zeros([1, c_dim], device=device) if c_dim else None
        img = model(z, c)
        # permute implies NCHW output; assumes values roughly in [-1, 1]
        # (StyleGAN convention) -- map to NHWC uint8 in [0, 255].
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    image = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
    # Resize takes (height, width); round() keeps the size integral even if a
    # fractional scale is supplied.
    resize = transforms.Resize(
        (round(image.height * scale), round(image.width * scale)),
        interpolation=transforms.InterpolationMode.BILINEAR,
    )
    return resize(image)
| | |
| |
|
# Gradio UI: a two-line prompt box in, a single PIL image out.
_prompt_box = gr.Textbox(lines=2, placeholder="Prompt here...")
_result_view = gr.Image(type="pil")

demo = gr.Interface(
    fn=gen_image,
    inputs=_prompt_box,
    outputs=_result_view,
    title="Text to Image Generator",
    description="Enter any text to generate an image of an animal",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()
| |
|