===================================
Pretraining Variational AutoEncoder
===================================

A Variational Autoencoder (VAE) compresses high-resolution images into a lower-dimensional latent space that preserves their essential features, which makes storing and processing image data more efficient. VAEs are integral to latent diffusion models such as Stable Diffusion (SD): the diffusion model is trained and run in the VAE's latent space instead of pixel space, which greatly reduces computational requirements. For instance, SDXL uses a VAE that downsamples each spatial dimension by 8x (a 1024x1024 image becomes a 128x128 latent), which makes both training and inference much cheaper. This repository provides code to pretrain a VAE from scratch, so you can reach higher spatial compression ratios such as 16x or 32x.
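
The spatial factor applies to each side of the image, so the number of latent positions shrinks by its square. The sketch below shows this arithmetic, assuming an RGB input and a latent channel count ``c`` that you supply yourself: SDXL's VAE uses 4 latent channels, but the channel count of the 16x model is not stated here, so the 16x and 32x calls use 4 channels purely for illustration. ``latent_shape`` is a hypothetical helper, not part of NeMo.

.. code-block:: bash

   # Latent-size arithmetic for a spatial downsampling factor f.
   # Assumes an RGB input (3 channels) and a hypothetical latent channel count c.
   latent_shape() {
     local h=$1 w=$2 f=$3 c=$4
     # The VAE downsamples each side by f, so H and W must be divisible by f
     if (( h % f != 0 || w % f != 0 )); then
       echo "error: ${h}x${w} is not divisible by f=${f}" >&2
       return 1
     fi
     local lh=$(( h / f ))
     local lw=$(( w / f ))
     # Pixel elements (3*H*W) divided by latent elements (c*(H/f)*(W/f)), integer division
     local ratio=$(( 3 * h * w / (c * lh * lw) ))
     echo "${c}x${lh}x${lw} ratio=${ratio}"
   }

   latent_shape 1024 1024 8 4    # prints 4x128x128 ratio=48
   latent_shape 1024 1024 16 4   # prints 4x64x64 ratio=192

Going from 8x to 16x at a fixed channel count divides the number of latent elements by 4, which is where the savings in diffusion training come from.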

Installation
============

To get started, pull the latest NeMo Docker container. See the `NeMo container page on NGC <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo>`_ for details.
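
A minimal sketch for starting the container, assuming Docker and the NVIDIA Container Toolkit are installed on the host. ``launch_nemo`` is a hypothetical helper, ``NEMO_TAG`` is a placeholder that you must set to a tag listed on the NGC page above, and mounting a host directory at ``/data`` is an illustrative choice, not a NeMo requirement. Setting ``DRY_RUN=1`` prints the command instead of running it.

.. code-block:: bash

   # Hypothetical launcher for the NeMo container (not part of NeMo).
   # NEMO_TAG must be set to a tag from the NGC catalog; /data is an arbitrary mount point.
   launch_nemo() {
     local tag=${NEMO_TAG:?set NEMO_TAG to a NeMo container tag from NGC}
     local host_dir=${1:-$PWD}
     # --gpus all needs the NVIDIA Container Toolkit; --ipc=host gives PyTorch
     # dataloader workers access to the host's shared memory.
     local -a cmd=(docker run --gpus all -it --rm --ipc=host
                   -v "${host_dir}:/data"
                   "nvcr.io/nvidia/nemo:${tag}")
     if [[ -n ${DRY_RUN:-} ]]; then
       # Print a shell-quoted version of the command instead of running it
       printf '%q ' "${cmd[@]}"
       echo
     else
       "${cmd[@]}"
     fi
   }

   # Example: NEMO_TAG=<tag> launch_nemo /path/on/host

Building the command as an array keeps paths with spaces intact, and the dry-run mode lets you inspect the exact invocation before starting a long-lived container.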

Validation
==========

We also provide a validation script to quickly evaluate our pretrained 16x VAE on a 50K-image dataset. Once inside the container, run the following command to start validation:

.. code-block:: bash

   torchrun --nproc-per-node 8 nemo/collections/diffusion/vae/validate_vae.py --yes \
      data.path=/path/to/validation/data \
      log.log_dir=/path/to/checkpoint

Set the following configuration overrides:

1. ``data.path``: the directory containing your test images (e.g., ``.jpg`` or ``.png`` files). The original and VAE-reconstructed images are logged side by side to Weights & Biases (wandb).

2. ``log.log_dir``: the directory containing the pretrained checkpoint. A link to our pretrained checkpoint will be added here (``TODO``).
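
Before launching, it can help to confirm that ``data.path`` actually contains images. The helper below is a hypothetical pre-flight check, not part of NeMo; it counts ``.jpg`` and ``.png`` files recursively and case-insensitively, which may not match exactly how ``validate_vae.py`` discovers files.

.. code-block:: bash

   # Hypothetical pre-flight check (not part of NeMo): count .jpg/.png files under
   # the directory you plan to pass as data.path and fail early if there are none.
   check_validation_dir() {
     local dir=$1 n
     if [[ ! -d $dir ]]; then
       echo "error: $dir is not a directory" >&2
       return 1
     fi
     # Case-insensitive, recursive search
     n=$(find "$dir" -type f \( -iname '*.jpg' -o -iname '*.png' \) | wc -l)
     n=${n//[[:space:]]/}   # some wc implementations pad the count with spaces
     if (( n == 0 )); then
       echo "error: no .jpg or .png files found in $dir" >&2
       return 1
     fi
     echo "$n"
   }

   # Example: only launch validation if the check passes
   # check_validation_dir /path/to/validation/data && \
   #   torchrun --nproc-per-node 8 nemo/collections/diffusion/vae/validate_vae.py --yes \
   #     data.path=/path/to/validation/data log.log_dir=/path/to/checkpoint

Failing fast here avoids spinning up eight processes only to find an empty or mistyped directory.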

Here are some sample reconstructions from our pretrained VAE.

Left: original image; right: 16x VAE reconstruction.

.. list-table::
   :align: center

   * - .. image:: https://github.com/user-attachments/assets/08122f5b-2e65-4d65-87d7-eceae9d158fb
         :width: 1400
         :align: center
     - .. image:: https://github.com/user-attachments/assets/6e805a0d-8783-4d24-a65b-d96a6ba1555d
         :width: 1400
         :align: center
     - .. image:: https://github.com/user-attachments/assets/aab1ef33-35da-444d-90ee-ba3ad58a6b2d
         :width: 1400
         :align: center

Data Preparation
================

1. We expect the data to be stored as WebDataset tar files. If you have a folder of images, you can use ``tar`` to pack them into WebDataset shards laid out as follows:

    .. code-block:: text

        000000.tar
        ├── 1.jpg
        └── 2.jpg
        000001.tar
        ├── 3.jpg
        └── 4.jpg

2. Next, index the WebDataset shards with `energon <https://nvidia.github.io/Megatron-Energon/>`_. Navigate to the dataset directory and run:

    .. code-block:: bash

        energon prepare . --num-workers 8 --shuffle-tars

3. When prompted, select the dataset class ``ImageWebdataset`` and enter ``jpg`` as the WebDataset field for ``image``. Below is an example of the interactive setup, using a 99/1/0 train/val/test split:

    .. code-block:: text

        Found 2925 tar files in total. The first and last ones are:
        - 000000.tar
        - 002924.tar
        If you want to exclude some of them, cancel with ctrl+c and specify an exclude filter in the command line.
        Please enter a desired train/val/test split like "0.5, 0.2, 0.3" or "8,1,1": 99,1,0
        Indexing shards  [####################################]  2925/2925
        Sample 0, keys:
        - jpg
        Sample 1, keys:
        - jpg
        Found the following part types in the dataset: jpg
        Do you want to create a dataset.yaml interactively? [Y/n]:
        The following dataset classes are available:
        0. CaptioningWebdataset
        1. CrudeWebdataset
        2. ImageClassificationWebdataset
        3. ImageWebdataset
        4. InterleavedWebdataset
        5. MultiChoiceVQAWebdataset
        6. OCRWebdataset
        7. SimilarityInterleavedWebdataset
        8. TextWebdataset
        9. VQAOCRWebdataset
        10. VQAWebdataset
        11. VidQAWebdataset
        Please enter a number to choose a class: 3
        The dataset you selected uses the following sample type:

        @dataclass
        class ImageSample(Sample):
            """Sample type for an image, e.g. for image reconstruction."""

            #: The input image tensor in the shape (C, H, W)
            image: torch.Tensor

        Do you want to set a simple field_map[Y] (or write your own sample_loader [n])? [Y/n]:

        For each field, please specify the corresponding name in the WebDataset.
        Available types in WebDataset: jpg
        Leave empty for skipping optional field
        You may also access json fields e.g. by setting the field to: json[field][field]
        You may also specify alternative fields e.g. by setting to: jpg,png
        Please enter the field_map for ImageWebdataset:
        Please enter a webdataset field name for 'image' (<class 'torch.Tensor'>):
        That type doesn't exist in the WebDataset. Please try again.
        Please enter a webdataset field name for 'image' (<class 'torch.Tensor'>): jpg
        Done

4. Finally, train the VAE on the indexed dataset by setting ``data.path=/path/to/dataset`` when launching the training script ``train_vae.py``.
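
Step 1 mentions packing images with ``tar``. A minimal sketch of that conversion is the hypothetical helper below, assuming a flat folder of ``.jpg`` files with unique names and no extra dots (WebDataset derives each sample key from the member name up to its first dot, so ``a.b.jpg`` would become key ``a`` with extension ``b.jpg``). Files are assigned to shards in sorted name order, and the default of 1000 images per shard is an arbitrary choice.

.. code-block:: bash

   # Hypothetical helper (not part of NeMo): pack a flat folder of .jpg files into
   # WebDataset-style shards 000000.tar, 000001.tar, ... with at most per_shard images each.
   # Each image becomes one sample with a single "jpg" part; names must be unique.
   shard_images() {
     local src_dir=$1 out_dir=$2 per_shard=${3:-1000}
     local shard=0 name
     local -a batch=()
     mkdir -p "$out_dir" || return 1

     # Sorted, NUL-delimited list of .jpg files, so names with spaces are safe
     while IFS= read -r -d '' name; do
       batch+=("${name##*/}")
       if (( ${#batch[@]} == per_shard )); then
         tar -cf "$(printf '%s/%06d.tar' "$out_dir" "$shard")" -C "$src_dir" "${batch[@]}" || return 1
         shard=$(( shard + 1 ))
         batch=()
       fi
     done < <(find "$src_dir" -maxdepth 1 -type f -name '*.jpg' -print0 | sort -z)

     # Write the last, possibly smaller, shard
     if (( ${#batch[@]} > 0 )); then
       tar -cf "$(printf '%s/%06d.tar' "$out_dir" "$shard")" -C "$src_dir" "${batch[@]}" || return 1
       shard=$(( shard + 1 ))
     fi
     echo "$shard"   # number of shards written
   }

   # Example: shard_images /path/to/images /path/to/dataset 1000

Using ``-C "$src_dir"`` stores bare file names in each tar, matching the flat layout shown in step 1, so ``energon prepare`` sees a single ``jpg`` part per sample.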

Training
========

We provide a sample training script for launching multi-node training. Simply configure ``data.path`` to point to your prepared dataset to get started.

.. code-block:: bash

   bash nemo/collections/diffusion/vae/train_vae.sh \
      data.path=/path/to/dataset
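
A hypothetical pre-launch check, assuming ``energon prepare`` stored its metadata in a ``.nv-meta`` folder inside the dataset directory (``dataset.yaml`` and ``split.yaml``); verify these file names against your Energon version. It prints the number of top-level ``.tar`` shards, or fails if the metadata is missing.

.. code-block:: bash

   # Hypothetical pre-launch check (not part of NeMo) for a dataset indexed with
   # `energon prepare`. Assumes metadata lives in <dataset>/.nv-meta/.
   check_energon_dataset() {
     local root=$1 f shards
     for f in dataset.yaml split.yaml; do
       if [[ ! -f "$root/.nv-meta/$f" ]]; then
         echo "error: $root/.nv-meta/$f not found; run 'energon prepare' in $root first" >&2
         return 1
       fi
     done
     # Count shards at the top level of the dataset directory
     shards=$(find "$root" -maxdepth 1 -type f -name '*.tar' | wc -l)
     shards=${shards//[[:space:]]/}
     echo "$shards"
   }

   # Example:
   # check_energon_dataset /path/to/dataset && \
   #   bash nemo/collections/diffusion/vae/train_vae.sh data.path=/path/to/dataset

On a multi-node job, catching a missing index before submission is cheaper than discovering it after every node has started.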