YashNagraj75 committed on
Commit
9774d79
·
1 Parent(s): 82afd8e

Add dataloader and add logging for training script

Browse files
data/mnist_dataset.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torchvision
7
+ from PIL import Image
8
+ from torch.utils.data.dataset import Dataset
9
+ from tqdm import tqdm
10
+
11
+
12
+ class MnistDataset(Dataset):
13
+ r"""
14
+ Nothing special here. Just a simple dataset class for mnist images.
15
+ Created a dataset class rather using torchvision to allow
16
+ replacement with any other image dataset
17
+ """
18
+
19
+ def __init__(self, split, im_path, im_ext="png", im_size=28, return_hints=False):
20
+ r"""
21
+ Init method for initializing the dataset properties
22
+ :param split: train/test to locate the image files
23
+ :param im_path: root folder of images
24
+ :param im_ext: image extension. assumes all
25
+ images would be this type.
26
+ """
27
+ self.split = split
28
+ self.im_ext = im_ext
29
+ self.return_hints = return_hints
30
+ self.images = self.load_images(im_path)
31
+
32
+ def load_images(self, im_path):
33
+ r"""
34
+ Gets all images from the path specified
35
+ and stacks them all up
36
+ :param im_path:
37
+ :return:
38
+ """
39
+ assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
40
+ ims = []
41
+ labels = []
42
+ for d_name in tqdm(os.listdir(im_path)):
43
+ for fname in glob.glob(
44
+ os.path.join(im_path, d_name, "*.{}".format(self.im_ext))
45
+ ):
46
+ ims.append(fname)
47
+ print("Found {} images for split {}".format(len(ims), self.split))
48
+ return ims
49
+
50
+ def __len__(self):
51
+ return len(self.images)
52
+
53
+ def __getitem__(self, index):
54
+ im = Image.open(self.images[index])
55
+ im_tensor = torchvision.transforms.ToTensor()(im)
56
+
57
+ # Convert input to -1 to 1 range.
58
+ im_tensor = (2 * im_tensor) - 1
59
+
60
+ if self.return_hints:
61
+ canny_image = Image.open(self.images[index])
62
+ canny_image = np.array(canny_image)
63
+ canny_image = cv2.Canny(canny_image, 100, 200)
64
+ canny_image = canny_image[:, :, None]
65
+ canny_image = np.concatenate(
66
+ [canny_image, canny_image, canny_image], axis=2
67
+ )
68
+ canny_image_tensor = torchvision.transforms.ToTensor()(canny_image)
69
+ return im_tensor, canny_image_tensor
70
+ else:
71
+ return im_tensor
model_config/mnist.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset_params:
2
+ im_path: 'data/mnist/train/images'
3
+ im_test_path: 'data/mnist/test/images'
4
+ canny_im_size: 28
5
+
6
+ diffusion_params:
7
+ num_timesteps : 1000
8
+ beta_start : 0.0001
9
+ beta_end : 0.02
10
+
11
+ model_params:
12
+ im_channels : 1
13
+ im_size : 28
14
+ hint_channels : 3
15
+ down_channels : [32, 64, 128, 256]
16
+ mid_channels : [256, 256, 128]
17
+ down_sample : [True, True, False]
18
+ time_emb_dim : 128
19
+ num_down_layers : 2
20
+ num_mid_layers : 2
21
+ num_up_layers : 2
22
+ num_heads : 4
23
+
24
+ train_params:
25
+ task_name: 'mnist'
26
+ batch_size: 64
27
+ num_epochs: 40
28
+ controlnet_epochs : 1
29
+ num_samples : 25
30
+ num_grid_rows : 5
31
+ save_epoch: 2
32
+ ddpm_lr: 0.0001
33
+ controlnet_lr: 0.0001
34
+ ddpm_ckpt_name: 'ddpm_ckpt.pth'
35
+ controlnet_ckpt_name: 'ddpm_controlnet_ckpt.pth'
training_scripts/train_ddpm.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+ import wandb
7
+ import yaml
8
+ from torch.utils.data import DataLoader
9
+ from tqdm import tqdm
10
+
11
+ from data import mnist_dataset
12
+ from data.mnist_dataset import MnistDataset
13
+ from model_blocks.unet_base import UNet
14
+ from scheduler.linear_scheduler import LinearNoiseScheduler
15
+
16
# Run on GPU when available; all tensors/models in this script move here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level logger per the logging.getLogger(__name__) convention.
logger = logging.getLogger(__name__)
# NOTE(review): wandb.login() runs at import time, so importing this module
# has a network/credential side effect — consider moving it into train().
wandb.login()
19
+
20
+
21
def init_wandb(config):
    """
    Start (or resume) a wandb run for this training job.

    :param config: dict of hyper-parameters to attach to the run
    :return: the active wandb run object
    """
    return wandb.init(
        project="controlnet-ddpm-mnist",
        config=config,
        resume="allow",  # allows resuming if run was interrupted
    )
