| --- |
| license: cc0-1.0 |
| tags: |
| - diffusion |
| - latent-diffusion |
| - classifier-free-guidance |
| - astronomy |
| - galaxies |
| - galaxy10 |
| library_name: pytorch |
| --- |
| |
| # Galaxy Diffusion: latent diffusion weights (Galaxy10 DECaLS) |
|
|
| Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating |
| galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images, |
| 256×256, 10 morphological classes). |
|
|
| These are the `.safetensors` weights. The model uses a **custom architecture**: it is |
| **not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need |
| the `galaxy_diffusion` package from the code repository to instantiate it. |
|
|
| - **Code:** https://github.com/LLapsus/galaxy-diffusion |
| - **License:** CC0 1.0 (public domain) |
|
|
| ## Files |
|
|
| | File | Contents | |
| |---|---| |
| | `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params | |
| `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image → 4×32×32 latent), ~1.09M params | |
| | `latent_diffusion_galaxy10_xattn_v1.config.json` | constructor args (`vae_config`, `unet_config`, `unet_type`) + latent normalisation stats (`latents_mean`, `latents_std`) | |
| | `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (val acc 0.829) | |
| | `galaxy10_classifier.config.json` | classifier metadata (`val_acc`, `epoch`) | |
|
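Before constructing the models, it can help to confirm that the downloaded config carries the keys listed in the table above. The helper below is a minimal sketch: `missing_config_keys` is a hypothetical name that is not part of `galaxy_diffusion`, and the example dictionary holds placeholder values, not the real config contents.

```python
import json

# Top-level keys named in the Files table above.
EXPECTED_KEYS = ("vae_config", "unet_config", "unet_type", "latents_mean", "latents_std")


def missing_config_keys(cfg: dict) -> list[str]:
    """Return the expected top-level keys that are absent from cfg."""
    return [key for key in EXPECTED_KEYS if key not in cfg]


# Placeholder config for illustration only; every value here is made up.
example_cfg = json.loads(
    '{"vae_config": {}, "unet_config": {}, "unet_type": "placeholder",'
    ' "latents_mean": [0.0], "latents_std": [1.0]}'
)
missing = missing_config_keys(example_cfg)
```

An empty list means all five keys are present. A non-empty list names the keys to look for before passing the config to the model constructors.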
|
| ## Installation |
|
|
| ```bash |
| pip install "git+https://github.com/LLapsus/galaxy-diffusion.git" |
| pip install huggingface_hub safetensors |
| ``` |
|
|
| ## Load the weights |
|
|
| ```python |
| import json |
| import torch |
| from huggingface_hub import snapshot_download |
| from safetensors.torch import load_file |
| |
| from galaxy_diffusion.models.vae import VAE |
| from galaxy_diffusion.models.unet import LatentUNetCA |
| |
| path = snapshot_download("llapsus/galaxy-diffusion") # downloads all files |
| with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f: |
|     cfg = json.load(f) |
| |
| vae = VAE(**cfg["vae_config"]) |
| vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors")) |
| vae.eval() |
| |
| unet = LatentUNetCA(**cfg["unet_config"]) |
| unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors")) |
| unet.eval() |
| ``` |
|
|
| ## Generate images |
|
|
| ```python |
| from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU |
| vae, unet = vae.to(device), unet.to(device) |
| |
| _, alpha, alpha_bar = cosine_schedule(1000) |
| alpha, alpha_bar = alpha.to(device), alpha_bar.to(device) |
| |
| latents_mean = torch.tensor(cfg["latents_mean"]) |
| latents_std = torch.tensor(cfg["latents_std"]) |
| |
| images = sample_cfg( |
|     unet, vae, |
|     classes=list(range(10)),  # one image per class |
|     alpha=alpha, alpha_bar=alpha_bar, |
|     latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32), |
|     latents_mean=latents_mean, latents_std=latents_std, |
|     device=device, |
|     guidance_scale=2.5,  # see "Guidance scale" below |
|     cfg_rescale=0.7,  # CFG rescaling (Lin et al., 2023) |
| ) # -> tensor (10, 3, 256, 256) in [-1, 1] |
| ``` |
|
|
| The classifier is loaded analogously with `GalaxyCNN` from |
| `galaxy_diffusion.models.classifier`. |
|
|
| ## Model details |
|
|
| - **VAE:** 8× spatial compression, 3×256×256 → 4×32×32, KL-regularised. |
| - **UNet (`LatentUNetCA`):** time conditioning via AdaGN; class conditioning via a |
| cross-attention block after each encoder and decoder level and at the bottleneck; cosine |
| noise schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout. |
| - **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the |
| distribution of diffusion outputs) for evaluating class fidelity of generated samples. |
|
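To make the cosine schedule and the Min-SNR weighting concrete, here is a minimal pure-Python sketch. It assumes the Nichol & Dhariwal cosine schedule (offset `s = 0.008`, betas clipped at 0.999) and the noise-prediction form of the Min-SNR weight with `gamma = 5`. None of these values are confirmed by this card, and the repository's `cosine_schedule` may differ in detail. The names `cosine_schedule_sketch` and `min_snr_weight` are hypothetical.

```python
import math


def cosine_schedule_sketch(T: int, s: float = 0.008, max_beta: float = 0.999):
    """Return (betas, alphas, alpha_bars) as lists of length T.

    Assumed formulation: alpha_bar(t) is proportional to
    cos^2(((t / T + s) / (1 + s)) * pi / 2), with each beta clipped at max_beta.
    """
    def f(t: int) -> float:
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    betas, alphas, alpha_bars = [], [], []
    running = 1.0
    for t in range(1, T + 1):
        beta = min(1.0 - f(t) / f(t - 1), max_beta)
        alpha = 1.0 - beta
        running *= alpha  # cumulative product of the alphas
        betas.append(beta)
        alphas.append(alpha)
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


def min_snr_weight(alpha_bar_t: float, gamma: float = 5.0) -> float:
    """Min-SNR loss weight for a noise-prediction MSE: min(SNR, gamma) / SNR."""
    snr = alpha_bar_t / (1.0 - alpha_bar_t)
    return min(snr, gamma) / snr


betas, alphas, alpha_bars = cosine_schedule_sketch(1000)
weights = [min_snr_weight(ab) for ab in alpha_bars]
```

With this weighting, low-noise timesteps (large SNR) are down-weighted to `gamma / SNR`, while timesteps whose SNR is already below `gamma` keep a weight of 1.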
|
| ### Guidance scale |
|
|
| Classifier recall on generated images peaks around `w ≈ 3`, but latent-space coverage |
| analysis shows `w ≈ 2.5` is the better fidelity/diversity operating point (matched |
| within-class spread). Higher `w` over-extrapolates samples toward neighbouring classes. |
| See the coverage analysis in the code repository. |
|
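The roles of `guidance_scale` and `cfg_rescale` can be illustrated on a flattened prediction. The sketch below assumes the common convention in which `w = 1` recovers the conditional prediction, together with the standard-deviation rescaling of Lin et al. (2023). The repository's `sample_cfg` may use a different convention, and `cfg_combine` is a hypothetical name.

```python
from statistics import pstdev


def cfg_combine(eps_uncond, eps_cond, w, rescale=0.0):
    """Combine unconditional and conditional predictions (flat lists of floats).

    guided   = eps_uncond + w * (eps_cond - eps_uncond)
    rescaled = guided * std(eps_cond) / std(guided)
    output   = rescale * rescaled + (1 - rescale) * guided
    """
    guided = [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]
    if rescale == 0.0:
        return guided
    std_guided = pstdev(guided)
    if std_guided == 0.0:
        return guided  # constant prediction: nothing to rescale
    factor = pstdev(eps_cond) / std_guided
    return [rescale * (g * factor) + (1.0 - rescale) * g for g in guided]


# Toy two-element "predictions"; real inputs would be the flattened UNet outputs.
eps_u = [0.0, 1.0]
eps_c = [1.0, 3.0]
plain = cfg_combine(eps_u, eps_c, w=2.5)
rescaled = cfg_combine(eps_u, eps_c, w=2.5, rescale=0.7)
```

Guidance with `w > 1` inflates the spread of the prediction. The rescale term pulls the standard deviation back toward that of the conditional branch, which is one reason a moderate `w` can be paired with `cfg_rescale=0.7`.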
|
| ## Training data |
|
|
| **Galaxy10 DECaLS**: https://astronn.readthedocs.io/en/latest/galaxy10.html |
| (17,736 images; 10 classes; not redistributed here). |
|
|