Jasmeet Singh commited on
Commit
056ab49
·
verified ·
1 Parent(s): 9ec6dd3

files upload

Browse files
Files changed (12) hide show
  1. app.py +53 -0
  2. attention.py +131 -0
  3. clip.py +97 -0
  4. decoder.py +99 -0
  5. encoder.py +102 -0
  6. generationPipeline.py +0 -0
  7. helperUNET.py +179 -0
  8. helperVAE.py +83 -0
  9. loadModel.py +0 -0
  10. model_converter.py +0 -0
  11. sampler.py +127 -0
  12. unet.py +183 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import spaces
from PIL import Image
from generationPipeline import generate
from transformers import CLIPTokenizer
from loadModel import preload_models_from_standard_weights
import gradio as gr


# Inference device. On Hugging Face Spaces the @spaces.GPU decorator below
# attaches a GPU for the duration of each decorated call.
Device = 'cuda'

print(f"Using device: {Device}")

# Tokenizer and model weights are loaded once at startup and shared by all requests.
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
model_file = "weights2.ckpt"
models = preload_models_from_standard_weights(model_file, Device)


@spaces.GPU(duration=242)
def generate_image(prompt, strength, seed):
    """Run the diffusion pipeline for ``prompt`` and return a PIL image.

    prompt: text prompt for generation.
    strength: img2img strength in [0, 1]; with input_image=None this is a
        pure text-to-image run, so the value has no visible effect here.
    seed: RNG seed forwarded to the sampler.
    """
    output_image = generate(
        prompt=prompt,
        uncond_prompt="",
        input_image=None,
        strength=strength,
        do_cfg=True,
        cfg_scale=8,
        sampler_name="ddpm",
        n_inference_steps=50,
        seed=seed,
        models=models,
        device=Device,
        idle_device="cpu",
        tokenizer=tokenizer,
    )
    return Image.fromarray(output_image)


# FIX: the gr.inputs / gr.outputs namespaces were removed in Gradio 3.x;
# components are now top-level classes (gr.Textbox, gr.Slider, gr.Number,
# gr.Image), and Number takes `value=` instead of `default=`.
iface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(0, 1, step=0.01, label="Strength (For Image-2-Image): Strength = 1 (Output further from input image), Strength = 0 (Output similar as Input image)"),
        gr.Number(value=42, label="Seed"),
    ],
    outputs=gr.Image(label="Generated Image"),
    title="Stable Diffusion Image Generator",
    description="Generate images from text prompts using Stable Diffusion.",
)

if __name__ == "__main__":
    iface.launch(debug=True)
attention.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ #Attention: softmax(q @ k.transpose / sqrt(dk)) @ w
7
+
8
+
9
class SelfAttention(nn.Module):
    """Multi-head self-attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Q, K and V are produced by a single fused input projection; an output
    projection (Wo) mixes the concatenated heads back together.
    """

    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Fused Wq/Wk/Wv projection (3 * d_embed outputs), then Wo.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        # x: (batch, seq_len, d_embed)
        batch, seq_len, d_embed = x.shape
        heads_shape = (batch, seq_len, self.n_heads, self.d_head)

        # One matmul, then split the result into query / key / value.
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (batch, seq_len, d_embed) -> (batch, n_heads, seq_len, d_head)
        q, k, v = (t.view(heads_shape).transpose(1, 2) for t in (q, k, v))

        # Attention scores: (batch, n_heads, seq_len, seq_len)
        scores = q @ k.transpose(-1, -2)

        if causal_mask:
            # Hide future tokens: set everything above the principal
            # diagonal to -inf so softmax assigns it zero weight.
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores.masked_fill_(future, -torch.inf)

        # Scale by sqrt(d_head) and normalize over the key dimension.
        scores = F.softmax(scores / math.sqrt(self.d_head), dim=-1)

        # Weighted sum of values, then merge heads back:
        # (batch, n_heads, seq_len, d_head) -> (batch, seq_len, d_embed)
        out = scores @ v
        out = out.transpose(1, 2).reshape(batch, seq_len, d_embed)

        return self.out_proj(out)
