File size: 4,003 Bytes
4aabce3
 
 
 
 
 
 
 
a625e96
 
 
 
 
 
 
 
 
 
 
 
 
 
4aabce3
 
 
 
 
a625e96
4aabce3
a625e96
4aabce3
a625e96
4aabce3
a625e96
 
4aabce3
 
a625e96
 
 
4aabce3
a625e96
 
 
4aabce3
 
a625e96
 
 
 
 
 
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a625e96
4aabce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import sys, os
sys.path.insert(0, os.path.dirname(__file__))
import torch
from torch import nn
import torch.nn.functional as F
from config import *
from attention import SelfCrossAttn

class VAEResBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> SiLU -> 3x3 conv) stages.

    The input is added back through ``skip``. ``skip`` is an identity when the
    channel count is unchanged; otherwise it is a bias-free 1x1 convolution
    projecting ``in_channels`` to ``out_channels``. Spatial size is preserved
    (3x3 convs, stride 1, padding 1).

    ``vae_group_size`` (from config) is passed as GroupNorm's ``num_groups``,
    i.e. the number of groups. It must divide both channel counts.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output. Any falsy value (default
            ``None``) means "same as ``in_channels``".
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if not out_channels:
            out_channels = in_channels

        def norm_act(channels):
            # SiLU may run in place: it only overwrites the fresh GroupNorm
            # output, never the block input that ``skip`` also reads.
            return [nn.GroupNorm(vae_group_size, channels), nn.SiLU(inplace=True)]

        # Layers are created in the same order as indexed in the Sequential,
        # so state_dict keys ("block.0" ... "block.5") stay stable.
        layers = norm_act(in_channels)
        layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False))
        layers += norm_act(out_channels)
        layers.append(nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False))
        self.block = nn.Sequential(*layers)

        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        residual_branch = self.block(x)
        return residual_branch + self.skip(x)

class VAE(nn.Module):
    """Convolutional VAE mapping RGB images to a spatial Gaussian latent.

    Shapes in the comments assume an input of ``(B, 3, H, W)`` with ``H`` and
    ``W`` divisible by 8. The encoder halves each spatial side three times, so
    the latent is ``(B, vae_latent_channels, H/8, W/8)``. The decoder upsamples
    by the same factor of 8 and ends in a ``Sigmoid``, so reconstructions lie
    in [0, 1]. For other sizes the reconstruction will not match the input
    size.

    Attribute names (``encoder_conv``, ``to_latent``, ``from_latent``,
    ``decoder_conv``) and the layer order inside each ``nn.Sequential`` define
    the ``state_dict`` keys. Keep them stable for existing checkpoints.
    """

    def __init__(self):
        super().__init__()
        # Encoder: three stride-2 convs (x1/8 spatially), then a 512-channel
        # residual / attention bottleneck.
        self.encoder_conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, padding=1, bias=False),  # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 128, 3, 2, padding=1, bias=False),  # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Conv2d(128, 256, 3, 2, padding=1, bias=False),  # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Conv2d(256, 512, 3, 2, padding=1, bias=False),  # (B, 512, H/8, W/8)
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            # SelfCrossAttn is project-local; assumed to keep (B, 512, h, w),
            # since a 512-channel VAEResBlock follows it.
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
        )
        # 1x1 projection to per-position mu and log sigma^2, stacked along
        # channels: (B, 2 * vae_latent_channels, H/8, W/8). Split in _encode.
        self.to_latent = nn.Conv2d(512, 2 * vae_latent_channels, kernel_size=1)
        # 1x1 projection of a latent back to the 512-channel bottleneck.
        self.from_latent = nn.Conv2d(vae_latent_channels, 512, kernel_size=1)
        # Decoder: bottleneck, then three nearest-neighbour x2 upsamples (x8).
        self.decoder_conv = nn.Sequential(
            VAEResBlock(512), SelfCrossAttn(512, heads=8, cross=False), VAEResBlock(512),
            nn.GroupNorm(vae_group_size, 512), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(512, 256, 3, padding=1, bias=False),  # (B, 256, H/4, W/4)
            nn.GroupNorm(vae_group_size, 256), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(256, 128, 3, padding=1, bias=False),  # (B, 128, H/2, W/2)
            nn.GroupNorm(vae_group_size, 128), nn.SiLU(inplace=True),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 32, 3, padding=1, bias=False),  # (B, 32, H, W)
            nn.GroupNorm(vae_group_size, 32), nn.SiLU(inplace=True),
            nn.Conv2d(32, 3, 3, 1, padding=1),  # (B, 3, H, W)
            nn.Sigmoid(),
        )

    @staticmethod
    def reparameterize(mu, logvar):
        """Draw ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        ``logvar`` is clamped to [-30, 20] only for computing ``std``.
        Otherwise a diverging encoder output would overflow ``exp`` to
        ``inf`` and turn the sample into inf/NaN. Values inside that range
        are unaffected.
        """
        std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
        eps = torch.randn_like(std)
        return mu + eps * std

    def _encode(self, x):
        """Run the encoder and return ``(mu, logvar)``.

        For an input ``(B, 3, H, W)`` with H, W divisible by 8, each output
        is ``(B, vae_latent_channels, H/8, W/8)``.
        """
        h = self.to_latent(self.encoder_conv(x))
        mu, logvar = torch.chunk(h, 2, dim=1)
        return mu, logvar

    def forward(self, x):
        """Encode, sample with the reparameterization trick, and decode.

        Returns:
            ``(recon, mu, logvar)``. ``mu`` and ``logvar`` are the raw
            (unclamped) encoder outputs, so any loss the caller computes from
            them (presumably a KL term) sees them unchanged.
        """
        mu, logvar = self._encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode_latent_to_img(z)
        return recon, mu, logvar

    def encode_img_to_latent(self, x, *, sample=True):
        """Encode images to latents.

        Args:
            x: image batch, assumed ``(B, 3, H, W)`` with H, W divisible by 8.
            sample: if True (default, the original behaviour), return a
                stochastic sample ``mu + eps * std``. If False, return the
                posterior mean ``mu``, which is deterministic.
        """
        mu, logvar = self._encode(x)
        if not sample:
            return mu
        return self.reparameterize(mu, logvar)

    def decode_latent_to_img(self, z):
        """Decode latents ``(B, vae_latent_channels, h, w)`` to images in [0, 1]."""
        h = self.from_latent(z)
        return self.decoder_conv(h)