sd-vae-ft-mse-spherical

A spherically-constrained fine-tune of stabilityai/sd-vae-ft-mse (AutoencoderKL, 4 latent channels).

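The encoder output is constrained to lie on a hypersphere of radius sqrt(latent_channels) = sqrt(4) = 2.0: at every spatial position, the 4-channel latent vector has L2 norm 2.0. During fine-tuning this constraint is enforced by a custom SphericalQuantConv layer. The exported weights store that layer as a plain Conv2d with numerically identical weights, so the model loads directly via the standard diffusers AutoencoderKL API. The plain Conv2d does not perform the projection itself, so it must be applied manually before decoding (see Usage).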
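To make the constraint concrete, here is a minimal sketch of the projection applied to a random tensor standing in for encoder output. The helper name `project_to_sphere` and the stand-in tensor are illustrative assumptions; this is not the SphericalQuantConv implementation, only the channel-wise normalization that the Usage section applies.

```python
import math

import torch
import torch.nn.functional as F

LATENT_CHANNELS = 4
SPHERE_RADIUS = math.sqrt(LATENT_CHANNELS)  # 2.0


def project_to_sphere(latents: torch.Tensor, radius: float = SPHERE_RADIUS) -> torch.Tensor:
    """Rescale the channel vector at each spatial position to L2 norm `radius`.

    Hypothetical helper for illustration; expects shape [B, C, H, W].
    """
    return F.normalize(latents, p=2, dim=1) * radius


# Random stand-in for encoder output (a 256px image gives a 32x32 latent grid,
# since this VAE family downsamples by a factor of 8).
torch.manual_seed(0)
fake_latents = torch.randn(2, LATENT_CHANNELS, 32, 32)

projected = project_to_sphere(fake_latents)

# One norm per spatial position: shape [2, 32, 32]
norms = projected.norm(p=2, dim=1)
```

After projection, every entry of `norms` should equal the sphere radius up to floating-point error, while the direction of each latent vector is unchanged.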

Model Details

| Property | Value |
|---|---|
| Base model | stabilityai/sd-vae-ft-mse |
| Latent channels | 4 |
| Sphere radius | 2.0 (= sqrt(4)) |
| Training steps | 10,000 |
| Training resolution | 256x256 |
| Dataset | FLUX-Reason-6M (256px subset) |
| Effective batch size | 128 (batch_size=4 x grad_accum=1 x 4 GPUs x 8) |

Training Configuration

See training_config.yaml for the full config. Key settings:

  • decoder_lr: 4e-5, encoder_lr: 4e-6
  • l1_weight: 1.0, lpips_weight: 1.0, kl_weight: 1e-6, adv_weight: 0.0
  • EMA decay: 0.9999 (exported weights are EMA)
  • No adversarial loss (adv_weight=0.0 throughout)
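As a rough illustration of how these settings could fit together, the sketch below wires up separate learning rates, a weighted loss sum, and an EMA update. Several things here are assumptions rather than facts from training_config.yaml: the AdamW optimizer, the weighted-sum form of the loss, the helper names `total_loss` and `ema_update`, and the small convolution layers standing in for the real encoder and decoder.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the VAE encoder and decoder modules.
encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
decoder = nn.Conv2d(4, 3, kernel_size=3, padding=1)

# Separate learning rates per module (optimizer choice is an assumption).
optimizer = torch.optim.AdamW(
    [
        {"params": encoder.parameters(), "lr": 4e-6},  # encoder_lr
        {"params": decoder.parameters(), "lr": 4e-5},  # decoder_lr
    ]
)


def total_loss(l1, lpips, kl, l1_weight=1.0, lpips_weight=1.0, kl_weight=1e-6):
    """Weighted sum of the loss terms; adv_weight is 0.0, so there is no adversarial term."""
    return l1_weight * l1 + lpips_weight * lpips + kl_weight * kl


@torch.no_grad()
def ema_update(ema_params, params, decay=0.9999):
    """In-place update: ema = decay * ema + (1 - decay) * param."""
    for ema_p, p in zip(ema_params, params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


# Tiny demonstration of the EMA step on dummy tensors.
ema_weights = [torch.zeros(3)]
live_weights = [torch.ones(3)]
ema_update(ema_weights, live_weights)
```

With a decay of 0.9999, each step moves the EMA weights only 0.01% of the way toward the live weights, which is why the exported EMA weights change slowly and smoothly over the 10,000 training steps.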

Baseline Metrics (sd-vae-ft-mse, before fine-tuning)

Computed on the FLUX-Reason-6M eval split at 256px:

| Metric | Value |
|---|---|
| rFID | 242.0 |
| LPIPS | 0.762 |

Usage

Important: The decoder was fine-tuned with latents on a hypersphere of radius 2.0. You must apply sphere projection before decoding. Skipping this step produces over-saturated, incorrect reconstructions.

import math
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

SPHERE_RADIUS = math.sqrt(4)  # = 2.0

vae = AutoencoderKL.from_pretrained("tmeral/sd-vae-ft-mse-spherical")
vae.eval()

# Encode an image ([-1, 1] normalized, shape [B, 3, H, W])
with torch.no_grad():
    latents = vae.encode(image).latent_dist.mean
    # Required: project onto sphere before decoding
    latents = F.normalize(latents, p=2, dim=1) * SPHERE_RADIUS
    reconstruction = vae.decode(latents).sample

See inference.py for a complete example.
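The snippet above assumes `image` is already a tensor in [-1, 1] with shape [B, 3, H, W]. As a sketch of one way to prepare such a tensor from an 8-bit RGB array and convert a reconstruction back, the hypothetical helpers `to_model_input` and `to_uint8` below may be useful; they are illustrative and not taken from inference.py.

```python
import numpy as np
import torch


def to_model_input(image_uint8: np.ndarray) -> torch.Tensor:
    """HxWx3 uint8 array in [0, 255] -> float tensor [1, 3, H, W] in [-1, 1]."""
    tensor = torch.from_numpy(image_uint8).float() / 127.5 - 1.0
    return tensor.permute(2, 0, 1).unsqueeze(0).contiguous()


def to_uint8(reconstruction: torch.Tensor) -> np.ndarray:
    """Float tensor [1, 3, H, W] (roughly in [-1, 1]) -> HxWx3 uint8 array."""
    scaled = (reconstruction.clamp(-1.0, 1.0) + 1.0) * 127.5
    return scaled.squeeze(0).permute(1, 2, 0).round().to(torch.uint8).cpu().numpy()


# Random stand-in for a 256x256 RGB image loaded from disk.
dummy_pixels = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
image = to_model_input(dummy_pixels)
```

The resulting `image` can be passed to `vae.encode(...)` as in the example above, and `to_uint8` can be applied to `reconstruction` before saving or displaying it.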

Reconstruction Comparison (256px)

[Figure: reconstruction comparison at 256px]

Related
