File size: 2,699 Bytes
66a2b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""HSIGene AutoencoderKL - nn.Module, no Lightning. Loss = Identity."""

import torch
import torch.nn as nn

from .vae_blocks import Encoder, Decoder, DiagonalGaussianDistribution


class AutoencoderKL(nn.Module):
    """
    Plain ``nn.Module`` KL autoencoder (no Lightning dependency).

    Pipeline: ``Encoder`` -> ``quant_conv`` -> ``DiagonalGaussianDistribution``
    on the way in, ``post_quant_conv`` -> ``Decoder`` on the way out.
    ``self.loss`` is a no-op ``nn.Identity`` placeholder; any real loss is
    computed outside this module.
    """

    def __init__(
        self,
        ddconfig,
        embed_dim=4,
        lossconfig=None,
        **kwargs,
    ):
        super().__init__()
        # Encoder and decoder are built symmetrically from one config dict.
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        assert ddconfig.get("double_z", True)
        z_channels = ddconfig["z_channels"]
        # 1x1 convs translate between the encoder's (mean, logvar) channel
        # pairs and the embed_dim-wide latent space, and back again.
        self.quant_conv = nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(embed_dim, z_channels, 1)
        self.embed_dim = embed_dim
        self.loss = nn.Identity()  # placeholder; lossconfig is ignored

    def encode(self, x):
        """Map ``x`` to a diagonal Gaussian posterior over the latent space."""
        moments = self.quant_conv(self.encoder(x))
        # NOTE(review): deterministic=True makes sample() collapse to the
        # mean, so forward()'s sample_posterior flag has no stochastic
        # effect — confirm this is intended (inference-only use?).
        return DiagonalGaussianDistribution(moments, deterministic=True)

    def decode(self, z):
        """Map a latent ``z`` back to the reconstruction space."""
        return self.decoder(self.post_quant_conv(z))

    def forward(self, input, sample_posterior=True):
        """Encode then decode; returns ``(reconstruction, posterior)``."""
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        return self.decode(z), posterior


class HSIGeneAutoencoderKL(AutoencoderKL):
    """
    HSIGene VAE accepting a diffusers-style configuration.

    Translates ``in_channels`` / ``out_channels`` / ``latent_channels`` /
    ``block_out_channels`` into the ldm-style ``ddconfig`` dict that the
    base ``AutoencoderKL`` (and its Encoder/Decoder) expects.
    """

    def __init__(
        self,
        in_channels: int = 48,
        out_channels: int = 48,
        latent_channels: int = 96,
        embed_dim: int = 4,
        block_out_channels: tuple = (64, 128, 256),
        num_res_blocks: int = 4,
        attn_resolutions: tuple = (16, 32, 64),
        dropout: float = 0.0,
        double_z: bool = True,
        resolution: int = 256,
        **kwargs,
    ):
        # Base channel count is the first stage width; each multiplier is
        # that stage's width relative to it. Integer division assumes every
        # entry is a multiple of the first — TODO confirm with callers.
        base_ch = block_out_channels[0]
        multipliers = [width // base_ch for width in block_out_channels]
        ddconfig = {
            "double_z": double_z,
            "z_channels": latent_channels,
            "resolution": resolution,
            "in_channels": in_channels,
            "out_ch": out_channels,
            "ch": base_ch,
            "ch_mult": multipliers,
            "num_res_blocks": num_res_blocks,
            "attn_resolutions": list(attn_resolutions),
            "dropout": dropout,
        }
        super().__init__(ddconfig=ddconfig, embed_dim=embed_dim, **kwargs)