---
tags:
- pytorch
- vae
- diffusion
- image-generation
- cc3m
license: mit
datasets:
- pixparse/cc3m-wds
library_name: diffusers
---

# UNet-Style VAE for 256x256 Image Reconstruction

This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.

## Architecture

- **Encoder/Decoder**: Multi-scale UNet architecture
- **Latent Space**: 8-channel latent bottleneck with reparameterization (mu, logvar)
- **Losses**:
  - L1 reconstruction loss
  - KL divergence with annealing
  - LPIPS perceptual loss (VGG backbone)
  - Identity loss via MoCo-v2 embeddings
  - Adversarial loss via a patch discriminator with spectral normalization

$$
\mathcal{L}_{total} = \mathcal{L}_{recon} + \mathcal{L}_{LPIPS} + 0.5 \cdot \mathcal{L}_{GAN} + 0.1 \cdot \mathcal{L}_{ID} + 10^{-6} \cdot \mathcal{L}_{KL}
$$
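As a hedged sketch of how these pieces fit together (the weights come from the formula above; the function names and placeholder tensors are illustrative, not the actual training code), the reparameterization step, the KL term, and the weighted loss sum could look like:

```python
import torch


def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std


def kl_term(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(q(z|x) || N(0, I)), summed over latent dims, averaged over the batch
    return -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=[1, 2, 3])
    )


def total_loss(recon, lpips, gan, ident, kl):
    # Weighted sum matching the formula above
    return recon + lpips + 0.5 * gan + 0.1 * ident + 1e-6 * kl
```

With zero-mean, unit-variance latents the KL term vanishes, which is the fixed point the 10⁻⁶ weight gently pulls the posterior toward.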

## Reconstructions

| Input | Output |
|-------|--------|
|  |  |

## Training Config

| Hyperparameter   | Value                              |
|------------------|------------------------------------|
| Dataset          | CC3M (850k images)                 |
| Image Resolution | 256 x 256                          |
| Batch Size       | 16                                 |
| Optimizer        | AdamW                              |
| Learning Rate    | 5e-5                               |
| Precision        | bf16 (mixed precision)             |
| Total Steps      | 210,000                            |
| GAN Start Step   | 50,000                             |
| KL Annealing     | Yes (10% of training)              |
| Augmentations    | Crop, flip, jitter, blur, rotation |

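The table states only that the KL weight is annealed over 10% of training, not the ramp shape. A minimal sketch, assuming a linear ramp from 0 to the final weight of 10⁻⁶ over the first 10% of the 210,000 steps:

```python
def kl_weight(step: int, total_steps: int = 210_000,
              max_weight: float = 1e-6, anneal_frac: float = 0.10) -> float:
    # Linearly ramp the KL weight over the first `anneal_frac` of training,
    # then hold it at `max_weight` (the linear shape is an assumption).
    anneal_steps = max(1, int(total_steps * anneal_frac))
    return max_weight * min(1.0, step / anneal_steps)
```

Delaying the full KL weight this way lets the reconstruction terms shape the latent space before the prior-matching pressure kicks in.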
Trained using a cosine learning rate schedule with gradient clipping and automatic mixed precision (`torch.cuda.amp`).
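A minimal sketch of one optimizer step under these settings (the model and loss are stand-ins, not the actual VAE or full loss; note that with bf16 autocast, unlike fp16, no `GradScaler` is needed):

```python
import torch
import torch.nn.functional as F

model = torch.nn.Conv2d(3, 3, 3, padding=1)  # placeholder for the actual VAE
opt = torch.optim.AdamW(model.parameters(), lr=5e-5)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=210_000)
device_type = "cuda" if torch.cuda.is_available() else "cpu"


def train_step(batch: torch.Tensor) -> float:
    opt.zero_grad(set_to_none=True)
    # bf16 autocast region; gradients stay in fp32, so no loss scaling
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        recon = model(batch)
        loss = F.l1_loss(recon, batch)  # reconstruction term only, for illustration
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    sched.step()  # cosine decay of the learning rate each step
    return loss.item()
```

The clip norm of 1.0 is an assumption; the card confirms gradient clipping but not its threshold.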