68
+
69
+
70
+ # Calculate Attention between latent and prompt(context)
71
+
72
+
73
class CrossAttention(nn.Module):
    """Multi-head attention where queries come from x and keys/values from y.

    Used to attend from the latent (x) to the text context (y), e.g.
    y of shape (batch, 77, 768) for the CLIP prompt embedding.
    """

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # Separate projections: Q from the latent side, K/V from the context side.
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        # x (latent):  (batch, seq_q, d_embed)
        # y (context): (batch, seq_kv, d_cross)
        in_shape = x.shape
        batch = in_shape[0]
        # -1 lets the same shape serve both seq_q (for Q) and seq_kv (for K/V).
        split = (batch, -1, self.n_heads, self.d_head)

        # Project, then reshape to (batch, n_heads, seq, d_head).
        q = self.q_proj(x).view(split).transpose(1, 2)
        k = self.k_proj(y).view(split).transpose(1, 2)
        v = self.v_proj(y).view(split).transpose(1, 2)

        # Scaled scores over the context positions; no causal mask here —
        # every latent position may attend to every context token.
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.d_head)
        scores = F.softmax(scores, dim=-1)

        # (batch, n_heads, seq_q, d_head) -> (batch, seq_q, d_embed)
        out = (scores @ v).transpose(1, 2).contiguous().view(in_shape)
        return self.out_proj(out)
clip.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from attention import SelfAttention
4
+
5
class CLIPEmbedding(nn.Module):
    """Token embedding plus a learned absolute position embedding."""

    def __init__(self, n_vocab: int, n_embd: int, n_token: int):
        super().__init__()
        # (vocab_size, embed_dim) lookup table.
        self.token_embedding = nn.Embedding(n_vocab, n_embd)
        # Learnable positional encoding, (seq_len, embed_dim), zero-initialized.
        self.position_embedding = nn.Parameter(torch.zeros((n_token, n_embd)))

    def forward(self, tokens):
        # (batch, seq_len) -> (batch, seq_len, embed_dim)
        embedded = self.token_embedding(tokens)
        # Broadcast-add the position embedding over the batch dimension.
        return embedded + self.position_embedding
20
+
21
+
22
class CLIPLayer(nn.Module):
    """One pre-norm transformer encoder block: causal self-attention + QuickGELU MLP."""

    def __init__(self, n_head: int, n_embd: int):
        super().__init__()
        self.layernorm_1 = nn.LayerNorm(n_embd)        # pre-attention norm
        self.attention = SelfAttention(n_head, n_embd)
        self.layernorm_2 = nn.LayerNorm(n_embd)        # pre-MLP norm
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)  # MLP expansion (4x)
        self.linear_2 = nn.Linear(4 * n_embd, n_embd)  # MLP projection back

    def forward(self, x):
        # x: (batch, seq_len, dim); shape is preserved throughout.

        # Self-attention sub-block with residual connection.
        residual = x
        x = self.attention(self.layernorm_1(x), causal_mask=True)
        x = x + residual

        # Feed-forward sub-block with residual connection.
        residual = x
        x = self.linear_1(self.layernorm_2(x))
        x = x * torch.sigmoid(1.702 * x)  # QuickGELU activation
        x = self.linear_2(x)
        return x + residual
71
+
72
+
73
class CLIP(nn.Module):
    """CLIP text encoder: embedding -> 12 transformer layers -> final LayerNorm.

    Fixed configuration: 49408-token vocabulary, 768-dim embeddings,
    77-token context, 12 heads per layer.
    """

    def __init__(self):
        super().__init__()
        self.embedding = CLIPEmbedding(49408, 768, 77)
        self.layers = nn.ModuleList(CLIPLayer(12, 768) for _ in range(12))
        self.layernorm = nn.LayerNorm(768)

    def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        # (batch, seq_len) token ids -> (batch, seq_len, 768) hidden states.
        state = self.embedding(tokens.type(torch.long))
        for layer in self.layers:
            state = layer(state)
        return self.layernorm(state)
