nitesh501 commited on
Commit
4691adb
·
verified ·
1 Parent(s): e359f57

Delete vae.py

Browse files
Files changed (1) hide show
  1. vae.py +0 -166
vae.py DELETED
@@ -1,166 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import math
5
-
6
- class SelfAttention(nn.Module):
7
- def __init__(self, n_heads, embd_dim, in_proj_bias=True, out_proj_bias=True):
8
- super().__init__()
9
- self.n_heads = n_heads
10
- self.in_proj = nn.Linear(embd_dim, 3 * embd_dim, bias=in_proj_bias)
11
- self.out_proj = nn.Linear(embd_dim, embd_dim, bias=out_proj_bias)
12
-
13
- self.d_heads = embd_dim // n_heads
14
- assert self.d_heads * n_heads == embd_dim, "embed_dim must be divisible by num_heads"
15
-
16
- def forward(self, x, casual_mask=False):
17
- batch_size, seq_len, embd_dim = x.shape
18
- interim_shape = (batch_size, seq_len, self.n_heads, self.d_heads)
19
- q, k, v = self.in_proj(x).chunk(3, dim=-1)
20
- q = q.view(interim_shape)
21
- k = k.view(interim_shape)
22
- v = v.view(interim_shape)
23
-
24
- q = q.transpose(1, 2)
25
- k = k.transpose(1, 2)
26
- v = v.transpose(1, 2)
27
-
28
- weight = q @ k.transpose(-1, -2)
29
- if casual_mask:
30
- mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
31
- weight.masked_fill_(mask, -torch.inf)
32
- weight /= math.sqrt(self.d_heads)
33
- weight = F.softmax(weight, dim=-1)
34
- output = weight @ v
35
- output = output.transpose(1, 2)
36
- output = output.reshape((batch_size, seq_len, embd_dim))
37
- output = self.out_proj(output)
38
- return output
39
-
40
- class AttentionBlock(nn.Module):
41
- def __init__(self, channels):
42
- super().__init__()
43
- self.groupnorm = nn.GroupNorm(num_groups=32, num_channels=channels)
44
- self.attention = SelfAttention(n_heads=1, embd_dim=channels)
45
-
46
- def forward(self, x):
47
- residual = x
48
- x = self.groupnorm(x)
49
- n, c, h, w = x.shape
50
- x = x.view((n, c, h * w)).transpose(-1, -2)
51
- x = self.attention(x)
52
- x = x.transpose(-1, -2).view((n, c, h, w))
53
- x = x + residual
54
- return x
55
-
56
- class Residual(nn.Module):
57
- def __init__(self, in_channels, out_channels):
58
- super().__init__()
59
- self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
60
- self.gn1 = nn.GroupNorm(32, out_channels)
61
- self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
62
- self.gn2 = nn.GroupNorm(32, out_channels)
63
- self.silu = nn.SiLU()
64
- if in_channels != out_channels:
65
- self.residual_layer = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
66
- else:
67
- self.residual_layer = nn.Identity()
68
-
69
- def forward(self, x):
70
- x_residual = x.clone()
71
- x = self.gn1(x)
72
- x = self.silu(x)
73
- x = self.conv1(x)
74
- x = self.gn2(x)
75
- x = self.conv2(x)
76
- x += self.residual_layer(x_residual)
77
- return x
78
-
79
- class Encoder(nn.Module):
80
- def __init__(self, latent_channels):
81
- super().__init__()
82
- self.net = nn.Sequential(
83
- nn.Conv2d(3, 64, 3, padding=1),
84
- nn.SiLU(),
85
-
86
- Residual(64, 64),
87
- Residual(64, 64),
88
-
89
- nn.Conv2d(64, 128, 3, 2, 1),
90
- Residual(128, 128),
91
- Residual(128, 128),
92
-
93
- nn.Conv2d(128, 256, 3, 2, 1),
94
- Residual(256, 256),
95
- Residual(256, 256),
96
-
97
- nn.Conv2d(256, 256, 3, 2, 1),
98
- Residual(256, 256),
99
- AttentionBlock(channels=256),
100
- Residual(256, 256),
101
-
102
- nn.GroupNorm(32, 256),
103
- nn.SiLU(),
104
- )
105
- self.mu = nn.Conv2d(256, latent_channels, 3, padding=1)
106
- self.logvar = nn.Conv2d(256, latent_channels, 3, padding=1)
107
- self.latent_channels = latent_channels
108
-
109
- def forward(self, x):
110
- x = self.net(x)
111
- mu = self.mu(x)
112
- logvar = self.logvar(x)
113
- return mu, logvar
114
-
115
- class Decoder(nn.Module):
116
- def __init__(self, latent_channels):
117
- super().__init__()
118
- self.net = nn.Sequential(
119
- nn.Conv2d(latent_channels, 256, 3, padding=1),
120
- Residual(256, 256),
121
- AttentionBlock(channels=256),
122
- Residual(256, 256),
123
-
124
- nn.Upsample(scale_factor=2, mode='nearest'),
125
- nn.Conv2d(256, 256, 3, padding=1),
126
- Residual(256, 256),
127
- Residual(256, 256),
128
-
129
- nn.Upsample(scale_factor=2, mode='nearest'),
130
- nn.Conv2d(256, 128, 3, padding=1),
131
- Residual(128, 128),
132
- Residual(128, 128),
133
-
134
- nn.Upsample(scale_factor=2, mode='nearest'),
135
- nn.Conv2d(128, 64, 3, padding=1),
136
- Residual(64, 64),
137
- Residual(64, 64),
138
-
139
- nn.GroupNorm(32, 64),
140
- nn.SiLU(),
141
- nn.Conv2d(64, 3, 3, padding=1),
142
- nn.Tanh(),
143
- )
144
- self.latent_channels = latent_channels
145
-
146
- def forward(self, x):
147
- return self.net(x)
148
-
149
- class Vae(nn.Module):
150
- def __init__(self, latent_channels):
151
- super().__init__()
152
- self.encoder = Encoder(latent_channels)
153
- self.decoder = Decoder(latent_channels)
154
- self.latent_channels = latent_channels
155
-
156
- def reparametrize(self, mu, logvar):
157
- logvar = torch.clamp(logvar, -30, 20)
158
- std = torch.exp(0.5 * logvar)
159
- eps = torch.randn_like(std)
160
- return mu + eps * std
161
-
162
- def forward(self, x):
163
- mu, logvar = self.encoder(x)
164
- z = self.reparametrize(mu, logvar)
165
- return self.decoder(z), mu, logvar
166
-