import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from tqdm import tqdm

from model import DiffusionModel, UNet

# Config
IMAGE_SIZE = 256
BATCH_SIZE = 16
EPOCHS = 50
LR = 2e-5
TIMESTEPS = 1000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_coco_dataset():
    """Build a DataLoader over COCO train2017 images paired with their captions."""
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        # Scale pixels from [0, 1] to [-1, 1], the range the diffusion model trains on
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    dataset = CocoCaptions(
        root='./train2017',
        annFile='./annotations/captions_train2017.json',
        transform=transform,
    )

    # Default collation cannot batch variable-length caption lists, so stack the
    # image tensors and keep the captions as a list of per-image caption lists.
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        collate_fn=lambda batch: (
            torch.stack([item[0] for item in batch]),
            [item[1] for item in batch],
        ),
    )
    return dataloader


def train():
    # Setup
    model = UNet().to(DEVICE)
    betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)  # linear DDPM noise schedule
    diffusion = DiffusionModel(model, betas, DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    dataloader = load_coco_dataset()

    # Training loop
    for epoch in range(EPOCHS):
        pbar = tqdm(dataloader)
        for images, captions in pbar:
            images = images.to(DEVICE)

            # Pair every caption with its image. Most COCO images have 5 captions,
            # but not all, so repeat each image by its actual caption count instead
            # of assuming exactly 5.
            counts = torch.tensor([len(caps) for caps in captions], device=DEVICE)
            images = images.repeat_interleave(counts, dim=0)
            captions = [cap for caps in captions for cap in caps]

            # Sample a random diffusion timestep for each image-caption pair
            t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=DEVICE).long()

            # Compute the diffusion loss
            loss = diffusion.p_losses(images, captions, t)

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        # Save a checkpoint after each epoch
        torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")


if __name__ == "__main__":
    train()
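

# ---------------------------------------------------------------------------
# Reference sketch (assumption): model.py is not shown here, so the function
# below only illustrates what the training loop expects
# DiffusionModel.p_losses(images, captions, t) to compute -- the standard DDPM
# noise-prediction MSE under the linear beta schedule built in train(). It is
# not the project's implementation and is never called by this script.
# ---------------------------------------------------------------------------
def _reference_p_losses(denoise_model, betas, x0, captions, t):
    import torch.nn.functional as F

    # Cumulative product of (1 - beta) gives the closed-form forward process
    # q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I).
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    sqrt_ac = alphas_cumprod[t].sqrt().view(-1, 1, 1, 1)
    sqrt_om = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1)

    # Diffuse the clean images to timestep t and train the network to recover
    # the injected noise (caption conditioning is assumed to be handled inside
    # the denoising model's forward pass).
    noise = torch.randn_like(x0)
    x_t = sqrt_ac * x0 + sqrt_om * noise
    predicted_noise = denoise_model(x_t, t, captions)
    return F.mse_loss(predicted_noise, noise)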