Spaces:
Runtime error
Runtime error
merge with diffusers latest version
Browse files
- scripts/train_unconditional.py +61 -34
- scripts/train_vae.py +1 -0
scripts/train_unconditional.py
CHANGED
|
@@ -35,24 +35,19 @@ def main(args):
|
|
| 35 |
output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
|
| 36 |
logging_dir = os.path.join(output_dir, args.logging_dir)
|
| 37 |
accelerator = Accelerator(
|
|
|
|
| 38 |
mixed_precision=args.mixed_precision,
|
| 39 |
log_with="tensorboard",
|
| 40 |
logging_dir=logging_dir,
|
| 41 |
)
|
| 42 |
|
|
|
|
|
|
|
|
|
|
| 43 |
if args.from_pretrained is not None:
|
| 44 |
-
|
| 45 |
-
pretrained = LDMPipeline.from_pretrained(args.from_pretrained)
|
| 46 |
-
vqvae = pretrained.vqvae
|
| 47 |
-
model = pretrained.unet
|
| 48 |
else:
|
| 49 |
-
vqvae = AutoencoderKL(sample_size=args.resolution,
|
| 50 |
-
in_channels=1,
|
| 51 |
-
out_channels=1,
|
| 52 |
-
latent_channels=1,
|
| 53 |
-
layers_per_block=2)
|
| 54 |
model = UNet2DModel(
|
| 55 |
-
sample_size=args.resolution,
|
| 56 |
in_channels=1,
|
| 57 |
out_channels=1,
|
| 58 |
layers_per_block=2,
|
|
@@ -75,10 +70,12 @@ def main(args):
|
|
| 75 |
),
|
| 76 |
)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
optimizer = torch.optim.AdamW(
|
| 83 |
model.parameters(),
|
| 84 |
lr=args.learning_rate,
|
|
@@ -115,7 +112,13 @@ def main(args):
|
|
| 115 |
)
|
| 116 |
|
| 117 |
def transforms(examples):
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
return {"input": images}
|
| 120 |
|
| 121 |
dataset.set_transform(transforms)
|
|
@@ -181,27 +184,42 @@ def main(args):
|
|
| 181 |
device=clean_images.device,
|
| 182 |
).long()
|
| 183 |
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
| 186 |
# (this is the forward diffusion process)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
|
| 190 |
with accelerator.accumulate(model):
|
| 191 |
# Predict the noise residual
|
| 192 |
-
|
| 193 |
-
noise_pred = vqvae.decode(
|
| 194 |
loss = F.mse_loss(noise_pred, noise)
|
| 195 |
accelerator.backward(loss)
|
| 196 |
|
| 197 |
-
accelerator.
|
|
|
|
| 198 |
optimizer.step()
|
| 199 |
lr_scheduler.step()
|
| 200 |
if args.use_ema:
|
| 201 |
ema_model.step(model)
|
| 202 |
optimizer.zero_grad()
|
| 203 |
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
logs = {
|
| 206 |
"loss": loss.detach().item(),
|
| 207 |
"lr": lr_scheduler.get_last_lr()[0],
|
|
@@ -211,7 +229,6 @@ def main(args):
|
|
| 211 |
logs["ema_decay"] = ema_model.decay
|
| 212 |
progress_bar.set_postfix(**logs)
|
| 213 |
accelerator.log(logs, step=global_step)
|
| 214 |
-
global_step += 1
|
| 215 |
progress_bar.close()
|
| 216 |
|
| 217 |
accelerator.wait_for_everyone()
|
|
@@ -219,17 +236,19 @@ def main(args):
|
|
| 219 |
# Generate sample images for visual inspection
|
| 220 |
if accelerator.is_main_process:
|
| 221 |
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
| 233 |
|
| 234 |
# save the model
|
| 235 |
if args.push_to_hub:
|
|
@@ -325,6 +344,14 @@ if __name__ == "__main__":
|
|
| 325 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 326 |
parser.add_argument("--from_pretrained", type=str, default=None)
|
| 327 |
parser.add_argument("--start_epoch", type=int, default=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
args = parser.parse_args()
|
| 330 |
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
|
|
|
| 35 |
output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
|
| 36 |
logging_dir = os.path.join(output_dir, args.logging_dir)
|
| 37 |
accelerator = Accelerator(
|
| 38 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 39 |
mixed_precision=args.mixed_precision,
|
| 40 |
log_with="tensorboard",
|
| 41 |
logging_dir=logging_dir,
|
| 42 |
)
|
| 43 |
|
| 44 |
+
if args.vae is not None:
|
| 45 |
+
vqvae = AutoencoderKL.from_pretrained(args.vae)
|
| 46 |
+
|
| 47 |
if args.from_pretrained is not None:
|
| 48 |
+
model = DDPMPipeline.from_pretrained(args.from_pretrained).unet
|
|
|
|
|
|
|
|
|
|
| 49 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
model = UNet2DModel(
|
|
|
|
| 51 |
in_channels=1,
|
| 52 |
out_channels=1,
|
| 53 |
layers_per_block=2,
|
|
|
|
| 70 |
),
|
| 71 |
)
|
| 72 |
|
| 73 |
+
if args.scheduler == "ddpm":
|
| 74 |
+
noise_scheduler = DDPMScheduler(num_train_timesteps=1000,
|
| 75 |
+
tensor_format="pt")
|
| 76 |
+
else:
|
| 77 |
+
noise_scheduler = DDIMScheduler(num_train_timesteps=1000,
|
| 78 |
+
tensor_format="pt")
|
| 79 |
optimizer = torch.optim.AdamW(
|
| 80 |
model.parameters(),
|
| 81 |
lr=args.learning_rate,
|
|
|
|
| 112 |
)
|
| 113 |
|
| 114 |
def transforms(examples):
|
| 115 |
+
if args.vae is not None:
|
| 116 |
+
images = [
|
| 117 |
+
augmentations(image).convert("RGB")
|
| 118 |
+
for image in examples["image"]
|
| 119 |
+
]
|
| 120 |
+
else:
|
| 121 |
+
images = [augmentations(image) for image in examples["image"]]
|
| 122 |
return {"input": images}
|
| 123 |
|
| 124 |
dataset.set_transform(transforms)
|
|
|
|
| 184 |
device=clean_images.device,
|
| 185 |
).long()
|
| 186 |
|
| 187 |
+
if args.vae is not None:
|
| 188 |
+
with torch.no_grad():
|
| 189 |
+
clean_images = vqvae.encode(
|
| 190 |
+
clean_images).latent_dist.sample()
|
| 191 |
+
|
| 192 |
# Add noise to the clean images according to the noise magnitude at each timestep
|
| 193 |
# (this is the forward diffusion process)
|
| 194 |
+
noisy_images = noise_scheduler.add_noise(clean_images, noise,
|
| 195 |
+
timesteps)
|
| 196 |
|
| 197 |
with accelerator.accumulate(model):
|
| 198 |
# Predict the noise residual
|
| 199 |
+
images = model(noisy_images, timesteps)["sample"]
|
| 200 |
+
noise_pred = vqvae.decode(images)["sample"]
|
| 201 |
loss = F.mse_loss(noise_pred, noise)
|
| 202 |
accelerator.backward(loss)
|
| 203 |
|
| 204 |
+
if accelerator.sync_gradients:
|
| 205 |
+
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
| 206 |
optimizer.step()
|
| 207 |
lr_scheduler.step()
|
| 208 |
if args.use_ema:
|
| 209 |
ema_model.step(model)
|
| 210 |
optimizer.zero_grad()
|
| 211 |
|
| 212 |
+
if args.vae is not None:
|
| 213 |
+
with torch.no_grad():
|
| 214 |
+
images = [
|
| 215 |
+
image.convert('L')
|
| 216 |
+
for image in vqvae.decode(images)["sample"]
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
if accelerator.sync_gradients:
|
| 220 |
+
progress_bar.update(1)
|
| 221 |
+
global_step += 1
|
| 222 |
+
|
| 223 |
logs = {
|
| 224 |
"loss": loss.detach().item(),
|
| 225 |
"lr": lr_scheduler.get_last_lr()[0],
|
|
|
|
| 229 |
logs["ema_decay"] = ema_model.decay
|
| 230 |
progress_bar.set_postfix(**logs)
|
| 231 |
accelerator.log(logs, step=global_step)
|
|
|
|
| 232 |
progress_bar.close()
|
| 233 |
|
| 234 |
accelerator.wait_for_everyone()
|
|
|
|
| 236 |
# Generate sample images for visual inspection
|
| 237 |
if accelerator.is_main_process:
|
| 238 |
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
| 239 |
+
if args.vae is not None:
|
| 240 |
+
pipeline = LDMPipeline(
|
| 241 |
+
unet=accelerator.unwrap_model(
|
| 242 |
+
ema_model.averaged_model if args.use_ema else model),
|
| 243 |
+
vqvae=vqvae,
|
| 244 |
+
scheduler=noise_scheduler,
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
pipeline = DDPMPipeline(
|
| 248 |
+
unet=accelerator.unwrap_model(
|
| 249 |
+
ema_model.averaged_model if args.use_ema else model),
|
| 250 |
+
scheduler=noise_scheduler,
|
| 251 |
+
)
|
| 252 |
|
| 253 |
# save the model
|
| 254 |
if args.push_to_hub:
|
|
|
|
| 344 |
parser.add_argument("--hop_length", type=int, default=512)
|
| 345 |
parser.add_argument("--from_pretrained", type=str, default=None)
|
| 346 |
parser.add_argument("--start_epoch", type=int, default=0)
|
| 347 |
+
parser.add_argument("--scheduler",
|
| 348 |
+
type=str,
|
| 349 |
+
default="ddpm",
|
| 350 |
+
help="ddpm or ddim")
|
| 351 |
+
parser.add_argument("--vae",
|
| 352 |
+
type=str,
|
| 353 |
+
default=None,
|
| 354 |
+
help="pretrained VAE model for latent diffusion")
|
| 355 |
|
| 356 |
args = parser.parse_args()
|
| 357 |
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
scripts/train_vae.py
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
# grayscale
|
| 7 |
# add vae to train_uncond (no_grad)
|
| 8 |
# update README
|
|
|
|
| 9 |
|
| 10 |
import os
|
| 11 |
import argparse
|
|
|
|
| 6 |
# grayscale
|
| 7 |
# add vae to train_uncond (no_grad)
|
| 8 |
# update README
|
| 9 |
+
# merge in changes to train_unconditional
|
| 10 |
|
| 11 |
import os
|
| 12 |
import argparse
|