---
license: cc0-1.0
tags:
  - diffusion
  - latent-diffusion
  - classifier-free-guidance
  - astronomy
  - galaxies
  - galaxy10
library_name: pytorch
---

# Galaxy Diffusion — latent diffusion weights (Galaxy10 DECaLS)

Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating
galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images,
256×256, 10 morphological classes).

These are the `.safetensors` weights. The model uses a **custom architecture** — it is
**not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need
the `galaxy_diffusion` package from the code repository to instantiate it.

- **Code:** https://github.com/LLapsus/galaxy-diffusion
- **License:** CC0 1.0 (public domain)

## Files

| File | Contents |
|---|---|
| `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params |
| `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image ↔ 4×32×32 latent), ~1.09M params |
| `latent_diffusion_galaxy10_xattn_v1.config.json` | constructor args (`vae_config`, `unet_config`, `unet_type`) + latent normalisation stats (`latents_mean`, `latents_std`) |
| `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (val acc 0.829) |
| `galaxy10_classifier.config.json` | classifier metadata (`val_acc`, `epoch`) |

## Installation

```bash
pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
pip install huggingface_hub safetensors
```

## Load the weights

```python
import json
import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

from galaxy_diffusion.models.vae import VAE
from galaxy_diffusion.models.unet import LatentUNetCA

path = snapshot_download("llapsus/galaxy-diffusion")  # downloads all files
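# Read the constructor args and latent normalisation stats; the context manager closes the file.
with open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json") as f:
    cfg = json.load(f)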

vae = VAE(**cfg["vae_config"])
vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
vae.eval()

unet = LatentUNetCA(**cfg["unet_config"])
unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
unet.eval()
```
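
Before building the models it can help to confirm that the downloaded config has the top-level keys listed in the Files table. The following is a minimal sketch in plain Python; `missing_config_keys` is a hypothetical helper, not part of `galaxy_diffusion`, and the example dictionary holds placeholder values rather than the real config contents.

```python
import json

# Top-level keys the Files table lists for the diffusion config.
EXPECTED_KEYS = ("vae_config", "unet_config", "unet_type", "latents_mean", "latents_std")


def missing_config_keys(cfg: dict) -> list:
    """Return the expected top-level keys that are absent from cfg (hypothetical helper)."""
    return [key for key in EXPECTED_KEYS if key not in cfg]


# Placeholder config for illustration only; these are not the published values.
example_cfg = json.loads('{"vae_config": {}, "unet_config": {}, "unet_type": "placeholder"}')
missing = missing_config_keys(example_cfg)
print(missing)  # lists the keys that the placeholder config lacks
```

With the real file, `missing_config_keys(cfg)` should return an empty list if the config matches the table above.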

## Generate images

```python
from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg

device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU when no GPU is present
vae, unet = vae.to(device), unet.to(device)

_, alpha, alpha_bar = cosine_schedule(1000)
alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)

latents_mean = torch.tensor(cfg["latents_mean"])
latents_std  = torch.tensor(cfg["latents_std"])

images = sample_cfg(
    unet, vae,
    classes=list(range(10)),                       # one image per class
    alpha=alpha, alpha_bar=alpha_bar,
    latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
    latents_mean=latents_mean, latents_std=latents_std,
    device=device,
    guidance_scale=2.5,                            # see "Guidance scale" below
    cfg_rescale=0.7,                               # CFG rescaling (Lin et al., 2023)
)   # -> tensor (10, 3, 256, 256) in [-1, 1]
```
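
The call above draws one image per class. To draw several images per class, one option is to repeat each class index in the `classes` list. This is a sketch that assumes `sample_cfg` accepts repeated indices in `classes`, which has not been checked against the repository; `classes_for_grid` is a hypothetical helper.

```python
def classes_for_grid(num_classes: int = 10, per_class: int = 4) -> list:
    """Build a class list with per_class consecutive entries for each class index."""
    return [c for c in range(num_classes) for _ in range(per_class)]


# 10 classes x 4 samples each, ordered so that each class forms one row of a grid.
grid_classes = classes_for_grid(num_classes=10, per_class=4)
print(len(grid_classes), grid_classes[:8])
```

Passing `classes=grid_classes` would then be expected to yield a batch of 40 images, memory permitting.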

The classifier is loaded analogously: use `GalaxyCNN` from
`galaxy_diffusion.models.classifier` with `galaxy10_classifier.model.safetensors`.

## Model details

- **VAE:** 8× spatial compression, 3×256×256 ↔ 4×32×32, KL-regularised.
- **UNet (`LatentUNetCA`):** time conditioning via AdaGN, class conditioning via a
  cross-attention block after each encoder/decoder level + bottleneck; cosine noise
  schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout.
- **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the
  distribution of diffusion outputs) for evaluating class fidelity of generated samples.
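
To make the cosine schedule and Min-SNR weighting in the list above concrete, here is a minimal sketch in plain Python. It follows the commonly published forms (cosine schedule with offset `s`, loss weight `min(SNR, gamma) / SNR` for noise-prediction MSE). The values `s=0.008`, `max_beta=0.999`, and `gamma=5.0` are assumed defaults from the literature, not values confirmed for this model, and the repository's `cosine_schedule` may differ in detail.

```python
import math


def cosine_alpha_bar(T: int = 1000, s: float = 0.008, max_beta: float = 0.999) -> list:
    """Cumulative signal fraction alpha_bar_t for t = 1..T under a cosine schedule."""
    def f(t: int) -> float:
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

    alpha_bar, running = [], 1.0
    for t in range(1, T + 1):
        # Per-step noise level, clipped so the last steps stay numerically stable.
        beta = min(1.0 - f(t) / f(t - 1), max_beta)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar


def min_snr_weights(alpha_bar: list, gamma: float = 5.0) -> list:
    """Per-timestep loss weight min(SNR, gamma) / SNR (assumed noise-prediction target)."""
    weights = []
    for ab in alpha_bar:
        snr = ab / (1.0 - ab)  # signal-to-noise ratio at this timestep
        weights.append(min(snr, gamma) / snr)
    return weights


alpha_bar = cosine_alpha_bar(1000)
weights = min_snr_weights(alpha_bar)
print(alpha_bar[0], alpha_bar[-1], weights[0], weights[-1])
```

The weights down-weight the low-noise timesteps (large SNR) and leave the high-noise timesteps at weight 1.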

### Guidance scale

Here `w` denotes the `guidance_scale` argument. Classifier recall on generated images
peaks around `w ≈ 3`, but the latent-space coverage analysis indicates that `w ≈ 2.5` is
the better fidelity/diversity operating point (matched within-class spread). Higher `w`
over-extrapolates samples toward neighbouring classes.
See the coverage analysis in the code repository.
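
For readers unfamiliar with how `guidance_scale` and `cfg_rescale` interact, the sketch below shows one common formulation on flat lists of numbers: the guided prediction is `uncond + w * (cond - uncond)`, then rescaled toward the standard deviation of the conditional prediction and blended with factor `rescale` (following Lin et al., 2023). This is an illustration under assumptions, not the repository's `sample_cfg`; in particular, some implementations use `1 + w` instead of `w`, and the statistics here are taken over one flattened sample.

```python
import statistics


def cfg_combine(eps_uncond: list, eps_cond: list, w: float = 2.5, rescale: float = 0.7) -> list:
    """Classifier-free guidance with std rescaling on flat lists (illustrative sketch)."""
    # Extrapolate from the unconditional toward the conditional prediction.
    guided = [u + w * (c - u) for u, c in zip(eps_uncond, eps_cond)]
    std_guided = statistics.pstdev(guided)
    if std_guided == 0.0:
        return guided  # nothing to rescale
    # Match the spread of the conditional prediction, then blend with the raw result.
    factor = statistics.pstdev(eps_cond) / std_guided
    return [rescale * g * factor + (1.0 - rescale) * g for g in guided]


eps_uncond = [0.0, 0.0, 0.0, 0.0]
eps_cond = [1.0, -1.0, 2.0, -2.0]
out = cfg_combine(eps_uncond, eps_cond, w=2.5, rescale=0.7)
print(out)
```

With `rescale=0.0` the function reduces to plain classifier-free guidance; with `rescale=1.0` the output has the same standard deviation as the conditional prediction.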

## Training data

**Galaxy10 DECaLS** — https://astronn.readthedocs.io/en/latest/galaxy10.html
(17,736 images; 10 classes; not redistributed here).