You need to agree to share your contact information to access this model

This repository is publicly accessible, but you have to accept the conditions to access its files and content.

By clicking "Agree", you acknowledge that these models are released solely for academic research purposes. The models are initialized from Stable Diffusion v3.5 medium (stabilityai-ai-community License) and further trained on a subset of the Re-LAION-5B (research-safe) dataset. You agree to review and comply with the terms and licenses of both the pretrained model and training dataset, and you bear responsibility for any use of this model.

SD v3-5-medium, covariance mismatch experiments

This repository contains versions of the transformer of Stable Diffusion v3.5 medium trained under various settings for the article "Covariance Mismatch in Diffusion Models". The weights are initialized from the pretrained model (Stable Diffusion v3.5 medium), and training was done on a 100,000-sample subset of the Re-LAION-5B research-safe dataset at resolution 512×512.

These models are intended for academic research use only and are not suitable for production deployment.

Training settings:

  • Original: white noise on original data (original). This is the typical training setting, which exhibits covariance mismatch.
  • Colored noise: colored noise on original data (colorednoise). The covariance of the noise is modified to align with the covariance of the data. The noise is colored in either the DCT or the DFT domain (colorednoiseDCT, colorednoiseDFT).
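
As a rough illustration of the colored-noise setting, the sketch below shapes white noise in the DFT domain so that its per-frequency variance follows a given spectrum. The spectrum used here is synthetic and made up for the example (the real one is stored in this repository's statistics files), `color_noise_dft` is a hypothetical helper name, and the centered (fftshift-ed) spectrum layout is an assumption.

```python
import numpy as np


def color_noise_dft(white_noise: np.ndarray, variance_spectrum: np.ndarray) -> np.ndarray:
    """Shape white noise so that its DFT variance follows `variance_spectrum`.

    `variance_spectrum` is assumed to be centered (zero frequency in the middle),
    i.e. laid out like the output of np.fft.fftshift.
    """
    axes = (-2, -1)
    # Orthonormal 2D DFT, shifted so that its layout matches the spectrum.
    ft = np.fft.fftshift(np.fft.fft2(white_noise, axes=axes, norm="ortho"), axes=axes)
    # Scale each frequency by the target standard deviation.
    ft = ft * np.sqrt(variance_spectrum)
    # Back to the spatial domain; keep the real part.
    colored = np.fft.ifft2(np.fft.ifftshift(ft, axes=axes), axes=axes, norm="ortho")
    return colored.real


rng = np.random.default_rng(0)
white = rng.standard_normal((1, 16, 64, 64))

# Synthetic spectrum (made up): variance decays with distance from the zero frequency.
freqs = np.fft.fftshift(np.fft.fftfreq(64))
radius = np.sqrt(freqs[:, None] ** 2 + freqs[None, :] ** 2)
spectrum = 1.0 / (1.0 + (10.0 * radius) ** 2)

colored = color_noise_dft(white, spectrum)
print(colored.shape, colored.std())
```

With a flat spectrum of ones the function returns the white noise unchanged, which is a quick sanity check for the transform pair.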

The models are trained for 100,000 steps.

Example usage 1: inference on smaller range of noise levels

In this example, we perform inference over a smaller range of noise levels, $SNR \in [0.0123, 0.1837]$, instead of the original one used in Stable Diffusion v3.5 medium ($SNR \in [0, 998001]$). For ease of implementation, we keep the original noise schedule and only use the range of timesteps $t \in [700, 900]$ (instead of $t \in [1, 1000]$).
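
The SNR values quoted above can be reproduced with a short calculation. The sketch below assumes the rectified-flow forward process $x_t = (1-\sigma) x_0 + \sigma \varepsilon$ with $\sigma = t/1000$, so that $SNR = ((1-\sigma)/\sigma)^2$. This formula is inferred from the quoted ranges rather than stated in this model card, and `snr_from_timestep` is a hypothetical helper name.

```python
def snr_from_timestep(t: float, num_train_timesteps: int = 1000) -> float:
    """SNR at timestep t, assuming x_t = (1 - sigma) * x_0 + sigma * noise."""
    sigma = t / num_train_timesteps  # noise level in (0, 1]
    return ((1.0 - sigma) / sigma) ** 2


# Endpoints of the full range, the reduced range, and the single level t = 825.
for t in (1, 700, 825, 900, 1000):
    print(f"t = {t:4d}  ->  SNR = {snr_from_timestep(t):.4f}")
```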

[Figure: smaller range of noise levels]

import numpy as np
import torch
from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel
from diffusers.utils import _get_model_file
from safetensors.torch import load_file
from torch_dct import dct_2d, idct_2d

subfolder = "transformer_colorednoiseDCT" # choose among "transformer_original", "transformer_colorednoiseDCT", "transformer_colorednoiseDFT"

# Load model
transformer = SD3Transformer2DModel.from_pretrained(
    "EPFL-IVRL/sd3.5-medium-covariance-mismatch",
    subfolder=subfolder,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    transformer=transformer,
)
pipe.enable_model_cpu_offload()

stats = load_file(_get_model_file("EPFL-IVRL/sd3.5-medium-covariance-mismatch", weights_name="stats.safetensors", subfolder="relaion2B-en-research-safe-subset100000/dft_stats"))
variance_spectrum_dft = stats["variance_spectrum_vae64"].to("cuda")
stats = load_file(_get_model_file("EPFL-IVRL/sd3.5-medium-covariance-mismatch", weights_name="stats.safetensors", subfolder="relaion2B-en-research-safe-subset100000/dct_stats"))
variance_spectrum_dct = stats["variance_spectrum_vae64"].to("cuda")

# Generate
prompt = "A colorful castle, vibrant colors, detailed."
generator = torch.manual_seed(123456)
initial_noise = torch.randn((1, 16, 64, 64), generator=generator).to("cuda")

# Color initial noise if necessary (if the model is trained with colored noise)
if subfolder == "transformer_colorednoiseDCT":
    dct = dct_2d(initial_noise, norm='ortho')
    dct *= torch.sqrt(variance_spectrum_dct)
    initial_noise = idct_2d(dct, norm='ortho')
elif subfolder == "transformer_colorednoiseDFT":
    ft = torch.fft.fftshift(torch.fft.fftn(initial_noise, dim=(-2, -1), norm="ortho"), dim=(-2, -1))
    ft *= torch.sqrt(variance_spectrum_dft)
    initial_noise = torch.real(torch.fft.ifftn(torch.fft.ifftshift(ft, dim=(-2, -1)), dim=(-2, -1), norm="ortho"))

# Inference timesteps
start = 900.0
stop = 700.0
num_steps = 50

start_sigma = start / pipe.scheduler.config.num_train_timesteps
stop_sigma = stop / pipe.scheduler.config.num_train_timesteps
start_u = start_sigma / (pipe.scheduler.config.shift - start_sigma * (pipe.scheduler.config.shift - 1))
stop_u = stop_sigma / (pipe.scheduler.config.shift - stop_sigma * (pipe.scheduler.config.shift - 1))
us = np.linspace(start_u, stop_u, num_steps)
sigmas = pipe.scheduler.config.shift * us / (1 + (pipe.scheduler.config.shift - 1) * us)   
timesteps = (sigmas*pipe.scheduler.config.num_train_timesteps).tolist()
sigmas = sigmas.tolist()

# Scale the initial noise to match the starting noise level
init_noise_sigma = ((1-sigmas[0])**2 + (sigmas[0])**2) ** 0.5
generated = pipe(
    prompt=prompt,
    sigmas=us.tolist(),  # unshifted values; the scheduler is expected to apply the shift itself
    latents=init_noise_sigma*initial_noise,
    output_type="latent",
).images

# Decode the latents to an image
image = pipe.vae.decode(generated / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor).sample[0]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().detach().permute(1, 2, 0).numpy()
image_pil = pipe.numpy_to_pil(image)[0]
image_pil.show()

[Table: generated image for each training setting (transformer_original, transformer_colorednoiseDCT, transformer_colorednoiseDFT)]
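
The timestep computation in the example above first undoes the scheduler's timestep shift, spaces the steps evenly in the unshifted variable, and then lets the shift map them back to noise levels. A minimal sketch of that round trip is given below. The value `shift = 3.0` is an assumption (in practice it is read from `pipe.scheduler.config.shift`), and `shift_sigma` / `unshift_sigma` are hypothetical helper names.

```python
def shift_sigma(u: float, shift: float) -> float:
    """Map an unshifted value u in [0, 1] to a shifted noise level sigma."""
    return shift * u / (1.0 + (shift - 1.0) * u)


def unshift_sigma(sigma: float, shift: float) -> float:
    """Inverse of shift_sigma: recover u from sigma."""
    return sigma / (shift - sigma * (shift - 1.0))


shift = 3.0  # assumed value; read pipe.scheduler.config.shift in practice
num_steps = 50
start_u = unshift_sigma(0.9, shift)  # t = 900
stop_u = unshift_sigma(0.7, shift)   # t = 700

# Evenly spaced in u, as in the example; shifting maps the ends back to 0.9 and 0.7.
us = [start_u + (stop_u - start_u) * i / (num_steps - 1) for i in range(num_steps)]
sigmas = [shift_sigma(u, shift) for u in us]
print(sigmas[0], sigmas[-1])
```

Because the shift is monotonic, the resulting noise levels still decrease from the start level to the stop level, only with non-uniform spacing.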

Example usage 2: inference on a single noise level

In this example, we perform inference over a single noise level, $SNR = 0.045$, instead of the original range used in Stable Diffusion v3.5 medium ($SNR \in [0, 998001]$). For ease of implementation, we keep the original noise schedule and only use the timestep $t = 825$ (instead of $t \in [1, 1000]$).

import numpy as np
import torch
from diffusers import StableDiffusion3Pipeline, SD3Transformer2DModel
from diffusers.utils import _get_model_file
from safetensors.torch import load_file
from torch_dct import dct_2d, idct_2d

subfolder = "transformer_colorednoiseDCT" # choose among "transformer_original", "transformer_colorednoiseDCT", "transformer_colorednoiseDFT"

# Load model
transformer = SD3Transformer2DModel.from_pretrained(
    "EPFL-IVRL/sd3.5-medium-covariance-mismatch",
    subfolder=subfolder,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",
    transformer=transformer,
)
pipe.enable_model_cpu_offload()

stats = load_file(_get_model_file("EPFL-IVRL/sd3.5-medium-covariance-mismatch", weights_name="stats.safetensors", subfolder="relaion2B-en-research-safe-subset100000/dft_stats"))
variance_spectrum_dft = stats["variance_spectrum_vae64"].to("cuda")
stats = load_file(_get_model_file("EPFL-IVRL/sd3.5-medium-covariance-mismatch", weights_name="stats.safetensors", subfolder="relaion2B-en-research-safe-subset100000/dct_stats"))
variance_spectrum_dct = stats["variance_spectrum_vae64"].to("cuda")

# Generate
prompt = "A colorful castle, vibrant colors, detailed."
generator = torch.manual_seed(123456)
initial_noise = torch.randn((1, 16, 64, 64), generator=generator).to("cuda")

# Color initial noise if necessary (if the model is trained with colored noise)
if subfolder == "transformer_colorednoiseDCT":
    dct = dct_2d(initial_noise, norm='ortho')
    dct *= torch.sqrt(variance_spectrum_dct)
    initial_noise = idct_2d(dct, norm='ortho')
elif subfolder == "transformer_colorednoiseDFT":
    ft = torch.fft.fftshift(torch.fft.fftn(initial_noise, dim=(-2, -1), norm="ortho"), dim=(-2, -1))
    ft *= torch.sqrt(variance_spectrum_dft)
    initial_noise = torch.real(torch.fft.ifftn(torch.fft.ifftshift(ft, dim=(-2, -1)), dim=(-2, -1), norm="ortho"))

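# Inference timesteps: three steps, all evaluated at t = 825
timesteps = [825, 825, 825]
# Size of each intermediate ODE step (from t = 825 to t = 550); set to None to step directly to the next timestep
fixed_delta_t = 275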

# Generate
with torch.no_grad():
    guidance_scale = 2.8
    timesteps = torch.tensor(timesteps).to("cuda")
    num_train_timesteps = pipe.scheduler.config.num_train_timesteps

    sigma_start = timesteps[0]/num_train_timesteps
    latents = initial_noise * ((1-sigma_start)**2 + sigma_start**2 ) ** 0.5
    #latents = latents.half()

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        prompt_3=prompt,
        do_classifier_free_guidance=True,
        device=pipe.device,
        num_images_per_prompt=1,
    )
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
        
    for i, t in enumerate(timesteps):    
        sigma = t/num_train_timesteps
        if i<len(timesteps)-1:
            sigma_nextinferencestep = timesteps[i+1]/num_train_timesteps
            if fixed_delta_t is not None:
                sigma_prev = (t-fixed_delta_t)/num_train_timesteps
            else:
                sigma_prev = timesteps[i+1]/num_train_timesteps
        else:
            sigma_prev = 0
            sigma_nextinferencestep = 0                        

        latent_model_input = torch.cat([latents] * 2) # expand the latents for classifier free guidance
        timestep =  t.expand(latent_model_input.shape[0])

        # predict the noise residual
        noise_pred = pipe.transformer(
            hidden_states=latent_model_input,
            timestep = timestep,
            encoder_hidden_states=prompt_embeds,
            pooled_projections=pooled_prompt_embeds,
            return_dict=False,
        )[0]  
        
        # perform classifier free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        
        # 1 step of ODE
        dt = sigma_prev - sigma            
        pred_original_sample = latents - sigma * noise_pred            
        prev_sample = latents + dt * noise_pred
        if fixed_delta_t is not None:
            latents = prev_sample / ((1-sigma_prev)**2 + sigma_prev**2 ) ** 0.5
            latents = latents * ((1-sigma_nextinferencestep)**2 + sigma_nextinferencestep**2 ) ** 0.5
        else:
            latents = prev_sample
        
    generated = pred_original_sample          
    image = pipe.vae.decode(generated / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor).sample[0]                
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(1, 2, 0).numpy()               
    image_pil = pipe.numpy_to_pil(image)[0]
    
image_pil.show()

[Table: generated image for each training setting (transformer_original, transformer_colorednoiseDCT, transformer_colorednoiseDFT)]
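
In the loop of the example above, each intermediate step moves the latents from $t = 825$ to $t = 550$ and then rescales them before the model is queried again at $t = 825$. The sketch below computes the factors involved. Reading $\sqrt{(1-\sigma)^2 + \sigma^2}$ as the standard deviation of a noisy latent assumes independent, unit-variance data and noise; that interpretation is not stated in this model card, and `noisy_std` is a hypothetical helper name.

```python
import math


def noisy_std(sigma: float) -> float:
    """Std of (1 - sigma) * x_0 + sigma * noise for independent unit-variance x_0 and noise."""
    return math.sqrt((1.0 - sigma) ** 2 + sigma ** 2)


num_train_timesteps = 1000
t, fixed_delta_t = 825, 275

sigma = t / num_train_timesteps                          # level at which the model is queried
sigma_prev = (t - fixed_delta_t) / num_train_timesteps   # level reached after one Euler step

# Factor applied to the stepped latents to bring them back to the scale of level sigma.
rescale = noisy_std(sigma) / noisy_std(sigma_prev)
print(f"sigma={sigma}, sigma_prev={sigma_prev}, rescale={rescale:.4f}")
```

Note that this only rescales the latents; as in the example, no fresh noise is added between steps.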

Model Description

Citation

@article{everaert2024covariancemismatch,
    author   = {Everaert, Martin Nicolas and Süsstrunk, Sabine and Achanta, Radhakrishna},
    title    = {{C}ovariance {M}ismatch in {D}iffusion {M}odels}, 
    journal  = {Infoscience preprint Infoscience:20.500.14299/242173},
    month    = {November},
    year     = {2024},
}

Training details

  • Dataset size: 100k image-caption pairs from Re-LAION-5B research-safe
  • Hardware: 1 × NVIDIA-H100-80GB-HBM3
  • Pretrained model: Stable Diffusion v3.5 medium. This Stability AI Model is licensed under the Stability AI Community License, Copyright © Stability AI Ltd. All Rights Reserved. Powered by Stability AI.
  • Optimizer: AdamW (32-bit, no quantization)
    • betas: (0.9, 0.999)
    • weight_decay: 0.01
    • eps: 1e-08
    • lr: 5e-06
    • lr_scheduler: get_cosine_schedule_with_warmup, num_warmup_steps: 500, num_training_steps: 100000
  • Batch size: 8 (no gradient accumulation)
  • Caption dropout: 10%
  • Exponential Moving Average (EMA): not used
  • Training steps: 100,000
  • Training time:
    • transformer_original: 15h53min
    • transformer_colorednoiseDCT: 16h04min
    • transformer_colorednoiseDFT: 16h05min
  • Covariance realignment method:
    • transformer_original: no covariance realignment, original data (not whitened), white noise (not colored)
    • transformer_colorednoiseDCT: original data, colored noise (DCT)
    • transformer_colorednoiseDFT: original data, colored noise (DFT)
  • Training range of noise levels:
    • full original range $SNR \in [0, 998001]$
  • Training loss (plots, one per model):
    • unet_original
    • unet_colorednoiseDCT
    • unet_colorednoiseDFT
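
To give a sense of the learning-rate schedule listed above, here is a pure-Python sketch of a linear warmup followed by a half-cosine decay, using the listed values (lr 5e-06, 500 warmup steps, 100,000 training steps). It mirrors the default half-cosine behavior of `get_cosine_schedule_with_warmup` as commonly documented, which is an assumption here; it is not the training code, and `cosine_lr_with_warmup` is a hypothetical helper name.

```python
import math


def cosine_lr_with_warmup(
    step: int,
    base_lr: float = 5e-6,
    num_warmup_steps: int = 500,
    num_training_steps: int = 100_000,
) -> float:
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to base_lr.
        return base_lr * step / max(1, num_warmup_steps)
    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 250, 500, 50_250, 100_000):
    print(f"step {step:7d}: lr = {cosine_lr_with_warmup(step):.3e}")
```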