decoder.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from helperVAE import VAE_AttentionBlock, VAE_ResidualBlock
5
+
6
class VAE_Decoder(nn.Sequential):
    """VAE decoder: maps a (Batch, 4, H/8, W/8) latent to a (Batch, 3, H, W) image.

    Mirrors the encoder: three nearest-neighbour 2x upsampling stages
    interleaved with residual blocks, with one attention block at the
    lowest resolution, ending in a 3-channel RGB convolution.
    """

    def __init__(self):
        super().__init__(
            # (Batch, 4, H/8, W/8) -> (Batch, 4, H/8, W/8)
            nn.Conv2d(4, 4, kernel_size=1, padding=0),

            # (Batch, 4, H/8, W/8) -> (Batch, 512, H/8, W/8)
            nn.Conv2d(4, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),

            # Global attention over the low-resolution feature map.
            VAE_AttentionBlock(512),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch, 512, H/8, W/8) -> (Batch, 512, H/4, W/4)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),

            # (Batch, 512, H/4, W/4) -> (Batch, 512, H/2, W/2)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),

            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),

            # (Batch, 256, H/2, W/2) -> (Batch, 256, H, W)
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),

            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),

            nn.GroupNorm(32, 128),
            nn.SiLU(),

            # (Batch, 128, H, W) -> (Batch, 3, H, W)
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode a latent ``x`` of shape (Batch, 4, H/8, W/8) to (Batch, 3, H, W)."""
        # Undo the encoder's 0.18215 scaling.
        # FIX: done out-of-place — the original `x /= 0.18215` mutated the
        # caller's latent tensor in place, corrupting it for any reuse
        # (e.g. a subsequent sampler step reading the same latent).
        x = x / 0.18215

        for module in self:
            x = module(x)

        # (Batch, 3, H, W)
        return x
encoder.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from helperVAE import VAE_ResidualBlock, VAE_AttentionBlock
5
+
6
class VAE_Encoder(nn.Sequential):
    """VAE encoder: maps an image plus sampling noise to a latent.

    Three stride-2 convolutions reduce the spatial resolution by 8x while
    residual blocks grow the channel count 128 -> 256 -> 512. The final
    1x1 conv outputs 8 channels that are split into the mean and
    log-variance of the latent distribution (reparameterization trick).
    """

    def __init__(self):
        super().__init__(
            # (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
            nn.Conv2d(3, 128, kernel_size=3, padding=1),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
            VAE_ResidualBlock(128, 128),

            # (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height / 2, Width / 2)
            # padding=0 here: the asymmetric pad is applied manually in forward().
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(128, 256),

            # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
            VAE_ResidualBlock(256, 256),

            # (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 4, Width / 4)
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 256, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
            VAE_ResidualBlock(256, 512),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 8, Width / 8)
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),

            # Global attention over the lowest-resolution feature map.
            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_AttentionBlock(512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            VAE_ResidualBlock(512, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            nn.GroupNorm(32, 512),

            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
            nn.SiLU(),

            # padding=1 keeps the spatial size: width/height grow by 2 from the
            # padding, which compensates for the kernel size of 3.
            # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
            nn.Conv2d(512, 8, kernel_size=3, padding=1),

            # (Batch_Size, 8, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, x, noise):
        """Encode ``x`` and sample a latent via the reparameterization trick.

        x: (Batch_Size, Channel, Height, Width) input image.
        noise: (Batch_Size, 4, Height / 8, Width / 8) — presumably drawn from
            N(0, 1) by the caller; TODO confirm against the pipeline.
        Returns: (Batch_Size, 4, Height / 8, Width / 8) scaled latent sample.
        """
        for module in self:

            if getattr(module, 'stride', None) == (2, 2):
                # Downsampling convs use asymmetric padding: pad only the
                # right and bottom edges by one pixel, instead of symmetric
                # padding=1 on the conv itself.
                # Pad order: (Padding_Left, Padding_Right, Padding_Top, Padding_Bottom)
                x = F.pad(x, (0, 1, 0, 1))

            x = module(x)
        # (Batch_Size, 8, Height / 8, Width / 8) -> two tensors of shape (Batch_Size, 4, Height / 8, Width / 8)
        mean, log_variance = torch.chunk(x, 2, dim=1)
        # Clamp the log variance to [-30, 20] so the variance stays within
        # roughly [1e-14, 1e8] and exp() cannot overflow/underflow.
        log_variance = torch.clamp(log_variance, -30, 20)
        # (Batch_Size, 4, Height / 8, Width / 8)
        variance = log_variance.exp()
        # (Batch_Size, 4, Height / 8, Width / 8)
        stdev = variance.sqrt()

        # Reparameterize: transform N(0, 1) noise into N(mean, stdev).
        x = mean + stdev * noise

        # Scale by the constant the decoder later divides out.
        x *= 0.18215

        return x
generationPipeline.py ADDED
File without changes
helperUNET.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from attention import SelfAttention, CrossAttention
3
+ from torch.nn import functional as F
4
+
5
class UNET_AttentionBlock(nn.Module):
    """Transformer block applied to a UNet feature map.

    Flattens the spatial dimensions into a token sequence, then applies
    self-attention, cross-attention against the text context, and a GeGLU
    feed-forward network — each with a short skip connection — plus one
    long skip connection around the whole block.
    """

    def __init__(self, n_head: int, n_embd: int, d_context=768):
        super().__init__()
        # Total channel count of the incoming feature map.
        channels = n_head * n_embd

        self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # GeGLU: a single linear produces both the value and the gate,
        # hence the extra factor of 2 on the output width.
        self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        # x: (Batch_Size, Features, Height, Width) UNet feature map
        # context: (Batch_Size, Seq_Len, Dim) text embedding for cross-attention

        # Saved for the final long skip connection.
        residue_long = x

        x = self.groupnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape

        # Flatten space into a token sequence:
        # (Batch_Size, Features, Height, Width) -> (Batch_Size, Height * Width, Features)
        x = x.view((n, c, h * w))
        x = x.transpose(-1, -2)

        # --- Self-attention with skip connection ---
        residue_short = x
        x = self.layernorm_1(x)
        x = self.attention_1(x)
        x += residue_short

        residue_short = x

        # --- Cross-attention against the prompt context, with skip ---
        x = self.layernorm_2(x)
        x = self.attention_2(x, context)
        x += residue_short

        residue_short = x

        # --- GeGLU feed-forward network, with skip ---
        x = self.layernorm_3(x)

        # GeGLU as in the original Stable Diffusion code:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/attention.py#L37C10-L37C10
        # one projection is split into the value and a GELU-activated gate.
        x, gate = self.linear_geglu_1(x).chunk(2, dim=-1)

        # Element-wise gating: (Batch_Size, Height * Width, Features * 4)
        x = x * F.gelu(gate)

        # Back down to the channel count: (..., Features * 4) -> (..., Features)
        x = self.linear_geglu_2(x)

        x += residue_short

        # Restore the spatial layout:
        # (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height, Width)
        x = x.transpose(-1, -2)
        x = x.view((n, c, h, w))

        # Long skip connection between the block's input and output.
        return self.conv_output(x) + residue_long
101
+
102
+
103
+
104
+
105
class Upsample(nn.Module):
    """Doubles the spatial resolution (nearest-neighbour) then applies a 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        # Channel count is preserved; only H and W change.
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # (batch, C, H, W) -> (batch, C, 2H, 2W)
        doubled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(doubled)
