sd-vae-ft-mse-spherical
A spherically-constrained fine-tune of stabilityai/sd-vae-ft-mse (AutoencoderKL, 4 latent channels).
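The encoder output is constrained to lie on a hypersphere of radius sqrt(latent_channels) = 2.0 by a custom SphericalQuantConv layer trained during fine-tuning. In the exported weights this layer is stored as a plain Conv2d with numerically identical weights, so the model loads directly through the standard diffusers AutoencoderKL API. A plain Conv2d does not perform the sphere projection, so you must apply it manually before decoding (see Usage).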
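Written out, the projection used in the Usage snippet (F.normalize over dim=1, then scaling by the radius) rescales the channel vector at each latent position independently. Here z is the encoder mean of shape C x H x W with C = 4. The interpretation in the last sentence is one way to read the choice r = sqrt(C), not a rationale stated by the authors:

$$
\hat{z}_{:,h,w} = r \, \frac{z_{:,h,w}}{\lVert z_{:,h,w} \rVert_2}, \qquad r = \sqrt{C} = \sqrt{4} = 2
$$

Every projected vector then satisfies $\lVert \hat{z}_{:,h,w} \rVert_2^2 = C$, so the mean squared value per channel is exactly 1. This matches the expected squared norm $\mathbb{E}\lVert z \rVert_2^2 = C$ of a standard Gaussian latent.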
Model Details
| Property | Value |
|---|---|
| Base model | stabilityai/sd-vae-ft-mse |
| Latent channels | 4 |
| Sphere radius | 2.0 (= sqrt(4)) |
| Training steps | 10,000 |
| Training resolution | 256x256 |
| Dataset | FLUX-Reason-6M (256px subset) |
| Effective batch size | 128 (batch_size=4 x grad_accum=1 x 4 GPUs x 8) |
Training Configuration
See training_config.yaml for the full config. Key settings:
- decoder_lr: 4e-5
- encoder_lr: 4e-6
- l1_weight: 1.0
- lpips_weight: 1.0
- kl_weight: 1e-6
- adv_weight: 0.0
- EMA decay: 0.9999 (exported weights are EMA)
- No adversarial loss (adv_weight=0.0 throughout)
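The following minimal sketch shows how the settings listed above could fit together, assuming plain PyTorch. The total loss is a weighted sum of L1, LPIPS, and KL terms with no adversarial term, since adv_weight is 0.0. An EMA copy of the weights is updated with decay 0.9999. The names weighted_loss and ema_update are hypothetical, and the LPIPS and KL values are passed in precomputed. This is not the training script behind training_config.yaml.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Weights copied from the key settings above (adv_weight = 0.0, so no adversarial term).
L1_WEIGHT = 1.0
LPIPS_WEIGHT = 1.0
KL_WEIGHT = 1e-6
EMA_DECAY = 0.9999


def weighted_loss(recon: torch.Tensor, target: torch.Tensor,
                  lpips_value: torch.Tensor, kl_value: torch.Tensor) -> torch.Tensor:
    """Hypothetical combination of the listed loss terms.

    lpips_value and kl_value are assumed to be precomputed scalar tensors,
    e.g. from an LPIPS network and the encoder posterior's KL divergence.
    """
    l1 = F.l1_loss(recon, target)
    return L1_WEIGHT * l1 + LPIPS_WEIGHT * lpips_value + KL_WEIGHT * kl_value


@torch.no_grad()
def ema_update(ema_model: nn.Module, model: nn.Module, decay: float = EMA_DECAY) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * current, per parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```

With decay 0.9999 the EMA averages over roughly 1 / (1 - 0.9999) = 10,000 updates, the same order as the 10,000 training steps in the table. Suppose the EMA copy started from the base weights and was updated once per step with no warm-up (neither is stated here). Then the base weights would still carry a coefficient of about 0.9999^10000 ≈ 0.37 in the exported EMA.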
Baseline Metrics (sd-vae-ft-mse, before fine-tuning)
Computed on the FLUX-Reason-6M eval split at 256px:
| Metric | Value |
|---|---|
| rFID | 242.0 |
| LPIPS | 0.762 |
Usage
Important: The decoder was fine-tuned with latents on a hypersphere of radius 2.0. You must apply sphere projection before decoding. Skipping this step produces over-saturated, incorrect reconstructions.
```python
import math

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

SPHERE_RADIUS = math.sqrt(4)  # = 2.0, i.e. sqrt(latent_channels)

vae = AutoencoderKL.from_pretrained("tmeral/sd-vae-ft-mse-spherical")
vae.eval()

# Input image: [-1, 1] normalized, shape [B, 3, H, W].
# Random placeholder; replace with a real image tensor.
image = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    latents = vae.encode(image).latent_dist.mean
    # Required: project onto the sphere before decoding
    latents = F.normalize(latents, p=2, dim=1) * SPHERE_RADIUS
    reconstruction = vae.decode(latents).sample  # [B, 3, H, W]; clamp to [-1, 1] before saving
```
See inference.py for a complete example.
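Because skipping the projection silently degrades reconstructions, a quick sanity check on the latents can help. The sketch below assumes only PyTorch. It wraps the same F.normalize call from the usage snippet in a hypothetical project_to_sphere helper and adds a hypothetical on_sphere check. The check confirms that every channel vector of a [B, C, H, W] latent has L2 norm sqrt(C) = 2.0. Random tensors stand in for real encoder output, so no model download is needed. A 256px image gives a 32x32 latent grid because the SD VAE downsamples by 8.

```python
import math

import torch
import torch.nn.functional as F


def project_to_sphere(latents: torch.Tensor, num_channels: int = 4) -> torch.Tensor:
    """Rescale each spatial position's channel vector to length sqrt(num_channels).

    Same operation as the usage snippet; dim=1 is the channel axis of [B, C, H, W].
    """
    radius = math.sqrt(num_channels)
    return F.normalize(latents, p=2, dim=1) * radius


def on_sphere(latents: torch.Tensor, num_channels: int = 4, tol: float = 1e-4) -> bool:
    """Return True if every channel vector has L2 norm sqrt(num_channels) within tol."""
    norms = latents.norm(p=2, dim=1)  # shape [B, H, W]
    target = torch.full_like(norms, math.sqrt(num_channels))
    return bool(torch.allclose(norms, target, atol=tol))


# Random stand-in for vae.encode(image).latent_dist.mean at 256px (32x32 latent grid).
latents = torch.randn(2, 4, 32, 32)
print(on_sphere(latents))                     # raw latents are generally not on the sphere
print(on_sphere(project_to_sphere(latents)))  # projected latents are
```

In a real pipeline, an assertion such as on_sphere(latents) placed just before vae.decode catches a missing projection early.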
Reconstruction Comparison (256px)
Related
- tmeral/spherical-vae-flux: the same approach applied to the FLUX.2-dev VAE (32 latent channels)