---
tags:
- pytorch
- vae
- diffusion
- image-generation
- cc3m
license: mit
datasets:
- pixparse/cc3m-wds
library_name: diffusers
---

# UNet-Style VAE for 256x256 Image Reconstruction

This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.

## Architecture

- **Encoder/Decoder**: Multi-scale UNet architecture
- **Latent Space**: 8-channel latent bottleneck with reparameterization (mu, logvar)
- **Losses**:
  - L1 reconstruction loss
  - KL divergence with annealing
  - LPIPS perceptual loss (VGG backbone)
  - Identity loss via MoCo-v2 embeddings
  - Adversarial loss via Patch Discriminator w/ Spectral Norm

$$
\mathcal{L}_{total} = \mathcal{L}_{recon} + \mathcal{L}_{LPIPS} + 0.5 \cdot \mathcal{L}_{GAN} + 0.1 \cdot \mathcal{L}_{ID} + 10^{-6} \cdot \mathcal{L}_{KL}
$$

## Reconstructions
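As a minimal sketch of how the pieces above fit together, the snippet below shows the standard VAE reparameterization (`mu`, `logvar` from the 8-channel bottleneck), the closed-form KL term, and the weighted sum from the formula. The helper names are hypothetical, and the LPIPS, GAN, and identity terms are assumed to be computed elsewhere and passed in as scalars:

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I); sigma = exp(0.5 * logvar)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)

def kl_divergence(mu, logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over batch
    return -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3])
    )

def total_loss(recon, lpips, gan, ident, kl, kl_weight=1e-6):
    # Weights follow the formula above; kl_weight is the post-annealing value
    return recon + lpips + 0.5 * gan + 0.1 * ident + kl_weight * kl
```

During KL annealing, `kl_weight` would ramp up from zero over the first portion of training rather than stay fixed at `1e-6`.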
| Input | Output |
|-------|--------|
| ![input](./input_grid.png) | ![output](./recon_grid.png) |

## Training Config
| Hyperparameter        | Value                      |
|-----------------------|----------------------------|
| Dataset               | CC3M (850k images)         |
| Image Resolution      | 256 x 256                  |
| Batch Size            | 16                         |
| Optimizer             | AdamW                      |
| Learning Rate         | 5e-5                       |
| Precision             | bf16 (mixed precision)     |
| Total Steps           | 210,000                    |
| GAN Start Step        | 50,000                     |
| KL Annealing          | Yes (10% of training)      |
| Augmentations         | Crop, flip, jitter, blur, rotation |

Trained using a cosine learning rate schedule with gradient clipping and automatic mixed precision (`torch.cuda.amp`).
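The training-step mechanics described above (cosine schedule, bf16 autocast, gradient clipping) can be sketched as follows. The model here is a stand-in `Conv2d` and the clipping norm of `1.0` is an assumption; only the step structure is meant to be illustrative:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Stand-in model; the real VAE would go here
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = AdamW(model.parameters(), lr=5e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=210_000)

def train_step(batch):
    optimizer.zero_grad(set_to_none=True)
    # bf16 autocast: no GradScaler needed, since bf16 shares fp32's exponent range
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        recon = model(batch)
        loss = torch.nn.functional.l1_loss(recon, batch)
    loss.backward()
    # Gradient clipping before the optimizer step (max_norm is an assumed value)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    return loss.item()
```

Note that with bf16 a `GradScaler` is unnecessary; loss scaling is only needed for fp16's narrower dynamic range.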