File size: 2,143 Bytes
582b238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from attention import ViTBlock

# Global parameters: module-level defaults for the Encoder/Decoder constructor
# arguments of the same names.
image_shape = 256  # image side length; image_shape // patch_size is the patch-grid side passed to ViTBlock
emb_dim = 768  # per-patch embedding width (patchifier Conv2d out-channels, LayerNorm size)
patch_size = 16  # patch side length; used as both kernel size and stride of the patchifier conv

class Encoder(nn.Module):
  """ViT-style image encoder: patchify, run transformer blocks, project to a latent.

  Args:
    latent_dim: Size of the per-patch latent vector produced by the encoder.
    image_shape: Image side length; ``image_shape // patch_size`` is the
      patch-grid side handed to each ``ViTBlock``.
    emb_dim: Width of the per-patch embedding used inside the blocks.
    patch_size: Side length of each patch (kernel size and stride of the
      patchifying convolution).
    n_heads: Number of attention heads forwarded to every ``ViTBlock``.
    dropout: Dropout rate forwarded to every ``ViTBlock``.
    layers: Number of ``ViTBlock`` modules stacked.
    gaussian: Accepted for interface compatibility; not used by this class.
  """

  def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
    super().__init__()
    grid = image_shape // patch_size  # patches per image side
    # Non-overlapping conv: one emb_dim-sized vector per patch_size x patch_size patch.
    self.patchifier = nn.Conv2d(3, emb_dim, kernel_size=patch_size, stride=patch_size)
    # Forward the caller's n_heads; it was previously hard-coded to 8 here,
    # which silently ignored the constructor argument.
    self.Blocks = nn.ModuleList([ViTBlock(grid, grid, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])
    self.ln = nn.LayerNorm(emb_dim)
    self.compress_latent = nn.Linear(emb_dim, latent_dim)

  def forward(self, x):
    """Encode a batch of 3-channel images into per-patch latents.

    Args:
      x: Image tensor of shape ``(B, 3, H, W)``.

    Returns:
      Tensor of shape ``(B, N, latent_dim)`` where ``N`` is the number of
      patches (assuming each ``ViTBlock`` preserves ``(B, N, emb_dim)``).
    """
    x = self.patchifier(x)  # (B, emb_dim, H/p, W/p)
    x = rearrange(x, "B D H W -> B (H W) D")  # flatten the patch grid to (B, N, D)
    for vit_block in self.Blocks:
      x = vit_block(x)
    x = self.ln(x)
    return self.compress_latent(x)

class Decoder(nn.Module):
  """ViT-style image decoder: expand latents, run transformer blocks, emit pixels.

  Args:
    latent_dim: Size of the per-patch latent vector fed to the decoder.
    image_shape: Image side length; ``image_shape // patch_size`` is the
      patch-grid side handed to each ``ViTBlock`` and used to rebuild the image.
    emb_dim: Width of the per-patch embedding used inside the blocks.
    patch_size: Side length of each reconstructed patch.
    n_heads: Number of attention heads forwarded to every ``ViTBlock``.
    dropout: Dropout rate forwarded to every ``ViTBlock``.
    layers: Number of ``ViTBlock`` modules stacked.
    gaussian: Accepted for interface compatibility; not used by this class.
  """

  def __init__(self, latent_dim, image_shape=image_shape, emb_dim=emb_dim, patch_size=patch_size, n_heads=8, dropout=0.1, layers=6, gaussian=False):
    super().__init__()
    self.hw = image_shape // patch_size  # patches per image side
    self.patch_size = patch_size
    self.decompress_latent = nn.Linear(latent_dim, emb_dim)
    self.ln = nn.LayerNorm(emb_dim)
    # Each token is mapped to the 3 * p * p pixel values of one RGB patch.
    self.emb_to_patch = nn.Linear(emb_dim, 3*(patch_size**2))
    # Forward the caller's n_heads; it was previously hard-coded to 8 here,
    # which silently ignored the constructor argument.
    self.Blocks = nn.ModuleList([ViTBlock(self.hw, self.hw, emb_dim, n_heads=n_heads, dropout=dropout) for _ in range(layers)])

  def forward(self, x):
    """Decode per-patch latents back into an image.

    Args:
      x: Latent tensor of shape ``(B, hw**2, latent_dim)``.

    Returns:
      Tensor of shape ``(B, 3, hw * patch_size, hw * patch_size)`` with values
      squashed by ``tanh``.

    Raises:
      AssertionError: If the token count does not match ``hw**2``.
    """
    x = self.decompress_latent(x)
    for vit_block in self.Blocks:
      x = vit_block(x)
    # The normalised tensor must be assigned back; the bare call `self.ln(x)`
    # discarded its result, so the LayerNorm was never applied.
    # NOTE: checkpoints trained before this fix never trained `ln` and will
    # produce different outputs now that it is active.
    x = self.ln(x)
    x = self.emb_to_patch(x)  # (B, hw**2, 3 * p * p)
    expected = torch.Size([x.shape[0], self.hw**2, 3*(self.patch_size**2)])
    assert x.shape == expected, f"Expected shape {expected} got {x.shape}"
    # Un-patchify: lay the hw x hw grid of p x p patches out as a channels-first image (B, D, H, W).
    x = rearrange(x, "B (H W) (D p1 p2) -> B D (H p1) (W p2)", H=self.hw, W=self.hw, p1=self.patch_size, p2=self.patch_size)
    return torch.tanh(x)