Yash Nagraj committed on
Commit
cb6bd3a
·
1 Parent(s): 3cb348b

Add attention to Down Blocks

Browse files
Files changed (1) hide show
  1. models/blocks.py +77 -0
models/blocks.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
def get_time_embedding(time_steps, temb_dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        time_steps: tensor of timesteps; expected to be 1-D with shape (B,)
            (the function indexes it as ``time_steps[:, None]``).
        temb_dim: size of the embedding produced per timestep. Must be even,
            because half of the features are sines and half are cosines.

    Returns:
        Tensor of shape (B, temb_dim) laid out as
        ``[sin(t / factor), cos(t / factor)]`` along the last dimension,
        created on the same device as ``time_steps``.

    Raises:
        AssertionError: if ``temb_dim`` is not divisible by 2.
    """
    # Validate the embedding size, not the timestep tensor: asserting on
    # `time_steps % 2` is ambiguous for batched tensors and rejects valid
    # odd timesteps while letting odd embedding sizes through.
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # factor[i] = 10000 ** (i / half_dim) for i in [0, half_dim)
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    factor = 10000 ** exponents

    # pos / factor
    # time_steps: (B,) -> (B, 1), broadcast against factor (half_dim,)
    # -> (B, half_dim)
    t_emb = time_steps[:, None] / factor

    # Concatenate sine and cosine halves -> (B, temb_dim)
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
    return t_emb
17
+
18
+
19
+ class DownBlock(nn.Module):
20
+ """
21
+ Down Block that down samples the image, flows like this:
22
+ 1) Resnet block with time embedding
23
+ 2) Self Attention block
24
+ 3) Down Sample
25
+ """
26
+
27
+ def __init__(self, in_channels, out_channels, t_emd_dim, down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False,
28
+ context_dim=None):
29
+ super().__init__()
30
+ self.down_sample = down_sample
31
+ self.cross_attn = cross_attn
32
+ self.context_dim = context_dim
33
+ self.cross_attn = cross_attn
34
+ self.t_emb_dim = t_emd_dim
35
+ self.attn = attn
36
+ self.resnet_conv_first = nn.ModuleList([
37
+ nn.Sequential(
38
+ nn.GroupNorm(norm_channels, in_channels if i ==
39
+ 0 else out_channels),
40
+ nn.SiLU(),
41
+ nn.Conv2d(in_channels=in_channels if i == 0 else out_channels,
42
+ out_channels=out_channels, kernel_size=3, stride=1, padding=1)
43
+
44
+ ) for i in range(num_layers)
45
+ ])
46
+ if self.t_emb_dim is not None:
47
+ self.time_embd_layers = nn.ModuleList([
48
+ nn.Sequential(
49
+ nn.SiLU(),
50
+ nn.Linear(self.t_emb_dim, out_channels)
51
+ )
52
+ for _ in range(num_layers)
53
+ ])
54
+
55
+ self.resnet_conv_second = nn.ModuleList([
56
+ nn.Sequential(
57
+ nn.GroupNorm(norm_channels, out_channels),
58
+ nn.SiLU(),
59
+ nn.Conv2d(in_channels, out_channels,
60
+ kernel_size=3, stride=1, padding=1),
61
+ )
62
+ for _ in range(num_layers)
63
+ ])
64
+
65
+ if self.attn:
66
+ self.attention_norms = nn.ModuleList(
67
+ [nn.GroupNorm(norm_channels, out_channels)
68
+ for _ in range(num_layers)]
69
+ )
70
+
71
+ self.attention = nn.ModuleList(
72
+ [nn.MultiheadAttention(
73
+ out_channels, num_heads=num_heads, batch_first=True) for _ in range(num_layers)]
74
+ )
75
+
76
+ if self.cross_attn:
77
+ assert context_dim is not None, "Context Dimension must be passed to croo attention"