SkillForge45 commited on
Commit
67dc5db
·
verified ·
1 Parent(s): 54d9695

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +76 -0
train.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch.utils.data import DataLoader
5
+ from torchvision import transforms
6
+ from model import DiffusionModel, UNet
7
+ from torchvision.datasets import CocoCaptions
8
+ import argparse
9
+ from tqdm import tqdm
10
+
11
+ # Config
12
+ IMAGE_SIZE = 256
13
+ BATCH_SIZE = 16
14
+ EPOCHS = 50
15
+ LR = 2e-5
16
+ TIMESTEPS = 1000
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ def load_coco_dataset():
20
+ transform = transforms.Compose([
21
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
24
+ ])
25
+
26
+ dataset = CocoCaptions(
27
+ root='./train2017',
28
+ annFile='./annotations/captions_train2017.json',
29
+ transform=transform
30
+ )
31
+
32
+ dataloader = DataLoader(
33
+ dataset,
34
+ batch_size=BATCH_SIZE,
35
+ shuffle=True,
36
+ num_workers=4,
37
+ collate_fn=lambda x: (torch.stack([item[0] for item in x]), [item[1] for item in x])
38
+ )
39
+ return dataloader
40
+
41
+ def train():
42
+ # Setup
43
+ model = UNet().to(DEVICE)
44
+ betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)
45
+ diffusion = DiffusionModel(model, betas, DEVICE)
46
+ optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
47
+ dataloader = load_coco_dataset()
48
+
49
+ # Training loop
50
+ for epoch in range(EPOCHS):
51
+ pbar = tqdm(dataloader)
52
+ for images, captions in pbar:
53
+ images = images.to(DEVICE)
54
+
55
+ # Flatten captions (5 per image) and repeat images
56
+ captions = [cap for sublist in captions for cap in sublist]
57
+ images = images.repeat_interleave(5, dim=0)
58
+
59
+ # Sample timesteps
60
+ t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=DEVICE).long()
61
+
62
+ # Compute loss
63
+ loss = diffusion.p_losses(images, captions, t)
64
+
65
+ # Optimize
66
+ optimizer.zero_grad()
67
+ loss.backward()
68
+ optimizer.step()
69
+
70
+ pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")
71
+
72
+ # Save checkpoint
73
+ torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")
74
+
75
+ if __name__ == "__main__":
76
+ train()