Spaces:
Running
on
Zero
Running
on
Zero
fix
Browse files- app.py +31 -30
- vae/config.json +1 -2
- vae/nextstep_ae.py +39 -72
app.py
CHANGED
|
@@ -3,9 +3,8 @@ import numpy as np
|
|
| 3 |
import random
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
| 6 |
-
|
| 7 |
-
# import spaces #[uncomment to use ZeroGPU]
|
| 8 |
import torch
|
|
|
|
| 9 |
|
| 10 |
from transformers import AutoTokenizer, AutoModel
|
| 11 |
from models.gen_pipeline import NextStepPipeline
|
|
@@ -15,8 +14,15 @@ HF_HUB = "stepfun-ai/NextStep-1-Large"
|
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
|
| 17 |
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
MAX_SEED = np.iinfo(np.int16).max
|
| 22 |
MAX_IMAGE_SIZE = 512
|
|
@@ -30,8 +36,6 @@ def infer(
|
|
| 30 |
seed=0,
|
| 31 |
width=512,
|
| 32 |
height=512,
|
| 33 |
-
#text_cfg=7.5,
|
| 34 |
-
#img_cfg=1.0,
|
| 35 |
num_inference_steps=28,
|
| 36 |
positive_prompt=DEFAULT_POSITIVE_PROMPT,
|
| 37 |
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
|
@@ -40,21 +44,23 @@ def infer(
|
|
| 40 |
if prompt in [None, ""]:
|
| 41 |
gr.Warning("⚠️ Please enter a prompt!")
|
| 42 |
return None
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
| 58 |
|
| 59 |
return image[0]
|
| 60 |
|
|
@@ -114,7 +120,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 114 |
step=1,
|
| 115 |
value=28,
|
| 116 |
)
|
| 117 |
-
|
| 118 |
with gr.Row():
|
| 119 |
width = gr.Slider(
|
| 120 |
label="Width",
|
|
@@ -132,9 +138,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 132 |
)
|
| 133 |
|
| 134 |
with gr.Row():
|
| 135 |
-
result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=
|
| 136 |
-
|
| 137 |
-
#gr.Examples(examples=examples, inputs=[prompt, ref])
|
| 138 |
|
| 139 |
def show_result():
|
| 140 |
return gr.update(visible=True)
|
|
@@ -147,15 +151,13 @@ with gr.Blocks(css=css) as demo:
|
|
| 147 |
seed,
|
| 148 |
width,
|
| 149 |
height,
|
| 150 |
-
#text_cfg,
|
| 151 |
-
#img_cfg,
|
| 152 |
num_inference_steps,
|
| 153 |
positive_prompt,
|
| 154 |
negative_prompt,
|
| 155 |
],
|
| 156 |
outputs=[result_1],
|
| 157 |
)
|
| 158 |
-
|
| 159 |
cancel_button.click(
|
| 160 |
fn=None,
|
| 161 |
inputs=None,
|
|
@@ -169,6 +171,5 @@ with gr.Blocks(css=css) as demo:
|
|
| 169 |
outputs=[result_1],
|
| 170 |
)
|
| 171 |
|
| 172 |
-
|
| 173 |
if __name__ == "__main__":
|
| 174 |
-
demo.launch()
|
|
|
|
| 3 |
import random
|
| 4 |
import spaces
|
| 5 |
from PIL import Image
|
|
|
|
|
|
|
| 6 |
import torch
|
| 7 |
+
from torch.amp import autocast
|
| 8 |
|
| 9 |
from transformers import AutoTokenizer, AutoModel
|
| 10 |
from models.gen_pipeline import NextStepPipeline
|
|
|
|
| 14 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
|
| 17 |
+
|
| 18 |
+
model = AutoModel.from_pretrained(
|
| 19 |
+
HF_HUB,
|
| 20 |
+
local_files_only=False,
|
| 21 |
+
trust_remote_code=True,
|
| 22 |
+
torch_dtype=torch.bfloat16,
|
| 23 |
+
).to(device)
|
| 24 |
+
|
| 25 |
+
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
|
| 26 |
|
| 27 |
MAX_SEED = np.iinfo(np.int16).max
|
| 28 |
MAX_IMAGE_SIZE = 512
|
|
|
|
| 36 |
seed=0,
|
| 37 |
width=512,
|
| 38 |
height=512,
|
|
|
|
|
|
|
| 39 |
num_inference_steps=28,
|
| 40 |
positive_prompt=DEFAULT_POSITIVE_PROMPT,
|
| 41 |
negative_prompt=DEFAULT_NEGATIVE_PROMPT,
|
|
|
|
| 44 |
if prompt in [None, ""]:
|
| 45 |
gr.Warning("⚠️ Please enter a prompt!")
|
| 46 |
return None
|
| 47 |
+
|
| 48 |
+
with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
|
| 49 |
+
image = pipeline.generate_image(
|
| 50 |
+
prompt,
|
| 51 |
+
hw=(height, width),
|
| 52 |
+
num_images_per_caption=1,
|
| 53 |
+
positive_prompt=positive_prompt,
|
| 54 |
+
negative_prompt=negative_prompt,
|
| 55 |
+
cfg=7.5,
|
| 56 |
+
cfg_img=1.0,
|
| 57 |
+
cfg_schedule="constant",
|
| 58 |
+
use_norm=False,
|
| 59 |
+
num_sampling_steps=num_inference_steps,
|
| 60 |
+
timesteps_shift=1.0,
|
| 61 |
+
seed=seed,
|
| 62 |
+
progress=True,
|
| 63 |
+
)
|
| 64 |
|
| 65 |
return image[0]
|
| 66 |
|
|
|
|
| 120 |
step=1,
|
| 121 |
value=28,
|
| 122 |
)
|
| 123 |
+
|
| 124 |
with gr.Row():
|
| 125 |
width = gr.Slider(
|
| 126 |
label="Width",
|
|
|
|
| 138 |
)
|
| 139 |
|
| 140 |
with gr.Row():
|
| 141 |
+
result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=MAX_IMAGE_SIZE, visible=False)
|
|
|
|
|
|
|
| 142 |
|
| 143 |
def show_result():
|
| 144 |
return gr.update(visible=True)
|
|
|
|
| 151 |
seed,
|
| 152 |
width,
|
| 153 |
height,
|
|
|
|
|
|
|
| 154 |
num_inference_steps,
|
| 155 |
positive_prompt,
|
| 156 |
negative_prompt,
|
| 157 |
],
|
| 158 |
outputs=[result_1],
|
| 159 |
)
|
| 160 |
+
|
| 161 |
cancel_button.click(
|
| 162 |
fn=None,
|
| 163 |
inputs=None,
|
|
|
|
| 171 |
outputs=[result_1],
|
| 172 |
)
|
| 173 |
|
|
|
|
| 174 |
if __name__ == "__main__":
|
| 175 |
+
demo.launch()
|
vae/config.json
CHANGED
|
@@ -9,7 +9,6 @@
|
|
| 9 |
"shift_factor": 0,
|
| 10 |
"scaling_factor": 1,
|
| 11 |
"deterministic": true,
|
| 12 |
-
"
|
| 13 |
-
"norm_level": "channel",
|
| 14 |
"psz": 1
|
| 15 |
}
|
|
|
|
| 9 |
"shift_factor": 0,
|
| 10 |
"scaling_factor": 1,
|
| 11 |
"deterministic": true,
|
| 12 |
+
"encoder_norm": true,
|
|
|
|
| 13 |
"psz": 1
|
| 14 |
}
|
vae/nextstep_ae.py
CHANGED
|
@@ -2,7 +2,6 @@ import os
|
|
| 2 |
import json
|
| 3 |
import inspect
|
| 4 |
from dataclasses import dataclass, field, asdict
|
| 5 |
-
from typing import Literal
|
| 6 |
from loguru import logger
|
| 7 |
from omegaconf import OmegaConf
|
| 8 |
from tabulate import tabulate
|
|
@@ -10,6 +9,7 @@ from einops import rearrange
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
|
|
|
| 13 |
from torch import Tensor
|
| 14 |
from torch.utils.checkpoint import checkpoint
|
| 15 |
|
|
@@ -17,7 +17,7 @@ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDis
|
|
| 17 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 18 |
|
| 19 |
from utils.misc import LargeInt
|
| 20 |
-
from utils.model_utils import
|
| 21 |
from utils.compile_utils import smart_compile
|
| 22 |
|
| 23 |
|
|
@@ -33,8 +33,7 @@ class AutoEncoderParams:
|
|
| 33 |
scaling_factor: float = 0.3611
|
| 34 |
shift_factor: float = 0.1159
|
| 35 |
deterministic: bool = False
|
| 36 |
-
|
| 37 |
-
norm_level: Literal["latent", "channel"] = "latent"
|
| 38 |
psz: int | None = None
|
| 39 |
|
| 40 |
|
|
@@ -306,6 +305,14 @@ class Decoder(nn.Module):
|
|
| 306 |
return h
|
| 307 |
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
class AutoencoderKL(nn.Module):
|
| 310 |
def __init__(self, params: AutoEncoderParams):
|
| 311 |
super().__init__()
|
|
@@ -333,19 +340,8 @@ class AutoencoderKL(nn.Module):
|
|
| 333 |
z_channels=params.z_channels,
|
| 334 |
)
|
| 335 |
|
|
|
|
| 336 |
self.psz = params.psz
|
| 337 |
-
# if self.psz is not None:
|
| 338 |
-
# logger.warning("psz has been deprecated, this is only used for hack's vae")
|
| 339 |
-
|
| 340 |
-
if params.norm_fn is None:
|
| 341 |
-
self.norm_fn = identity
|
| 342 |
-
elif params.norm_fn == "layer_norm":
|
| 343 |
-
self.norm_fn = layer_norm
|
| 344 |
-
elif params.norm_fn == "rms_norm":
|
| 345 |
-
self.norm_fn = rms_norm
|
| 346 |
-
else:
|
| 347 |
-
raise ValueError(f"Invalid norm_fn: {params.norm_fn}")
|
| 348 |
-
self.norm_level = params.norm_level
|
| 349 |
|
| 350 |
self.apply(self._init_weights)
|
| 351 |
|
|
@@ -420,18 +416,17 @@ class AutoencoderKL(nn.Module):
|
|
| 420 |
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
| 421 |
moments = self.encoder(x)
|
| 422 |
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
moments = torch.cat([mean, logvar], dim=1).contiguous()
|
| 435 |
|
| 436 |
posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
|
| 437 |
|
|
@@ -448,14 +443,7 @@ class AutoencoderKL(nn.Module):
|
|
| 448 |
|
| 449 |
return DecoderOutput(sample=dec)
|
| 450 |
|
| 451 |
-
def forward(
|
| 452 |
-
self,
|
| 453 |
-
input,
|
| 454 |
-
sample_posterior=True,
|
| 455 |
-
noise_strength=0.0,
|
| 456 |
-
interpolative_noise=False,
|
| 457 |
-
t_dist: Literal["uniform", "logitnormal"] = "logitnormal",
|
| 458 |
-
):
|
| 459 |
posterior = self.encode(input).latent_dist
|
| 460 |
z = posterior.sample() if sample_posterior else posterior.mode()
|
| 461 |
if noise_strength > 0.0:
|
|
@@ -463,46 +451,25 @@ class AutoencoderKL(nn.Module):
|
|
| 463 |
z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
|
| 464 |
z.shape, device=z.device, dtype=z.dtype
|
| 465 |
)
|
| 466 |
-
if interpolative_noise:
|
| 467 |
-
z = self.patchify(z)
|
| 468 |
-
bsz, c, h, w = z.shape
|
| 469 |
-
z = z.permute(0, 2, 3, 1) # [bsz, h, w, c]
|
| 470 |
-
z = z.reshape(-1, c) # [bsz * h * w, c]
|
| 471 |
-
|
| 472 |
-
if t_dist == "logitnormal":
|
| 473 |
-
u = torch.normal(mean=0.0, std=1.0, size=(z.shape[0],))
|
| 474 |
-
t = (1 / (1 + torch.exp(-u))).to(z)
|
| 475 |
-
elif t_dist == "uniform":
|
| 476 |
-
t = torch.randn((z.shape[0],)).to(z)
|
| 477 |
-
else:
|
| 478 |
-
raise ValueError(f"Invalid t_dist: {t_dist}")
|
| 479 |
-
|
| 480 |
-
noise = torch.randn_like(z)
|
| 481 |
-
z = expand_t(t, z) * z + (1 - expand_t(t, z)) * noise
|
| 482 |
-
|
| 483 |
-
z = z.reshape(bsz, h, w, c).permute(0, 3, 1, 2)
|
| 484 |
-
z = self.unpatchify(z)
|
| 485 |
-
|
| 486 |
dec = self.decode(z).sample
|
| 487 |
return dec, posterior
|
| 488 |
|
| 489 |
@classmethod
|
| 490 |
-
def from_pretrained(cls,
|
| 491 |
-
config_path =
|
| 492 |
-
ckpt_path =
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
kwargs = config
|
| 506 |
|
| 507 |
# Filter out kwargs that are not in AutoEncoderParams
|
| 508 |
# This ensures we only pass parameters that the model can accept
|
|
|
|
| 2 |
import json
|
| 3 |
import inspect
|
| 4 |
from dataclasses import dataclass, field, asdict
|
|
|
|
| 5 |
from loguru import logger
|
| 6 |
from omegaconf import OmegaConf
|
| 7 |
from tabulate import tabulate
|
|
|
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
from torch import Tensor
|
| 14 |
from torch.utils.checkpoint import checkpoint
|
| 15 |
|
|
|
|
| 17 |
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 18 |
|
| 19 |
from utils.misc import LargeInt
|
| 20 |
+
from utils.model_utils import randn_tensor
|
| 21 |
from utils.compile_utils import smart_compile
|
| 22 |
|
| 23 |
|
|
|
|
| 33 |
scaling_factor: float = 0.3611
|
| 34 |
shift_factor: float = 0.1159
|
| 35 |
deterministic: bool = False
|
| 36 |
+
encoder_norm: bool = False
|
|
|
|
| 37 |
psz: int | None = None
|
| 38 |
|
| 39 |
|
|
|
|
| 305 |
return h
|
| 306 |
|
| 307 |
|
| 308 |
+
def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
|
| 309 |
+
# input.shape = (bsz, c, h, w)
|
| 310 |
+
_input = input.permute(0, 2, 3, 1)
|
| 311 |
+
_input = F.layer_norm(_input, normalized_shape, None, None, eps)
|
| 312 |
+
_input = _input.permute(0, 3, 1, 2)
|
| 313 |
+
return _input
|
| 314 |
+
|
| 315 |
+
|
| 316 |
class AutoencoderKL(nn.Module):
|
| 317 |
def __init__(self, params: AutoEncoderParams):
|
| 318 |
super().__init__()
|
|
|
|
| 340 |
z_channels=params.z_channels,
|
| 341 |
)
|
| 342 |
|
| 343 |
+
self.encoder_norm = params.encoder_norm
|
| 344 |
self.psz = params.psz
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
self.apply(self._init_weights)
|
| 347 |
|
|
|
|
| 416 |
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
| 417 |
moments = self.encoder(x)
|
| 418 |
|
| 419 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 420 |
+
if self.psz is not None:
|
| 421 |
+
mean = self.patchify(mean)
|
| 422 |
+
|
| 423 |
+
if self.encoder_norm:
|
| 424 |
+
mean = layer_norm_2d(mean, mean.size()[-1:])
|
| 425 |
+
|
| 426 |
+
if self.psz is not None:
|
| 427 |
+
mean = self.unpatchify(mean)
|
| 428 |
+
|
| 429 |
+
moments = torch.cat([mean, logvar], dim=1).contiguous()
|
|
|
|
| 430 |
|
| 431 |
posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
|
| 432 |
|
|
|
|
| 443 |
|
| 444 |
return DecoderOutput(sample=dec)
|
| 445 |
|
| 446 |
+
def forward(self, input, sample_posterior=True, noise_strength=0.0):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
posterior = self.encode(input).latent_dist
|
| 448 |
z = posterior.sample() if sample_posterior else posterior.mode()
|
| 449 |
if noise_strength > 0.0:
|
|
|
|
| 451 |
z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
|
| 452 |
z.shape, device=z.device, dtype=z.dtype
|
| 453 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
dec = self.decode(z).sample
|
| 455 |
return dec, posterior
|
| 456 |
|
| 457 |
@classmethod
|
| 458 |
+
def from_pretrained(cls, model_path, **kwargs):
|
| 459 |
+
config_path = os.path.join(model_path, "config.json")
|
| 460 |
+
ckpt_path = os.path.join(model_path, "checkpoint.pt")
|
| 461 |
+
|
| 462 |
+
if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
|
| 463 |
+
raise ValueError(
|
| 464 |
+
f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
| 468 |
+
|
| 469 |
+
with open(config_path, "r") as f:
|
| 470 |
+
config: dict = json.load(f)
|
| 471 |
+
config.update(kwargs)
|
| 472 |
+
kwargs = config
|
|
|
|
| 473 |
|
| 474 |
# Filter out kwargs that are not in AutoEncoderParams
|
| 475 |
# This ensures we only pass parameters that the model can accept
|