114
+
115
+
116
class UNET_ResidualBlock(nn.Module):
    """Residual block that injects a timestep embedding into the feature map."""

    def __init__(self, in_channels, out_channels, n_time=1280):
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Projects the (1, n_time) timestep embedding to this block's channel count.
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Skip path: identity when channels match, 1x1 conv otherwise.
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, feature, time):
        # feature: (batch, in_channels, H, W); time: (1, n_time)
        skip = feature

        # Norm -> SiLU -> conv on the feature map.
        h = self.conv_feature(F.silu(self.groupnorm_feature(feature)))

        # (1, n_time) -> (1, out_channels), then broadcast over H and W.
        t = self.linear_time(F.silu(time))
        h = h + t.unsqueeze(-1).unsqueeze(-1)

        # Second norm -> SiLU -> conv on the merged tensor.
        h = self.conv_merged(F.silu(self.groupnorm_merged(h)))

        # (batch, out_channels, H, W)
        return h + self.residual_layer(skip)
167
+
168
+
169
class SwitchSequential(nn.Sequential):
    """Sequential container that routes extra arguments by layer type.

    Attention blocks receive the text context, residual blocks receive the
    time embedding, and every other layer gets only the feature map.
    """

    def forward(self, x, context, time):
        for layer in self:
            if isinstance(layer, UNET_AttentionBlock):
                extra = (context,)
            elif isinstance(layer, UNET_ResidualBlock):
                extra = (time,)
            else:
                extra = ()
            x = layer(x, *extra)
        return x
