# Extracted from a hosted "Spaces" file view (status: Sleeping).
# File size: 2,624 Bytes; revision ad9ba57.
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from tqdm import tqdm
def train_diffusion(
    epochs: int = 5,
    batch_size: int = 128,
    lr: float = 1e-4,
    output_dir: str = "my_diffusion_model",
) -> "DDPMPipeline":
    """Train a class-conditional DDPM on MNIST and save it to disk.

    The UNet is trained to predict the noise that was added to an image at a
    randomly drawn timestep, conditioned on the digit label. Pressing Ctrl-C
    stops training early; the weights trained so far are still saved.

    Args:
        epochs: Number of full passes over the MNIST training split.
        batch_size: Number of images per optimisation step.
        lr: Learning rate for AdamW.
        output_dir: Directory passed to ``pipeline.save_pretrained``.

    Returns:
        The ``DDPMPipeline`` wrapping the trained UNet and the scheduler.
    """
    # Prefer Apple's MPS backend (the original target), then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # NOTE(review): only ToTensor is applied, so no rescaling to [-1, 1] is
    # done here - confirm the sampling code expects the same input range.
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    # Conditional DDPM UNet for MNIST digits (one class embedding per digit).
    unet = UNet2DModel(
        sample_size=28,
        in_channels=1,
        out_channels=1,
        block_out_channels=(32, 64, 128),
        down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
        num_class_embeds=10,
    ).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    pipeline = DDPMPipeline(unet=unet, scheduler=scheduler).to(device)
    # AdamW (changed from Adam) so weight decay is decoupled from the update.
    optimizer = torch.optim.AdamW(unet.parameters(), lr=lr, weight_decay=1e-4)

    # Read the step count from the scheduler config rather than relying on
    # attribute access on the scheduler object itself.
    num_train_timesteps = scheduler.config.num_train_timesteps

    print(f"Training DDPM for {epochs} epochs...")
    interrupted = False
    try:
        for epoch in range(1, epochs + 1):
            pbar = tqdm(loader, desc=f"Epoch {epoch}/{epochs}")
            for images, labels in pbar:
                images = images.to(device)
                labels = labels.to(device)

                # Draw one noise sample and one timestep per image.
                noise = torch.randn_like(images)
                timesteps = torch.randint(
                    0,
                    num_train_timesteps,
                    (images.shape[0],),
                    device=device,
                    dtype=torch.long,
                )
                noisy = scheduler.add_noise(images, noise, timesteps)

                # Conditional noise prediction
                model_pred = unet(
                    noisy, timesteps, class_labels=labels, return_dict=False
                )[0]
                loss = F.mse_loss(model_pred, noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=f"{loss.item():.4f}")
    except KeyboardInterrupt:
        # Fall through to the shared save below so partial progress is kept.
        interrupted = True
        print("\nKeyboard interrupt, saving model...")

    pipeline.save_pretrained(output_dir)
    if interrupted:
        print(f"Model saved to {output_dir}/")
    else:
        print(f"Training complete. Model saved to {output_dir}/")
    return pipeline
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train_diffusion()