Commit
·
161aead
1
Parent(s):
1f308ee
init
Browse files- .gitattributes +1 -0
- README.md +81 -0
- assets/robustness.png +3 -0
- autoencoder.py +134 -0
- config.json +40 -0
- diffusion_pytorch_model.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/robustness.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- NextStep
|
| 5 |
+
- Image Tokenizer
|
| 6 |
+
---
|
| 7 |
+
# Improved Image Tokenizer
|
| 8 |
+
|
| 9 |
+
This is an improved image tokenizer of NextStep-1, featuring a fine-tuned decoder with a frozen encoder. The decoder refinement **improves performance** while preserving robust reconstruction quality. We **recommend using this Image Tokenizer** for optimal results with NextStep-1 models.
|
| 10 |
+
|
| 11 |
+
## Usage
|
| 12 |
+
|
| 13 |
+
```py
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torchvision.transforms as transforms
|
| 18 |
+
|
| 19 |
+
from autoencoder import AutoencoderKLNextStep
|
| 20 |
+
|
| 21 |
+
device = "cuda"
|
| 22 |
+
dtype = torch.bfloat16
|
| 23 |
+
|
| 24 |
+
model_path = "/path/to/vae_dir"
|
| 25 |
+
vae = AutoencoderKLNextStep.from_pretrained(model_path).to(device=device, dtype=dtype)
|
| 26 |
+
|
| 27 |
+
pil2tensor = transforms.Compose(
|
| 28 |
+
[
|
| 29 |
+
transforms.ToTensor(),
|
| 30 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
| 31 |
+
]
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
image = Image.open("/path/to/image.jpg").convert("RGB")
|
| 35 |
+
pixel_values = pil2tensor(image).unsqueeze(0).to(device=device, dtype=dtype)
|
| 36 |
+
|
| 37 |
+
# encode
|
| 38 |
+
latents = vae.encode(pixel_values).latent_dist.sample()
|
| 39 |
+
|
| 40 |
+
# decode
|
| 41 |
+
sampled_images = vae.decode(latents).sample
|
| 42 |
+
sampled_images = sampled_images.detach().cpu().to(torch.float32)
|
| 43 |
+
|
| 44 |
+
def tensor_to_pil(tensor):
|
| 45 |
+
image = tensor.detach().cpu().to(torch.float32)
|
| 46 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 47 |
+
image = image.mul(255).round().to(dtype=torch.uint8)
|
| 48 |
+
image = image.permute(1, 2, 0).numpy()
|
| 49 |
+
return Image.fromarray(image, mode="RGB")
|
| 50 |
+
|
| 51 |
+
rec_image = tensor_to_pil(sampled_images[0])
|
| 52 |
+
rec_image.save("/path/to/output.jpg")
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Evaluation
|
| 56 |
+
|
| 57 |
+
### Reconstruction Performance on ImageNet-1K 256×256
|
| 58 |
+
|
| 59 |
+
| Tokenizer | Latent Shape | PSNR ↑ | SSIM ↑ |
|
| 60 |
+
| ------------------------- | ------------ | --------- | -------- |
|
| 61 |
+
| **Discrete Tokenizers** | | | |
|
| 62 |
+
| SBER-MoVQGAN (270M) | 32×32 | 27.04 | 0.74 |
|
| 63 |
+
| LlamaGen | 32×32 | 24.44 | 0.77 |
|
| 64 |
+
| VAR | 680 | 22.12 | 0.62 |
|
| 65 |
+
| TiTok-S-128 | 128 | 17.52 | 0.44 |
|
| 66 |
+
| Selftok                   | 1024         | 26.30     | 0.81     |
|
| 67 |
+
| **Continuous Tokenizers** | | | |
|
| 68 |
+
| Stable Diffusion 1.5 | 32×32×4 | 25.18 | 0.73 |
|
| 69 |
+
| Stable Diffusion XL | 32×32×4 | 26.22 | 0.77 |
|
| 70 |
+
| Stable Diffusion 3 Medium | 32×32×16 | 30.00 | 0.88 |
|
| 71 |
+
| Flux.1-dev | 32×32×16 | 31.64 | 0.91 |
|
| 72 |
+
| **NextStep-1** | **32×32×16** | **30.60** | **0.89** |
|
| 73 |
+
|
| 74 |
+
### Robustness of NextStep-1-f8ch16-Tokenizer
|
| 75 |
+
|
| 76 |
+
Impact of Noise Perturbation on Image Tokenizer Performance. The top panel displays
|
| 77 |
+
quantitative metrics (rFID↓, PSNR↑, and SSIM↑) versus noise intensity. The bottom panel presents qualitative reconstruction examples at noise standard deviations of 0.2 and 0.5.
|
| 78 |
+
|
| 79 |
+
<div align='center'>
|
| 80 |
+
<img src="assets/robustness.png" class="interpolation-image" alt="arch." width="100%" />
|
| 81 |
+
</div>
|
assets/robustness.png
ADDED
|
Git LFS Details
|
autoencoder.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Optional, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from diffusers import AutoencoderKL
|
| 7 |
+
from diffusers.configuration_utils import register_to_config
|
| 8 |
+
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 9 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 10 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AutoencoderKLNextStep(AutoencoderKL):
    """AutoencoderKL variant used by the NextStep-1 image tokenizer.

    Extends ``diffusers.AutoencoderKL`` with three extra config options:

    * ``deterministic`` -- build the latent posterior with zero variance, so
      ``sample()`` and ``mode()`` both return the mean.
    * ``normalize_latents`` -- apply a non-affine layer norm over the channel
      dimension of the latent mean before building the posterior.
    * ``patch_size`` -- space-to-depth the latent mean into ``p x p`` patches
      before the normalization (and undo it afterwards), so the layer norm is
      taken over ``channels * p**2`` values per spatial position.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        # NextStep-specific options (also persisted in the model config by
        # @register_to_config).
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Space-to-depth: (B, C, H, W) -> (B, C*p*p, H//p, W//p).

        Assumes H and W are divisible by ``self.patch_size``.
        """
        b, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(b, c, h_, p, w_, p)
        x = torch.einsum("bchpwq->bcpqhw", x)
        x = x.reshape(b, c * p**2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`patchify`: (B, C*p*p, H', W') -> (B, C, H'*p, W'*p)."""
        b, _, h_, w_ = x.shape
        p = self.patch_size
        c = x.shape[1] // (p**2)

        x = x.reshape(b, c, p, p, h_, w_)
        x = torch.einsum("bcpqhw->bchpwq", x)
        x = x.reshape(b, c, h_ * p, w_ * p)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode pixels into a (possibly normalized) latent posterior.

        Args:
            x: Pixel tensor of shape (B, C, H, W).
            return_dict: If False, return a 1-tuple of the posterior instead
                of an ``AutoencoderKLOutput``.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Slicing mode: encode one sample at a time to bound peak memory.
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.normalize_latents:
            # Layer-norm the mean over the channel axis, per spatial position.
            # With a patch size set, the norm spans channels * p**2 values via
            # a patchify/unpatchify round-trip. (The original code ran that
            # round-trip even when normalize_latents was False; it is an
            # identity, so guarding on normalize_latents preserves behavior
            # while skipping the wasted reshapes.)
            if self.patch_size is not None:
                mean = self.patchify(mean)
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
            if self.patch_size is not None:
                mean = self.unpatchify(mean)
        h = torch.cat([mean, logvar], dim=1).contiguous()
        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Encode then decode ``sample``, optionally perturbing the latent.

        Args:
            sample: Input pixel tensor of shape (B, C, H, W).
            sample_posterior: Sample from the posterior instead of taking its
                mode.
            return_dict: If False, return a 1-tuple of the reconstruction.
            generator: RNG used when sampling the posterior.
            noise_strength: If > 0, add Gaussian noise to the latent with a
                per-sample std drawn uniformly from [0, noise_strength)
                (used for the robustness evaluation in the README).
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if noise_strength > 0.0:
            # Per-sample noise scale, broadcast over (C, H, W); cast to z's
            # dtype so the addition does not silently promote the latent.
            p = torch.distributions.Uniform(0, noise_strength)
            scale = p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(device=z.device, dtype=z.dtype)
            # BUG FIX: the original called `randn_tensor`, which is never
            # imported in this file and raised NameError whenever
            # noise_strength > 0. torch.randn is equivalent here.
            z = z + scale * torch.randn(z.shape, device=z.device, dtype=z.dtype)
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
|
config.json
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKLNextStep",
|
| 3 |
+
"_diffusers_version": "0.35.0.dev0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"force_upcast": true,
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 16,
|
| 20 |
+
"latents_mean": null,
|
| 21 |
+
"latents_std": null,
|
| 22 |
+
"layers_per_block": 2,
|
| 23 |
+
"mid_block_add_attention": true,
|
| 24 |
+
"norm_num_groups": 32,
|
| 25 |
+
"out_channels": 3,
|
| 26 |
+
"sample_size": 512,
|
| 27 |
+
"scaling_factor": 1,
|
| 28 |
+
"shift_factor": 0,
|
| 29 |
+
"up_block_types": [
|
| 30 |
+
"UpDecoderBlock2D",
|
| 31 |
+
"UpDecoderBlock2D",
|
| 32 |
+
"UpDecoderBlock2D",
|
| 33 |
+
"UpDecoderBlock2D"
|
| 34 |
+
],
|
| 35 |
+
"use_post_quant_conv": false,
|
| 36 |
+
"use_quant_conv": false,
|
| 37 |
+
"deterministic": true,
|
| 38 |
+
"normalize_latents": true,
|
| 39 |
+
"patch_size": 2
|
| 40 |
+
}
|
diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d67ef6afe4ec377d53e99b270cf9a5f346f4c21dfe00732e2043b5b4c42ba394
|
| 3 |
+
size 335306212
|