YashNagraj75 committed on
Commit
edc370c
·
1 Parent(s): 6a1e886

Add the training script for dit

Browse files
Files changed (2) hide show
  1. celeba/config.yaml +2 -2
  2. train_dit.py +128 -0
celeba/config.yaml CHANGED
@@ -11,10 +11,10 @@ diffusion_params:
11
  dit_params:
12
  patch_size: 2
13
  num_layers: 12
14
- hidden_size: 768
15
  num_heads: 12
16
  head_dim: 64
17
- timestep_emb_dim: 768
18
 
19
  autoencoder_params:
20
  z_channels: 4
 
11
  dit_params:
12
  patch_size: 2
13
  num_layers: 12
14
+ hidden_dim: 768
15
  num_heads: 12
16
  head_dim: 64
17
+ temb_dim: 768
18
 
19
  autoencoder_params:
20
  z_channels: 4
train_dit.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import yaml
8
+ from torch.optim import AdamW
9
+ from tqdm import tqdm
10
+
11
+ from celeba import create_dataloader
12
+ from model.transformer import DIT
13
+ from model.vae import VAE
14
+ from scheduler.linear_scheduler import LinearNoiseScheduler
15
+
16
# Run on the GPU when torch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
17
+
18
+
19
def train(args):
    """Train the DiT noise-prediction model on latents from a frozen VAE.

    The YAML file at ``args.config_path`` is loaded and its ``train_params``,
    ``dit_params``, ``dataset_params``, ``diffusion_params`` and
    ``autoencoder_params`` sections drive the run. If a DiT checkpoint already
    exists under ``train_params["task_name"]`` training resumes from it, and a
    DiT checkpoint is written after every epoch.

    Args:
        args: Parsed command-line namespace; only ``config_path`` is read.

    Returns:
        None. Returns early, after printing a message, when the VAE
        checkpoint file does not exist.

    Raises:
        yaml.YAMLError: If the config file cannot be parsed.
    """
    with open(args.config_path, "r") as file:
        try:
            config = yaml.safe_load(file)
        except yaml.YAMLError as e:
            print(f"Error in loading yaml: {e}")
            # Nothing below can run without a config; surface the real parse
            # error instead of falling through to a NameError on ``config``.
            raise

    train_config = config["train_params"]
    dit_config = config["dit_params"]
    dataset_config = config["dataset_params"]
    diffusion_params = config["diffusion_params"]
    vae_config = config["autoencoder_params"]

    # Resolve both checkpoint locations once so that loading and saving can
    # never disagree about which file belongs to which model.
    dit_ckpt_path = os.path.join(
        train_config["task_name"], train_config["dit_ckpt_name"]
    )
    vae_ckpt_path = os.path.join(
        train_config["task_name"], train_config["vae_autoencoder_ckpt_name"]
    )

    # Single source of truth for the number of diffusion steps: the scheduler
    # and the timestep sampler below must use the same value.
    num_timesteps = diffusion_params["num_timesteps"]

    dataloader = create_dataloader(dataset_config["im_path"])

    scheduler = LinearNoiseScheduler(
        num_timesteps,
        diffusion_params["beta_start"],
        diffusion_params["beta_end"],
    )

    # Spatial size handed to the DiT: the image size halved once per set entry
    # of ``down_sample`` (presumably the VAE's downsampling stages - confirm).
    im_size = dataset_config["im_size"] // 2 ** sum(vae_config["down_sample"])
    # NOTE(review): the DiT is fed VAE latents but is built with the dataset's
    # ``im_channels``; confirm this matches the latent channel count
    # (``autoencoder_params["z_channels"]``).
    model = DIT(
        im_size=im_size, im_channels=dataset_config["im_channels"], config=dit_config
    ).to(device)
    model.train()

    if os.path.exists(dit_ckpt_path):
        checkpoint = torch.load(dit_ckpt_path, map_location=device)
        model.load_state_dict(checkpoint["dit"])
        start_epoch = checkpoint["epoch"]
        # Checkpoints are written with "step_count"; also accept the "step"
        # key that earlier revisions of this script wrote.
        step_count = checkpoint.get("step_count", checkpoint.get("step", 0))
    else:
        step_count = 0
        start_epoch = 0

    if not os.path.exists(vae_ckpt_path):
        print("No VAE checkpoint found, VAE checkpoint needed")
        return

    vae = VAE(dataset_config["im_channels"], vae_config).to(device)
    vae.load_state_dict(torch.load(vae_ckpt_path, map_location=device))
    # The VAE only produces latents here; keep it in eval mode and frozen.
    vae.eval()
    for param in vae.parameters():
        param.requires_grad = False
    print("VAE checkpoint loaded")

    num_epochs = train_config["dit_epochs"]
    optimizer = AdamW(model.parameters(), lr=train_config["dit_lr"])
    accu_steps = train_config["dit_acc_steps"]
    criterion = nn.MSELoss()

    for epoch in range(start_epoch, num_epochs):
        losses = []
        for im in tqdm(dataloader):
            im = im.float().to(device)
            step_count += 1
            with torch.no_grad():
                latents, _ = vae.encode(im)

            # randn_like already allocates on the latents' device.
            noise = torch.randn_like(latents)

            # One random timestep per sample in the batch.
            t = torch.randint(
                0, num_timesteps, (latents.shape[0],), device=device
            )
            noisy_latents = scheduler.add_noise(latents, noise, t)
            pred = model(noisy_latents, t)
            loss = criterion(pred, noise)
            losses.append(loss.item())
            # Scale so that gradients accumulated over ``accu_steps`` batches
            # average rather than sum.
            loss = loss / accu_steps
            loss.backward()
            if step_count % accu_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
        # Apply any gradients left over from an incomplete accumulation
        # window at the end of the epoch.
        optimizer.step()
        optimizer.zero_grad()
        print(f"Epoch {epoch}: Loss: {np.mean(losses)}")
        # Save to the DiT checkpoint path (not the VAE one) using the same
        # keys the resume branch above reads.
        torch.save(
            {"dit": model.state_dict(), "epoch": epoch + 1, "step_count": step_count},
            dit_ckpt_path,
        )
    print("Done Training")
120
+
121
+
122
if __name__ == "__main__":
    # Script entry point: read the config location from the command line
    # (falling back to the CelebA config) and start training.
    cli = argparse.ArgumentParser(description="Arguments for dit training")
    cli.add_argument(
        "--config",
        type=str,
        default="celeba/config.yaml",
        dest="config_path",
    )
    train(cli.parse_args())