|
|
import re

import clip
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
from diffusers import AutoencoderKL, AutoencoderTiny
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from tld.denoiser import Denoiser
from tld.diffusion import DiffusionGenerator
|
|
# Use the first CUDA GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Reusable tensor -> PIL image converter.
to_pil = transforms.ToPILImage()

# dtype used for the denoiser and VAE weights loaded further down.
model_dtype = torch.float32

# NOTE(review): not referenced in the visible part of this file; they look like
# the latent side length and a VAE-related factor -- confirm before relying on them.
img_size = 32
vae_scale_factor = 8
|
|
@torch.no_grad()
def encode_text(label, model):
    """Return ``model``'s text encoding of ``label`` as a CPU tensor.

    Args:
        label: Text (or list of texts) handed to ``clip.tokenize``; inputs
            longer than the tokenizer's context are truncated rather than
            rejected (``truncate=True``).
        model: Object exposing ``encode_text(tokens)``.

    Returns:
        The encoding produced by ``model.encode_text``, moved to the CPU.
    """
    # Tokens are moved to the module-level ``device`` before encoding.
    tokens = clip.tokenize(label, truncate=True)
    return model.encode_text(tokens.to(device)).cpu()
|
|
def _safe_filename_stem(text, max_len=100):
    """Reduce ``text`` to a short, portable file-name stem."""
    # Anything other than ASCII letters, digits, underscore, dash, dot or space
    # (path separators, colons, control characters, ...) becomes an underscore.
    cleaned = re.sub(r"[^\w\-. ]+", "_", str(text), flags=re.ASCII)
    # Cap the length so a long prompt cannot exceed file-name length limits, and
    # trim separators left dangling at either end.
    cleaned = cleaned[:max_len].strip(" ._")
    return cleaned or "prompt"


def generate_image_from_text(prompt, class_guidance=6, seed=11, num_imgs=1, img_size=32):
    """Generate ``num_imgs`` images for ``prompt`` and return them as one PIL grid.

    The prompt is encoded with the module-level ``clip_model``, passed to the
    module-level ``diffuser``, and the resulting images are tiled into a grid.
    The grid is also written to a PNG in the current working directory.

    Args:
        prompt: Text prompt; the same prompt is used for every image.
        class_guidance: Guidance value forwarded to ``diffuser.generate``.
        seed: Seed forwarded to ``diffuser.generate``.
        num_imgs: Number of images to generate.
        img_size: Currently unused; kept so existing callers keep working.

    Returns:
        The PIL image containing the grid of generated images.
    """
    n_iter = 15
    # Roughly square grid: floor(sqrt(num_imgs)) images per row.
    nrow = int(np.sqrt(num_imgs))

    cur_prompts = [prompt] * num_imgs
    labels = encode_text(cur_prompts, clip_model)

    # ``generate`` returns a pair; only the first element is used here.
    out, _ = diffuser.generate(
        labels=labels,
        num_imgs=num_imgs,
        class_guidance=class_guidance,
        seed=seed,
        n_iter=n_iter,
        exponent=1,
        scale_factor=8,
        sharp_f=0,
        bright_f=0,
    )

    # (out + 1) / 2 presumably maps [-1, 1] outputs to [0, 1] -- TODO confirm;
    # the clip guarantees the range before conversion to PIL either way.
    grid = vutils.make_grid((out + 1) / 2, nrow=nrow, padding=4)
    out = to_pil(grid.float().clip(0, 1))

    # The prompt is user-supplied: sanitise it so path separators cannot
    # redirect the write, and avoid ':' which is not allowed in Windows names.
    out.save(f"{_safe_filename_stem(prompt)}_cfg-{class_guidance}_seed-{seed}.png")

    print("Images Generated and Saved. They will shortly output below.")
    return out
|
|
|
|
|
|
# Read the checkpoint onto the CPU; the model is moved to ``device`` after loading.
state_dict = torch.load("state_dict_378000.pth", map_location=torch.device("cpu"))

# Build the denoiser in ``model_dtype``, load the weights, then move it to ``device``.
denoiser = Denoiser(
    image_size=32,
    noise_embed_dims=256,
    patch_size=2,
    embed_dim=768,
    dropout=0,
    n_layers=12,
).to(model_dtype)
denoiser.load_state_dict(state_dict)
denoiser = denoiser.to(device)

# Pretrained VAE, loaded in the same dtype and placed on the same device.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=model_dtype,
).to(device)

# CLIP model used by encode_text; ``preprocess`` is not used in the visible code.
clip_model, preprocess = clip.load("ViT-L/14")
clip_model = clip_model.to(device)

# Generator object that generate_image_from_text calls into.
diffuser = DiffusionGenerator(denoiser, vae, device, model_dtype)
|
|
| |
# Web UI: a prompt text box and a guidance slider feed generate_image_from_text;
# its returned PIL image is shown as the output.
iface = gr.Interface(
    fn=generate_image_from_text,
    inputs=[
        "text",
        # The bare "slider" shortcut starts at its minimum (0), which does not
        # match the function's default class_guidance of 6. The 0-100 range
        # mirrors the shortcut's defaults and could be tightened once a sensible
        # upper bound for guidance is agreed.
        gr.Slider(minimum=0, maximum=100, value=6),
    ],
    outputs="image",
    title="Text-to-Image Generator",
    description="Enter a text prompt to generate an image.",
)

# Start the Gradio server (runs at import time, as before).
iface.launch()
|
|