Spherical VAE (fine-tuned from FLUX.2-dev)

A variational autoencoder fine-tuned from black-forest-labs/FLUX.2-dev with a spherical latent space constraint. All encoder outputs are projected onto the hypersphere S^{d-1} with radius sqrt(latent_channels) = sqrt(32) ~ 5.66.

What is the spherical constraint?

Standard VAEs produce latents with unconstrained norms, leading to "holes" in latent space and unstable generation. This model replaces the standard quant_conv with a SphericalQuantConv that L2-normalizes encoder outputs and scales them to lie exactly on a hypersphere of radius sqrt(32). This ensures:

  • Uniform latent coverage: all latent vectors have the same L2 norm
  • No posterior collapse: the fixed, nonzero norm keeps latents from collapsing toward the origin
  • Compatible with diffusers: exported weights use standard Conv2d (spherical projection can be applied at inference time)
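The SphericalQuantConv layer itself is not reproduced in this card; a minimal sketch of what such a projection layer could look like, based only on the description above (the class name matches, but the kernel size, channel counts, and normalization details are assumptions):

```python
import math

import torch
import torch.nn as nn


class SphericalQuantConv(nn.Module):
    """Hypothetical sketch: a 1x1 conv whose output is L2-normalized per
    spatial position onto a hypersphere of radius sqrt(latent_channels)."""

    def __init__(self, in_channels: int, latent_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, latent_channels, kernel_size=1)
        self.radius = math.sqrt(latent_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.conv(x)  # (B, C, H, W)
        flat = z.flatten(2)  # (B, C, H*W)
        # Each spatial position's channel vector is scaled to the sphere radius
        norms = flat.norm(dim=1, keepdim=True).clamp_min(1e-8)  # (B, 1, H*W)
        return (flat / norms * self.radius).view_as(z)


layer = SphericalQuantConv(in_channels=64, latent_channels=32)
out = layer(torch.randn(2, 64, 8, 8))
print(out.shape)  # torch.Size([2, 32, 8, 8])
```

Every channel vector of `out` has norm sqrt(32) ≈ 5.66, regardless of the input, which is the invariant the bullets above rely on.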

Training Details

| Parameter | Value |
|---|---|
| Base model | FLUX.2-dev VAE (AutoencoderKLFlux2) |
| Training steps | 10,000 |
| Dataset | FLUX-Reason-6M (1024x1024) |
| Weights | EMA (decay=0.9999) |
| Sphere radius | sqrt(32) ~ 5.66 |
| Latent channels | 32 |
| Decoder LR | 4e-5 |
| Encoder LR | 4e-6 |
| Loss | L1 + LPIPS + KL + adversarial (after 10k steps) |
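The loss row above combines four terms; a hedged sketch of how such a composite objective is typically assembled (the weights, function name, and the optional `lpips_fn` hook are illustrative assumptions, not this checkpoint's actual training code):

```python
import torch
import torch.nn.functional as F


def vae_loss(recon, target, mean, logvar, disc_fake_logits=None,
             w_lpips=1.0, w_kl=1e-6, w_adv=0.5, lpips_fn=None):
    """Composite VAE loss: L1 + LPIPS + KL (+ adversarial).
    Weights are illustrative, not the values used for this model."""
    loss = F.l1_loss(recon, target)
    if lpips_fn is not None:
        # Perceptual term, e.g. from the lpips package
        loss = loss + w_lpips * lpips_fn(recon, target).mean()
    # KL divergence of the diagonal Gaussian posterior against N(0, I)
    kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
    loss = loss + w_kl * kl
    if disc_fake_logits is not None:
        # Generator-side non-saturating adversarial term
        loss = loss + w_adv * F.softplus(-disc_fake_logits).mean()
    return loss


loss = vae_loss(torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64),
                torch.randn(2, 32, 8, 8), torch.randn(2, 32, 8, 8))
print(float(loss))
```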

Usage

```python
import math

import torch
from diffusers import AutoencoderKLFlux2

# Load the fine-tuned VAE
vae = AutoencoderKLFlux2.from_pretrained("tmeral/spherical-vae-flux")
vae = vae.to("cuda").eval()

# Encode an image (B, 3, H, W) -> latent
image = torch.randn(1, 3, 512, 512, device="cuda")  # replace with a real image scaled to [-1, 1]
with torch.no_grad():
    latent = vae.encode(image).latent_dist.sample()

# Apply spherical projection (optional, for strict sphere constraint)
radius = math.sqrt(32)
latent_flat = latent.flatten(2)  # (B, C, H*W)
norms = latent_flat.norm(dim=1, keepdim=True)  # (B, 1, H*W)
latent_flat = latent_flat / norms * radius
latent = latent_flat.view_as(latent)

# Decode back to pixel space
with torch.no_grad():
    recon = vae.decode(latent).sample

print(f"Input: {image.shape} -> Latent: {latent.shape} -> Recon: {recon.shape}")
```
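To confirm the projection step worked, the per-position channel norms can be checked against the sphere radius. A quick sanity check (the helper name is ours, not part of the published inference.py; it runs on any latent tensor and is shown here on a synthetic one):

```python
import math

import torch


def check_on_sphere(latent: torch.Tensor, radius: float, tol: float = 1e-4) -> bool:
    """True if every spatial position's channel vector lies on the
    hypersphere of the given radius (within tol)."""
    norms = latent.flatten(2).norm(dim=1)  # (B, H*W)
    return bool(torch.allclose(norms, torch.full_like(norms, radius), atol=tol))


# Synthetic latent, projected the same way as in the usage snippet above
z = torch.randn(1, 32, 64, 64)
r = math.sqrt(32)
z = (z.flatten(2) / z.flatten(2).norm(dim=1, keepdim=True) * r).view_as(z)
print(check_on_sphere(z, r))  # True
```

An unprojected latent from a standard VAE would fail this check, since its per-position norms vary freely.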

Qualitative Comparisons

Side-by-side comparisons of original FLUX.2-dev VAE vs this fine-tuned spherical VAE at different resolutions:

256x256

Comparison at 256x256

512x512

Comparison at 512x512

1024x1024

Comparison at 1024x1024

Files

| File | Description |
|---|---|
| diffusion_pytorch_model.safetensors | Model weights in safetensors format |
| config.json | Diffusers model config |
| spherical_vae_metadata.json | Spherical VAE training metadata (radius, channels, base model) |
| inference.py | Self-contained encode/decode inference script |
| training_config.yaml | Full training configuration |
| comparison_*.png | Qualitative comparison images at 256/512/1024 |
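The metadata file lets downstream code recover the sphere radius without hard-coding it. A sketch of reading it (the field names here are assumptions based on the description above — radius, channels, base model — not the real schema; in practice you would `json.load` the downloaded spherical_vae_metadata.json instead of the inline example string):

```python
import json
import math

# Hypothetical example of the metadata contents; field names are assumed
example = json.dumps({
    "sphere_radius": math.sqrt(32),
    "latent_channels": 32,
    "base_model": "black-forest-labs/FLUX.2-dev",
})

meta = json.loads(example)
print(meta["sphere_radius"])  # 5.656854249492381
```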

License

This is a research checkpoint for evaluation and collaboration purposes. The base model (FLUX.2-dev) has its own license terms.