helperVAE.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from attention import SelfAttention
3
+ from torch.nn import functional as F
4
+
5
+
6
class VAE_AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map."""

    def __init__(self, channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        # x: (batch, C, H, W)
        skip = x

        h = self.groupnorm(x)
        n, c, height, width = h.shape

        # Flatten space so each pixel becomes one token of size C:
        # (batch, C, H, W) -> (batch, H*W, C)
        h = h.view((n, c, height * width)).transpose(-1, -2)

        # Unmasked self-attention across all pixel positions.
        h = self.attention(h)

        # Restore the spatial layout: (batch, H*W, C) -> (batch, C, H, W)
        h = h.transpose(-1, -2).view((n, c, height, width))

        # Residual connection around the whole block.
        return h + skip
43
+
44
+
45
class VAE_ResidualBlock(nn.Module):
    """Two GroupNorm -> SiLU -> Conv stages with a (possibly projected) skip."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # Skip path: identity when channel counts match, 1x1 conv otherwise.
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        # x: (batch, in_channels, H, W) -> (batch, out_channels, H, W)
        skip = x
        h = self.conv_1(F.silu(self.groupnorm_1(x)))
        h = self.conv_2(F.silu(self.groupnorm_2(h)))
        return h + self.residual_layer(skip)
loadModel.py ADDED
File without changes
model_converter.py ADDED
The diff for this file is too large to render. See raw diff
 
sampler.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
class DDPMSampler:
    """DDPM sampler (Ho et al., https://arxiv.org/pdf/2006.11239.pdf).

    Holds the training-time noise schedule and implements:
      * the forward noising process q(x_t | x_0)          -> ``add_noise``
      * one reverse denoising step toward x_{t-1}         -> ``step``
    ``set_inference_timesteps`` must be called before ``step`` /
    ``set_strength``, since the inference stride depends on
    ``num_inference_steps``.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        # "Scaled linear" schedule: linear in sqrt(beta), squared back -> beta_t.
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2 #beta
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # alpha bar (cumulative product of alphas)
        self.one = torch.tensor(1.0)  # stands in for alpha_bar at t < 0

        self.generator = generator  # RNG used for every noise draw (reproducibility)

        self.num_train_timesteps = num_training_steps
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy()) ##[999, 998, ...0]

    def set_inference_timesteps(self, num_inference_steps=50):
        """Subsample the training schedule to ``num_inference_steps`` evenly spaced steps."""
        # e.g. num_inference_steps = 50
        # step ratio = num_training_steps // inference_steps = 20
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps # 1000/50 = 20
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) #[980, 960, ..0]
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Timestep preceding ``timestep`` at the inference stride (may be negative at the last step)."""
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps #eg: t = 960, t-1 = 960-20 = 940
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Posterior variance of q(x_{t-1} | x_t, x_0); formulas (6)-(7) of the DDPM paper."""
        prev_t = self._get_previous_timestep(timestep) #t-1

        alpha_prod_t = self.alphas_cumprod[timestep] #alpha bar t
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one #alpha bar t minus 1
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev #beta t at the inference stride

        # For t > 0, compute predicted variance beta_t (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t #variance#

        # we always take the log of variance, so clamp it to ensure it's not 0
        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image (image-to-image mode).
        More noise (strength ~ 1) means that the output will be further from the input image.
        Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # more strength -> start step is approximately 0, i.e. the model starts from pure noise
        #   and generates the image from it: strength = 1 -> start_step = 50 - (50 * 1) = 0
        # less strength -> more timesteps are skipped, so the model starts from a
        #   lightly noisified image and mostly reconstructs it: strength = 0 -> start_step = 50

        # start_step is the number of noise levels to skip
        # e.g. inf_steps = 50: strength = 1 -> start_step = 0; strength = 0 -> start_step = 50
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:] # drop the skipped timesteps from the schedule
        self.start_step = start_step # remembered so callers can offset their loop

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """One reverse-diffusion step: given x_t and the predicted noise, sample x_{t-1}."""
        t = timestep #t
        prev_t = self._get_previous_timestep(t) #t-1

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t] #alpha_bar_t
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one #alpha_bar_t-1
        beta_prod_t = 1 - alpha_prod_t #beta_bar_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev #beta_bar_t-1
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev #alpha_t
        current_beta_t = 1 - current_alpha_t #beta_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) #x_0 estimate (latent minus scaled noise)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t #coeff_x_0
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t #coeff_x_t

        # 5. Compute predicted previous sample mu_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents # posterior mean

        # 6. Add noise (only for t > 0; the final step is deterministic)
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            variance = (self._get_variance(t) ** 0.5) * noise

        # sample from N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # the variable "variance" is already multiplied by the noise N(0, 1)
        pred_prev_sample = pred_prev_sample + variance #predicted x_{t-1}

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Forward noisification: sample from q(x_t | x_0) at each given timestep."""
        #q(xt | x_0) = N(xt; sqrt(alpha_cumprod) * x_0; (1 - alpha_cumprod) I)
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) #alpha_bar
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 #sqrt(alpha_bar_t)
        sqrt_alpha_prod = sqrt_alpha_prod.flatten() #flatten
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape): #add trailing dims for broadcasting
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 #sqrt(1 - alpha_bar_t)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype) #noise
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise #noisy samples
        return noisy_samples
