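"""Train a text-conditioned diffusion model on the COCO Captions dataset.

Assumes a local model.py that provides:
  - UNet: a text-conditioned noise-prediction network
  - DiffusionModel: wraps the network and a beta schedule; its
    p_losses(x0, captions, t) is expected to noise x0 to timestep t and
    return the training loss

Expects the COCO 2017 images and caption annotations at the paths
configured in load_coco_dataset(); run the script directly to train.
"""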
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from tqdm import tqdm

from model import DiffusionModel, UNet

# Config
IMAGE_SIZE = 256
BATCH_SIZE = 16
EPOCHS = 50
LR = 2e-5
TIMESTEPS = 1000
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def collate_coco(batch):
    # Each sample is (image_tensor, [caption strings]). The caption lists are
    # variable-length, so the default collate cannot batch them. Defined at
    # module level (rather than as a lambda) so DataLoader worker processes
    # can pickle it on platforms that spawn workers.
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    return images, captions

def load_coco_dataset():
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        # Scale pixels from [0, 1] to [-1, 1], the range diffusion models
        # conventionally train on
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # CocoCaptions expects the raw COCO images under `root` and the caption
    # annotation JSON at `annFile`
    dataset = CocoCaptions(
        root='./train2017',
        annFile='./annotations/captions_train2017.json',
        transform=transform
    )
    
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_coco
    )
    return dataloader
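
# For reference: a minimal sketch of the DDPM objective that
# DiffusionModel.p_losses is assumed to implement. The real implementation
# lives in model.py; `net`, `alphas_cumprod`, and this function are
# illustrative only and are not used below.
def _p_losses_sketch(net, alphas_cumprod, x0, captions, t):
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    # Sample from the forward (noising) process q(x_t | x_0)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise
    # The network is trained to predict the added noise
    return torch.nn.functional.mse_loss(net(x_t, t, captions), noise)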

def train():
    # Setup
    model = UNet().to(DEVICE)
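    # Standard DDPM linear beta schedule (Ho et al., 2020): with
    # alpha_bar_t = prod_{s<=t} (1 - beta_s), the forward process gives
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps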
    betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)
    diffusion = DiffusionModel(model, betas, DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    dataloader = load_coco_dataset()
    
    # Training loop
    for epoch in range(EPOCHS):
        pbar = tqdm(dataloader)
        for images, captions in pbar:
            images = images.to(DEVICE)
            
            # Pair every caption with its image. COCO images usually have five
            # captions but not always, so repeat each image by its actual
            # caption count instead of assuming a fixed 5.
            counts = torch.tensor([len(caps) for caps in captions], device=DEVICE)
            images = images.repeat_interleave(counts, dim=0)
            captions = [cap for caps in captions for cap in caps]
            
            # Uniformly sample a diffusion timestep for each image-caption pair
            t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=DEVICE)
            
            # p_losses is assumed to noise the images to step t and return the
            # training loss (typically noise-prediction MSE; see sketch above)
            loss = diffusion.p_losses(images, captions, t)
            
            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")
        
        # Save model weights after each epoch (optimizer state is not checkpointed)
        torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")

if __name__ == "__main__":
    train()