Update app.py
app.py CHANGED
@@ -1,6 +1,5 @@
-import os
+# import os
 import spaces
-from dataclasses import dataclass
 
 import gradio as gr
 import torch
@@ -22,6 +21,7 @@ from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 from safetensors.torch import load_file
+# from torch.profiler import profile, record_function, ProfilerActivity
 # from optimum.quanto import freeze, qfloat8, quantize
 
 
@@ -216,18 +216,27 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    # x = rearrange(x, "B H L D -> B L (H D)")
+    x = x.permute(0, 2, 1, 3).contiguous().reshape(x.size(0), x.size(2), -1)
 
     return x
 
 
-def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
-    assert dim % 2 == 0
+def rope(pos, dim, theta):
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
-    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    omega = 1.0 / (theta ** scale)
+
+    # out = torch.einsum("...n,d->...nd", pos, omega)
+    out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+
+    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    b, n, d, _ = out.shape
+    out = out.view(b, n, d, 2, 2)
+
     return out.float()
 
 
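The two replacements above drop einops in favor of plain tensor ops. A minimal self-check (not part of app.py, arbitrary example shapes) that the permute/reshape pair matches rearrange and that the broadcasted multiply matches the einsum:

    import torch
    from einops import rearrange

    x = torch.randn(2, 8, 16, 64)  # B H L D
    a = rearrange(x, "B H L D -> B L (H D)")
    b = x.permute(0, 2, 1, 3).contiguous().reshape(x.size(0), x.size(2), -1)
    assert torch.equal(a, b)

    pos, omega = torch.randn(2, 16), torch.randn(32)
    c = torch.einsum("...n,d->...nd", pos, omega)
    d = pos.unsqueeze(-1) * omega.unsqueeze(0)
    assert torch.allclose(c, d)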
@@ -267,9 +276,12 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(
-        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-    ).to(t.device)
+
+    # Does not block the CUDA stream, but differs from the official Flux code by about 1e-4:
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
+
+    # Blocks the CUDA stream, but matches the official code:
+    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
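A small sketch of the trade-off those comments describe: creating the arange directly on t's device avoids a host synchronization, while the CPU-then-copy variant reproduces the official Flux values exactly. On a GPU the two can differ at roughly the 1e-4 level noted above; on CPU they are identical:

    import math
    import torch

    t = torch.tensor([1.0, 250.0, 999.0])
    dim, max_period = 256, 10000
    half = dim // 2
    freqs_dev = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
    freqs_cpu = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half).to(t.device)
    print((freqs_dev - freqs_cpu).abs().max())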
@@ -327,7 +339,10 @@ class SelfAttention(nn.Module):
 
     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
         qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        qkv = qkv.view(B, L, 3, self.num_heads, -1)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         x = attention(q, k, v, pe=pe)
         x = self.proj(x)
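The view + permute pair is a drop-in for the einops pattern used before. A quick equivalence check with arbitrary sizes:

    import torch
    from einops import rearrange

    B, L, H, D = 2, 16, 8, 64
    qkv = torch.randn(B, L, 3 * H * D)
    q1, k1, v1 = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
    q2, k2, v2 = qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
    assert torch.equal(q1, q2) and torch.equal(k1, k2) and torch.equal(v1, v2)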
@@ -394,14 +409,20 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        H = self.num_heads
+        D = img_qkv.shape[-1] // (3 * H)
+        img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = txt_qkv.shape
+        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
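Why D = img_qkv.shape[-1] // (3 * H) works: the fused projection emits Q, K and V for every head in a single tensor. With Flux-dev-sized dimensions (hidden size 3072 and 24 heads, assumed here purely for illustration) the bookkeeping is:

    H = 24
    hidden_size = 3072
    qkv_features = 3 * hidden_size  # 9216: one hidden_size each for Q, K, V
    D = qkv_features // (3 * H)     # 9216 // 72 = 128 channels per head
    assert H * D == hidden_size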
@@ -460,7 +481,9 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
 
         # compute attention
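Here linear1 produces the attention QKV and the MLP activation in one matmul, and torch.split with a list of sizes separates them; in miniature:

    import torch

    hidden_size, mlp_hidden_dim = 64, 256
    y = torch.randn(2, 10, 3 * hidden_size + mlp_hidden_dim)
    qkv, mlp = torch.split(y, [3 * hidden_size, mlp_hidden_dim], dim=-1)
    print(qkv.shape, mlp.shape)  # (2, 10, 192) and (2, 10, 256)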
@@ -677,9 +700,7 @@ def denoise(
             timesteps=t_vec,
             guidance=guidance_vec,
         )
-
         img = img + (t_prev - t_curr) * pred
-
     return img
 
 
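The retained line is one Euler step of the rectified-flow ODE: the model output is treated as a velocity and the sample moves by the (negative) step t_prev - t_curr. A minimal sketch of the loop shape, with a hypothetical stand-in for the model:

    import torch

    def velocity(img, t):  # hypothetical stand-in for the Flux forward pass
        return -img

    img = torch.randn(1, 4)
    timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        pred = velocity(img, t_curr)
        img = img + (t_prev - t_curr) * pred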
@@ -723,7 +744,7 @@ from safetensors.torch import load_file
 
 sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4.safetensors"))
 sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
-model = Flux().to(dtype=torch.bfloat16)
+model = Flux().to(dtype=torch.bfloat16, device="cuda")
 result = model.load_state_dict(sd)
 print(result)
 
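The dict comprehension keeps only diffusion-model weights and strips the wrapper prefix so the names line up with Flux's own module names; with a toy state dict:

    sd = {
        "model.diffusion_model.img_in.weight": 1,
        "text_encoder.embed.weight": 2,  # dropped: not a diffusion-model key
    }
    sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
    print(sd)  # {'img_in.weight': 1}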
@@ -731,7 +752,7 @@ print(result)
 # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
 
 @spaces.GPU
-@torch.
+@torch.no_grad()
 def generate_image(
     prompt, width, height, guidance, seed,
     do_img2img, init_image, image2image_strength, resize_img,
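@torch.no_grad() disables autograd recording for the whole sampling function, which keeps activation memory flat across the denoising loop. A minimal illustration:

    import torch

    w = torch.randn(3, requires_grad=True)

    @torch.no_grad()
    def sample(w):
        return (w * 2).sum()

    print(sample(w).requires_grad)  # False: nothing is recorded for backward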
@@ -742,7 +763,7 @@ def generate_image(
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch_device = torch.device(device)
-
+
     global model
     model = model.to(torch_device)
 
@@ -761,7 +782,7 @@ def generate_image(
     generator = torch.Generator(device=device).manual_seed(seed)
     x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
 
-    num_steps =
+    num_steps = 20
    timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
 
     if do_img2img and init_image is not None:
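The second argument to get_schedule is the image-token count: the latent has 2 * ceil(H/16) by 2 * ceil(W/16) positions and is packed into 2x2 patches, hence the division by 4. For a 1024x1024 request:

    import math

    height = width = 1024
    lat_h = 2 * math.ceil(height / 16)  # 128
    lat_w = 2 * math.ceil(width / 16)   # 128
    print((lat_h * lat_w) // 4)         # 4096 image tokens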
@@ -770,13 +791,16 @@ def generate_image(
         timesteps = timesteps[t_idx:]
         x = t * x + (1.0 - t) * init_image.to(x.dtype)
 
-
-
-
-
-
-
-
+    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
+    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
+
+    # with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof:
+    # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+
+    x = unpack(x.float(), height, width)
+    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
+        x = ae.decode(x).sample
 
     x = x.clamp(-1, 1)
     x = rearrange(x[0], "c h w -> h w c")
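If the commented-out profiling lines were re-enabled, they would wrap the denoise call roughly like this sketch (the matmul below stands in for denoise):

    import torch
    from torch.profiler import profile, ProfilerActivity

    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof:
        y = torch.randn(64, 64) @ torch.randn(64, 64)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))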