31
+
32
+
33
def load_checkpoint(model, optimizer, scheduler, checkpoint_path, map_location=None):
    """
    Load a training checkpoint from a local file.

    :param model: model whose weights are restored in place
    :param optimizer: optimizer whose state is restored in place
    :param scheduler: optional noise scheduler; restored only when the
        checkpoint actually stored a scheduler state
    :param checkpoint_path: path of the checkpoint file
    :param map_location: device to map tensors to; defaults to the
        module-level ``device`` (backward compatible)
    :return: tuple ``(start_epoch, step)`` to resume training from
    """
    if map_location is None:
        map_location = device
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Older checkpoints may not carry scheduler state; restore only if present.
    if (
        scheduler
        and "scheduler_state_dict" in checkpoint
        and checkpoint["scheduler_state_dict"]
    ):
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    start_epoch = checkpoint["epoch"] + 1  # Start from the next epoch
    # Tolerate checkpoints saved without a step counter.
    step = checkpoint.get("step", 0)

    print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
    return start_epoch, step
53
+
54
+
55
def save_checkpoint(
    model, optimizer, scheduler, epoch, loss, step, run, checkpoint_path
):
    """
    Persist a training checkpoint locally and upload it to wandb.

    :param model: model whose state_dict is saved
    :param optimizer: optimizer whose state_dict is saved
    :param scheduler: optional scheduler; its state is saved when present
    :param epoch: epoch index stored in the checkpoint
    :param loss: loss value stored for bookkeeping
    :param step: global step counter stored in the checkpoint
    :param run: active wandb run used to log the artifact
    :param checkpoint_path: destination file path
    :return: the checkpoint path that was written
    """
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
        "loss": loss,
        "step": step,
    }

    # Write to disk first, then mirror the file to wandb.
    torch.save(state, checkpoint_path)

    # Upload as a versioned model artifact.
    artifact = wandb.Artifact(f"model-checkpoint-epoch-{epoch}", type="model")
    artifact.add_file(checkpoint_path)
    run.log_artifact(artifact)

    print(f"Checkpoint saved at epoch {epoch}")
    return checkpoint_path
