---
license: cc0-1.0
tags:
- diffusion
- latent-diffusion
- classifier-free-guidance
- astronomy
- galaxies
- galaxy10
library_name: pytorch
---
# Galaxy Diffusion: latent diffusion weights (Galaxy10 DECaLS)
Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating
galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images,
256×256, 10 morphological classes).
These are the `.safetensors` weights. The model uses a **custom architecture**; it is
**not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need
the `galaxy_diffusion` package from the code repository to instantiate it.
- **Code:** https://github.com/LLapsus/galaxy-diffusion
- **License:** CC0 1.0 (public domain)
## Files
| File | Contents |
|---|---|
| `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params |
| `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image ↔ 4×32×32 latent), ~1.09M params |
| `latent_diffusion_galaxy10_xattn_v1.config.json` | constructor args (`vae_config`, `unet_config`, `unet_type`) + latent normalisation stats (`latents_mean`, `latents_std`) |
| `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (val acc 0.829) |
| `galaxy10_classifier.config.json` | classifier metadata (`val_acc`, `epoch`) |
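
The `latents_mean` / `latents_std` entries in the config are the statistics used to standardise VAE latents before diffusion. The sketch below shows how such statistics are typically applied; it assumes they are either scalars or per-channel lists (length 4) and that sampled latents are de-normalised before decoding. The helpers `normalize_latents` / `denormalize_latents` are hypothetical names, not part of `galaxy_diffusion`, which handles this internally in `sample_cfg`.

```python
import torch

def _as_channel_stats(stat, z: torch.Tensor) -> torch.Tensor:
    """Reshape a scalar or per-channel statistic so it broadcasts over (B, C, H, W)."""
    t = torch.as_tensor(stat, dtype=z.dtype, device=z.device)
    # Assumption: 1-D stats are per-channel, shape (C,) -> (1, C, 1, 1); scalars stay 0-d.
    return t.view(1, -1, 1, 1) if t.ndim == 1 else t

def normalize_latents(z, mean, std):
    """Map raw VAE latents to roughly zero-mean, unit-variance latents for the UNet."""
    return (z - _as_channel_stats(mean, z)) / _as_channel_stats(std, z)

def denormalize_latents(z_norm, mean, std):
    """Invert normalize_latents before passing sampled latents to the VAE decoder."""
    return z_norm * _as_channel_stats(std, z_norm) + _as_channel_stats(mean, z_norm)

# Round trip on a dummy batch of 4x32x32 latents with made-up per-channel stats.
z = torch.randn(2, 4, 32, 32) * 3.0 + 1.0
mean, std = [1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]
z_n = normalize_latents(z, mean, std)
z_back = denormalize_latents(z_n, mean, std)
print(torch.allclose(z, z_back, atol=1e-5))
```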
## Installation
```bash
pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors
```
## Load the weights
```python
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA
path = snapshot_download("llapsus/galaxy-diffusion") # downloads all files
with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:
    cfg = json.load(f)
vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()
unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()
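
As a quick sanity check after loading, the parameter counts can be compared with the approximate sizes in the Files table (~27.9M for the UNet, ~1.09M for the VAE). The helper below is plain PyTorch; `count_params` and `describe` are hypothetical names used only for illustration. With the weights loaded as above, `describe("unet", unet)` and `describe("vae", vae)` should report figures close to the table.

```python
import torch
from torch import nn

def count_params(module: nn.Module, trainable_only: bool = False) -> int:
    """Total number of parameter elements in a module (optionally trainable ones only)."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad or not trainable_only)

def describe(name: str, module: nn.Module) -> str:
    """Human-readable size string in millions of parameters."""
    return f"{name}: {count_params(module) / 1e6:.2f}M params"

# Stand-alone demo on a tiny module; with the real weights call
# describe("unet", unet) and describe("vae", vae) instead.
demo = nn.Sequential(nn.Conv2d(4, 8, kernel_size=3, padding=1), nn.GroupNorm(4, 8))
print(describe("demo", demo))
```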
```
## Generate images
```python
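import torch
from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg
# Continues from the loading snippet above (uses `cfg`, `vae`, `unet`).
device = "cuda" if torch.cuda.is_available() else "cpu"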
vae, unet = vae.to(device), unet.to(device)
_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)
latents_mean = torch.tensor(cfg["latents_mean"])
latents_std = torch.tensor(cfg["latents_std"])
images = sample_cfg(
unet, vae,
classes=list(range(10)), # one image per class
alpha=alpha, alpha_bar=alpha_bar,
latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
latents_mean=latents_mean, latents_std=latents_std,
device=device,
guidance_scale=2.5, # see "Guidance scale" below
cfg_rescale=0.7, # CFG rescaling (Lin et al., 2023)
) # -> tensor (10, 3, 256, 256) in [-1, 1]
```
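
`sample_cfg` returns a float tensor in [-1, 1] with shape (N, 3, 256, 256). To view or save the samples they must be mapped to 8-bit RGB. A minimal sketch in plain PyTorch follows; `to_uint8_hwc` is a hypothetical helper, not part of `galaxy_diffusion`.

```python
import torch

def to_uint8_hwc(images: torch.Tensor) -> torch.Tensor:
    """Convert a (N, 3, H, W) float tensor in [-1, 1] to (N, H, W, 3) uint8 in [0, 255]."""
    x = images.detach().float().clamp(-1.0, 1.0)   # guard against slight overshoot
    x = (x + 1.0) * 127.5                           # [-1, 1] -> [0, 255]
    x = x.round().to(torch.uint8)                   # quantise to 8 bits
    return x.permute(0, 2, 3, 1).cpu()              # channels-last for image libraries

# Demo on a random batch standing in for the sampler output.
fake = torch.rand(10, 3, 256, 256) * 2.2 - 1.1      # slightly outside [-1, 1]
out = to_uint8_hwc(fake)
print(out.shape, out.dtype, int(out.min()), int(out.max()))
```

Each element of the result can then be handed to an image library, for example `PIL.Image.fromarray(to_uint8_hwc(images)[0].numpy())` for the first sample.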
The classifier is loaded analogously with `GalaxyCNN` from
`galaxy_diffusion.models.classifier`.
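
Once loaded, the classifier can score generated samples for class fidelity: generate images for each target class and measure how often the classifier predicts that class (per-class recall). The function below is a sketch that accepts any classifier `nn.Module` returning logits of shape (N, 10); it deliberately does not assume `GalaxyCNN`'s constructor arguments or input preprocessing, which should be taken from the code repository. `per_class_recall` and `ToyClassifier` are hypothetical names used only for illustration.

```python
import torch
from torch import nn

@torch.no_grad()
def per_class_recall(classifier: nn.Module, images: torch.Tensor, labels: torch.Tensor,
                     num_classes: int = 10, batch_size: int = 64) -> torch.Tensor:
    """Fraction of images of each intended class that the classifier assigns to that class.

    Assumes `images` are already in the range/normalisation the classifier expects.
    Classes with no images get NaN.
    """
    classifier.eval()
    preds = []
    for i in range(0, images.shape[0], batch_size):
        logits = classifier(images[i:i + batch_size])
        preds.append(logits.argmax(dim=1).cpu())
    preds = torch.cat(preds)
    labels = labels.cpu()
    recall = torch.full((num_classes,), float("nan"))
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            recall[c] = (preds[mask] == c).float().mean()
    return recall

class ToyClassifier(nn.Module):
    """Stand-in for GalaxyCNN: predicts the class encoded as the image's mean value."""
    def forward(self, x):
        idx = x.mean(dim=(1, 2, 3)).round().long().clamp(0, 9)
        return nn.functional.one_hot(idx, num_classes=10).float()

labels = torch.tensor([0, 0, 1, 1, 2])
images = labels.float().view(-1, 1, 1, 1).expand(-1, 3, 8, 8).clone()
images[1] = 5.0  # one class-0 image is "misclassified" as class 5
print(per_class_recall(ToyClassifier(), images, labels))
```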
## Model details
- **VAE:** 8× spatial compression, 3×256×256 ↔ 4×32×32, KL-regularised.
- **UNet (`LatentUNetCA`):** time conditioning via AdaGN, class conditioning via a
cross-attention block after each encoder/decoder level + bottleneck; cosine noise
schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout.
- **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the
distribution of diffusion outputs) for evaluating class fidelity of generated samples.
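
The training ingredients listed for the UNet (cosine schedule with T=1000, Min-SNR weighting, 10% label dropout) can be sketched as follows. This assumes the standard cosine schedule of Nichol & Dhariwal (2021) with offset `s = 0.008`, epsilon-prediction with a Min-SNR cap `gamma = 5` (Hang et al., 2023), and an extra "null" label index `num_classes` for dropped labels. None of these constants or conventions are confirmed by this card, `galaxy_diffusion.diffusion.ddpm.cosine_schedule` may differ in details such as beta clipping, and all function names here are hypothetical.

```python
import math
import torch

def cosine_alpha_bar(T: int = 1000, s: float = 0.008, max_beta: float = 0.999):
    """Cosine noise schedule; returns (betas, alphas, alpha_bars), each of length T."""
    steps = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos(((steps / T) + s) / (1 + s) * math.pi / 2) ** 2
    alpha_bar_cont = f / f[0]                             # alpha_bar(0) = 1
    betas = 1 - alpha_bar_cont[1:] / alpha_bar_cont[:-1]
    betas = betas.clamp(max=max_beta)                     # avoid a singular final step
    alphas = 1 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return betas.float(), alphas.float(), alpha_bars.float()

def min_snr_weights(alpha_bars: torch.Tensor, t: torch.Tensor, gamma: float = 5.0):
    """Min-SNR loss weights for epsilon-prediction: min(SNR_t, gamma) / SNR_t."""
    ab = alpha_bars[t]
    snr = ab / (1 - ab)
    return torch.clamp(snr, max=gamma) / snr

def drop_labels(labels: torch.Tensor, num_classes: int = 10, p: float = 0.1):
    """Replace each label by a 'null' index (num_classes) with probability p, for CFG training."""
    drop = torch.rand(labels.shape, device=labels.device) < p
    return torch.where(drop, torch.full_like(labels, num_classes), labels)

# Per-batch usage: sample timesteps, compute per-sample weights, drop some labels.
betas, alphas, alpha_bars = cosine_alpha_bar(1000)
t = torch.randint(0, 1000, (8,))
weights = min_snr_weights(alpha_bars, t)       # one weight per sample
train_labels = drop_labels(torch.randint(0, 10, (8,)))
```

In a training step the weights would multiply the per-sample MSE, e.g. `(weights * ((eps_pred - eps) ** 2).mean(dim=(1, 2, 3))).mean()`, where `eps_pred` and `eps` are the predicted and true noise.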
### Guidance scale
Classifier recall on generated images peaks around `w ≈ 3`, but latent-space coverage
analysis shows `w ≈ 2.5` is the better fidelity/diversity operating point (matched
within-class spread). Higher `w` over-extrapolates samples toward neighbouring classes.
See the coverage analysis in the code repository.
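
For reference, the guidance scale `w` and the `cfg_rescale` argument used in the sampling snippet correspond, in the usual formulation, to the combination sketched below: the conditional and unconditional model outputs are extrapolated by `w`, then the guided output's per-sample standard deviation is pulled back toward that of the conditional output (Lin et al., 2023), blended by a factor `phi` (0.7 in the snippet). This is a sketch of the standard technique, assuming the convention `uncond + w * (cond - uncond)`; it is not taken from `sample_cfg`, whose internals may differ, and `cfg_combine` is a hypothetical name.

```python
import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor,
                w: float = 2.5, phi: float = 0.7) -> torch.Tensor:
    """Classifier-free guidance with std rescaling.

    Assumes pred = uncond + w * (cond - uncond), so w = 1 means no guidance.
    phi = 0 disables rescaling; phi = 1 fully matches the conditional output's std.
    """
    guided = pred_uncond + w * (pred_cond - pred_uncond)
    dims = tuple(range(1, pred_cond.ndim))                # per-sample statistics
    std_cond = pred_cond.std(dim=dims, keepdim=True)
    std_guided = guided.std(dim=dims, keepdim=True).clamp_min(1e-8)
    rescaled = guided * (std_cond / std_guided)           # match the conditional std
    return phi * rescaled + (1 - phi) * guided            # blend to avoid overly flat images

torch.manual_seed(0)
cond, uncond = torch.randn(2, 4, 32, 32), torch.randn(2, 4, 32, 32)
out = cfg_combine(cond, uncond, w=2.5, phi=1.0)
print(out.std(dim=(1, 2, 3)), cond.std(dim=(1, 2, 3)))   # equal when phi = 1
```

Larger `w` inflates the guided output's variance; the rescaling step counteracts that, which is one reason moderate `w` with rescaling can keep samples within their class rather than pushing them toward neighbouring classes.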
## Training data
**Galaxy10 DECaLS**: https://astronn.readthedocs.io/en/latest/galaxy10.html
(17,736 images; 10 classes; not redistributed here).