llapsus commited on
Commit
87daef1
·
verified ·
1 Parent(s): 09ffd40

Upload folder using huggingface_hub

Browse files
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc0-1.0
3
+ tags:
4
+ - diffusion
5
+ - latent-diffusion
6
+ - classifier-free-guidance
7
+ - astronomy
8
+ - galaxies
9
+ - galaxy10
10
+ library_name: pytorch
11
+ ---
12
+
13
+ # Galaxy Diffusion — latent diffusion weights (Galaxy10 DECaLS)
14
+
15
+ Conditional **latent diffusion model** (VAE + classifier-free guidance) for generating
16
+ galaxy images by morphology class, trained on **Galaxy10 DECaLS** (17,736 RGB images,
17
+ 256×256, 10 morphological classes).
18
+
19
+ These are the `.safetensors` weights. The model uses a **custom architecture** — it is
20
+ **not** a `transformers` / `diffusers` model and does not load via `AutoModel`. You need
21
+ the `galaxy_diffusion` package from the code repository to instantiate it.
22
+
23
+ - **Code:** https://github.com/LLapsus/galaxy-diffusion
24
+ - **License:** CC0 1.0 (public domain)
25
+
26
+ ## Files
27
+
28
+ | File | Contents |
29
+ |---|---|
30
+ | `latent_diffusion_galaxy10_xattn_v1.model.safetensors` | UNet denoiser (`LatentUNetCA`, cross-attention conditioning), ~27.9M params |
31
+ | `latent_diffusion_galaxy10_xattn_v1.vae.safetensors` | VAE (image ↔ 4×32×32 latent), ~1.09M params |
32
+ | `latent_diffusion_galaxy10_xattn_v1.config.json` | constructor args (`vae_config`, `unet_config`, `unet_type`) + latent normalisation stats (`latents_mean`, `latents_std`) |
33
+ | `galaxy10_classifier.model.safetensors` | `GalaxyCNN` evaluation classifier, ~1.75M params (val acc 0.829) |
34
+ | `galaxy10_classifier.config.json` | classifier metadata (`val_acc`, `epoch`) |
35
+
36
+ ## Installation
37
+
38
+ ```bash
39
+ pip install "git+https://github.com/LLapsus/galaxy-diffusion.git"
40
+ pip install huggingface_hub safetensors
41
+ ```
42
+
43
+ ## Load the weights
44
+
45
+ ```python
46
+ import json
47
+ import torch
48
+ from huggingface_hub import snapshot_download
49
+ from safetensors.torch import load_file
50
+
51
+ from galaxy_diffusion.models.vae import VAE
52
+ from galaxy_diffusion.models.unet import LatentUNetCA
53
+
54
+ path = snapshot_download("LLapsus/galaxy-diffusion") # downloads all files
55
+ cfg = json.load(open(f"{path}/latent_diffusion_galaxy10_xattn_v1.config.json"))
56
+
57
+ vae = VAE(**cfg["vae_config"])
58
+ vae.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.vae.safetensors"))
59
+ vae.eval()
60
+
61
+ unet = LatentUNetCA(**cfg["unet_config"])
62
+ unet.load_state_dict(load_file(f"{path}/latent_diffusion_galaxy10_xattn_v1.model.safetensors"))
63
+ unet.eval()
64
+ ```
65
+
66
+ ## Generate images
67
+
68
+ ```python
69
+ from galaxy_diffusion.diffusion.ddpm import cosine_schedule, sample_cfg
70
+
71
+ device = "cuda"
72
+ vae, unet = vae.to(device), unet.to(device)
73
+
74
+ _, alpha, alpha_bar = cosine_schedule(1000)
75
+ alpha, alpha_bar = alpha.to(device), alpha_bar.to(device)
76
+
77
+ latents_mean = torch.tensor(cfg["latents_mean"])
78
+ latents_std = torch.tensor(cfg["latents_std"])
79
+
80
+ images = sample_cfg(
81
+ unet, vae,
82
+ classes=list(range(10)), # one image per class
83
+ alpha=alpha, alpha_bar=alpha_bar,
84
+ latent_shape=(cfg["unet_config"]["latent_channels"], 32, 32),
85
+ latents_mean=latents_mean, latents_std=latents_std,
86
+ device=device,
87
+ guidance_scale=2.5, # see "Guidance scale" below
88
+ cfg_rescale=0.7, # CFG rescaling (Lin et al., 2023)
89
+ ) # -> tensor (10, 3, 256, 256) in [-1, 1]
90
+ ```
91
+
92
+ The classifier is loaded analogously with `GalaxyCNN` from
93
+ `galaxy_diffusion.models.classifier`.
94
+
95
+ ## Model details
96
+
97
+ - **VAE:** 8× spatial compression, 3×256×256 ↔ 4×32×32, KL-regularised.
98
+ - **UNet (`LatentUNetCA`):** time conditioning via AdaGN, class conditioning via a
99
+ cross-attention block after each encoder/decoder level + bottleneck; cosine noise
100
+ schedule (T=1000); trained with Min-SNR-weighted MSE and 10% CFG label dropout.
101
+ - **Classifier (`GalaxyCNN`):** trained on VAE-reconstructed images (to match the
102
+ distribution of diffusion outputs) for evaluating class fidelity of generated samples.
103
+
104
+ ### Guidance scale
105
+
106
+ Classifier recall on generated images peaks around `w ≈ 3`, but latent-space coverage
107
+ analysis shows `w ≈ 2.5` is the better fidelity/diversity operating point (matched
108
+ within-class spread). Higher `w` over-extrapolates samples toward neighbouring classes.
109
+ See the coverage analysis in the code repository.
110
+
111
+ ## Training data
112
+
113
+ **Galaxy10 DECaLS** — https://astronn.readthedocs.io/en/latest/galaxy10.html
114
+ (17,736 images; 10 classes; not redistributed here).
115
+
116
+ ## Citation
117
+
118
+ TODO(pavel): citation / blog post link.
galaxy10_classifier.config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "val_acc": 0.8291925465838509,
3
+ "epoch": 84
4
+ }
galaxy10_classifier.model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbed56693211bafd98ab136c5935f84e3dfe6ca3faaf870b02dc67afe0baf16a
3
+ size 7022264
latent_diffusion_galaxy10_xattn_v1.config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "latents_mean": -0.03181709349155426,
3
+ "latents_std": 0.8395193219184875,
4
+ "vae_config": {
5
+ "in_channels": 3,
6
+ "latent_channels": 4,
7
+ "base_channels": 32,
8
+ "num_downsamples": 3
9
+ },
10
+ "unet_config": {
11
+ "latent_channels": 4,
12
+ "base_channels": 128,
13
+ "channel_mult": [
14
+ 1,
15
+ 2,
16
+ 4
17
+ ],
18
+ "num_res_blocks": 2,
19
+ "time_emb_dim": 256,
20
+ "class_emb_dim": 256,
21
+ "num_classes": 10,
22
+ "attn_levels": [
23
+ 1
24
+ ],
25
+ "cross_attn_heads": 4
26
+ },
27
+ "unet_type": "LatentUNetCA"
28
+ }
latent_diffusion_galaxy10_xattn_v1.model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:744d5491c1f6318e795b0aaa4bbba9eb9be6c40a96d7e110039fc0050c782663
3
+ size 111486200
latent_diffusion_galaxy10_xattn_v1.vae.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea48697060ddb9ea816876d90c6090d3383c3803f30643a38384c7f694844917
3
+ size 4367516