Update app.py
app.py
CHANGED
@@ -13,18 +13,15 @@ from torch.nn import functional as F
 from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm import trange
-from
-import
-from
-import train_latent_diffusion as train
+from transformers import CLIPProcessor, CLIPModel
+from vqvae import VQVAE2 # Autoencoder replacement
+from diffusion_models import Diffusion # Swapped Diffusion model for DALL·E 2 based model
 from huggingface_hub import hf_hub_url, cached_download
 import gradio as gr  # 🎨 The magic canvas for AI-powered image generation!
 
-# 🖼️ Download the necessary model files
-
-
-ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
-ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
+# 🖼️ Download the necessary model files from HuggingFace
+vqvae_model_path = cached_download(hf_hub_url("huggingface/vqvae-2", filename="vqvae_model.ckpt"))
+diffusion_model_path = cached_download(hf_hub_url("huggingface/dalle-2", filename="diffusion_model.ckpt"))
 
 # 📐 Utility Functions: Math and images, what could go wrong?
 # These functions help parse prompts and resize/crop images to fit nicely
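A note on the download calls: `cached_download(hf_hub_url(...))` matches older `huggingface_hub` releases; that helper has since been deprecated and removed in favor of `hf_hub_download`. A minimal equivalent sketch, keeping the repo IDs and filenames exactly as the diff assumes them (they look like placeholders rather than published Hub repos):

```python
# Sketch only: hf_hub_download is the current huggingface_hub download API.
# Repo IDs and filenames are copied verbatim from the diff and are assumptions;
# they may not exist on the Hub.
from huggingface_hub import hf_hub_download

vqvae_model_path = hf_hub_download("huggingface/vqvae-2", filename="vqvae_model.ckpt")
diffusion_model_path = hf_hub_download("huggingface/dalle-2", filename="diffusion_model.ckpt")
```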
@@ -33,11 +30,7 @@ def parse_prompt(prompt, default_weight=3.):
     """
     🎯 Parses a prompt into text and weight.
     """
-
-        vals = prompt.rsplit(':', 2)
-        vals = [vals[0] + ':' + vals[1], *vals[2:]]
-    else:
-        vals = prompt.rsplit(':', 1)
+    vals = prompt.rsplit(':', 1)
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])
 
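The rewrite collapses the old two-branch split, which existed so that prompts containing extra colons (e.g. URLs) could still carry a trailing weight, into a single `rsplit`. A quick sketch of the simplified behavior and its one regression:

```python
def parse_prompt(prompt, default_weight=3.):
    # Split on the last colon only; fall back to default_weight if none given.
    vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])

print(parse_prompt('a red circle:2'))  # ('a red circle', 2.0)
print(parse_prompt('a red circle'))    # ('a red circle', 3.0)

# Regression: a colon-bearing prompt with no weight now breaks --
# parse_prompt('https://example.com/img.png') raises ValueError, because
# float('//example.com/img.png') fails; the dropped branch guarded this case.
```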
@@ -49,59 +42,51 @@ def resize_and_center_crop(image, size):
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])
 
-
 # 🧠 Model loading: the brain of our operation! 🔥
-# Load all the models: autoencoder, diffusion, and CLOOB
 
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
 print('loading models... 🛠️')
 
-#
-
-
-ae_model.eval().requires_grad_(False).to(device)
-ae_model.load_state_dict(torch.load(ae_model_path))
-n_ch, side_y, side_x = 4, 32, 32
-
-# 🌀 Diffusion Model Setup: The artist behind the scenes
-model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
-model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
-model = model.to(device).eval().requires_grad_(False)
-
-#
-
-
-
-cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
-cloob.eval().requires_grad_(False).to(device)
+# Load CLIP model
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+# Load VQ-VAE-2 Autoencoder
+vqvae = VQVAE2()
+vqvae.load_state_dict(torch.load(vqvae_model_path))
+vqvae.eval().requires_grad_(False).to(device)
+
+# Load Diffusion Model
+diffusion_model = Diffusion()
+diffusion_model.load_state_dict(torch.load(diffusion_model_path))
+diffusion_model = diffusion_model.to(device).eval().requires_grad_(False)
 
 # 🎨 The key function: Where the magic happens!
 # This is where we generate images based on text and image prompts
 
-def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='
+def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='ddim', eta=None):
     """
     🖼️ Generates a list of PIL images based on given text and image prompts.
     """
-    zero_embed = torch.zeros([1,
+    zero_embed = torch.zeros([1, clip_model.config.projection_dim], device=device)
     target_embeds, weights = [zero_embed], []
 
-    # Parse text prompts
+    # Parse text prompts and encode with CLIP
     for prompt in prompts:
-
-
-
+        inputs = clip_processor(text=prompt, return_tensors="pt").to(device)
+        text_embed = clip_model.get_text_features(**inputs).float()
+        target_embeds.append(text_embed)
+        weights.append(1.0)
 
     # Parse image prompts
     for prompt in images:
         path, weight = parse_prompt(prompt)
-        img = Image.open(
-
-
-
-
-        target_embeds.append(embed)
+        img = Image.open(path).convert('RGB')
+        img = resize_and_center_crop(img, (224, 224))
+        inputs = clip_processor(images=img, return_tensors="pt").to(device)
+        image_embed = clip_model.get_image_features(**inputs).float()
+        target_embeds.append(image_embed)
         weights.append(weight)
 
     # Adjust weights and set seed
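For readers unfamiliar with the `transformers` CLIP API used in the new loops: `get_text_features` and `get_image_features` both project into a shared embedding space whose width is `clip_model.config.projection_dim` (512 for `clip-vit-base-patch32`), which is what makes the `zero_embed` shape above line up. A self-contained sketch:

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with torch.no_grad():
    # Text side: tokenize, then project into the shared space.
    text_inputs = clip_processor(text=["a red circle"], return_tensors="pt")
    text_embed = clip_model.get_text_features(**text_inputs)      # [1, 512]

    # Image side: resize/normalize, then project likewise.
    img = Image.new("RGB", (224, 224), "red")  # stand-in for a prompt image
    image_inputs = clip_processor(images=img, return_tensors="pt")
    image_embed = clip_model.get_image_features(**image_inputs)   # [1, 512]

print(text_embed.shape, image_embed.shape)
```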
@@ -115,7 +100,7 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
         x_in = x.repeat([n_conds, 1, 1, 1])
         t_in = t.repeat([n_conds])
         embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
-        vs =
+        vs = diffusion_model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
         v = vs.mul(weights[:, None, None, None, None]).sum(0)
         return v
 
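The restored `vs` line and the weighted sum after it are classifier-free-guidance-style blending: the model runs once per conditioning embedding (the zero embedding plus each prompt), and the per-condition predictions are combined by weight. The weight-adjustment code itself falls outside this diff's hunks; in this style of sampler the unconditional weight is typically chosen so the weights sum to 1. A toy illustration of just the blend:

```python
import torch

# Hypothetical shapes: 3 conditions (zero-embed + 2 prompts), batch of 2.
n_conds, n, c, h, w = 3, 2, 3, 64, 64
vs = torch.randn([n_conds, n, c, h, w])   # one prediction per condition
weights = torch.tensor([-2.0, 1.5, 1.5])  # sums to 1.0, zero-embed negative

v = vs.mul(weights[:, None, None, None, None]).sum(0)
print(v.shape)  # torch.Size([2, 3, 64, 64])
```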
@@ -131,22 +116,19 @@ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method
 
     # 🏃‍♂️ Generate the output images
     batch_size = n
-    x = torch.randn([n,
+    x = torch.randn([n, 3, 64, 64], device=device)
     t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
     pil_ims = []
     for i in trange(0, n, batch_size):
         cur_batch_size = min(n - i, batch_size)
         out_latents = run(x[i:i + cur_batch_size], steps)
-        outs =
+        outs = vqvae.decode(out_latents)
         for j, out in enumerate(outs):
-            pil_ims.append(
+            pil_ims.append(transforms.ToPILImage()(out))
 
     return pil_ims
 
-
 # 🖌️ Interface: Gradio's brush to paint the UI
-# Gradio is used here to create a user-friendly interface for art generation.
-
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
     """
     💡 Gradio function to wrap image generation.
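Two caveats on the output loop, stated as assumptions since `VQVAE2` is not a class this diff shows: `transforms.ToPILImage()` expects float input in [0, 1], so if the decoder can emit values outside that range the tensor should be clamped first; and `gen_ims` below still defaults to `method='plms'` while the new `generate` signature defaults to `'ddim'`, which is worth reconciling. A defensive version of the tensor-to-PIL conversion:

```python
import torch
from torchvision import transforms

def decoded_to_pils(outs: torch.Tensor) -> list:
    # outs: [batch, 3, H, W] decoded images, assumed roughly in [0, 1].
    to_pil = transforms.ToPILImage()
    return [to_pil(out.clamp(0, 1)) for out in outs]

pils = decoded_to_pils(torch.rand(2, 3, 64, 64))  # stand-in decoder output
print(len(pils), pils[0].size)  # 2 (64, 64)
```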
@@ -169,56 +151,12 @@ iface = gr.Interface(
     ],
     outputs=gr.Image(type="pil", label="Generated Image"),
     examples=[
-
-
-
-        ["Abstract Art, in the style of M.C. Escher"],
-        ['Surrealism, in the style of Salvador Dali'],
-        ["Romanesque Art, in the style of Leonardo da Vinci"],
-        ["landscape"],
-        ["portrait"],
-        ["sculpture"],
-        ["photo"],
-        ["figurative"],
-        ["illustration"],
-        ["still life"],
-        ["cityscape"],
-        ["marina"],
-        ["animal painting"],
-        ["graffiti"],
-        ["mythological painting"],
-        ["battle painting"],
-        ["self-portrait"],
-        ["Impressionism, oil on canvas"],
-        ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
-        ["Moon Light Sonata by Basuki Abdullah"],
-        ["Two Trees by M.C. Escher"],
-        ["Futurism, in the style of Wassily Kandinsky"],
-        ["Surrealism, in the style of Edgar Degas"],
-        ["Expressionism, in the style of Wassily Kandinsky"],
-        ["Futurism, in the style of Egon Schiele"],
-        ["Cubism, in the style of Gustav Klimt"],
-        ["Op Art, in the style of Marc Chagall"],
-        ["Romanticism, in the style of M.C. Escher"],
-        ["Futurism, in the style of M.C. Escher"],
-        ["Mannerism, in the style of Paul Klee"],
-        ["High Renaissance, in the style of Rembrandt"],
-        ["Magic Realism, in the style of Gustave Dore"],
-        ["Realism, in the style of Jean-Michel Basquiat"],
-        ["Art Nouveau, in the style of Paul Gauguin"],
-        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
-        ["Baroque, in the style of Edward Hopper"],
-        ["Post-Impressionism, in the style of Wassily Kandinsky"],
-        ["Naturalism, in the style of Rene Magritte"],
-        ["Constructivism, in the style of Paul Cezanne"],
-        ["Abstract Expressionism, in the style of Henri Matisse"],
-        ["Pop Art, in the style of Vincent van Gogh"],
-        ["Futurism, in the style of Zdzislaw Beksinski"],
-        ["Aaron Wacker, oil on canvas"]
+        ["A beautiful sunset over the ocean"],
+        ["A futuristic cityscape at night"],
+        ["A surreal dream-like landscape"]
     ],
-    title='
-    description="
-    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
+    title='CLIP + Diffusion Model Image Generator',
+    description="Generate stunning images from text and image prompts using CLIP and a diffusion model.",
 )
 
 # 🚀 Launch the Gradio interface
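The diff ends at the launch comment; the launch call itself sits outside the changed hunks. For completeness, the standard Gradio pattern (a sketch, not the file's verbatim tail):

```python
# Standard Gradio entry point; iface is the gr.Interface built above.
if __name__ == '__main__':
    iface.launch()
```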