YashNagraj75 commited on
Commit
d62b4c3
·
1 Parent(s): 76a0a2e

Add ControlNet and Scheduler

Browse files
model_blocks/controlnet.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ import logging
3
+ import os
4
+ from re import UNICODE
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from unet_base import UNet, get_time_embedding
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def make_zero_module(module):
14
+ for p in module.parameters():
15
+ p.detach().zero_()
16
+ return module
17
+
18
+
19
+ class ControlNet(nn.Module):
20
+ r"""
21
+ ControlNet for trained DDPM
22
+ """
23
+
24
+ def __init__(
25
+ self, device, model_config, trained_ckpt_path=None, model_locked=True
26
+ ) -> None:
27
+ super().__init__()
28
+
29
+ # Trained DDPM
30
+ self.model = UNet(model_config)
31
+ self.model_locked = model_locked
32
+
33
+ if trained_ckpt_path is not None:
34
+ print("Loading Checkpoint")
35
+ self.model = torch.load(trained_ckpt_path).to(device)
36
+
37
+ # False the upblocks (Decoder blocks) from the DDPM and uses only the encoder
38
+ self.control_copy = UNet(model_config, use_up=False)
39
+ if trained_ckpt_path is not None:
40
+ self.control_copy.load_state_dict(self.model.state_dict(), strict=False)
41
+
42
+ # Hint Block for ControlNet
43
+ # Stack of Conv Activation and Zero Convolution at the end
44
+ self.hint_block = nn.Sequential(
45
+ nn.Conv2d(model_config["hint_channels"], 64, kernel_size=3, padding=1),
46
+ nn.SiLU(),
47
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
48
+ nn.SiLU(),
49
+ nn.Conv2d(128, self.model.down_channels[0], kernel_size=3, padding=1),
50
+ nn.SiLU(),
51
+ make_zero_module(
52
+ nn.Conv2d(
53
+ self.model.down_channels[0],
54
+ self.model.down_channels[0],
55
+ kernel_size=1,
56
+ padding=0,
57
+ )
58
+ ),
59
+ )
60
+
61
+ self.control_copy_down_blocks = nn.ModuleList(
62
+ [
63
+ make_zero_module(
64
+ nn.Conv2d(
65
+ self.model.down_channels[i],
66
+ self.model.down_channels[i],
67
+ kernel_size=1,
68
+ padding=0,
69
+ )
70
+ )
71
+ for i in range(len(self.model.down_channels) - 1)
72
+ ]
73
+ )
74
+
75
+ self.control_copy_mid_blocks = nn.ModuleList(
76
+ [
77
+ make_zero_module(
78
+ nn.Conv2d(
79
+ self.model.mid_channels[i],
80
+ self.model.mid_channels[i],
81
+ kernel_size=1,
82
+ padding=0,
83
+ )
84
+ )
85
+ for i in range(1, len(self.model.mid_channels) - 1)
86
+ ]
87
+ )
88
+
89
+ def get_params(self):
90
+ # Get all the control_net params
91
+ params = list(self.control_copy.parameters())
92
+ params += list(self.hint_block.parameters())
93
+ params += list(self.control_copy_down_blocks.parameters())
94
+ params += list(self.control_copy_mid_blocks.parameters())
95
+
96
+ return params
97
+
98
+ def forward(self, x, t, hint):
99
+ time_embedding = get_time_embedding(
100
+ torch.as_tensor(t).long(), self.model.t_emb_dim
101
+ )
102
+ time_embedding = self.model.t_proj(time_embedding)
103
+ logger.debug(f"Got Time embeddings for Original Copy : {time_embedding.shape}")
104
+
105
+ model_down_outs = []
106
+
107
+ with torch.no_grad():
108
+ model_out = self.model.conv_in(x)
109
+ for idx, down in enumerate(self.model.downs):
110
+ model_down_outs.append(model_in)
111
+ model_out = down(model_out, time_embedding)
112
+ logger.debug(
113
+ f"Getting output of Down Layer {idx} from the original copy : {model_out.shape}"
114
+ )
115
+
116
+ logger.debug("Passing into ControlNet")
117
+
118
+ controlnet_time_embedding = get_time_embedding(
119
+ torch.as_tensor(t).long(), self.control_copy.t_emb_dim
120
+ )
121
+ controlnet_time_embedding = self.control_copy.t_proj(controlnet_time_embedding)
122
+ logger.debug(
123
+ f"Got Time embedding for ControlNet : {controlnet_time_embedding.shape}"
124
+ )
125
+
126
+ # Hint layer output here
127
+ controlnet_hint_output = self.hint_block(hint)
128
+ logger.debug(
129
+ f"Getting output of the Hint Block into the ControlNet : {controlnet_hint_output.shape}"
130
+ )
131
+
132
+ controlnet_out = self.control_copy.conv_in(x)
133
+ logger.debug(
134
+ f"Getting output of the Input Conv of ControlNet: {controlnet_out.shape}"
135
+ )
136
+
137
+ controlnet_out += controlnet_hint_output
138
+ logger.debug(f"Added Hint to the Conv Input: {controlnet_out.shape}")
139
+
140
+ controlnet_down_outs = []
141
+ # Get all the outputs of the controlnet down blocks
142
+ for idx, down in enumerate(self.control_copy.downs):
143
+ down_out = self.control_copy_down_blocks[idx](controlnet_out)
144
+ controlnet_down_outs.append(down_out)
145
+ logger.debug(
146
+ f"Got output of the {idx} Down Block of the ControlNet: {down_out.shape}"
147
+ )
148
+
149
+ # Now get the midblocks and then give to original copy
150
+ for idx in range(len(self.control_copy.mids)):
151
+ controlnet_out = self.control_copy.mids[idx](
152
+ controlnet_out, controlnet_time_embedding
153
+ )
154
+ logger.debug(
155
+ f"Got the output of the mid block {idx} in controlnet : {controlnet_out.shape}"
156
+ )
157
+
158
+ model_out = self.model.mids[idx](model_out, time_embedding)
159
+ logger.debug(
160
+ f"Got the output of Mid Block {idx} from original model : {model_out.shape}"
161
+ )
162
+
163
+ model_out += self.control_copy_mid_blocks[idx](controlnet_out)
164
+ logger.debug(
165
+ f"Concatinating the ControlNet Mid Block {idx} output :{model_out.shape} to original copy"
166
+ )
167
+
168
+ # Call the upblocks now
169
+ for idx, up in enumerate(self.model.ups):
170
+ model_down_out = model_down_outs.pop()
171
+ logger.debug(
172
+ f"Got the output from the down blocks from original model : {model_down_out.shape}"
173
+ )
174
+ controlnet_down_out = controlnet_down_outs.pop()
175
+ logger.debug(
176
+ f"Got the output from the down blocks from controlnet copy : {controlnet_down_out.shape}"
177
+ )
178
+
179
+ model_out = up(
180
+ model_out, controlnet_down_out + model_down_out, time_embedding
181
+ )
182
+
183
+ model_out = self.model.norm_out(model_out)
184
+ model_out = nn.SiLU()(model_out)
185
+ model_out = self.model.conv_out(model_out)
186
+
187
+ return model_out
scheduler/linear_scheduler.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class LinearNoiseScheduler:
5
+ r"""
6
+ Class for the linear noise scheduler that is used in DDPM.
7
+ """
8
+
9
+ def __init__(self, num_timesteps, beta_start, beta_end, ldm_scheduler=False):
10
+ self.num_timesteps = num_timesteps
11
+ self.beta_start = beta_start
12
+ self.beta_end = beta_end
13
+
14
+ if ldm_scheduler:
15
+ # Mimicking how compvis repo creates schedule
16
+ self.betas = (
17
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
18
+ )
19
+ else:
20
+ self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
21
+ self.alphas = 1.0 - self.betas
22
+ self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
23
+ self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
24
+ self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
25
+
26
+ def add_noise(self, original, noise, t):
27
+ r"""
28
+ Forward method for diffusion
29
+ :param original: Image on which noise is to be applied
30
+ :param noise: Random Noise Tensor (from normal dist)
31
+ :param t: timestep of the forward process of shape -> (B,)
32
+ :return:
33
+ """
34
+ original_shape = original.shape
35
+ batch_size = original_shape[0]
36
+
37
+ sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(
38
+ batch_size
39
+ )
40
+ sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(
41
+ original.device
42
+ )[t].reshape(batch_size)
43
+
44
+ # Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
45
+ for _ in range(len(original_shape) - 1):
46
+ sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
47
+ for _ in range(len(original_shape) - 1):
48
+ sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
49
+
50
+ # Apply and Return Forward process equation
51
+ return (
52
+ sqrt_alpha_cum_prod.to(original.device) * original
53
+ + sqrt_one_minus_alpha_cum_prod.to(original.device) * noise
54
+ )
55
+
56
+ def sample_prev_timestep(self, xt, noise_pred, t):
57
+ r"""
58
+ Use the noise prediction by model to get
59
+ xt-1 using xt and the noise predicted
60
+ :param xt: current timestep sample
61
+ :param noise_pred: model noise prediction
62
+ :param t: current timestep we are at
63
+ :return:
64
+ """
65
+ x0 = (
66
+ xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)
67
+ ) / torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])
68
+ x0 = torch.clamp(x0, -1.0, 1.0)
69
+
70
+ mean = (
71
+ xt
72
+ - ((self.betas.to(xt.device)[t]) * noise_pred)
73
+ / (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
74
+ )
75
+ mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
76
+
77
+ if t == 0:
78
+ return mean, x0
79
+ else:
80
+ variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (
81
+ 1.0 - self.alpha_cum_prod.to(xt.device)[t]
82
+ )
83
+ variance = variance * self.betas.to(xt.device)[t]
84
+ sigma = variance**0.5
85
+ z = torch.randn(xt.shape).to(xt.device)
86
+
87
+ # OR
88
+ # variance = self.betas[t]
89
+ # sigma = variance ** 0.5
90
+ # z = torch.randn(xt.shape).to(xt.device)
91
+ return mean + sigma * z, x0