unet.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from helperUNET import SwitchSequential, UNET_AttentionBlock, UNET_ResidualBlock, Upsample
5
+
6
+
7
class UNET(nn.Module):
    """Stable Diffusion UNet backbone operating in latent space.

    Encoder halves the spatial resolution three times while widening the
    channels (320 -> 640 -> 1280); the bottleneck mixes features with
    attention; the decoder mirrors the encoder, consuming skip connections
    (hence its doubled input channel counts) and upsampling back to the
    input resolution. Module order/structure must not change: checkpoint
    state-dict keys depend on these exact indices.
    """

    def __init__(self):
        super().__init__()
        self.encoders = nn.ModuleList([
            # (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 320, Height / 8, Width / 8) -> # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),

            # Strided conv downsample: (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
            SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),

            # Strided conv downsample: (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
            SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),

            # Strided conv downsample: (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = SwitchSequential(
            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_AttentionBlock(8, 160),

            # (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            UNET_ResidualBlock(1280, 1280),
        )

        # Decoder input channels are (encoder skip channels + current channels),
        # which is why they are larger than the corresponding encoder stages.
        self.decoders = nn.ModuleList([
            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
            SwitchSequential(UNET_ResidualBlock(2560, 1280)),

            # (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
            SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),

            # (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),

            # (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
            SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),

            # (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),

            # (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),

            # (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
            SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        """Run one denoising pass.

        x:       (Batch_Size, 4, Height / 8, Width / 8) noisy latent
        context: (Batch_Size, Seq_Len, Dim) text conditioning (cross-attention)
        time:    (1, 1280) timestep embedding
        returns: (Batch_Size, 320, Height / 8, Width / 8) feature map
                 (projected to 4 channels by UNET_OutputLayer downstream)
        """
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layers in self.decoders:
            # Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layers(x, context, time)

        return x
115
+
116
+
117
+
118
class UNET_OutputLayer(nn.Module):
    """Final UNet head: GroupNorm + SiLU, then a 3x3 conv projecting the
    UNet feature channels back down to the latent channel count."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """(Batch, In_Channels, H, W) -> (Batch, Out_Channels, H, W)."""
        # Normalize, apply the SiLU non-linearity, then project the channels.
        return self.conv(F.silu(self.groupnorm(x)))
138
+
139
+
140
class TimeEmbedding(nn.Module):
    """Expands the sinusoidal timestep encoding from n_embd to 4 * n_embd
    using two linear layers with a SiLU activation in between."""

    def __init__(self, n_embd):
        super().__init__()
        self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
        self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd)

    def forward(self, x):
        """(1, n_embd) -> (1, 4 * n_embd)."""
        # Project up, apply SiLU, then mix once more at the wider width.
        hidden = F.silu(self.linear_1(x))
        return self.linear_2(hidden)
159
+
160
+
161
class Diffusion(nn.Module):
    """Full latent-space denoiser: timestep embedding + UNet + output head."""

    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNET()
        self.final = UNET_OutputLayer(320, 4)

    def forward(self, latent, context, time):
        """Predict the noise present in `latent`.

        latent:  (Batch, 4, Height / 8, Width / 8) noisy latent
        context: (Batch, Seq_Len, Dim) CLIP text embeddings
        time:    (1, 320) sinusoidal timestep encoding
        returns: (Batch, 4, Height / 8, Width / 8) predicted noise
        """
        # (1, 320) -> (1, 1280)
        t_emb = self.time_embedding(time)

        # (Batch, 4, H/8, W/8) -> (Batch, 320, H/8, W/8)
        features = self.unet(latent, context, t_emb)

        # (Batch, 320, H/8, W/8) -> (Batch, 4, H/8, W/8)
        return self.final(features)