---
license: cc0-1.0
tags:
- diffusion
- latent-diffusion
- classifier-free-guidance
- astronomy
- galaxies
- galaxy10
library_name: pytorch
---

# Galaxy Diffusion: latent diffusion weights (Galaxy10 DECaLS)

Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images, 256×256, 10 morphological classes).

This repository contains the `.safetensors` weights. The model uses a **custom architecture**: it is **not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need the `galaxy_diffusion` package from the code repository to instantiate it.

- **Code:** https://github.com/LLapsus/galaxy-diffusion
- **License:** CC0 1.0 (public domain)

## Files

| File | Contents |
|---|---|
| `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params |
| `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image ↔ 4×32×32 latent), ~1.09M params |
| `latent_diffusion_galaxy10_xattn_v1.config.json` | Constructor args (`vae_config`, `unet_config`, `unet_type`) and latent normalisation stats (`latents_mean`, `latents_std`) |
| `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (validation accuracy 0.829) |
| `galaxy10_classifier.config.json` | Classifier metadata (`val_acc`, `epoch`) |

## Installation

```bash
pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors
```

## Load the weights

```python
import json

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA

path = snapshot_download("llapsus/galaxy-diffusion")  # downloads all files

with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:
    cfg = json.load(f)

# Rebuild the VAE from its stored constructor args, then load its weights.
vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()

# Same for the cross-attention UNet denoiser.
unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()
```

## Generate images

```python
from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg

device = "cuda" if torch.cuda.is_available() else "cpu"
vae, unet = vae.to(device), unet.to(device)

# Cosine noise schedule with T=1000 steps, as used in training.
_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)

# Latent normalisation statistics stored in the config.
latents_mean = torch.tensor(cfg["latents_mean"]).to(device)
latents_std = torch.tensor(cfg["latents_std"]).to(device)

with torch.no_grad():
    images = sample_cfg(
        unet,
        vae,
        classes=list(range(10)),  # one image per class
        alpha=alpha,
        alpha_bar=alpha_bar,
        latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
        latents_mean=latents_mean,
        latents_std=latents_std,
        device=device,
        guidance_scale=2.5,  # see "Guidance scale" below
        cfg_rescale=0.7,  # CFG rescaling (Lin et al., 2023)
    )
# -> images: tensor of shape (10, 3, 256, 256), values in [-1, 1]
```

The classifier is loaded analogously with `GalaxyCNN` from `galaxy_diffusion.models.classifier`.

## Model details

- **VAE:** 8× spatial compression (3×256×256 ↔ 4×32×32), KL-regularised.
- **UNet (`LatentUNetCA`):** time conditioning via AdaGN; class conditioning via a cross-attention block after each encoder and decoder level and at the bottleneck; cosine noise schedule (T=1000); trained with Min-SNR-weighted MSE and 10% class-label dropout for classifier-free guidance.
- **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the distribution of diffusion outputs) and used to evaluate the class fidelity of generated samples.

### Guidance scale

Classifier recall on generated images peaks around `w ≈ 3` (`guidance_scale`), but latent-space coverage analysis shows that `w ≈ 2.5` is the better fidelity/diversity operating point: there, the within-class spread of the generated samples matches that of the real data. Higher `w` over-extrapolates samples toward neighbouring classes. See the coverage analysis in the code repository.
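To make the two sampling knobs concrete, here is a minimal sketch of classifier-free guidance with the rescaling of Lin et al. (2023), applied to a pair of conditional and unconditional denoiser outputs. It assumes the common `uncond + w * (cond - uncond)` convention for `guidance_scale` and per-sample standard deviations over all non-batch dimensions. `cfg_with_rescale` is a hypothetical helper written for illustration, not the internals of the repository's `sample_cfg`, which may use a different convention.

```python
import torch


def cfg_with_rescale(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    cfg_rescale: float,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Classifier-free guidance with std rescaling (after Lin et al., 2023).

    Hypothetical helper for illustration; assumes the convention
    uncond + w * (cond - uncond), which may differ from the repo's sample_cfg.
    """
    # Plain CFG: extrapolate from the unconditional toward the conditional prediction.
    guided = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

    # Per-sample standard deviation over every non-batch dimension.
    dims = tuple(range(1, pred_cond.ndim))
    std_cond = pred_cond.std(dim=dims, keepdim=True)
    # clamp_min guards against division by zero for a constant prediction.
    std_guided = guided.std(dim=dims, keepdim=True).clamp_min(eps)

    # Rescale so each sample's std matches the conditional prediction ...
    rescaled = guided * (std_cond / std_guided)
    # ... then blend: cfg_rescale=0 is plain CFG, cfg_rescale=1 is fully rescaled.
    return cfg_rescale * rescaled + (1.0 - cfg_rescale) * guided


# Usage with random stand-ins for the UNet outputs on a batch of 4x32x32 latents.
torch.manual_seed(0)
cond = torch.randn(10, 4, 32, 32)
uncond = torch.randn(10, 4, 32, 32)
out = cfg_with_rescale(cond, uncond, guidance_scale=2.5, cfg_rescale=0.7)
print(out.shape)  # torch.Size([10, 4, 32, 32])
```

With `cfg_rescale=0` this reduces to plain CFG; with `cfg_rescale=1` every sample ends up with exactly the standard deviation of the conditional prediction, which counteracts the over-saturation that large `w` tends to cause.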
## Training data

**Galaxy10 DECaLS**: https://astronn.readthedocs.io/en/latest/galaxy10.html (17,736 images, 10 classes; not redistributed here).
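The sampling example passes integer labels via `classes=`. A small lookup helps when you want specific morphologies by name. The class names below follow the Galaxy10 DECaLS list in the astroNN documentation linked above; that the model's class indices are exactly these dataset labels is an assumption, and `GALAXY10_CLASSES` and `class_ids` are hypothetical names, not part of the `galaxy_diffusion` package.

```python
# Galaxy10 DECaLS class names, indexed by the dataset's integer label (per the astroNN docs).
# Assumption: the model's class indices follow these dataset labels unchanged.
GALAXY10_CLASSES = [
    "Disturbed Galaxies",
    "Merging Galaxies",
    "Round Smooth Galaxies",
    "In-between Round Smooth Galaxies",
    "Cigar Shaped Smooth Galaxies",
    "Barred Spiral Galaxies",
    "Unbarred Tight Spiral Galaxies",
    "Unbarred Loose Spiral Galaxies",
    "Edge-on Galaxies without Bulge",
    "Edge-on Galaxies with Bulge",
]


def class_ids(names):
    """Map class names (case-insensitive) to integer labels, e.g. for sample_cfg(classes=...)."""
    lookup = {name.lower(): i for i, name in enumerate(GALAXY10_CLASSES)}
    return [lookup[name.lower()] for name in names]


print(class_ids(["Barred Spiral Galaxies", "Edge-on Galaxies with Bulge"]))  # [5, 9]
```

An unknown name raises `KeyError`, which is preferable to silently sampling the wrong class.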