Spaces:
Runtime error
Runtime error
added vae notebook
Browse files
- notebooks/test_vae.ipynb +0 -0
- scripts/train_unconditional.py +6 -6
- scripts/train_vae.py +1 -1
notebooks/test_vae.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/train_unconditional.py
CHANGED
|
@@ -11,7 +11,7 @@ from accelerate import Accelerator
|
|
| 11 |
from accelerate.logging import get_logger
|
| 12 |
from datasets import load_from_disk, load_dataset
|
| 13 |
from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
|
| 14 |
-                       DDIMScheduler,
|
| 15 |
from diffusers.hub_utils import init_git_repo, push_to_hub
|
| 16 |
from diffusers.optimization import get_scheduler
|
| 17 |
from diffusers.training_utils import EMAModel
|
|
@@ -46,11 +46,11 @@ def main(args):
|
|
| 46 |
vqvae = pretrained.vqvae
|
| 47 |
model = pretrained.unet
|
| 48 |
else:
|
| 49 |
-        vqvae =
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
model = UNet2DModel(
|
| 55 |
sample_size=args.resolution,
|
| 56 |
in_channels=1,
|
|
|
|
| 11 |
from accelerate.logging import get_logger
|
| 12 |
from datasets import load_from_disk, load_dataset
|
| 13 |
from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
|
| 14 |
+                       DDIMScheduler, AutoencoderKL)
|
| 15 |
from diffusers.hub_utils import init_git_repo, push_to_hub
|
| 16 |
from diffusers.optimization import get_scheduler
|
| 17 |
from diffusers.training_utils import EMAModel
|
|
|
|
| 46 |
vqvae = pretrained.vqvae
|
| 47 |
model = pretrained.unet
|
| 48 |
else:
|
| 49 |
+        vqvae = AutoencoderKL(sample_size=args.resolution,
|
| 50 |
+                              in_channels=1,
|
| 51 |
+                              out_channels=1,
|
| 52 |
+                              latent_channels=1,
|
| 53 |
+                              layers_per_block=2)
|
| 54 |
model = UNet2DModel(
|
| 55 |
sample_size=args.resolution,
|
| 56 |
in_channels=1,
|
scripts/train_vae.py
CHANGED
|
@@ -152,7 +152,7 @@ if __name__ == "__main__":
|
|
| 152 |
trainer_opt,
|
| 153 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
| 154 |
callbacks=[
|
| 155 |
-            ImageLogger(),
|
| 156 |
HFModelCheckpoint(ldm_config=config,
|
| 157 |
hf_checkpoint=args.hf_checkpoint_dir,
|
| 158 |
dirpath=args.ldm_checkpoint_dir,
|
|
|
|
| 152 |
trainer_opt,
|
| 153 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
| 154 |
callbacks=[
|
| 155 |
+            ImageLogger(every=10),
|
| 156 |
HFModelCheckpoint(ldm_config=config,
|
| 157 |
hf_checkpoint=args.hf_checkpoint_dir,
|
| 158 |
dirpath=args.ldm_checkpoint_dir,
|