Spaces:
Runtime error
Runtime error
update train_unconditional for latent diffusion
Browse files- README.md +2 -2
- scripts/train_unconditional.py +16 -19
- scripts/train_vae.py +0 -2
README.md
CHANGED
|
@@ -89,7 +89,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
|
|
| 89 |
scripts/train_unconditional.py \
|
| 90 |
--dataset_name teticio/audio-diffusion-256 \
|
| 91 |
--resolution 256 \
|
| 92 |
-
--output_dir
|
| 93 |
--num_epochs 100 \
|
| 94 |
--train_batch_size 2 \
|
| 95 |
--eval_batch_size 2 \
|
|
@@ -98,7 +98,7 @@ accelerate launch --config_file config/accelerate_local.yaml \
|
|
| 98 |
--lr_warmup_steps 500 \
|
| 99 |
--mixed_precision no \
|
| 100 |
--push_to_hub True \
|
| 101 |
-
--hub_model_id
|
| 102 |
--hub_token $(cat $HOME/.huggingface/token)
|
| 103 |
```
|
| 104 |
#### Run training on SageMaker.
|
|
|
|
| 89 |
scripts/train_unconditional.py \
|
| 90 |
--dataset_name teticio/audio-diffusion-256 \
|
| 91 |
--resolution 256 \
|
| 92 |
+
--output_dir audio-diffusion-256 \
|
| 93 |
--num_epochs 100 \
|
| 94 |
--train_batch_size 2 \
|
| 95 |
--eval_batch_size 2 \
|
|
|
|
| 98 |
--lr_warmup_steps 500 \
|
| 99 |
--mixed_precision no \
|
| 100 |
--push_to_hub True \
|
| 101 |
+
--hub_model_id audio-diffusion-256 \
|
| 102 |
--hub_token $(cat $HOME/.huggingface/token)
|
| 103 |
```
|
| 104 |
#### Run training on SageMaker.
|
scripts/train_unconditional.py
CHANGED
|
@@ -48,8 +48,9 @@ def main(args):
|
|
| 48 |
model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
|
| 49 |
else:
|
| 50 |
model = UNet2DModel(
|
| 51 |
-
|
| 52 |
-
|
|
|
|
| 53 |
layers_per_block=2,
|
| 54 |
block_out_channels=(128, 128, 256, 256, 512, 512),
|
| 55 |
down_block_types=(
|
|
@@ -114,7 +115,7 @@ def main(args):
|
|
| 114 |
def transforms(examples):
|
| 115 |
if args.vae is not None:
|
| 116 |
images = [
|
| 117 |
-
augmentations(image
|
| 118 |
for image in examples["image"]
|
| 119 |
]
|
| 120 |
else:
|
|
@@ -173,6 +174,13 @@ def main(args):
|
|
| 173 |
model.train()
|
| 174 |
for step, batch in enumerate(train_dataloader):
|
| 175 |
clean_images = batch["input"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
# Sample noise that we'll add to the images
|
| 177 |
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
| 178 |
bsz = clean_images.shape[0]
|
|
@@ -184,11 +192,6 @@ def main(args):
|
|
| 184 |
device=clean_images.device,
|
| 185 |
).long()
|
| 186 |
|
| 187 |
-
if args.vae is not None:
|
| 188 |
-
with torch.no_grad():
|
| 189 |
-
clean_images = vqvae.encode(
|
| 190 |
-
clean_images).latent_dist.sample()
|
| 191 |
-
|
| 192 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
| 193 |
# (this is the forward diffusion process)
|
| 194 |
noisy_images = noise_scheduler.add_noise(clean_images, noise,
|
|
@@ -196,8 +199,7 @@ def main(args):
|
|
| 196 |
|
| 197 |
with accelerator.accumulate(model):
|
| 198 |
# Predict the noise residual
|
| 199 |
-
|
| 200 |
-
noise_pred = vqvae.decode(images)["sample"]
|
| 201 |
loss = F.mse_loss(noise_pred, noise)
|
| 202 |
accelerator.backward(loss)
|
| 203 |
|
|
@@ -209,13 +211,6 @@ def main(args):
|
|
| 209 |
ema_model.step(model)
|
| 210 |
optimizer.zero_grad()
|
| 211 |
|
| 212 |
-
if args.vae is not None:
|
| 213 |
-
with torch.no_grad():
|
| 214 |
-
images = [
|
| 215 |
-
image.convert('L')
|
| 216 |
-
for image in vqvae.decode(images)["sample"]
|
| 217 |
-
]
|
| 218 |
-
|
| 219 |
if accelerator.sync_gradients:
|
| 220 |
progress_bar.update(1)
|
| 221 |
global_step += 1
|
|
@@ -239,14 +234,16 @@ def main(args):
|
|
| 239 |
if args.vae is not None:
|
| 240 |
pipeline = LDMPipeline(
|
| 241 |
unet=accelerator.unwrap_model(
|
| 242 |
-
ema_model.averaged_model if args.use_ema else model
|
|
|
|
| 243 |
vqvae=vqvae,
|
| 244 |
scheduler=noise_scheduler,
|
| 245 |
)
|
| 246 |
else:
|
| 247 |
pipeline = DDPMPipeline(
|
| 248 |
unet=accelerator.unwrap_model(
|
| 249 |
-
ema_model.averaged_model if args.use_ema else model
|
|
|
|
| 250 |
scheduler=noise_scheduler,
|
| 251 |
)
|
| 252 |
|
|
|
|
| 48 |
model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
|
| 49 |
else:
|
| 50 |
model = UNet2DModel(
|
| 51 |
+
sample_size=args.resolution if args.vae is None else 64,
|
| 52 |
+
in_channels=1 if args.vae is None else 3,
|
| 53 |
+
out_channels=1 if args.vae is None else 3,
|
| 54 |
layers_per_block=2,
|
| 55 |
block_out_channels=(128, 128, 256, 256, 512, 512),
|
| 56 |
down_block_types=(
|
|
|
|
| 115 |
def transforms(examples):
|
| 116 |
if args.vae is not None:
|
| 117 |
images = [
|
| 118 |
+
augmentations(image.convert("RGB"))
|
| 119 |
for image in examples["image"]
|
| 120 |
]
|
| 121 |
else:
|
|
|
|
| 174 |
model.train()
|
| 175 |
for step, batch in enumerate(train_dataloader):
|
| 176 |
clean_images = batch["input"]
|
| 177 |
+
|
| 178 |
+
if args.vae is not None:
|
| 179 |
+
vqvae.to(clean_images.device)
|
| 180 |
+
with torch.no_grad():
|
| 181 |
+
clean_images = vqvae.encode(
|
| 182 |
+
clean_images).latent_dist.sample()
|
| 183 |
+
|
| 184 |
# Sample noise that we'll add to the images
|
| 185 |
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
| 186 |
bsz = clean_images.shape[0]
|
|
|
|
| 192 |
device=clean_images.device,
|
| 193 |
).long()
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
| 196 |
# (this is the forward diffusion process)
|
| 197 |
noisy_images = noise_scheduler.add_noise(clean_images, noise,
|
|
|
|
| 199 |
|
| 200 |
with accelerator.accumulate(model):
|
| 201 |
# Predict the noise residual
|
| 202 |
+
noise_pred = model(noisy_images, timesteps)["sample"]
|
|
|
|
| 203 |
loss = F.mse_loss(noise_pred, noise)
|
| 204 |
accelerator.backward(loss)
|
| 205 |
|
|
|
|
| 211 |
ema_model.step(model)
|
| 212 |
optimizer.zero_grad()
|
| 213 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
if accelerator.sync_gradients:
|
| 215 |
progress_bar.update(1)
|
| 216 |
global_step += 1
|
|
|
|
| 234 |
if args.vae is not None:
|
| 235 |
pipeline = LDMPipeline(
|
| 236 |
unet=accelerator.unwrap_model(
|
| 237 |
+
ema_model.averaged_model if args.use_ema else model
|
| 238 |
+
),
|
| 239 |
vqvae=vqvae,
|
| 240 |
scheduler=noise_scheduler,
|
| 241 |
)
|
| 242 |
else:
|
| 243 |
pipeline = DDPMPipeline(
|
| 244 |
unet=accelerator.unwrap_model(
|
| 245 |
+
ema_model.averaged_model if args.use_ema else model
|
| 246 |
+
),
|
| 247 |
scheduler=noise_scheduler,
|
| 248 |
)
|
| 249 |
|
scripts/train_vae.py
CHANGED
|
@@ -4,9 +4,7 @@
|
|
| 4 |
|
| 5 |
# TODO
|
| 6 |
# grayscale
|
| 7 |
-
# add vae to train_uncond (no_grad)
|
| 8 |
# update README
|
| 9 |
-
# merge in changes to train_unconditional
|
| 10 |
|
| 11 |
import os
|
| 12 |
import argparse
|
|
|
|
| 4 |
|
| 5 |
# TODO
|
| 6 |
# grayscale
|
|
|
|
| 7 |
# update README
|
|
|
|
| 8 |
|
| 9 |
import os
|
| 10 |
import argparse
|