lavinal712 committed on
Commit
161aead
·
1 Parent(s): 1f308ee
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/robustness.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - NextStep
5
+ - Image Tokenizer
6
+ ---
7
+ # Improved Image Tokenizer
8
+
9
+ This is an improved image tokenizer of NextStep-1, featuring a fine-tuned decoder with a frozen encoder. The decoder refinement **improves performance** while preserving robust reconstruction quality. We **recommend using this Image Tokenizer** for optimal results with NextStep-1 models.
10
+
11
+ ## Usage
12
+
13
+ ```py
14
+ import torch
15
+ from PIL import Image
16
+ import numpy as np
17
+ import torchvision.transforms as transforms
18
+
19
+ from autoencoder import AutoencoderKLNextStep
20
+
21
+ device = "cuda"
22
+ dtype = torch.bfloat16
23
+
24
+ model_path = "/path/to/vae_dir"
25
+ vae = AutoencoderKLNextStep.from_pretrained(model_path).to(device=device, dtype=dtype)
26
+
27
+ pil2tensor = transforms.Compose(
28
+ [
29
+ transforms.ToTensor(),
30
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
31
+ ]
32
+ )
33
+
34
+ image = Image.open("/path/to/image.jpg").convert("RGB")
35
+ pixel_values = pil2tensor(image).unsqueeze(0).to(device=device, dtype=dtype)
36
+
37
+ # encode
38
+ latents = vae.encode(pixel_values).latent_dist.sample()
39
+
40
+ # decode
41
+ sampled_images = vae.decode(latents).sample
42
+ sampled_images = sampled_images.detach().cpu().to(torch.float32)
43
+
44
+ def tensor_to_pil(tensor):
45
+ image = tensor.detach().cpu().to(torch.float32)
46
+ image = (image / 2 + 0.5).clamp(0, 1)
47
+ image = image.mul(255).round().to(dtype=torch.uint8)
48
+ image = image.permute(1, 2, 0).numpy()
49
+ return Image.fromarray(image, mode="RGB")
50
+
51
+ rec_image = tensor_to_pil(sampled_images[0])
52
+ rec_image.save("/path/to/output.jpg")
53
+ ```
54
+
55
+ ## Evaluation
56
+
57
+ ### Reconstruction Performance on ImageNet-1K 256×256
58
+
59
+ | Tokenizer | Latent Shape | PSNR ↑ | SSIM ↑ |
60
+ | ------------------------- | ------------ | --------- | -------- |
61
+ | **Discrete Tokenizers** | | | |
62
+ | SBER-MoVQGAN (270M) | 32×32 | 27.04 | 0.74 |
63
+ | LlamaGen | 32×32 | 24.44 | 0.77 |
64
+ | VAR | 680 | 22.12 | 0.62 |
65
+ | TiTok-S-128 | 128 | 17.52 | 0.44 |
66
+ | Selftok                   | 1024         | 26.30     | 0.81     |
67
+ | **Continuous Tokenizers** | | | |
68
+ | Stable Diffusion 1.5 | 32×32×4 | 25.18 | 0.73 |
69
+ | Stable Diffusion XL | 32×32×4 | 26.22 | 0.77 |
70
+ | Stable Diffusion 3 Medium | 32×32×16 | 30.00 | 0.88 |
71
+ | Flux.1-dev | 32×32×16 | 31.64 | 0.91 |
72
+ | **NextStep-1** | **32×32×16** | **30.60** | **0.89** |
73
+
74
+ ### Robustness of NextStep-1-f8ch16-Tokenizer
75
+
76
+ Impact of Noise Perturbation on Image Tokenizer Performance. The top panel displays
77
+ quantitative metrics (rFID↓, PSNR↑, and SSIM↑) versus noise intensity. The bottom panel presents qualitative reconstruction examples at noise standard deviations of 0.2 and 0.5.
78
+
79
+ <div align='center'>
80
+ <img src="assets/robustness.png" class="interpolation-image" alt="arch." width="100%" />
81
+ </div>
assets/robustness.png ADDED

Git LFS Details

  • SHA256: bb814f6477b339a07c78296033fb81c134ce57d6e83d4fd061478ef7701f9fba
  • Pointer size: 132 Bytes
  • Size of remote file: 8.14 MB
