# image_generator / train_diff.py
# Kyryll Kochkin
# new frontend
# ad9ba57
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from tqdm import tqdm
def _pick_device() -> str:
    """Return the preferred torch device name: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def train_diffusion(
    *,
    epochs: int = 5,
    batch_size: int = 128,
    learning_rate: float = 1e-4,
    weight_decay: float = 1e-4,
    output_dir: str = "my_diffusion_model",
    data_root: str = "./data",
    device: "str | None" = None,
) -> "DDPMPipeline":
    """Train a class-conditional DDPM on MNIST and save it to disk.

    A small ``UNet2DModel`` with a 10-way class embedding is trained to
    predict the noise added to single-channel 28x28 MNIST digits. The model
    is saved both on normal completion and when training is interrupted
    with Ctrl-C, so partial progress is never lost.

    Args:
        epochs: Number of full passes over the MNIST training split.
        batch_size: Mini-batch size used by the data loader.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight-decay coefficient.
        output_dir: Directory passed to ``pipeline.save_pretrained``.
        data_root: Directory where MNIST is downloaded / cached.
        device: Torch device name. When ``None`` the best available device
            is chosen (CUDA, then MPS, then CPU).

    Returns:
        The ``DDPMPipeline`` wrapping the trained UNet and the scheduler.

    NOTE(review): the UNet requires ``class_labels``; the stock
    ``DDPMPipeline.__call__`` does not appear to forward class labels, so
    sampling presumably happens through a custom loop elsewhere - confirm.
    """
    if device is None:
        device = _pick_device()
    print(f"Using device: {device}")

    # NOTE(review): ToTensor yields pixel values in [0, 1]; diffusers
    # pipelines conventionally work in [-1, 1]. Left unchanged because the
    # sampler consuming this model is not visible here - confirm the range.
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.MNIST(
        root=data_root, train=True, download=True, transform=transform
    )
    loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # Conditional DDPM UNet for MNIST digits (one class embedding per digit).
    unet = UNet2DModel(
        sample_size=28,
        in_channels=1,
        out_channels=1,
        block_out_channels=(32, 64, 128),
        down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        num_class_embeds=10,
    ).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(device)
    optimizer = torch.optim.AdamW(
        unet.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Read the value through ``.config``: direct attribute access to config
    # entries on diffusers objects is deprecated. Hoisted out of the loop.
    num_train_timesteps = scheduler.config.num_train_timesteps

    print(f"Training DDPM for {epochs} epochs...")
    interrupted = False
    try:
        for epoch in range(1, epochs + 1):
            pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}")
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)

                # Forward diffusion: corrupt each image at a random timestep.
                noise = torch.randn_like(images)
                timesteps = torch.randint(
                    0,
                    num_train_timesteps,
                    (images.shape[0],),
                    device=device,
                    dtype=torch.long,
                )
                noisy = scheduler.add_noise(images, noise, timesteps)

                # Conditional noise prediction
                model_pred = unet(
                    noisy, timesteps, class_labels=labels, return_dict=False
                )[0]
                loss = F.mse_loss(model_pred, noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=f"{loss.item():.4f}")
    except KeyboardInterrupt:
        # Fall through to the shared save path so a partial model is kept.
        interrupted = True
        print("\nKeyboard interrupt, saving model...")

    pipeline.save_pretrained(output_dir)
    if interrupted:
        print(f"Model saved to {output_dir}/")
    else:
        print(f"Training complete. Model saved to {output_dir}/")
    return pipeline
# Script entry point: run training with the default settings when this file
# is executed directly; importing the module does not start training.
if __name__ == "__main__":
    train_diffusion()