Update to v3
- app.py +126 -52
- requirements.txt +3 -3
app.py CHANGED
@@ -1,23 +1,26 @@
 import gradio as gr
-import open_clip
 import torch
+import open_clip
+import torchvision
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from open_clip import tokenizer
-from
-from einops import rearrange
-from huggingface_hub import hf_hub_download
-from modules import DenoiseUNet
+from Paella.utils.modules import Paella
 from arroz import Diffuzz, PriorModel
+from transformers import AutoTokenizer, T5EncoderModel
+from Paella.src.vqgan import VQModel
+from Paella.utils.alter_attention import replace_attention_layers
 
-model_repo = "
-model_file = "
-prior_file = "
+model_repo = "dome272/Paella"
+model_file = "paella_v3.pt"
+prior_file = "prior_v1.pt"
+vqgan_file = "vqgan_f4.pt"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-device_text = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
 
 batch_size = 4
-latent_shape = (64, 64)
+latent_shape = (batch_size, 64, 64) # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
+prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (batch_size, 1024)
 
 generator_timesteps = 12
 generator_cfg = 5
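A short note on the new latent_shape, not part of the diff itself: the comment above says the v3 app samples a 64x64 grid of VQGAN token indices per image and decodes it with an f4 VQGAN, so each token corresponds to a 4x4 pixel patch. The tiny sketch below only spells out that arithmetic; the factor of 4 is an assumption taken from the "f4" in the comment and the vqgan_f4.pt filename.

# Illustration only, assuming an f4 VQGAN (4x4 pixels per latent token).
downscale_factor = 4                  # the "f4" in vqgan_f4.pt
token_grid = (64, 64)                 # spatial part of latent_shape above
pixels = tuple(s * downscale_factor for s in token_grid)
print(pixels)                         # (256, 256)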
@@ -98,61 +101,135 @@ def sample(model, c, x=None, negative_embeddings=None, mask=None, T=12, size=(32
 
 # Model loading
 
-
-
+# Load T5 on CPU
+t5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-xl")
+t5_model = T5EncoderModel.from_pretrained("google/byt5-xl")
 
+# Load other models on GPU
 clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
 clip_model = clip_model.to(device).half().eval().requires_grad_(False)
 
-
-
-
-
-
-
-
-
-
-    img = vqmodel.model.decode(z)
-    img = (img.clamp(-1., 1.) + 1) * 0.5
-    return img
-
-model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
-model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1])
-model = model.to(device).half()
-model.load_state_dict(torch.load(model_path, map_location=device))
-model.eval().requires_grad_()
+clip_preprocess = torchvision.transforms.Compose([
+    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+    torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
+])
+
+vqgan_path = hf_hub_download(repo_id=model_repo, filename=vqgan_file)
+vqmodel = VQModel().to(device)
+vqmodel.load_state_dict(torch.load(vqgan_path, map_location=device))
+vqmodel.eval().requires_grad_(False)
 
 prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
 prior = PriorModel().to(device).half()
 prior.load_state_dict(torch.load(prior_path, map_location=device))
 prior.eval().requires_grad_(False)
+
+model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
+model = Paella(byt5_embd=2560)
+model.load_state_dict(torch.load(model_path, map_location=device))
+model.eval().requires_grad_().half()
+replace_attention_layers(model)
+model.to(device)
+
 diffuzz = Diffuzz(device=device)
 
+@torch.inference_mode()
+def decode(img_seq):
+    return vqmodel.decode_indices(img_seq)
+
+@torch.inference_mode()
+def embed_t5(text, t5_tokenizer, t5_model, final_device="cuda"):
+    device = t5_model.device
+    t5_tokens = t5_tokenizer(text, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
+    t5_embeddings = t5_model(input_ids=t5_tokens).last_hidden_state.to(final_device)
+    return t5_embeddings
+
+@torch.inference_mode()
+def sample(model, model_inputs, latent_shape,
+           unconditional_inputs=None, init_x=None, steps=12, renoise_steps=None,
+           temperature=(0.7, 0.3), cfg=(8.0, 8.0),
+           mode='multinomial',  # 'quant', 'multinomial', 'argmax'
+           t_start=1.0, t_end=0.0,
+           sampling_conditional_steps=None, sampling_quant_steps=None, attn_weights=None
+           ):
+    device = unconditional_inputs["byt5"].device
+    if sampling_conditional_steps is None:
+        sampling_conditional_steps = steps
+    if sampling_quant_steps is None:
+        sampling_quant_steps = steps
+    if renoise_steps is None:
+        renoise_steps = steps-1
+    if unconditional_inputs is None:
+        unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
+
+    init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
+    if init_x != None:
+        sampled = init_x
+    else:
+        sampled = init_noise.clone()
+    t_list = torch.linspace(t_start, t_end, steps+1)
+    temperatures = torch.linspace(temperature[0], temperature[1], steps)
+    cfgs = torch.linspace(cfg[0], cfg[1], steps)
+    for i, tv in enumerate(t_list[:steps]):
+        if i >= sampling_quant_steps:
+            mode = "quant"
+        t = torch.ones(latent_shape[0], device=device) * tv
+
+        logits = model(sampled, t, **model_inputs, attn_weights=attn_weights)
+        if cfg is not None and i < sampling_conditional_steps:
+            logits = logits * cfgs[i] + model(sampled, t, **unconditional_inputs) * (1-cfgs[i])
+        scores = logits.div(temperatures[i]).softmax(dim=1)
+
+        if mode == 'argmax':
+            sampled = logits.argmax(dim=1)
+        elif mode == 'multinomial':
+            sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
+            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
+        elif mode == 'quant':
+            sampled = scores.permute(0, 2, 3, 1) @ vqmodel.vquantizer.codebook.weight.data
+            sampled = vqmodel.vquantizer.forward(sampled, dim=-1)[-1]
+        else:
+            raise Exception(f"Mode '{mode}' not supported, use: 'quant', 'multinomial' or 'argmax'")
+
+        if i < renoise_steps:
+            t_next = torch.ones(latent_shape[0], device=device) * t_list[i+1]
+            sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
+    return sampled
+
 # -----
 
 def infer(prompt, negative_prompt):
-
-    negative_text = tokenizer.tokenize([negative_prompt] * batch_size).to(device)
+    text = tokenizer.tokenize([prompt] * latent_shape[0]).to(device)
     with torch.inference_mode():
-
-
-
+        if negative_prompt:
+            clip_text_tokens_uncond = tokenizer.tokenize([negative_prompt] * len(text)).to(device)
+            t5_embeddings_uncond = embed_t5([negative_prompt] * len(text), t5_tokenizer, t5_model)
+        else:
+            clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
+            t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model)
+
+        t5_embeddings = embed_t5([prompt] * latent_shape[0], t5_tokenizer, t5_model)
+        clip_text_embeddings = clip_model.encode_text(text)
+        clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
 
-
-
+        with torch.autocast(device_type="cuda"):
+            clip_image_embeddings = diffuzz.sample(
+                prior, {'c': clip_text_embeddings}, clip_embedding_shape,
                 timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
             )[-1]
-
-
-
-
-
-
-
-
-
-
+
+        attn_weights = torch.ones((t5_embeddings.shape[1]))
+        attn_weights[-4:] = 0.4 # reweigh attention weights for image embeddings --> less influence
+        attn_weights[:-4] = 1.2 # reweigh attention weights for the rest --> more influence
+        attn_weights = attn_weights.to(device)
+
+        sampled_tokens = sample(model,
+            model_inputs={'byt5': t5_embeddings, 'clip': clip_text_embeddings, 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond, 'clip': clip_text_embeddings_uncond, 'clip_image': None},
+            temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
+            mode="multinomial", sampling_conditional_steps=20, attn_weights=attn_weights)
+
+        sampled = decode(sampled_tokens)
+        return to_pil(sampled.clamp(0, 1))
 
 css = """
 .gradio-container {
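The core of this hunk is the new token-space sampler: at each step the model predicts logits over the VQGAN codebook for every latent position, and in the default 'multinomial' mode one token id is drawn per position from a temperature-scaled softmax. The standalone sketch below is not from the repository; batch size, codebook size and grid size are made-up stand-ins, and it reproduces only that single sampling step with plain PyTorch.

import torch

# Illustrative only: one 'multinomial' sampling step in the style of sample() above.
batch, num_labels, h, w = 2, 1024, 16, 16
logits = torch.randn(batch, num_labels, h, w)   # stand-in for model output: codebook logits per position
temperature = 0.7

scores = logits.div(temperature).softmax(dim=1)             # (B, K, H, W) probabilities over the codebook
flat = scores.permute(0, 2, 3, 1).reshape(-1, num_labels)   # one probability row per latent position
tokens = torch.multinomial(flat, 1)[:, 0].view(batch, h, w) # draw one token id per position
print(tokens.shape)                                          # torch.Size([2, 16, 16]), ids in [0, num_labels)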
@@ -304,9 +381,6 @@
               Paella Demo
             </h1>
           </div>
-          <p>
-            Running on <b>{device_text}</b>
-          </p>
           <p style="margin-bottom: 10px; font-size: 94%">
            Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
           </p>
@@ -321,7 +395,7 @@
                     label="Enter your prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="
+                    placeholder="an image of a shiba inu, donning a spacesuit and helmet, traversing the uncharted terrain of a distant, extraterrestrial world, as a symbol of the intrepid spirit of exploration and the unrelenting curiosity that drives humanity to push beyond the bounds of the known",
                     elem_id="prompt-text-input",
                 ).style(
                     border=(True, False, True, True),
@@ -332,7 +406,7 @@
                     label="Enter your negative prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="
+                    placeholder="low quality, low resolution, bad image, blurry, blur",
                     elem_id="negative-prompt-text-input",
                 ).style(
                     border=(True, False, True, True),
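One more illustration before the dependency changes, covering the byT5 conditioning added to app.py. This is only a hedged sketch of the same pattern embed_t5 uses (tokenize, run the encoder, keep last_hidden_state); it swaps in the much smaller google/byt5-small checkpoint so it can run quickly, whereas the Space itself loads google/byt5-xl.

from transformers import AutoTokenizer, T5EncoderModel

# Sketch of the embed_t5 pattern, with byt5-small as a lightweight stand-in for byt5-xl.
tok = AutoTokenizer.from_pretrained("google/byt5-small")
enc = T5EncoderModel.from_pretrained("google/byt5-small")

prompts = ["a painting of a paella pan on the beach"] * 2
ids = tok(prompts, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids
embeddings = enc(input_ids=ids).last_hidden_state   # (batch, seq_len, hidden_dim)
print(embeddings.shape)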
requirements.txt CHANGED

@@ -1,7 +1,7 @@
 torch
 open_clip_torch
-einops
 Pillow
 huggingface_hub
-git+https://github.com/
-git+https://github.com/
+git+https://github.com/pabloppp/pytorch-tools
+git+https://github.com/pabloppp/Arroz-Con-Cosas
+git+https://github.com/fbcotter/pytorch_wavelets
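For reference, the three git sources are installed straight from GitHub when the Space builds; to reproduce the environment locally one would typically run, e.g.:

pip install -r requirements.txt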