autoencoder.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from diffusers import AutoencoderKL
7
+ from diffusers.configuration_utils import register_to_config
8
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
9
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
10
+ from diffusers.utils.accelerate_utils import apply_forward_hook
11
+
12
+
13
class AutoencoderKLNextStep(AutoencoderKL):
    """KL autoencoder used by NextStep-1.

    Thin wrapper around diffusers' ``AutoencoderKL`` that adds three
    config-registered options:

    * ``deterministic`` — build the posterior as a deterministic
      ``DiagonalGaussianDistribution`` (sampling then returns the mean).
    * ``normalize_latents`` — LayerNorm the posterior *mean* over the channel
      dimension before building the distribution (log-variance untouched).
    * ``patch_size`` — when set, patchify the mean (space-to-channel) before
      the LayerNorm so normalization spans ``c * patch_size**2`` channels,
      then unpatchify back.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        # NextStep-specific switches; also serialized into the model config
        # by @register_to_config (see config.json: deterministic/normalize_latents/patch_size).
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Space-to-channel rearrangement: (b, c, h, w) -> (b, c*p*p, h/p, w/p).

        Assumes ``self.patch_size`` is set and h, w are divisible by it
        (integer division would silently truncate otherwise).
        """
        b, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(b, c, h_, p, w_, p)
        x = torch.einsum("bchpwq->bcpqhw", x)
        x = x.reshape(b, c * p ** 2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`patchify`: (b, c*p*p, h, w) -> (b, c, h*p, w*p)."""
        b, _, h_, w_ = x.shape
        p = self.patch_size
        c = x.shape[1] // (p ** 2)

        x = x.reshape(b, c, p, p, h_, w_)
        x = torch.einsum("bcpqhw->bchpwq", x)
        x = x.reshape(b, c, h_ * p, w_ * p)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode ``x`` into a (possibly channel-normalized) latent posterior.

        Args:
            x: Input pixel tensor of shape (b, in_channels, H, W).
            return_dict: If False, return a 1-tuple of the posterior instead
                of an ``AutoencoderKLOutput``.

        Only the posterior mean is patchified/normalized; the log-variance
        half of the encoder output passes through unchanged.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time to bound peak memory.
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        mean, logvar = torch.chunk(h, 2, dim=1)
        if self.patch_size is not None:
            mean = self.patchify(mean)
        if self.normalize_latents:
            # LayerNorm over channels: move C last, normalize, move back.
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
        if self.patch_size is not None:
            mean = self.unpatchify(mean)
        h = torch.cat([mean, logvar], dim=1).contiguous()
        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Encode ``sample``, optionally perturb the latent, then decode.

        Args:
            sample: Input pixel tensor.
            sample_posterior: Sample from the posterior instead of taking its mode.
            return_dict: If False, return a 1-tuple of the decoded tensor.
            generator: Optional RNG for posterior sampling.
            noise_strength: If > 0, each sample's latent is perturbed by
                Gaussian noise scaled by a per-sample factor drawn from
                U(0, noise_strength) — used for the robustness evaluation.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if noise_strength > 0.0:
            # Per-sample noise scale drawn from U(0, noise_strength); cast to
            # z's dtype so bf16 latents are not silently upcast to float32.
            p = torch.distributions.Uniform(0, noise_strength)
            scale = p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(device=z.device, dtype=z.dtype)
            # Fix: the original called `randn_tensor`, which was never imported
            # in this module and raised NameError; torch.randn is equivalent here.
            z = z + scale * torch.randn(z.shape, device=z.device, dtype=z.dtype)
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKLNextStep",
3
+ "_diffusers_version": "0.35.0.dev0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 16,
20
+ "latents_mean": null,
21
+ "latents_std": null,
22
+ "layers_per_block": 2,
23
+ "mid_block_add_attention": true,
24
+ "norm_num_groups": 32,
25
+ "out_channels": 3,
26
+ "sample_size": 512,
27
+ "scaling_factor": 1,
28
+ "shift_factor": 0,
29
+ "up_block_types": [
30
+ "UpDecoderBlock2D",
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D"
34
+ ],
35
+ "use_post_quant_conv": false,
36
+ "use_quant_conv": false,
37
+ "deterministic": true,
38
+ "normalize_latents": true,
39
+ "patch_size": 2
40
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d67ef6afe4ec377d53e99b270cf9a5f346f4c21dfe00732e2043b5b4c42ba394
3
+ size 335306212