81
+
82
+
83
def train(args):
    """
    Train the DDPM UNet on MNIST and log metrics/checkpoints to wandb.

    :param args: parsed CLI arguments; only ``args.config_path`` is read
    :raises yaml.YAMLError: if the config file is malformed
    """
    # Fail loudly on a malformed config; the previous version printed the
    # error and then crashed on an unbound ``config`` variable.
    with open(args.config_path, "r") as file:
        config = yaml.safe_load(file)
    print(config)
    run = init_wandb(config["train_params"])

    diffusion_config = config["diffusion_params"]
    dataset_config = config["dataset_params"]
    model_config = config["model_params"]
    train_config = config["train_params"]
    scheduler = LinearNoiseScheduler(
        num_timesteps=diffusion_config["num_timesteps"],
        beta_start=diffusion_config["beta_start"],
        beta_end=diffusion_config["beta_end"],
    )

    mnist = MnistDataset("train", dataset_config["im_path"])
    mnist_loader = DataLoader(
        mnist, batch_size=train_config["batch_size"], shuffle=True, num_workers=4
    )

    model = UNet(model_config).to(device)
    model.train()
    logger.debug("Initialized model and set to train")
    optimizer = torch.optim.AdamW(model.parameters(), lr=train_config["ddpm_lr"])
    criterion = torch.nn.MSELoss()

    # Create the output directory. ``os.mkdir`` raised FileExistsError when
    # the directory already existed but the checkpoint file did not.
    os.makedirs(train_config["task_name"], exist_ok=True)

    ckpt_path = os.path.join(train_config["task_name"], train_config["ddpm_ckpt_name"])
    # Resume from a checkpoint when one is present.
    if os.path.exists(ckpt_path):
        start_epoch, step = load_checkpoint(
            model, optimizer=optimizer, scheduler=scheduler, checkpoint_path=ckpt_path
        )
    else:
        start_epoch = 0
        step = 0

    # Log model architecture as a Table
    model_table = wandb.Table(columns=["Layer", "Parameters"])
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            params = param.numel()
            total_params += params
            model_table.add_data(name, params)

    wandb.log({"model_architecture": model_table})
    wandb.log({"total_parameters": total_params})

    # Watch model gradients and parameters
    wandb.watch(model, log="all", log_freq=100)

    # Defined up front so the post-loop summary works even when the loop
    # body never runs (e.g. resuming at the final epoch).
    avg_loss = float("nan")
    best_loss = float("inf")
    best_epoch = start_epoch

    for epoch in range(start_epoch, train_config["num_epochs"]):
        losses = []
        progress_bar = tqdm(
            mnist_loader, desc=f"Epoch {epoch + 1}/{train_config['num_epochs']}"
        )

        for batch_idx, im in enumerate(progress_bar, start=1):
            optimizer.zero_grad()
            im = im.float().to(device)

            # Sample noise
            noise = torch.randn_like(im).to(device)
            logger.debug(f"Sampled noise epoch {epoch} : {noise.shape}")

            # Sample a random timestep per image in the batch
            t = torch.randint(0, diffusion_config["num_timesteps"], (im.shape[0],)).to(
                device
            )

            noisy_im = scheduler.add_noise(im, noise, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

            # Calculate gradient norm for monitoring
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    param_norm = p.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm**0.5

            # Update progress bar
            progress_bar.set_postfix({"loss": loss.item(), "avg_loss": np.mean(losses)})

            wandb.log(
                {
                    "train/batch_loss": loss.item(),
                    "train/step": step,
                    "train/epoch": epoch + batch_idx / len(mnist_loader),
                    "train/gradient_norm": total_norm,
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

            step += 1

        avg_loss = np.mean(losses)
        # Track the best epoch for the final summary (previously ``best_loss``
        # was referenced without ever being defined -> NameError).
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_epoch = epoch

        # Log epoch-level metrics
        wandb.log(
            {
                "train/epoch_loss": avg_loss,
                "train/epoch_completed": epoch,
            }
        )

        print(f"Finished epoch: {epoch} | Loss: {avg_loss:.4f}")
        # Sample every ``save_epoch`` epochs. The original condition
        # (``if epoch % save_epoch:``) was inverted and skipped exactly the
        # epochs it was meant to visualize.
        if epoch % train_config["save_epoch"] == 0:
            visualize_samples(
                model, scheduler, epoch, diffusion_config["num_timesteps"], device
            )

        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_loss, step, run, ckpt_path
        )

    # Log final model as artifact (use the module logger for consistency).
    logger.info("Finished training and starting to save model")
    final_model_path = os.path.join(
        train_config["task_name"], f"final_{train_config['ddpm_ckpt_name']}"
    )
    save_checkpoint(
        model,
        optimizer,
        scheduler,
        train_config["num_epochs"] - 1,
        avg_loss,
        step,
        run,
        final_model_path,
    )
    logger.info("Saved Model to Wandb and local")

    # Log a summary table of training
    summary_table = wandb.Table(columns=["Metric", "Value"])
    summary_table.add_data("Final Loss", avg_loss)
    summary_table.add_data("Best Loss", best_loss)
    summary_table.add_data("Best Epoch", best_epoch)
    summary_table.add_data("Total Steps", step)
    # NOTE(review): ``duration`` is not a documented wandb Run attribute —
    # guarded so the summary never crashes; confirm the intended metric.
    summary_table.add_data(
        "Training Time (hours)", getattr(wandb.run, "duration", 0.0) / 3600
    )

    wandb.log({"training_summary": summary_table})

    # Finish the run
    wandb.finish()
245
+
246
+
247
def visualize_samples(
    model,
    scheduler,
    epoch,
    num_timesteps,
    device,
    num_samples=4,
    im_size=28,
    im_channels=1,
):
    """
    Generate sample images from pure noise and log them to wandb.

    :param model: trained noise-prediction UNet
    :param scheduler: noise scheduler providing the reverse ``step`` update
    :param epoch: current epoch, used only for labelling the logged images
    :param num_timesteps: number of reverse diffusion steps to run
    :param device: device on which to run sampling
    :param num_samples: number of images to generate
    :param im_size: spatial size of the generated images (was hard-coded 28)
    :param im_channels: number of image channels (was hard-coded 1)
    """
    model.eval()
    try:
        with torch.no_grad():
            # Start with random noise
            samples = torch.randn(num_samples, im_channels, im_size, im_size).to(
                device
            )

            # Store the denoising process
            sample_images = []

            # Record more frequently at the beginning of sampling
            log_steps = set([0, 20, 50, 100, 200, 400, 600, 800, num_timesteps - 1])

            # Denoise gradually
            for i in tqdm(reversed(range(num_timesteps)), desc="Sampling"):
                t = torch.full((num_samples,), i, device=device, dtype=torch.long)

                # Get model prediction and update sample
                predicted_noise = model(samples, t)
                samples = scheduler.step(predicted_noise, i, samples)

                # Save images at specified timesteps
                if i in log_steps:
                    # Denormalize and convert to numpy for logging
                    denorm_samples = samples.clamp(-1, 1).cpu().numpy()
                    denorm_samples = (
                        denorm_samples + 1
                    ) / 2.0  # scale from [-1, 1] to [0, 1]
                    sample_images.append((i, denorm_samples))

            # Create a grid to show denoising process
            images_to_log = {}

            # Log individual samples
            for i, sample in enumerate(samples):
                sample_np = sample.clamp(-1, 1).cpu().numpy()
                sample_np = (sample_np + 1) / 2.0  # scale from [-1, 1] to [0, 1]
                images_to_log[f"sample_{i}_epoch_{epoch}"] = wandb.Image(
                    sample_np[0], caption=f"Sample {i}, Epoch {epoch}"
                )

            # Log denoising process for first sample
            denoising_steps = []
            for step_idx, samples_np in sample_images:
                denoising_steps.append(
                    wandb.Image(
                        samples_np[0][0],
                        caption=f"Step {num_timesteps - step_idx}/{num_timesteps}",
                    )
                )

            images_to_log["denoising_process_epoch_" + str(epoch)] = denoising_steps

            # Log all images
            wandb.log(images_to_log)
    finally:
        # Restore training mode even if sampling/logging raises, so a failed
        # visualization cannot leave the model stuck in eval mode.
        model.train()