|
|
--- |
|
|
license: mit |
|
|
--- |
|
|
|
|
|
# 🍰 Tiny AutoEncoder for FLUX.2 |
|
|
|
|
|
[TAEF2](https://github.com/madebyollin/taesd) is very tiny autoencoder which uses the same "latent API" as FLUX.2's VAE. |
|
|
FLUX.2 is useful for real-time previewing of the FLUX.2 generation process, as well as general resource-constrained encoding/decoding. |
|
|
|
|
|
This repo contains `.safetensors` versions of the TAEF2 weights. |
|
|
|
|
|
## Using in 🧨 diffusers |
|
|
|
|
|
**NOTE**: Unlike TAEF1, TAEF2's architecture [isn't properly integrated](https://github.com/madebyollin/taesd/issues/35#issuecomment-3765620926) into Diffusers yet. |
|
|
So for now you'll want some wrapper code: |
|
|
|
|
|
```sh |
|
|
pip install git+https://www.github.com/huggingface/diffusers # needed for Klein support as of 2026-01-18 |
|
|
wget -nc -nv https://raw.githubusercontent.com/madebyollin/taesd/refs/heads/main/taesd.py -O taesd.py |
|
|
wget -nc -nv https://huggingface.co/madebyollin/taef2/resolve/main/taef2.safetensors -O taef2.safetensors |
|
|
``` |
|
|
|
|
|
```python |
|
|
# Construction |
|
|
from taesd import TAESD |
|
|
import torch |
|
|
import safetensors.torch as stt |
|
|
from diffusers.utils.accelerate_utils import apply_forward_hook |
|
|
|
|
|
def convert_diffusers_sd_to_taesd(sd): |
|
|
out = {} |
|
|
for k, v in sd.items(): |
|
|
encdec, _layers, index, *suffix = k.split(".") |
|
|
offset = 0 |
|
|
if encdec == "decoder": |
|
|
offset = +1 |
|
|
out[".".join([encdec, str(int(index)+offset), *suffix])] = v |
|
|
return out |
|
|
|
|
|
class DotDict(dict): |
|
|
__getattr__ = dict.__getitem__ |
|
|
__setattr__ = dict.__setitem__ |
|
|
|
|
|
class DiffusersTAEF2Wrapper(torch.nn.Module): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
self.dtype = torch.bfloat16 |
|
|
self.taesd = TAESD(encoder_path=None, decoder_path=None, latent_channels=32, arch_variant="flux_2").to(self.dtype) |
|
|
self.taesd.load_state_dict(convert_diffusers_sd_to_taesd(stt.load_file("taef2.safetensors"))) |
|
|
self.bn = torch.nn.BatchNorm2d(128, affine=False, eps=0.0) # default bn |
|
|
self.config = DotDict(batch_norm_eps=self.bn.eps) |
|
|
|
|
|
@apply_forward_hook |
|
|
def encode(self, x): |
|
|
return DotDict(latent_dist=DotDict(sample=lambda : self.taesd.encoder(x.to(self.dtype).mul(0.5).add_(0.5)).to(x.dtype))) |
|
|
|
|
|
@apply_forward_hook |
|
|
def decode(self, x, return_dict=True): |
|
|
x = self.taesd.decoder(x.to(self.dtype)).mul(2).sub_(1).clamp_(-1, 1).to(x.dtype) |
|
|
return dict(sample=x) if return_dict else x, |
|
|
|
|
|
taef2_diffusers = DiffusersTAEF2Wrapper().eval().requires_grad_(False) |
|
|
|
|
|
# Usage |
|
|
from diffusers import Flux2KleinPipeline |
|
|
|
|
|
device = "cuda" |
|
|
dtype = torch.bfloat16 |
|
|
|
|
|
pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-4B", torch_dtype=dtype) |
|
|
pipe.vae = taef2_diffusers |
|
|
pipe.enable_sequential_cpu_offload() # pipe.enable_model_cpu_offload() # pipe = pipe.to(device) |
|
|
|
|
|
prompt = "A slice of delicious New York-style berry cheesecake" |
|
|
image = pipe( |
|
|
prompt=prompt, |
|
|
height=1024, |
|
|
width=1024, |
|
|
guidance_scale=1.0, |
|
|
num_inference_steps=4, |
|
|
generator=torch.Generator(device="cpu").manual_seed(0) |
|
|
).images[0] |
|
|
image.save("flux-klein.png") |
|
|
image |
|
|
``` |
|
|
|
|
|
|
|
|
 |
|
|
|
|
|
## Quality Comparisons |
|
|
|
|
|
These compare TAEF2, the [full FLUX.2 VAE](https://huggingface.co/black-forest-labs/FLUX.2-klein-4B/blob/main/vae/config.json), and the alternate [FAL FLUX.2 Tiny](https://huggingface.co/fal/FLUX.2-Tiny-AutoEncoder) AE. |
|
|
|
|
|
 |
|
|
|
|
|
 |
|
|
|
|
|
 |
|
|
|
|
|
 |