# Spherical VAE (fine-tuned from FLUX.2-dev)
A variational autoencoder fine-tuned from black-forest-labs/FLUX.2-dev with a spherical latent space constraint. All encoder outputs are projected onto the hypersphere S^{d-1} with radius sqrt(latent_channels) = sqrt(32) ~ 5.66.
## What is the spherical constraint?
Standard VAEs produce latents with unconstrained norms, leading to "holes" in latent space and unstable generation. This model replaces the standard quant_conv with a SphericalQuantConv that L2-normalizes encoder outputs and scales them to lie exactly on a hypersphere of radius sqrt(32). This ensures:
- Uniform latent coverage: all latent vectors have the same L2 norm
- No posterior collapse: the fixed-norm constraint keeps every latent vector informative, preventing the posterior from collapsing toward the prior
- Compatible with diffusers: exported weights use standard Conv2d (spherical projection can be applied at inference time)
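The exact `SphericalQuantConv` implementation is not reproduced here, but a minimal sketch follows from the description above: apply the quantization conv, then L2-normalize each spatial location's channel vector and rescale it to radius sqrt(C). The 1x1 kernel and the handling of the distribution moments are assumptions, not confirmed details of the released model.

```python
import math
import torch
import torch.nn as nn

class SphericalQuantConv(nn.Module):
    """Sketch of a quant_conv replacement that projects latents onto a sphere.

    Assumption: the conv output holds the latent channels directly; the
    released model may instead handle mean/logvar moments internally.
    """

    def __init__(self, channels: int = 32):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.radius = math.sqrt(channels)  # sqrt(32) ~ 5.66

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.conv(x)                                   # (B, C, H, W)
        flat = z.flatten(2)                                # (B, C, H*W)
        # L2-normalize over channels, then scale onto the sphere
        flat = flat / flat.norm(dim=1, keepdim=True).clamp_min(1e-8)
        return (flat * self.radius).view_as(z)
```

After this layer, every per-location channel vector has L2 norm sqrt(32) by construction, which is the "exactly on a hypersphere" property described above.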
## Training Details
| Parameter | Value |
|---|---|
| Base model | FLUX.2-dev VAE (AutoencoderKLFlux2) |
| Training steps | 10,000 |
| Dataset | FLUX-Reason-6M (1024x1024) |
| Weights | EMA (decay=0.9999) |
| Sphere radius | sqrt(32) ~ 5.66 |
| Latent channels | 32 |
| Decoder LR | 4e-5 |
| Encoder LR | 4e-6 |
| Loss | L1 + LPIPS + KL + adversarial (after 10k steps) |
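The loss row reads as a weighted sum of its terms. A hedged sketch of how they could combine, assuming placeholder weight names and values (the real ones are in `training_config.yaml`), with the perceptual and adversarial terms injected as callables:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, target, mean, logvar, step,
             lpips_fn=None, disc_fn=None,
             w_lpips=1.0, w_kl=1e-6, w_adv=0.5, adv_start=10_000):
    # NOTE: weight values here are illustrative placeholders only.
    loss = F.l1_loss(recon, target)
    if lpips_fn is not None:
        loss = loss + w_lpips * lpips_fn(recon, target)
    # KL of the diagonal-Gaussian posterior against a standard normal
    kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    loss = loss + w_kl * kl
    # Adversarial term switches on only after `adv_start` steps
    if disc_fn is not None and step >= adv_start:
        loss = loss + w_adv * (-disc_fn(recon).mean())
    return loss
```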
## Usage

```python
from diffusers import AutoencoderKLFlux2
import math
import torch

# Load the fine-tuned VAE
vae = AutoencoderKLFlux2.from_pretrained("tmeral/spherical-vae-flux")
vae = vae.to("cuda").eval()

# Encode an image (B, 3, H, W) -> latent
image = torch.randn(1, 3, 512, 512, device="cuda")  # replace with a real image
with torch.no_grad():
    latent = vae.encode(image).latent_dist.sample()

# Apply spherical projection (optional, for a strict sphere constraint)
radius = math.sqrt(32)
latent_flat = latent.flatten(2)                # (B, C, H*W)
norms = latent_flat.norm(dim=1, keepdim=True)  # (B, 1, H*W)
latent_flat = latent_flat / norms * radius
latent = latent_flat.view_as(latent)

# Decode back to pixel space
with torch.no_grad():
    recon = vae.decode(latent).sample

print(f"Input: {image.shape} -> Latent: {latent.shape} -> Recon: {recon.shape}")
```
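The inline projection steps can be packaged as a small helper (the name `project_to_sphere` is an assumption, not part of the released API) and sanity-checked: after projection, every per-location channel vector should have norm sqrt(32).

```python
import math
import torch

def project_to_sphere(latent: torch.Tensor,
                      radius: float = math.sqrt(32)) -> torch.Tensor:
    """Rescale each spatial location's channel vector onto the sphere."""
    flat = latent.flatten(2)                       # (B, C, H*W)
    flat = flat / flat.norm(dim=1, keepdim=True) * radius
    return flat.view_as(latent)

# Sanity check on a stand-in latent (32 channels, as in this model)
z = project_to_sphere(torch.randn(1, 32, 64, 64))
norms = z.flatten(2).norm(dim=1)                   # (B, H*W)
print(norms.min().item(), norms.max().item())      # both ~ 5.66
```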
## Qualitative Comparisons

Side-by-side comparisons of the original FLUX.2-dev VAE vs. this fine-tuned spherical VAE at different resolutions (see the `comparison_*.png` files):

- 256x256
- 512x512
- 1024x1024
## Files

| File | Description |
|---|---|
| `diffusion_pytorch_model.safetensors` | Model weights in safetensors format |
| `config.json` | Diffusers model config |
| `spherical_vae_metadata.json` | Spherical VAE training metadata (radius, channels, base model) |
| `inference.py` | Self-contained encode/decode inference script |
| `training_config.yaml` | Full training configuration |
| `comparison_*.png` | Qualitative comparison images at 256/512/1024 |
## License
This is a research checkpoint for evaluation and collaboration purposes. The base model (FLUX.2-dev) has its own license terms.