---
license: cc0-1.0
tags:
- diffusion
- latent-diffusion
- classifier-free-guidance
- astronomy
- galaxies
- galaxy10
library_name: pytorch
---
# Galaxy Diffusion — latent diffusion weights (Galaxy10 DECaLS)
Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating
galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images,
256×256, 10 morphological classes).
These are the `.safetensors` weights. The model uses a **custom architecture** — it is
**not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need
the `galaxy_diffusion` package from the code repository to instantiate it.
- **Code:** https://github.com/LLapsus/galaxy-diffusion
- **License:** CC0 1.0 (public domain)
## Files
| File | Contents |
|---|---|
| `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params |
| `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image ↔ 4×32×32 latent), ~1.09M params |
| `latent_diffusion_galaxy10_xattn_v1.config.json` | constructor args (`vae_config`, `unet_config`, `unet_type`) + latent normalisation stats (`latents_mean`, `latents_std`) |
| `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (val acc 0.829) |
| `galaxy10_classifier.config.json` | classifier metadata (`val_acc`, `epoch`) |
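
The parameter counts in the table can be sanity-checked without instantiating any model, because a `.safetensors` file starts with an 8-byte little-endian header length followed by a JSON header listing every tensor's shape. The following is a minimal sketch using only the standard library; the helper names `read_safetensors_header` and `count_parameters` are illustrative and not part of the `galaxy_diffusion` package.

```python
import json
import struct
from math import prod


def read_safetensors_header(path):
    """Return the JSON header of a .safetensors file without loading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    header.pop("__metadata__", None)
    return header


def count_parameters(path):
    """Sum the element counts of all tensors listed in the header."""
    header = read_safetensors_header(path)
    return sum(prod(info["shape"]) for info in header.values())
```

Calling `count_parameters` on the downloaded `*.model.safetensors` file should give a number close to the figure in the table. The count covers every tensor stored in the file, so any non-trainable buffers saved in the state dict would be included as well.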
## Installation
```bash
pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors
```
## Load the weights
```python
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA
path = snapshot_download("llapsus/galaxy-diffusion") # downloads all files
with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:
    cfg = json.load(f)  # constructor args + latent normalisation stats
vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()
unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()
```
## Generate images
```python
from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg
device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
vae, unet = vae.to(device), unet.to(device)
_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)
latents_mean = torch.tensor(cfg["latents_mean"])
latents_std = torch.tensor(cfg["latents_std"])
images = sample_cfg(
    unet, vae,
    classes=list(range(10)),  # one image per class
    alpha=alpha, alpha_bar=alpha_bar,
    latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
    latents_mean=latents_mean, latents_std=latents_std,
    device=device,
    guidance_scale=2.5,  # see "Guidance scale" below
    cfg_rescale=0.7,  # CFG rescaling (Lin et al., 2023)
)  # -> tensor (10, 3, 256, 256) in [-1, 1]
```
The classifier is loaded analogously with `GalaxyCNN` from
`galaxy_diffusion.models.classifier`.
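
The sampling snippet above documents its output as a `(10, 3, 256, 256)` tensor in `[-1, 1]`. To view or save the images, the values have to be mapped to 8-bit channels-last arrays. Here is a small sketch of that conversion in NumPy; `to_uint8_hwc` is a hypothetical helper, not something the package is known to provide.

```python
import numpy as np


def to_uint8_hwc(images):
    """Map a (N, 3, H, W) float array in [-1, 1] to (N, H, W, 3) uint8 in [0, 255]."""
    images = np.asarray(images, dtype=np.float32)
    images = np.clip(images, -1.0, 1.0)        # guard against slight overshoot
    images = (images + 1.0) * 127.5            # [-1, 1] -> [0, 255]
    images = np.rint(images).astype(np.uint8)  # round to the nearest integer level
    return images.transpose(0, 2, 3, 1)        # channels-first -> channels-last
```

With the tensor from the snippet above, `to_uint8_hwc(images.cpu().numpy())` gives arrays that can be passed to `PIL.Image.fromarray` one image at a time.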
## Model details
- **VAE:** 8× spatial compression, 3×256×256 ↔ 4×32×32, KL-regularised.
- **UNet (`LatentUNetCA`):** time conditioning via AdaGN; class conditioning via a
  cross-attention block after each encoder/decoder level and the bottleneck; cosine
  noise schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout.
- **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the
  distribution of diffusion outputs); used to evaluate the class fidelity of generated samples.
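
The three training ingredients named above (cosine schedule, Min-SNR weighting, CFG label dropout) can be illustrated in NumPy. This sketch assumes the common textbook forms: the Nichol and Dhariwal cosine schedule with offset `s=0.008`, the Min-SNR weight `min(SNR, gamma) / SNR` for noise prediction with `gamma=5`, and a dedicated "null" class index for dropped labels. None of these specifics is stated in this card, so the repository's `cosine_schedule` and training loop may differ; all function names here are hypothetical.

```python
import numpy as np


def cosine_schedule_sketch(T=1000, s=0.008, max_beta=0.999):
    """Cosine noise schedule; returns (beta, alpha, alpha_bar), each of shape (T,)."""
    steps = np.arange(T + 1, dtype=np.float64)
    f = np.cos((steps / T + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar_full = f / f[0]
    # Per-step noise variance, clipped to avoid a singular final step.
    beta = np.clip(1.0 - alpha_bar_full[1:] / alpha_bar_full[:-1], 0.0, max_beta)
    alpha = 1.0 - beta
    alpha_bar = np.cumprod(alpha)
    return beta, alpha, alpha_bar


def min_snr_weight(alpha_bar, gamma=5.0):
    """Per-timestep loss weight min(SNR, gamma) / SNR (assumes noise prediction)."""
    snr = alpha_bar / (1.0 - alpha_bar)
    return np.minimum(snr, gamma) / snr


def drop_labels(labels, null_class, p=0.1, rng=None):
    """Replace each label by `null_class` with probability p (CFG label dropout)."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    mask = rng.random(labels.shape) < p
    return np.where(mask, null_class, labels)
```

The weight is small at low-noise timesteps, where the SNR is large, and equals 1 once the SNR falls below `gamma`, which keeps the easy, nearly clean steps from dominating the loss.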
### Guidance scale
Classifier recall on generated images peaks around guidance scale `w ≈ 3`, but latent-space coverage
analysis shows `w ≈ 2.5` is the better fidelity/diversity operating point (matched
within-class spread). Higher `w` over-extrapolates samples toward neighbouring classes.
See the coverage analysis in the code repository.
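
To make the roles of `guidance_scale` and `cfg_rescale` concrete, here is a sketch of classifier-free guidance followed by the standard-deviation rescaling described by Lin et al. (2023). It assumes the convention in which `w = 1` recovers the conditional prediction and that guidance is applied to the network's noise prediction; how `sample_cfg` implements this internally is not documented here, and `cfg_combine` is a hypothetical name.

```python
import numpy as np


def cfg_combine(eps_cond, eps_uncond, guidance_scale=2.5, cfg_rescale=0.7):
    """Guided prediction from conditional/unconditional outputs of shape (N, C, H, W)."""
    # Extrapolate from the unconditional towards the conditional prediction.
    eps_cfg = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

    # Per-sample standard deviation over all non-batch axes.
    axes = tuple(range(1, eps_cond.ndim))
    std_cond = eps_cond.std(axis=axes, keepdims=True)
    std_cfg = eps_cfg.std(axis=axes, keepdims=True)

    # Rescale the guided prediction to the conditional prediction's spread,
    # then blend it with the unrescaled result.
    eps_rescaled = eps_cfg * (std_cond / std_cfg)
    return cfg_rescale * eps_rescaled + (1.0 - cfg_rescale) * eps_cfg
```

Larger `guidance_scale` pushes the prediction further along the conditional direction, which inflates its variance. The rescale factor pulls the variance back towards that of the conditional prediction, and `cfg_rescale` sets how much of that correction is applied.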
## Training data
**Galaxy10 DECaLS** — https://astronn.readthedocs.io/en/latest/galaxy10.html
(17,736 images; 10 classes; not redistributed here).